Skip to main content

mz_repr/adt/
range.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::any::type_name;
11use std::cmp::Ordering;
12use std::error::Error;
13use std::fmt::{self, Debug, Display};
14use std::hash::{Hash, Hasher};
15
16use bitflags::bitflags;
17use chrono::{DateTime, NaiveDateTime, Utc};
18use dec::OrderedDecimal;
19use mz_lowertest::MzReflect;
20use mz_proto::{RustType, TryFromProtoError};
21use postgres_protocol::types;
22use proptest_derive::Arbitrary;
23use serde::{Deserialize, Serialize};
24use tokio_postgres::types::{FromSql, Type as PgType};
25
26use crate::Datum;
27use crate::adt::date::Date;
28use crate::adt::numeric::Numeric;
29use crate::adt::timestamp::CheckedTimestamp;
30use crate::scalar::DatumKind;
31
32include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.range.rs"));
33
34bitflags! {
35    pub(crate) struct InternalFlags: u8 {
36        const EMPTY = 1;
37        const LB_INCLUSIVE = 1 << 1;
38        const LB_INFINITE = 1 << 2;
39        const UB_INCLUSIVE = 1 << 3;
40        const UB_INFINITE = 1 << 4;
41    }
42}
43
44bitflags! {
45    pub(crate) struct PgFlags: u8 {
46        const EMPTY = 0b0000_0001;
47        const LB_INCLUSIVE = 0b0000_0010;
48        const UB_INCLUSIVE = 0b0000_0100;
49        const LB_INFINITE = 0b0000_1000;
50        const UB_INFINITE = 0b0001_0000;
51    }
52}
53
54/// A range of values along the domain `D`.
55///
56/// `D` is generic to facilitate interoperating over multiple representation,
57/// e.g. `Datum` and `mz_pgrepr::Value`. Because of the latter, we have to
58/// "manually derive" traits over `Range`.
59///
60/// Also notable, is that `Datum`s themselves store ranges as
61/// `Range<DatumNested<'a>>`, which lets us avoid unnecessary boxing of the
62/// range's finite bounds, which are most often expressed as `Datum`.
63pub struct Range<D> {
64    /// None value represents empty range
65    pub inner: Option<RangeInner<D>>,
66}
67
68impl<D: Display> Display for Range<D> {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        match &self.inner {
71            None => f.write_str("empty"),
72            Some(i) => i.fmt(f),
73        }
74    }
75}
76
77impl<D: Debug> Debug for Range<D> {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        f.debug_struct("Range").field("inner", &self.inner).finish()
80    }
81}
82
83impl<D: Clone> Clone for Range<D> {
84    fn clone(&self) -> Self {
85        Self {
86            inner: self.inner.clone(),
87        }
88    }
89}
90
91impl<D: Copy> Copy for Range<D> {}
92
93impl<D: PartialEq> PartialEq for Range<D> {
94    fn eq(&self, other: &Self) -> bool {
95        self.inner == other.inner
96    }
97}
98
99impl<D: Eq> Eq for Range<D> {}
100
101impl<D: Ord + PartialOrd> PartialOrd for Range<D> {
102    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
103        Some(self.cmp(other))
104    }
105}
106
107impl<D: Ord> Ord for Range<D> {
108    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
109        self.inner.cmp(&other.inner)
110    }
111}
112
113impl<D: Hash> Hash for Range<D> {
114    fn hash<H: Hasher>(&self, hasher: &mut H) {
115        self.inner.hash(hasher)
116    }
117}
118
119/// Trait alias for traits required for generic range function implementations.
120pub trait RangeOps<'a>:
121    Debug + Ord + PartialOrd + Eq + PartialEq + TryFrom<Datum<'a>> + Into<Datum<'a>>
122where
123    <Self as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
124{
125    /// Increment `self` one step forward, if applicable. Return `None` if
126    /// overflows.
127    fn step(self) -> Option<Self> {
128        Some(self)
129    }
130
131    fn unwrap_datum(d: Datum<'a>) -> Self {
132        <Self>::try_from(d)
133            .unwrap_or_else(|_| panic!("cannot take {} to {}", d, type_name::<Self>()))
134    }
135
136    fn err_type_name() -> &'static str;
137}
138
139impl<'a> RangeOps<'a> for i32 {
140    fn step(self) -> Option<i32> {
141        self.checked_add(1)
142    }
143
144    fn err_type_name() -> &'static str {
145        "integer"
146    }
147}
148
149impl<'a> RangeOps<'a> for i64 {
150    fn step(self) -> Option<i64> {
151        self.checked_add(1)
152    }
153
154    fn err_type_name() -> &'static str {
155        "bigint"
156    }
157}
158
159impl<'a> RangeOps<'a> for Date {
160    fn step(self) -> Option<Date> {
161        self.checked_add(1).ok()
162    }
163
164    fn err_type_name() -> &'static str {
165        "date"
166    }
167}
168
169impl<'a> RangeOps<'a> for OrderedDecimal<Numeric> {
170    fn err_type_name() -> &'static str {
171        "numeric"
172    }
173}
174
175impl<'a> RangeOps<'a> for CheckedTimestamp<NaiveDateTime> {
176    fn err_type_name() -> &'static str {
177        "timestamp"
178    }
179}
180
181impl<'a> RangeOps<'a> for CheckedTimestamp<DateTime<Utc>> {
182    fn err_type_name() -> &'static str {
183        "timestamptz"
184    }
185}
186
187// Totally generic range implementations.
188impl<D> Range<D> {
189    /// Create a new range.
190    ///
191    /// Note that when constructing `Range<Datum<'a>>`, the range must still be
192    /// canonicalized. If this becomes a common operation, we should consider
193    /// addinga `new_canonical` function that performs both steps.
194    pub fn new(inner: Option<(RangeLowerBound<D>, RangeUpperBound<D>)>) -> Range<D> {
195        Range {
196            inner: inner.map(|(lower, upper)| RangeInner { lower, upper }),
197        }
198    }
199
200    /// Get the flag bits appropriate to use in our internal (i.e. row) encoding
201    /// of range values.
202    ///
203    /// Note that this differs from the flags appropriate to encode with
204    /// Postgres, which has `UB_INFINITE` and `LB_INCLUSIVE` in the alternate
205    /// position.
206    pub fn internal_flag_bits(&self) -> u8 {
207        let mut flags = InternalFlags::empty();
208
209        match &self.inner {
210            None => {
211                flags.set(InternalFlags::EMPTY, true);
212            }
213            Some(RangeInner { lower, upper }) => {
214                flags.set(InternalFlags::EMPTY, false);
215                flags.set(InternalFlags::LB_INFINITE, lower.bound.is_none());
216                flags.set(InternalFlags::UB_INFINITE, upper.bound.is_none());
217                flags.set(InternalFlags::LB_INCLUSIVE, lower.inclusive);
218                flags.set(InternalFlags::UB_INCLUSIVE, upper.inclusive);
219            }
220        }
221
222        flags.bits()
223    }
224
225    /// Get the flag bits appropriate to use in PG-compatible encodings of range
226    /// values.
227    ///
228    /// Note that this differs from the flags appropriate for our internal
229    /// encoding, which has `UB_INFINITE` and `LB_INCLUSIVE` in the alternate
230    /// position.
231    pub fn pg_flag_bits(&self) -> u8 {
232        let mut flags = PgFlags::empty();
233
234        match &self.inner {
235            None => {
236                flags.set(PgFlags::EMPTY, true);
237            }
238            Some(RangeInner { lower, upper }) => {
239                flags.set(PgFlags::EMPTY, false);
240                flags.set(PgFlags::LB_INFINITE, lower.bound.is_none());
241                flags.set(PgFlags::UB_INFINITE, upper.bound.is_none());
242                flags.set(PgFlags::LB_INCLUSIVE, lower.inclusive);
243                flags.set(PgFlags::UB_INCLUSIVE, upper.inclusive);
244            }
245        }
246
247        flags.bits()
248    }
249
250    /// Converts `self` from having bounds of type `D` to type `O`, converting
251    /// the current bounds using `conv`.
252    pub fn into_bounds<F, O>(self, conv: F) -> Range<O>
253    where
254        F: Fn(D) -> O,
255    {
256        Range {
257            inner: self
258                .inner
259                .map(|RangeInner::<D> { lower, upper }| RangeInner::<O> {
260                    lower: RangeLowerBound {
261                        inclusive: lower.inclusive,
262                        bound: lower.bound.map(&conv),
263                    },
264                    upper: RangeUpperBound {
265                        inclusive: upper.inclusive,
266                        bound: upper.bound.map(&conv),
267                    },
268                }),
269        }
270    }
271
272    /// Like [`into_bounds`](Self::into_bounds), but the conversion may fail.
273    ///
274    /// Use this when converting each bound with a fallible function (e.g.
275    /// `into_datum`). Callers need not reach into `Range`'s internals
276    /// (`RangeInner`, `RangeLowerBound`, `RangeUpperBound`).
277    pub fn try_into_bounds<F, O, E>(self, conv: F) -> Result<Range<O>, E>
278    where
279        F: Fn(D) -> Result<O, E>,
280    {
281        let inner = match self.inner {
282            None => None,
283            Some(RangeInner { lower, upper }) => Some(RangeInner {
284                lower: RangeLowerBound {
285                    inclusive: lower.inclusive,
286                    bound: lower.bound.map(&conv).transpose()?,
287                },
288                upper: RangeUpperBound {
289                    inclusive: upper.inclusive,
290                    bound: upper.bound.map(&conv).transpose()?,
291                },
292            }),
293        };
294        Ok(Range { inner })
295    }
296}
297
298/// Range implementations meant to work with `Range<Datum>` and `Range<DatumNested>`.
299impl<'a, B: Copy + Ord + PartialOrd + Display + Debug> Range<B>
300where
301    Datum<'a>: From<B>,
302{
303    pub fn contains_elem<T: RangeOps<'a>>(&self, elem: &T) -> bool
304    where
305        <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
306    {
307        match self.inner {
308            None => false,
309            Some(inner) => inner.lower.satisfied_by(elem) && inner.upper.satisfied_by(elem),
310        }
311    }
312
313    pub fn contains_range(&self, other: &Range<B>) -> bool {
314        match (self.inner, other.inner) {
315            (None, None) | (Some(_), None) => true,
316            (None, Some(_)) => false,
317            (Some(i), Some(j)) => i.lower <= j.lower && j.upper <= i.upper,
318        }
319    }
320
321    pub fn overlaps(&self, other: &Range<B>) -> bool {
322        match (self.inner, other.inner) {
323            (Some(s), Some(o)) => {
324                let r = match s.cmp(&o) {
325                    Ordering::Equal => Ordering::Equal,
326                    Ordering::Less => s.upper.range_bound_cmp(&o.lower),
327                    Ordering::Greater => o.upper.range_bound_cmp(&s.lower),
328                };
329
330                // If smaller upper is >= larger lower, elements overlap.
331                matches!(r, Ordering::Greater | Ordering::Equal)
332            }
333            _ => false,
334        }
335    }
336
337    pub fn before(&self, other: &Range<B>) -> bool {
338        match (self.inner, other.inner) {
339            (Some(s), Some(o)) => {
340                matches!(s.upper.range_bound_cmp(&o.lower), Ordering::Less)
341            }
342            _ => false,
343        }
344    }
345
346    pub fn after(&self, other: &Range<B>) -> bool {
347        match (self.inner, other.inner) {
348            (Some(s), Some(o)) => {
349                matches!(s.lower.range_bound_cmp(&o.upper), Ordering::Greater)
350            }
351            _ => false,
352        }
353    }
354
355    pub fn overleft(&self, other: &Range<B>) -> bool {
356        match (self.inner, other.inner) {
357            (Some(s), Some(o)) => {
358                matches!(
359                    s.upper.range_bound_cmp(&o.upper),
360                    Ordering::Less | Ordering::Equal
361                )
362            }
363            _ => false,
364        }
365    }
366
367    pub fn overright(&self, other: &Range<B>) -> bool {
368        match (self.inner, other.inner) {
369            (Some(s), Some(o)) => {
370                matches!(
371                    s.lower.range_bound_cmp(&o.lower),
372                    Ordering::Greater | Ordering::Equal
373                )
374            }
375            _ => false,
376        }
377    }
378
379    pub fn adjacent(&self, other: &Range<B>) -> bool {
380        match (self.inner, other.inner) {
381            (Some(s), Some(o)) => {
382                // Look at each (lower, upper) pair.
383                for (lower, upper) in [(s.lower, o.upper), (o.lower, s.upper)] {
384                    if let (Some(l), Some(u)) = (lower.bound, upper.bound) {
385                        // If ..x](x.. or ..x)[x.., adjacent
386                        if lower.inclusive ^ upper.inclusive && l == u {
387                            return true;
388                        }
389                    }
390                }
391                false
392            }
393            _ => false,
394        }
395    }
396
397    pub fn union(&self, other: &Range<B>) -> Result<Range<B>, InvalidRangeError> {
398        // Handle self or other being empty
399        let (s, o) = match (self.inner, other.inner) {
400            (None, None) => return Ok(Range { inner: None }),
401            (inner @ Some(_), None) | (None, inner @ Some(_)) => return Ok(Range { inner }),
402            (Some(s), Some(o)) => {
403                // if not overlapping or adjacent, then result would not present continuity, so error.
404                if !(self.overlaps(other) || self.adjacent(other)) {
405                    return Err(InvalidRangeError::DiscontiguousUnion);
406                }
407                (s, o)
408            }
409        };
410
411        let lower = std::cmp::min(s.lower, o.lower);
412        let upper = std::cmp::max(s.upper, o.upper);
413
414        Ok(Range {
415            inner: Some(RangeInner { lower, upper }),
416        })
417    }
418
419    pub fn intersection(&self, other: &Range<B>) -> Range<B> {
420        // Handle self or other being empty
421        let (s, o) = match (self.inner, other.inner) {
422            (Some(s), Some(o)) => {
423                if !self.overlaps(other) {
424                    return Range { inner: None };
425                }
426
427                (s, o)
428            }
429            _ => return Range { inner: None },
430        };
431
432        let lower = std::cmp::max(s.lower, o.lower);
433        let upper = std::cmp::min(s.upper, o.upper);
434
435        Range {
436            inner: Some(RangeInner { lower, upper }),
437        }
438    }
439
440    // Function requires canonicalization so must be taken into `Range<Datum>`,
441    // which can be taken back into `Range<DatumNested>` by the caller if need
442    // be.
443    pub fn difference(&self, other: &Range<B>) -> Result<Range<Datum<'a>>, InvalidRangeError> {
444        use std::cmp::Ordering::*;
445
446        // Difference op does nothing if no overlap.
447        if !self.overlaps(other) {
448            return Ok(self.into_bounds(Datum::from));
449        }
450
451        let (s, o) = match (self.inner, other.inner) {
452            (None, _) | (_, None) => unreachable!("already returned from overlap check"),
453            (Some(s), Some(o)) => (s, o),
454        };
455
456        let ll = s.lower.cmp(&o.lower);
457        let uu = s.upper.cmp(&o.upper);
458
459        let r = match (ll, uu) {
460            // `self` totally contains `other`
461            (Less, Greater) => return Err(InvalidRangeError::DiscontiguousDifference),
462            // `other` totally contains `self`
463            (Greater | Equal, Less | Equal) => Range { inner: None },
464            (Greater | Equal, Greater) => {
465                let lower = RangeBound {
466                    inclusive: !o.upper.inclusive,
467                    bound: o.upper.bound,
468                };
469                Range {
470                    inner: Some(RangeInner {
471                        lower,
472                        upper: s.upper,
473                    }),
474                }
475            }
476            (Less, Less | Equal) => {
477                let upper = RangeBound {
478                    inclusive: !o.lower.inclusive,
479                    bound: o.lower.bound,
480                };
481                Range {
482                    inner: Some(RangeInner {
483                        lower: s.lower,
484                        upper,
485                    }),
486                }
487            }
488        };
489
490        let mut r = r.into_bounds(Datum::from);
491
492        r.canonicalize()?;
493
494        Ok(r)
495    }
496}
497
498impl<'a> Range<Datum<'a>> {
499    /// Canonicalize the range by PG's heuristics, which are:
500    /// - Infinite bounds are always exclusive
501    /// - If type has step:
502    ///  - Exclusive lower bounds are rewritten as inclusive += step
503    ///  - Inclusive lower bounds are rewritten as exclusive += step
504    /// - Ranges are empty if lower >= upper after prev. step unless range type
505    ///   does not have step and both bounds are inclusive
506    ///
507    /// # Panics
508    /// - If the upper and lower bounds are finite and of different types.
509    pub fn canonicalize(&mut self) -> Result<(), InvalidRangeError> {
510        let (lower, upper) = match &mut self.inner {
511            Some(inner) => (&mut inner.lower, &mut inner.upper),
512            None => return Ok(()),
513        };
514
515        match (lower.bound, upper.bound) {
516            (Some(l), Some(u)) => {
517                assert_eq!(
518                    DatumKind::from(l),
519                    DatumKind::from(u),
520                    "finite bounds must be of same type"
521                );
522                if l > u {
523                    return Err(InvalidRangeError::MisorderedRangeBounds);
524                }
525            }
526            _ => {}
527        };
528
529        lower.canonicalize()?;
530        upper.canonicalize()?;
531
532        // The only way that you have two inclusive bounds with equal value are
533        // if type does not have step.
534        if !(lower.inclusive && upper.inclusive)
535            && lower.bound >= upper.bound
536            // None is less than any Some, so only need to check this condition.
537            && upper.bound.is_some()
538        {
539            // emtpy range
540            self.inner = None
541        }
542
543        Ok(())
544    }
545}
546
547/// Holds the upper and lower bounds for non-empty ranges.
548#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
549pub struct RangeInner<B> {
550    pub lower: RangeLowerBound<B>,
551    pub upper: RangeUpperBound<B>,
552}
553
554impl<B: Display> Display for RangeInner<B> {
555    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
556        f.write_str(if self.lower.inclusive { "[" } else { "(" })?;
557        self.lower.fmt(f)?;
558        f.write_str(",")?;
559        Display::fmt(&self.upper, f)?;
560        f.write_str(if self.upper.inclusive { "]" } else { ")" })
561    }
562}
563
564impl<B: Ord> Ord for RangeInner<B> {
565    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
566        self.lower
567            .cmp(&other.lower)
568            .then(self.upper.cmp(&other.upper))
569    }
570}
571
572impl<B: PartialOrd + Ord> PartialOrd for RangeInner<B> {
573    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
574        Some(self.cmp(other))
575    }
576}
577
/// Represents a terminal point of a range.
///
/// `UPPER` selects whether the value sorts as an upper bound (`true`) or as a
/// lower bound (`false`, the default).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct RangeBound<B, const UPPER: bool = false> {
    /// Whether the bound is inclusive.
    pub inclusive: bool,
    /// The bound's value; `None` represents an infinite bound.
    pub bound: Option<B>,
}

impl<const UPPER: bool, D: Display> Display for RangeBound<D, UPPER> {
    /// Writes the bound's value; writes nothing for an infinite bound.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(value) = &self.bound {
            Display::fmt(value, f)?;
        }
        Ok(())
    }
}

impl<const UPPER: bool, D: Ord> Ord for RangeBound<D, UPPER> {
    fn cmp(&self, other: &Self) -> Ordering {
        // 1. Sort by bounds; `None` (infinite) sorts before any finite value.
        let by_value = self.bound.cmp(&other.bound);

        // 2. For uppers, an infinite bound is greater than any finite bound,
        //    so the comparison flips when exactly one side is infinite.
        let exactly_one_infinite = self.bound.is_none() != other.bound.is_none();
        let by_value = if UPPER && exactly_one_infinite {
            by_value.reverse()
        } else {
            by_value
        };

        // 3. Tie break by inclusivity, which is inverted between lowers and
        //    uppers: an inclusive lower sorts first, an inclusive upper last.
        let by_inclusivity = match (self.inclusive, other.inclusive) {
            (true, true) | (false, false) => Ordering::Equal,
            (true, false) => {
                if UPPER {
                    Ordering::Greater
                } else {
                    Ordering::Less
                }
            }
            (false, true) => {
                if UPPER {
                    Ordering::Less
                } else {
                    Ordering::Greater
                }
            }
        };

        by_value.then(by_inclusivity)
    }
}

impl<const UPPER: bool, D: PartialOrd + Ord> PartialOrd for RangeBound<D, UPPER> {
    /// The ordering is total, so this always defers to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// A `RangeBound` that sorts correctly for use as a lower bound.
pub type RangeLowerBound<B> = RangeBound<B, false>;

/// A `RangeBound` that sorts correctly for use as an upper bound.
pub type RangeUpperBound<B> = RangeBound<B, true>;
632
633// Generic RangeBound implementations meant to work over `RangeBound<Datum,..>`
634// and `RangeBound<DatumNested,..>`.
635impl<'a, const UPPER: bool, B: Copy + Ord + PartialOrd + Display + Debug> RangeBound<B, UPPER>
636where
637    Datum<'a>: From<B>,
638{
639    /// Determines where `elem` lies in relation to the range bound.
640    ///
641    /// # Panics
642    /// - If `self.bound.datum()` is not convertible to `T`.
643    fn elem_cmp<T: RangeOps<'a>>(&self, elem: &T) -> Ordering
644    where
645        <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
646    {
647        match self.bound.map(|bound| <T>::unwrap_datum(bound.into())) {
648            None if UPPER => Ordering::Greater,
649            None => Ordering::Less,
650            Some(bound) => bound.cmp(elem),
651        }
652    }
653
654    /// Does `elem` satisfy this bound?
655    fn satisfied_by<T: RangeOps<'a>>(&self, elem: &T) -> bool
656    where
657        <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
658    {
659        match self.elem_cmp(elem) {
660            // Inclusive always satisfied with equality, regardless of upper or
661            // lower.
662            Ordering::Equal => self.inclusive,
663            // Upper satisfied with values less than itself
664            Ordering::Greater => UPPER,
665            // Lower satisfied with values greater than itself
666            Ordering::Less => !UPPER,
667        }
668    }
669
670    // Compares two `RangeBound`, which do not need to both be of the same
671    // `UPPER`.
672    fn range_bound_cmp<const OTHER_UPPER: bool>(
673        &self,
674        other: &RangeBound<B, OTHER_UPPER>,
675    ) -> Ordering {
676        if UPPER == OTHER_UPPER {
677            return self.cmp(&RangeBound {
678                inclusive: other.inclusive,
679                bound: other.bound,
680            });
681        }
682
683        // Handle cases where either are infinite bounds, which have special
684        // semantics.
685        if self.bound.is_none() || other.bound.is_none() {
686            return if UPPER {
687                Ordering::Greater
688            } else {
689                Ordering::Less
690            };
691        }
692        // 1. Sort by bounds
693        let cmp = self.bound.cmp(&other.bound);
694        // 2. Tie break by sorting by inclusivity, which is inverted between
695        //    lowers and uppers.
696        cmp.then(if self.inclusive && other.inclusive {
697            Ordering::Equal
698        } else if UPPER {
699            Ordering::Less
700        } else {
701            Ordering::Greater
702        })
703    }
704}
705
706impl<'a, const UPPER: bool> RangeBound<Datum<'a>, UPPER> {
707    /// Create a new `RangeBound` whose value is "infinite" (i.e. None) if `d ==
708    /// Datum::Null`, otherwise finite (i.e. Some).
709    ///
710    /// There is not a corresponding generic implementation of this because
711    /// genericizing how to express infinite bounds is less clear.
712    pub fn new(d: Datum<'a>, inclusive: bool) -> RangeBound<Datum<'a>, UPPER> {
713        RangeBound {
714            inclusive,
715            bound: match d {
716                Datum::Null => None,
717                o => Some(o),
718            },
719        }
720    }
721
722    /// Rewrite the bounds to the consistent format. This is absolutely
723    /// necessary to perform the correct equality/comparison operations on
724    /// types.
725    fn canonicalize(&mut self) -> Result<(), InvalidRangeError> {
726        Ok(match self.bound {
727            None => {
728                self.inclusive = false;
729            }
730            // Valid range types are defined in typeconv.rs:validate_range_element_type
731            Some(value) => match value {
732                d @ Datum::Int32(_) => self.canonicalize_inner::<i32>(d)?,
733                d @ Datum::Int64(_) => self.canonicalize_inner::<i64>(d)?,
734                d @ Datum::Date(_) => self.canonicalize_inner::<Date>(d)?,
735                Datum::Numeric(..) | Datum::Timestamp(..) | Datum::TimestampTz(..) => {}
736                d => unreachable!("{d:?} not yet supported in ranges"),
737            },
738        })
739    }
740
741    /// Canonicalize `self`'s representation for types that have discrete steps
742    /// between values.
743    ///
744    /// Continuous values (e.g. timestamps, numeric) must not be
745    /// canonicalized.
746    fn canonicalize_inner<T: RangeOps<'a>>(&mut self, d: Datum<'a>) -> Result<(), InvalidRangeError>
747    where
748        <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
749    {
750        // Upper bounds must be exclusive, lower bounds inclusive
751        if UPPER == self.inclusive {
752            let cur = <T>::unwrap_datum(d);
753            self.bound = Some(
754                cur.step()
755                    .ok_or_else(|| {
756                        InvalidRangeError::CanonicalizationOverflow(T::err_type_name().into())
757                    })?
758                    .into(),
759            );
760            self.inclusive = !UPPER;
761        }
762
763        Ok(())
764    }
765}
766
767#[derive(
768    Arbitrary,
769    Ord,
770    PartialOrd,
771    Clone,
772    Debug,
773    Eq,
774    PartialEq,
775    Serialize,
776    Deserialize,
777    Hash,
778    MzReflect
779)]
780pub enum InvalidRangeError {
781    MisorderedRangeBounds,
782    CanonicalizationOverflow(Box<str>),
783    InvalidRangeBoundFlags,
784    DiscontiguousUnion,
785    DiscontiguousDifference,
786    NullRangeBoundFlags,
787}
788
789impl Display for InvalidRangeError {
790    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
791        match self {
792            InvalidRangeError::MisorderedRangeBounds => {
793                f.write_str("range lower bound must be less than or equal to range upper bound")
794            }
795            InvalidRangeError::CanonicalizationOverflow(t) => {
796                write!(f, "{} out of range", t)
797            }
798            InvalidRangeError::InvalidRangeBoundFlags => f.write_str("invalid range bound flags"),
799            InvalidRangeError::DiscontiguousUnion => {
800                f.write_str("result of range union would not be contiguous")
801            }
802            InvalidRangeError::DiscontiguousDifference => {
803                f.write_str("result of range difference would not be contiguous")
804            }
805            InvalidRangeError::NullRangeBoundFlags => {
806                f.write_str("range constructor flags argument must not be null")
807            }
808        }
809    }
810}
811
812impl Error for InvalidRangeError {
813    fn source(&self) -> Option<&(dyn Error + 'static)> {
814        None
815    }
816}
817
818// Required due to Proto decoding using string as its error type
819impl From<InvalidRangeError> for String {
820    fn from(e: InvalidRangeError) -> Self {
821        e.to_string()
822    }
823}
824
825impl RustType<ProtoInvalidRangeError> for InvalidRangeError {
826    fn into_proto(&self) -> ProtoInvalidRangeError {
827        use Kind::*;
828        use proto_invalid_range_error::*;
829        let kind = match self {
830            InvalidRangeError::MisorderedRangeBounds => MisorderedRangeBounds(()),
831            InvalidRangeError::CanonicalizationOverflow(s) => {
832                CanonicalizationOverflow(s.into_proto())
833            }
834            InvalidRangeError::InvalidRangeBoundFlags => InvalidRangeBoundFlags(()),
835            InvalidRangeError::DiscontiguousUnion => DiscontiguousUnion(()),
836            InvalidRangeError::DiscontiguousDifference => DiscontiguousDifference(()),
837            InvalidRangeError::NullRangeBoundFlags => NullRangeBoundFlags(()),
838        };
839        ProtoInvalidRangeError { kind: Some(kind) }
840    }
841
842    fn from_proto(proto: ProtoInvalidRangeError) -> Result<Self, TryFromProtoError> {
843        use proto_invalid_range_error::Kind::*;
844        match proto.kind {
845            Some(kind) => Ok(match kind {
846                MisorderedRangeBounds(()) => InvalidRangeError::MisorderedRangeBounds,
847                CanonicalizationOverflow(s) => {
848                    InvalidRangeError::CanonicalizationOverflow(s.into())
849                }
850                InvalidRangeBoundFlags(()) => InvalidRangeError::InvalidRangeBoundFlags,
851                DiscontiguousUnion(()) => InvalidRangeError::DiscontiguousUnion,
852                DiscontiguousDifference(()) => InvalidRangeError::DiscontiguousDifference,
853                NullRangeBoundFlags(()) => InvalidRangeError::NullRangeBoundFlags,
854            }),
855            None => Err(TryFromProtoError::missing_field(
856                "`ProtoInvalidRangeError::kind`",
857            )),
858        }
859    }
860}
861
862pub fn parse_range_bound_flags<'a>(flags: &'a str) -> Result<(bool, bool), InvalidRangeError> {
863    let mut flags = flags.chars();
864
865    let lower = match flags.next() {
866        Some('(') => false,
867        Some('[') => true,
868        _ => return Err(InvalidRangeError::InvalidRangeBoundFlags),
869    };
870
871    let upper = match flags.next() {
872        Some(')') => false,
873        Some(']') => true,
874        _ => return Err(InvalidRangeError::InvalidRangeBoundFlags),
875    };
876
877    match flags.next() {
878        Some(_) => Err(InvalidRangeError::InvalidRangeBoundFlags),
879        None => Ok((lower, upper)),
880    }
881}
882
883impl<'a, T: FromSql<'a>> FromSql<'a> for Range<T> {
884    fn from_sql(ty: &PgType, raw: &'a [u8]) -> Result<Range<T>, Box<dyn Error + Sync + Send>> {
885        let inner_typ = match ty {
886            &PgType::INT4_RANGE => PgType::INT4,
887            &PgType::INT8_RANGE => PgType::INT8,
888            &PgType::DATE_RANGE => PgType::DATE,
889            &PgType::NUM_RANGE => PgType::NUMERIC,
890            &PgType::TS_RANGE => PgType::TIMESTAMP,
891            &PgType::TSTZ_RANGE => PgType::TIMESTAMPTZ,
892            _ => unreachable!(),
893        };
894
895        let inner = match types::range_from_sql(raw)? {
896            types::Range::Empty => None,
897            types::Range::Nonempty(lower, upper) => {
898                let mut bounds = Vec::with_capacity(2);
899
900                for bound_outer in [lower, upper].into_iter() {
901                    let bound = match bound_outer {
902                        types::RangeBound::Exclusive(bound)
903                        | types::RangeBound::Inclusive(bound) => bound
904                            .map(|bound| T::from_sql(&inner_typ, bound))
905                            .transpose()?,
906                        types::RangeBound::Unbounded => None,
907                    };
908                    let inclusive = matches!(bound_outer, types::RangeBound::Inclusive(_));
909                    bounds.push(RangeBound { bound, inclusive });
910                }
911
912                let lower = bounds.remove(0);
913                let upper = bounds.remove(0);
914                assert!(bounds.is_empty());
915
916                Some(RangeInner {
917                    lower,
918                    // Rewrite bound in terms of appropriate `UPPER`
919                    upper: RangeBound {
920                        bound: upper.bound,
921                        inclusive: upper.inclusive,
922                    },
923                })
924            }
925        };
926
927        Ok(Range { inner })
928    }
929
930    fn accepts(ty: &PgType) -> bool {
931        matches!(
932            ty,
933            &PgType::INT4_RANGE
934                | &PgType::INT8_RANGE
935                | &PgType::DATE_RANGE
936                | &PgType::NUM_RANGE
937                | &PgType::TS_RANGE
938                | &PgType::TSTZ_RANGE
939        )
940    }
941}