1use std::any::type_name;
11use std::cmp::Ordering;
12use std::error::Error;
13use std::fmt::{self, Debug, Display};
14use std::hash::{Hash, Hasher};
15
16use bitflags::bitflags;
17use chrono::{DateTime, NaiveDateTime, Utc};
18use dec::OrderedDecimal;
19use mz_lowertest::MzReflect;
20use mz_proto::{RustType, TryFromProtoError};
21use postgres_protocol::types;
22use proptest_derive::Arbitrary;
23use serde::{Deserialize, Serialize};
24use tokio_postgres::types::{FromSql, Type as PgType};
25
26use crate::Datum;
27use crate::adt::date::Date;
28use crate::adt::numeric::Numeric;
29use crate::adt::timestamp::CheckedTimestamp;
30use crate::scalar::{DatumKind, SqlScalarType};
31
32include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.range.rs"));
33
34bitflags! {
35 pub(crate) struct InternalFlags: u8 {
36 const EMPTY = 1;
37 const LB_INCLUSIVE = 1 << 1;
38 const LB_INFINITE = 1 << 2;
39 const UB_INCLUSIVE = 1 << 3;
40 const UB_INFINITE = 1 << 4;
41 }
42}
43
44bitflags! {
45 pub(crate) struct PgFlags: u8 {
46 const EMPTY = 0b0000_0001;
47 const LB_INCLUSIVE = 0b0000_0010;
48 const UB_INCLUSIVE = 0b0000_0100;
49 const LB_INFINITE = 0b0000_1000;
50 const UB_INFINITE = 0b0001_0000;
51 }
52}
53
/// A range over values of type `D`.
///
/// The empty range is represented by `inner == None` (and is displayed as
/// `empty`); every other range carries a lower and an upper bound.
pub struct Range<D> {
    /// `None` for the empty range; otherwise the range's two bounds.
    pub inner: Option<RangeInner<D>>,
}
67
68impl crate::scalar::SqlContainerType for Range<Datum<'_>> {
69 fn unwrap_element_type(container: &SqlScalarType) -> &SqlScalarType {
70 container.unwrap_range_element_type()
71 }
72 fn wrap_element_type(element: SqlScalarType) -> SqlScalarType {
73 SqlScalarType::Range {
74 element_type: Box::new(element),
75 }
76 }
77}
78
79impl<D: Display> Display for Range<D> {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 match &self.inner {
82 None => f.write_str("empty"),
83 Some(i) => i.fmt(f),
84 }
85 }
86}
87
88impl<D: Debug> Debug for Range<D> {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 f.debug_struct("Range").field("inner", &self.inner).finish()
91 }
92}
93
94impl<D: Clone> Clone for Range<D> {
95 fn clone(&self) -> Self {
96 Self {
97 inner: self.inner.clone(),
98 }
99 }
100}
101
102impl<D: Copy> Copy for Range<D> {}
103
104impl<D: PartialEq> PartialEq for Range<D> {
105 fn eq(&self, other: &Self) -> bool {
106 self.inner == other.inner
107 }
108}
109
110impl<D: Eq> Eq for Range<D> {}
111
112impl<D: Ord + PartialOrd> PartialOrd for Range<D> {
113 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
114 Some(self.cmp(other))
115 }
116}
117
118impl<D: Ord> Ord for Range<D> {
119 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
120 self.inner.cmp(&other.inner)
121 }
122}
123
124impl<D: Hash> Hash for Range<D> {
125 fn hash<H: Hasher>(&self, hasher: &mut H) {
126 self.inner.hash(hasher)
127 }
128}
129
130pub trait RangeOps<'a>:
132 Debug + Ord + PartialOrd + Eq + PartialEq + TryFrom<Datum<'a>> + Into<Datum<'a>>
133where
134 <Self as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
135{
136 fn step(self) -> Option<Self> {
139 Some(self)
140 }
141
142 fn unwrap_datum(d: Datum<'a>) -> Self {
143 <Self>::try_from(d)
144 .unwrap_or_else(|_| panic!("cannot take {} to {}", d, type_name::<Self>()))
145 }
146
147 fn err_type_name() -> &'static str;
148}
149
150impl<'a> RangeOps<'a> for i32 {
151 fn step(self) -> Option<i32> {
152 self.checked_add(1)
153 }
154
155 fn err_type_name() -> &'static str {
156 "integer"
157 }
158}
159
160impl<'a> RangeOps<'a> for i64 {
161 fn step(self) -> Option<i64> {
162 self.checked_add(1)
163 }
164
165 fn err_type_name() -> &'static str {
166 "bigint"
167 }
168}
169
170impl<'a> RangeOps<'a> for Date {
171 fn step(self) -> Option<Date> {
172 self.checked_add(1).ok()
173 }
174
175 fn err_type_name() -> &'static str {
176 "date"
177 }
178}
179
// Uses the default `step`, which returns the value unchanged.
impl<'a> RangeOps<'a> for OrderedDecimal<Numeric> {
    /// Name reported for this element type in range error messages.
    fn err_type_name() -> &'static str {
        "numeric"
    }
}
185
// Uses the default `step`, which returns the value unchanged.
impl<'a> RangeOps<'a> for CheckedTimestamp<NaiveDateTime> {
    /// Name reported for this element type in range error messages.
    fn err_type_name() -> &'static str {
        "timestamp"
    }
}
191
// Uses the default `step`, which returns the value unchanged.
impl<'a> RangeOps<'a> for CheckedTimestamp<DateTime<Utc>> {
    /// Name reported for this element type in range error messages.
    fn err_type_name() -> &'static str {
        "timestamptz"
    }
}
197
198impl<D> Range<D> {
200 pub fn new(inner: Option<(RangeLowerBound<D>, RangeUpperBound<D>)>) -> Range<D> {
206 Range {
207 inner: inner.map(|(lower, upper)| RangeInner { lower, upper }),
208 }
209 }
210
211 pub fn internal_flag_bits(&self) -> u8 {
218 let mut flags = InternalFlags::empty();
219
220 match &self.inner {
221 None => {
222 flags.set(InternalFlags::EMPTY, true);
223 }
224 Some(RangeInner { lower, upper }) => {
225 flags.set(InternalFlags::EMPTY, false);
226 flags.set(InternalFlags::LB_INFINITE, lower.bound.is_none());
227 flags.set(InternalFlags::UB_INFINITE, upper.bound.is_none());
228 flags.set(InternalFlags::LB_INCLUSIVE, lower.inclusive);
229 flags.set(InternalFlags::UB_INCLUSIVE, upper.inclusive);
230 }
231 }
232
233 flags.bits()
234 }
235
236 pub fn pg_flag_bits(&self) -> u8 {
243 let mut flags = PgFlags::empty();
244
245 match &self.inner {
246 None => {
247 flags.set(PgFlags::EMPTY, true);
248 }
249 Some(RangeInner { lower, upper }) => {
250 flags.set(PgFlags::EMPTY, false);
251 flags.set(PgFlags::LB_INFINITE, lower.bound.is_none());
252 flags.set(PgFlags::UB_INFINITE, upper.bound.is_none());
253 flags.set(PgFlags::LB_INCLUSIVE, lower.inclusive);
254 flags.set(PgFlags::UB_INCLUSIVE, upper.inclusive);
255 }
256 }
257
258 flags.bits()
259 }
260
261 pub fn into_bounds<F, O>(self, conv: F) -> Range<O>
264 where
265 F: Fn(D) -> O,
266 {
267 Range {
268 inner: self
269 .inner
270 .map(|RangeInner::<D> { lower, upper }| RangeInner::<O> {
271 lower: RangeLowerBound {
272 inclusive: lower.inclusive,
273 bound: lower.bound.map(&conv),
274 },
275 upper: RangeUpperBound {
276 inclusive: upper.inclusive,
277 bound: upper.bound.map(&conv),
278 },
279 }),
280 }
281 }
282
283 pub fn try_into_bounds<F, O, E>(self, conv: F) -> Result<Range<O>, E>
289 where
290 F: Fn(D) -> Result<O, E>,
291 {
292 let inner = match self.inner {
293 None => None,
294 Some(RangeInner { lower, upper }) => Some(RangeInner {
295 lower: RangeLowerBound {
296 inclusive: lower.inclusive,
297 bound: lower.bound.map(&conv).transpose()?,
298 },
299 upper: RangeUpperBound {
300 inclusive: upper.inclusive,
301 bound: upper.bound.map(&conv).transpose()?,
302 },
303 }),
304 };
305 Ok(Range { inner })
306 }
307}
308
309impl<'a, B: Copy + Ord> Range<B> {
311 pub fn contains_elem<T: RangeOps<'a>>(&self, elem: &T) -> bool
312 where
313 Datum<'a>: From<B>,
314 <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
315 {
316 match self.inner {
317 None => false,
318 Some(inner) => inner.lower.satisfied_by(elem) && inner.upper.satisfied_by(elem),
319 }
320 }
321
322 pub fn contains_range(&self, other: &Range<B>) -> bool {
323 match (self.inner, other.inner) {
324 (None, None) | (Some(_), None) => true,
325 (None, Some(_)) => false,
326 (Some(i), Some(j)) => i.lower <= j.lower && j.upper <= i.upper,
327 }
328 }
329
330 pub fn overlaps(&self, other: &Range<B>) -> bool {
331 match (self.inner, other.inner) {
332 (Some(s), Some(o)) => {
333 let r = match s.cmp(&o) {
334 Ordering::Equal => Ordering::Equal,
335 Ordering::Less => s.upper.range_bound_cmp(&o.lower),
336 Ordering::Greater => o.upper.range_bound_cmp(&s.lower),
337 };
338
339 matches!(r, Ordering::Greater | Ordering::Equal)
341 }
342 _ => false,
343 }
344 }
345
346 pub fn before(&self, other: &Range<B>) -> bool {
347 match (self.inner, other.inner) {
348 (Some(s), Some(o)) => {
349 matches!(s.upper.range_bound_cmp(&o.lower), Ordering::Less)
350 }
351 _ => false,
352 }
353 }
354
355 pub fn after(&self, other: &Range<B>) -> bool {
356 match (self.inner, other.inner) {
357 (Some(s), Some(o)) => {
358 matches!(s.lower.range_bound_cmp(&o.upper), Ordering::Greater)
359 }
360 _ => false,
361 }
362 }
363
364 pub fn overleft(&self, other: &Range<B>) -> bool {
365 match (self.inner, other.inner) {
366 (Some(s), Some(o)) => {
367 matches!(
368 s.upper.range_bound_cmp(&o.upper),
369 Ordering::Less | Ordering::Equal
370 )
371 }
372 _ => false,
373 }
374 }
375
376 pub fn overright(&self, other: &Range<B>) -> bool {
377 match (self.inner, other.inner) {
378 (Some(s), Some(o)) => {
379 matches!(
380 s.lower.range_bound_cmp(&o.lower),
381 Ordering::Greater | Ordering::Equal
382 )
383 }
384 _ => false,
385 }
386 }
387
388 pub fn adjacent(&self, other: &Range<B>) -> bool {
389 match (self.inner, other.inner) {
390 (Some(s), Some(o)) => {
391 for (lower, upper) in [(s.lower, o.upper), (o.lower, s.upper)] {
393 if let (Some(l), Some(u)) = (lower.bound, upper.bound) {
394 if lower.inclusive ^ upper.inclusive && l == u {
396 return true;
397 }
398 }
399 }
400 false
401 }
402 _ => false,
403 }
404 }
405
406 pub fn union(&self, other: &Range<B>) -> Result<Range<B>, InvalidRangeError> {
407 let (s, o) = match (self.inner, other.inner) {
409 (None, None) => return Ok(Range { inner: None }),
410 (inner @ Some(_), None) | (None, inner @ Some(_)) => return Ok(Range { inner }),
411 (Some(s), Some(o)) => {
412 if !(self.overlaps(other) || self.adjacent(other)) {
414 return Err(InvalidRangeError::DiscontiguousUnion);
415 }
416 (s, o)
417 }
418 };
419
420 let lower = std::cmp::min(s.lower, o.lower);
421 let upper = std::cmp::max(s.upper, o.upper);
422
423 Ok(Range {
424 inner: Some(RangeInner { lower, upper }),
425 })
426 }
427
428 pub fn intersection(&self, other: &Range<B>) -> Range<B> {
429 let (s, o) = match (self.inner, other.inner) {
431 (Some(s), Some(o)) => {
432 if !self.overlaps(other) {
433 return Range { inner: None };
434 }
435
436 (s, o)
437 }
438 _ => return Range { inner: None },
439 };
440
441 let lower = std::cmp::max(s.lower, o.lower);
442 let upper = std::cmp::min(s.upper, o.upper);
443
444 Range {
445 inner: Some(RangeInner { lower, upper }),
446 }
447 }
448
449 pub fn difference(&self, other: &Range<B>) -> Result<Range<Datum<'a>>, InvalidRangeError>
453 where
454 Datum<'a>: From<B>,
455 {
456 use std::cmp::Ordering::*;
457
458 if !self.overlaps(other) {
460 return Ok(self.into_bounds(Datum::from));
461 }
462
463 let (s, o) = match (self.inner, other.inner) {
464 (None, _) | (_, None) => unreachable!("already returned from overlap check"),
465 (Some(s), Some(o)) => (s, o),
466 };
467
468 let ll = s.lower.cmp(&o.lower);
469 let uu = s.upper.cmp(&o.upper);
470
471 let r = match (ll, uu) {
472 (Less, Greater) => return Err(InvalidRangeError::DiscontiguousDifference),
474 (Greater | Equal, Less | Equal) => Range { inner: None },
476 (Greater | Equal, Greater) => {
477 let lower = RangeBound {
478 inclusive: !o.upper.inclusive,
479 bound: o.upper.bound,
480 };
481 Range {
482 inner: Some(RangeInner {
483 lower,
484 upper: s.upper,
485 }),
486 }
487 }
488 (Less, Less | Equal) => {
489 let upper = RangeBound {
490 inclusive: !o.lower.inclusive,
491 bound: o.lower.bound,
492 };
493 Range {
494 inner: Some(RangeInner {
495 lower: s.lower,
496 upper,
497 }),
498 }
499 }
500 };
501
502 let mut r = r.into_bounds(Datum::from);
503
504 r.canonicalize()?;
505
506 Ok(r)
507 }
508}
509
510impl<'a> Range<Datum<'a>> {
511 pub fn canonicalize(&mut self) -> Result<(), InvalidRangeError> {
522 let (lower, upper) = match &mut self.inner {
523 Some(inner) => (&mut inner.lower, &mut inner.upper),
524 None => return Ok(()),
525 };
526
527 match (lower.bound, upper.bound) {
528 (Some(l), Some(u)) => {
529 assert_eq!(
530 DatumKind::from(l),
531 DatumKind::from(u),
532 "finite bounds must be of same type"
533 );
534 if l > u {
535 return Err(InvalidRangeError::MisorderedRangeBounds);
536 }
537 }
538 _ => {}
539 };
540
541 lower.canonicalize()?;
542 upper.canonicalize()?;
543
544 if !(lower.inclusive && upper.inclusive)
547 && lower.bound >= upper.bound
548 && upper.bound.is_some()
550 {
551 self.inner = None
553 }
554
555 Ok(())
556 }
557}
558
/// The two bounds of a non-empty range.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct RangeInner<B> {
    /// The range's lower bound.
    pub lower: RangeLowerBound<B>,
    /// The range's upper bound.
    pub upper: RangeUpperBound<B>,
}
565
566impl<B: Display> Display for RangeInner<B> {
567 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
568 f.write_str(if self.lower.inclusive { "[" } else { "(" })?;
569 self.lower.fmt(f)?;
570 f.write_str(",")?;
571 Display::fmt(&self.upper, f)?;
572 f.write_str(if self.upper.inclusive { "]" } else { ")" })
573 }
574}
575
576impl<B: Ord> Ord for RangeInner<B> {
577 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
578 self.lower
579 .cmp(&other.lower)
580 .then(self.upper.cmp(&other.upper))
581 }
582}
583
584impl<B: PartialOrd + Ord> PartialOrd for RangeInner<B> {
585 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
586 Some(self.cmp(other))
587 }
588}
589
/// One bound of a range.
///
/// The const parameter `UPPER` records whether this is an upper
/// (`true`) or a lower (`false`, the default) bound; the ordering and
/// comparison methods on this type depend on it.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct RangeBound<B, const UPPER: bool = false> {
    /// Whether the bound value itself belongs to the range.
    pub inclusive: bool,
    /// The bound's value; `None` means the range is unbounded (infinite) on
    /// this side.
    pub bound: Option<B>,
}
597
598impl<const UPPER: bool, D: Display> Display for RangeBound<D, UPPER> {
599 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
600 match &self.bound {
601 None => Ok(()),
602 Some(bound) => bound.fmt(f),
603 }
604 }
605}
606
607impl<const UPPER: bool, D: Ord> Ord for RangeBound<D, UPPER> {
608 fn cmp(&self, other: &Self) -> Ordering {
609 let mut cmp = self.bound.cmp(&other.bound);
611 if UPPER && other.bound.is_none() ^ self.bound.is_none() {
613 cmp = cmp.reverse();
614 }
615 cmp.then(if self.inclusive == other.inclusive {
618 Ordering::Equal
619 } else if self.inclusive {
620 if UPPER {
621 Ordering::Greater
622 } else {
623 Ordering::Less
624 }
625 } else if UPPER {
626 Ordering::Less
627 } else {
628 Ordering::Greater
629 })
630 }
631}
632
633impl<const UPPER: bool, D: PartialOrd + Ord> PartialOrd for RangeBound<D, UPPER> {
634 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
635 Some(self.cmp(other))
636 }
637}
638
/// A range's lower bound, i.e. a `RangeBound` with `UPPER = false`.
pub type RangeLowerBound<B> = RangeBound<B, false>;

/// A range's upper bound, i.e. a `RangeBound` with `UPPER = true`.
pub type RangeUpperBound<B> = RangeBound<B, true>;
644
645impl<'a, const UPPER: bool, B: Copy + Ord> RangeBound<B, UPPER> {
648 fn elem_cmp<T: RangeOps<'a>>(&self, elem: &T) -> Ordering
653 where
654 Datum<'a>: From<B>,
655 <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
656 {
657 match self.bound.map(|bound| <T>::unwrap_datum(bound.into())) {
658 None if UPPER => Ordering::Greater,
659 None => Ordering::Less,
660 Some(bound) => bound.cmp(elem),
661 }
662 }
663
664 fn satisfied_by<T: RangeOps<'a>>(&self, elem: &T) -> bool
666 where
667 Datum<'a>: From<B>,
668 <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
669 {
670 match self.elem_cmp(elem) {
671 Ordering::Equal => self.inclusive,
674 Ordering::Greater => UPPER,
676 Ordering::Less => !UPPER,
678 }
679 }
680
681 fn range_bound_cmp<const OTHER_UPPER: bool>(
684 &self,
685 other: &RangeBound<B, OTHER_UPPER>,
686 ) -> Ordering {
687 if UPPER == OTHER_UPPER {
688 return self.cmp(&RangeBound {
689 inclusive: other.inclusive,
690 bound: other.bound,
691 });
692 }
693
694 if self.bound.is_none() || other.bound.is_none() {
697 return if UPPER {
698 Ordering::Greater
699 } else {
700 Ordering::Less
701 };
702 }
703 let cmp = self.bound.cmp(&other.bound);
705 cmp.then(if self.inclusive && other.inclusive {
708 Ordering::Equal
709 } else if UPPER {
710 Ordering::Less
711 } else {
712 Ordering::Greater
713 })
714 }
715}
716
717impl<'a, const UPPER: bool> RangeBound<Datum<'a>, UPPER> {
718 pub fn new(d: Datum<'a>, inclusive: bool) -> RangeBound<Datum<'a>, UPPER> {
724 RangeBound {
725 inclusive,
726 bound: match d {
727 Datum::Null => None,
728 o => Some(o),
729 },
730 }
731 }
732
733 fn canonicalize(&mut self) -> Result<(), InvalidRangeError> {
737 Ok(match self.bound {
738 None => {
739 self.inclusive = false;
740 }
741 Some(value) => match value {
743 d @ Datum::Int32(_) => self.canonicalize_inner::<i32>(d)?,
744 d @ Datum::Int64(_) => self.canonicalize_inner::<i64>(d)?,
745 d @ Datum::Date(_) => self.canonicalize_inner::<Date>(d)?,
746 Datum::Numeric(..) | Datum::Timestamp(..) | Datum::TimestampTz(..) => {}
747 d => unreachable!("{d:?} not yet supported in ranges"),
748 },
749 })
750 }
751
752 fn canonicalize_inner<T: RangeOps<'a>>(&mut self, d: Datum<'a>) -> Result<(), InvalidRangeError>
758 where
759 <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
760 {
761 if UPPER == self.inclusive {
763 let cur = <T>::unwrap_datum(d);
764 self.bound = Some(
765 cur.step()
766 .ok_or_else(|| {
767 InvalidRangeError::CanonicalizationOverflow(T::err_type_name().into())
768 })?
769 .into(),
770 );
771 self.inclusive = !UPPER;
772 }
773
774 Ok(())
775 }
776}
777
/// Errors that can arise when constructing or operating on ranges.
#[derive(
    Arbitrary,
    Ord,
    PartialOrd,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum InvalidRangeError {
    /// The lower bound is greater than the upper bound.
    MisorderedRangeBounds,
    /// Stepping a bound during canonicalization overflowed; carries the
    /// element type's name as given by `RangeOps::err_type_name`.
    CanonicalizationOverflow(Box<str>),
    /// The bound-flags string is malformed (see `parse_range_bound_flags`).
    InvalidRangeBoundFlags,
    /// The union of two ranges would not be a single contiguous range.
    DiscontiguousUnion,
    /// The difference of two ranges would not be a single contiguous range.
    DiscontiguousDifference,
    /// The flags argument of a range constructor was null.
    NullRangeBoundFlags,
}
799
800impl Display for InvalidRangeError {
801 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
802 match self {
803 InvalidRangeError::MisorderedRangeBounds => {
804 f.write_str("range lower bound must be less than or equal to range upper bound")
805 }
806 InvalidRangeError::CanonicalizationOverflow(t) => {
807 write!(f, "{} out of range", t)
808 }
809 InvalidRangeError::InvalidRangeBoundFlags => f.write_str("invalid range bound flags"),
810 InvalidRangeError::DiscontiguousUnion => {
811 f.write_str("result of range union would not be contiguous")
812 }
813 InvalidRangeError::DiscontiguousDifference => {
814 f.write_str("result of range difference would not be contiguous")
815 }
816 InvalidRangeError::NullRangeBoundFlags => {
817 f.write_str("range constructor flags argument must not be null")
818 }
819 }
820 }
821}
822
823impl Error for InvalidRangeError {
824 fn source(&self) -> Option<&(dyn Error + 'static)> {
825 None
826 }
827}
828
829impl From<InvalidRangeError> for String {
831 fn from(e: InvalidRangeError) -> Self {
832 e.to_string()
833 }
834}
835
836impl RustType<ProtoInvalidRangeError> for InvalidRangeError {
837 fn into_proto(&self) -> ProtoInvalidRangeError {
838 use Kind::*;
839 use proto_invalid_range_error::*;
840 let kind = match self {
841 InvalidRangeError::MisorderedRangeBounds => MisorderedRangeBounds(()),
842 InvalidRangeError::CanonicalizationOverflow(s) => {
843 CanonicalizationOverflow(s.into_proto())
844 }
845 InvalidRangeError::InvalidRangeBoundFlags => InvalidRangeBoundFlags(()),
846 InvalidRangeError::DiscontiguousUnion => DiscontiguousUnion(()),
847 InvalidRangeError::DiscontiguousDifference => DiscontiguousDifference(()),
848 InvalidRangeError::NullRangeBoundFlags => NullRangeBoundFlags(()),
849 };
850 ProtoInvalidRangeError { kind: Some(kind) }
851 }
852
853 fn from_proto(proto: ProtoInvalidRangeError) -> Result<Self, TryFromProtoError> {
854 use proto_invalid_range_error::Kind::*;
855 match proto.kind {
856 Some(kind) => Ok(match kind {
857 MisorderedRangeBounds(()) => InvalidRangeError::MisorderedRangeBounds,
858 CanonicalizationOverflow(s) => {
859 InvalidRangeError::CanonicalizationOverflow(s.into())
860 }
861 InvalidRangeBoundFlags(()) => InvalidRangeError::InvalidRangeBoundFlags,
862 DiscontiguousUnion(()) => InvalidRangeError::DiscontiguousUnion,
863 DiscontiguousDifference(()) => InvalidRangeError::DiscontiguousDifference,
864 NullRangeBoundFlags(()) => InvalidRangeError::NullRangeBoundFlags,
865 }),
866 None => Err(TryFromProtoError::missing_field(
867 "`ProtoInvalidRangeError::kind`",
868 )),
869 }
870 }
871}
872
873pub fn parse_range_bound_flags<'a>(flags: &'a str) -> Result<(bool, bool), InvalidRangeError> {
874 let mut flags = flags.chars();
875
876 let lower = match flags.next() {
877 Some('(') => false,
878 Some('[') => true,
879 _ => return Err(InvalidRangeError::InvalidRangeBoundFlags),
880 };
881
882 let upper = match flags.next() {
883 Some(')') => false,
884 Some(']') => true,
885 _ => return Err(InvalidRangeError::InvalidRangeBoundFlags),
886 };
887
888 match flags.next() {
889 Some(_) => Err(InvalidRangeError::InvalidRangeBoundFlags),
890 None => Ok((lower, upper)),
891 }
892}
893
894impl<'a, T: FromSql<'a>> FromSql<'a> for Range<T> {
895 fn from_sql(ty: &PgType, raw: &'a [u8]) -> Result<Range<T>, Box<dyn Error + Sync + Send>> {
896 let inner_typ = match ty {
897 &PgType::INT4_RANGE => PgType::INT4,
898 &PgType::INT8_RANGE => PgType::INT8,
899 &PgType::DATE_RANGE => PgType::DATE,
900 &PgType::NUM_RANGE => PgType::NUMERIC,
901 &PgType::TS_RANGE => PgType::TIMESTAMP,
902 &PgType::TSTZ_RANGE => PgType::TIMESTAMPTZ,
903 _ => unreachable!(),
904 };
905
906 let inner = match types::range_from_sql(raw)? {
907 types::Range::Empty => None,
908 types::Range::Nonempty(lower, upper) => {
909 let mut bounds = Vec::with_capacity(2);
910
911 for bound_outer in [lower, upper].into_iter() {
912 let bound = match bound_outer {
913 types::RangeBound::Exclusive(bound)
914 | types::RangeBound::Inclusive(bound) => bound
915 .map(|bound| T::from_sql(&inner_typ, bound))
916 .transpose()?,
917 types::RangeBound::Unbounded => None,
918 };
919 let inclusive = matches!(bound_outer, types::RangeBound::Inclusive(_));
920 bounds.push(RangeBound { bound, inclusive });
921 }
922
923 let lower = bounds.remove(0);
924 let upper = bounds.remove(0);
925 assert!(bounds.is_empty());
926
927 Some(RangeInner {
928 lower,
929 upper: RangeBound {
931 bound: upper.bound,
932 inclusive: upper.inclusive,
933 },
934 })
935 }
936 };
937
938 Ok(Range { inner })
939 }
940
941 fn accepts(ty: &PgType) -> bool {
942 matches!(
943 ty,
944 &PgType::INT4_RANGE
945 | &PgType::INT8_RANGE
946 | &PgType::DATE_RANGE
947 | &PgType::NUM_RANGE
948 | &PgType::TS_RANGE
949 | &PgType::TSTZ_RANGE
950 )
951 }
952}