use std::any::type_name;
use std::cmp::Ordering;
use std::error::Error;
use std::fmt::{self, Debug, Display};
use std::hash::{Hash, Hasher};
use bitflags::bitflags;
use chrono::{DateTime, NaiveDateTime, Utc};
use dec::OrderedDecimal;
use mz_lowertest::MzReflect;
use mz_proto::{RustType, TryFromProtoError};
use postgres_protocol::types;
use proptest_derive::Arbitrary;
use serde::{Deserialize, Serialize};
use tokio_postgres::types::{FromSql, Type as PgType};
use crate::adt::date::Date;
use crate::adt::numeric::Numeric;
use crate::adt::timestamp::CheckedTimestamp;
use crate::scalar::DatumKind;
use crate::Datum;
// Splices in Rust code that the build writes to `OUT_DIR`. Presumably this is the
// protobuf-generated module that provides `ProtoInvalidRangeError` and
// `proto_invalid_range_error` used further down — confirm against the crate's build script.
include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.range.rs"));
bitflags! {
    /// Flag bits returned by `Range::internal_flag_bits`.
    ///
    /// Note that `LB_INFINITE` and `UB_INCLUSIVE` sit at different bit positions here
    /// than they do in `PgFlags`.
    pub(crate) struct InternalFlags: u8 {
        /// The range is empty.
        const EMPTY = 0b0000_0001;
        /// The lower bound is inclusive.
        const LB_INCLUSIVE = 0b0000_0010;
        /// The lower bound is infinite (has no value).
        const LB_INFINITE = 0b0000_0100;
        /// The upper bound is inclusive.
        const UB_INCLUSIVE = 0b0000_1000;
        /// The upper bound is infinite (has no value).
        const UB_INFINITE = 0b0001_0000;
    }
}
bitflags! {
    /// Flag bits returned by `Range::pg_flag_bits`.
    ///
    /// NOTE(review): the name and layout suggest these are meant to match the flag byte of
    /// PostgreSQL's range encoding — confirm against PostgreSQL's `rangetypes.h`.
    pub(crate) struct PgFlags: u8 {
        /// The range is empty.
        const EMPTY = 1;
        /// The lower bound is inclusive.
        const LB_INCLUSIVE = 1 << 1;
        /// The upper bound is inclusive.
        const UB_INCLUSIVE = 1 << 2;
        /// The lower bound is infinite (has no value).
        const LB_INFINITE = 1 << 3;
        /// The upper bound is infinite (has no value).
        const UB_INFINITE = 1 << 4;
    }
}
/// A range of values over the domain `D`.
///
/// An empty range is represented by `inner == None`; a non-empty range stores its lower
/// and upper bounds in a `RangeInner`.
pub struct Range<D> {
/// The range's bounds, or `None` when the range is empty.
pub inner: Option<RangeInner<D>>,
}
impl<D: Display> Display for Range<D> {
    /// Writes `empty` for an empty range, otherwise defers to the bounds' own `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(bounds) = &self.inner {
            Display::fmt(bounds, f)
        } else {
            f.write_str("empty")
        }
    }
}
impl<D: Debug> Debug for Range<D> {
    /// Formats as a struct named `Range` with a single `inner` field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Range");
        out.field("inner", &self.inner);
        out.finish()
    }
}
impl<D: Clone> Clone for Range<D> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
// A range is `Copy` whenever its element type is.
impl<D: Copy> Copy for Range<D> {}
impl<D: PartialEq> PartialEq for Range<D> {
    /// Two ranges are equal when their `inner` values are equal.
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.inner, &other.inner)
    }
}
// Equality on ranges is total whenever it is total on the element type.
impl<D: Eq> Eq for Range<D> {}
impl<D: Ord + PartialOrd> PartialOrd for Range<D> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<D: Ord> Ord for Range<D> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.inner.cmp(&other.inner)
}
}
impl<D: Hash> Hash for Range<D> {
    /// Feeds only `inner` into the hasher, so a range hashes exactly like its `inner`.
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        Hash::hash(&self.inner, hasher);
    }
}
/// Operations a type must support to be used as a range element.
///
/// Implementors must be totally ordered and convertible to and from `Datum`.
pub trait RangeOps<'a>:
    Debug + Ord + PartialOrd + Eq + PartialEq + TryFrom<Datum<'a>> + Into<Datum<'a>>
where
    <Self as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
{
    /// Returns the value used to replace `self` when a bound is canonicalized, or `None`
    /// if no such value exists.
    ///
    /// The default returns `self` unchanged.
    fn step(self) -> Option<Self> {
        Some(self)
    }

    /// Converts `d` into `Self`.
    ///
    /// # Panics
    ///
    /// Panics if `d` cannot be converted via `TryFrom<Datum>`.
    fn unwrap_datum(d: Datum<'a>) -> Self {
        match <Self>::try_from(d) {
            Ok(value) => value,
            Err(_) => panic!("cannot take {} to {}", d, type_name::<Self>()),
        }
    }

    /// The type name embedded in `InvalidRangeError::CanonicalizationOverflow`.
    fn err_type_name() -> &'static str;
}
impl<'a> RangeOps<'a> for i32 {
    /// The next integer, or `None` if incrementing would overflow `i32`.
    fn step(self) -> Option<i32> {
        i32::checked_add(self, 1)
    }

    /// Type name reported in canonicalization-overflow errors.
    fn err_type_name() -> &'static str {
        "integer"
    }
}
impl<'a> RangeOps<'a> for i64 {
    /// The next integer, or `None` if incrementing would overflow `i64`.
    fn step(self) -> Option<i64> {
        i64::checked_add(self, 1)
    }

    /// Type name reported in canonicalization-overflow errors.
    fn err_type_name() -> &'static str {
        "bigint"
    }
}
impl<'a> RangeOps<'a> for Date {
    /// The result of `checked_add(1)`, or `None` if that addition reports an error.
    fn step(self) -> Option<Date> {
        let next = self.checked_add(1);
        next.ok()
    }

    /// Type name reported in canonicalization-overflow errors.
    fn err_type_name() -> &'static str {
        "date"
    }
}
// Numeric range elements keep the trait's default `step`, which returns the value unchanged.
impl<'a> RangeOps<'a> for OrderedDecimal<Numeric> {
/// Type name reported in canonicalization-overflow errors.
fn err_type_name() -> &'static str {
"numeric"
}
}
// Timestamp range elements keep the trait's default `step`, which returns the value unchanged.
impl<'a> RangeOps<'a> for CheckedTimestamp<NaiveDateTime> {
/// Type name reported in canonicalization-overflow errors.
fn err_type_name() -> &'static str {
"timestamp"
}
}
// Timestamptz range elements keep the trait's default `step`, which returns the value
// unchanged.
impl<'a> RangeOps<'a> for CheckedTimestamp<DateTime<Utc>> {
/// Type name reported in canonicalization-overflow errors.
fn err_type_name() -> &'static str {
"timestamptz"
}
}
impl<D> Range<D> {
pub fn new(inner: Option<(RangeLowerBound<D>, RangeUpperBound<D>)>) -> Range<D> {
Range {
inner: inner.map(|(lower, upper)| RangeInner { lower, upper }),
}
}
pub fn internal_flag_bits(&self) -> u8 {
let mut flags = InternalFlags::empty();
match &self.inner {
None => {
flags.set(InternalFlags::EMPTY, true);
}
Some(RangeInner { lower, upper }) => {
flags.set(InternalFlags::EMPTY, false);
flags.set(InternalFlags::LB_INFINITE, lower.bound.is_none());
flags.set(InternalFlags::UB_INFINITE, upper.bound.is_none());
flags.set(InternalFlags::LB_INCLUSIVE, lower.inclusive);
flags.set(InternalFlags::UB_INCLUSIVE, upper.inclusive);
}
}
flags.bits()
}
pub fn pg_flag_bits(&self) -> u8 {
let mut flags = PgFlags::empty();
match &self.inner {
None => {
flags.set(PgFlags::EMPTY, true);
}
Some(RangeInner { lower, upper }) => {
flags.set(PgFlags::EMPTY, false);
flags.set(PgFlags::LB_INFINITE, lower.bound.is_none());
flags.set(PgFlags::UB_INFINITE, upper.bound.is_none());
flags.set(PgFlags::LB_INCLUSIVE, lower.inclusive);
flags.set(PgFlags::UB_INCLUSIVE, upper.inclusive);
}
}
flags.bits()
}
pub fn into_bounds<F, O>(self, conv: F) -> Range<O>
where
F: Fn(D) -> O,
{
Range {
inner: self
.inner
.map(|RangeInner::<D> { lower, upper }| RangeInner::<O> {
lower: RangeLowerBound {
inclusive: lower.inclusive,
bound: lower.bound.map(&conv),
},
upper: RangeUpperBound {
inclusive: upper.inclusive,
bound: upper.bound.map(&conv),
},
}),
}
}
}
impl<'a, B: Copy + Ord + PartialOrd + Display + Debug> Range<B>
where
    Datum<'a>: From<B>,
{
    /// Reports whether `elem` satisfies both the lower and the upper bound.
    ///
    /// An empty range contains no elements.
    pub fn contains_elem<T: RangeOps<'a>>(&self, elem: &T) -> bool
    where
        <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
    {
        if let Some(RangeInner { lower, upper }) = self.inner {
            lower.satisfied_by(elem) && upper.satisfied_by(elem)
        } else {
            false
        }
    }

    /// Reports whether `other` lies entirely within `self`.
    ///
    /// Every range contains the empty range; the empty range contains only itself.
    pub fn contains_range(&self, other: &Range<B>) -> bool {
        match (self.inner, other.inner) {
            (_, None) => true,
            (None, Some(_)) => false,
            (Some(outer), Some(inner)) => {
                outer.lower <= inner.lower && inner.upper <= outer.upper
            }
        }
    }

    /// Reports whether the two ranges overlap; always `false` if either is empty.
    pub fn overlaps(&self, other: &Range<B>) -> bool {
        if let (Some(s), Some(o)) = (self.inner, other.inner) {
            // Compare the upper bound of whichever range sorts first against the lower
            // bound of the one that sorts second.
            let touching = match s.cmp(&o) {
                Ordering::Equal => return true,
                Ordering::Less => s.upper.range_bound_cmp(&o.lower),
                Ordering::Greater => o.upper.range_bound_cmp(&s.lower),
            };
            touching != Ordering::Less
        } else {
            false
        }
    }

    /// Reports whether `self`'s upper bound compares strictly below `other`'s lower bound.
    ///
    /// Always `false` if either range is empty.
    pub fn before(&self, other: &Range<B>) -> bool {
        match self.inner.zip(other.inner) {
            Some((s, o)) => s.upper.range_bound_cmp(&o.lower) == Ordering::Less,
            None => false,
        }
    }

    /// Reports whether `self`'s lower bound compares strictly above `other`'s upper bound.
    ///
    /// Always `false` if either range is empty.
    pub fn after(&self, other: &Range<B>) -> bool {
        match self.inner.zip(other.inner) {
            Some((s, o)) => s.lower.range_bound_cmp(&o.upper) == Ordering::Greater,
            None => false,
        }
    }

    /// Reports whether `self`'s upper bound does not exceed `other`'s upper bound.
    ///
    /// Always `false` if either range is empty.
    pub fn overleft(&self, other: &Range<B>) -> bool {
        match self.inner.zip(other.inner) {
            Some((s, o)) => s.upper.range_bound_cmp(&o.upper) != Ordering::Greater,
            None => false,
        }
    }

    /// Reports whether `self`'s lower bound is not below `other`'s lower bound.
    ///
    /// Always `false` if either range is empty.
    pub fn overright(&self, other: &Range<B>) -> bool {
        match self.inner.zip(other.inner) {
            Some((s, o)) => s.lower.range_bound_cmp(&o.lower) != Ordering::Less,
            None => false,
        }
    }

    /// Reports whether the ranges are adjacent.
    ///
    /// That is the case when one range's lower bound and the other's upper bound are both
    /// finite, hold the same value, and exactly one of the two is inclusive. Always `false`
    /// if either range is empty.
    pub fn adjacent(&self, other: &Range<B>) -> bool {
        let (s, o) = match (self.inner, other.inner) {
            (Some(s), Some(o)) => (s, o),
            _ => return false,
        };
        [(s.lower, o.upper), (o.lower, s.upper)]
            .iter()
            .any(|(lower, upper)| match (lower.bound, upper.bound) {
                (Some(l), Some(u)) => (lower.inclusive ^ upper.inclusive) && l == u,
                _ => false,
            })
    }

    /// Returns the union of the two ranges.
    ///
    /// If either range is empty, the result is the other range.
    ///
    /// # Errors
    ///
    /// Returns `InvalidRangeError::DiscontiguousUnion` if both ranges are non-empty and
    /// neither overlap nor are adjacent.
    pub fn union(&self, other: &Range<B>) -> Result<Range<B>, InvalidRangeError> {
        match (self.inner, other.inner) {
            // Covers (None, None) as well: `inner` is then `None`.
            (None, inner) | (inner, None) => Ok(Range { inner }),
            (Some(s), Some(o)) => {
                if self.overlaps(other) || self.adjacent(other) {
                    let lower = Ord::min(s.lower, o.lower);
                    let upper = Ord::max(s.upper, o.upper);
                    Ok(Range {
                        inner: Some(RangeInner { lower, upper }),
                    })
                } else {
                    Err(InvalidRangeError::DiscontiguousUnion)
                }
            }
        }
    }

    /// Returns the intersection of the two ranges, which is empty unless they overlap.
    pub fn intersection(&self, other: &Range<B>) -> Range<B> {
        match (self.inner, other.inner) {
            (Some(s), Some(o)) if self.overlaps(other) => {
                let lower = Ord::max(s.lower, o.lower);
                let upper = Ord::min(s.upper, o.upper);
                Range {
                    inner: Some(RangeInner { lower, upper }),
                }
            }
            _ => Range { inner: None },
        }
    }

    /// Returns `self` with the part covered by `other` removed, as a range of `Datum`s.
    ///
    /// If the ranges do not overlap, `self` is returned converted but otherwise unchanged.
    /// Otherwise the remainder is canonicalized before being returned.
    ///
    /// # Errors
    ///
    /// Returns `InvalidRangeError::DiscontiguousDifference` if `other` starts after
    /// `self`'s lower bound and ends before `self`'s upper bound, and propagates any error
    /// from canonicalizing the result.
    pub fn difference(&self, other: &Range<B>) -> Result<Range<Datum<'a>>, InvalidRangeError> {
        // `overlaps` is only true when both ranges are non-empty.
        let (s, o) = match (self.inner, other.inner) {
            (Some(s), Some(o)) if self.overlaps(other) => (s, o),
            _ => return Ok(self.into_bounds(Datum::from)),
        };

        let by_lower = s.lower.cmp(&o.lower);
        let by_upper = s.upper.cmp(&o.upper);
        let remainder: Range<B> = match (by_lower, by_upper) {
            // `other` sits strictly inside `self`, which would split it in two.
            (Ordering::Less, Ordering::Greater) => {
                return Err(InvalidRangeError::DiscontiguousDifference)
            }
            // `other` covers all of `self`.
            (Ordering::Greater | Ordering::Equal, Ordering::Less | Ordering::Equal) => {
                Range { inner: None }
            }
            // `other` covers the start of `self`: keep what follows `other`'s upper bound.
            (Ordering::Greater | Ordering::Equal, Ordering::Greater) => Range {
                inner: Some(RangeInner {
                    lower: RangeBound {
                        inclusive: !o.upper.inclusive,
                        bound: o.upper.bound,
                    },
                    upper: s.upper,
                }),
            },
            // `other` covers the end of `self`: keep what precedes `other`'s lower bound.
            (Ordering::Less, Ordering::Less | Ordering::Equal) => Range {
                inner: Some(RangeInner {
                    lower: s.lower,
                    upper: RangeBound {
                        inclusive: !o.lower.inclusive,
                        bound: o.lower.bound,
                    },
                }),
            },
        };

        let mut remainder = remainder.into_bounds(Datum::from);
        remainder.canonicalize()?;
        Ok(remainder)
    }
}
impl<'a> Range<Datum<'a>> {
pub fn canonicalize(&mut self) -> Result<(), InvalidRangeError> {
let (lower, upper) = match &mut self.inner {
Some(inner) => (&mut inner.lower, &mut inner.upper),
None => return Ok(()),
};
match (lower.bound, upper.bound) {
(Some(l), Some(u)) => {
assert_eq!(
DatumKind::from(l),
DatumKind::from(u),
"finite bounds must be of same type"
);
if l > u {
return Err(InvalidRangeError::MisorderedRangeBounds);
}
}
_ => {}
};
lower.canonicalize()?;
upper.canonicalize()?;
if !(lower.inclusive && upper.inclusive)
&& lower.bound >= upper.bound
&& upper.bound.is_some()
{
self.inner = None
}
Ok(())
}
}
/// The lower and upper bounds of a non-empty range.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct RangeInner<B> {
/// The range's lower bound.
pub lower: RangeLowerBound<B>,
/// The range's upper bound.
pub upper: RangeUpperBound<B>,
}
impl<B: Display> Display for RangeInner<B> {
    /// Writes the opening bracket, the lower bound, a comma, the upper bound and the
    /// closing bracket; square brackets mark inclusive bounds, parentheses exclusive ones.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let open = if self.lower.inclusive { "[" } else { "(" };
        let close = if self.upper.inclusive { "]" } else { ")" };

        f.write_str(open)?;
        // Call `Display::fmt` directly so the bounds see the caller's formatter options.
        Display::fmt(&self.lower, f)?;
        f.write_str(",")?;
        Display::fmt(&self.upper, f)?;
        f.write_str(close)
    }
}
impl<B: Ord> Ord for RangeInner<B> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.lower
.cmp(&other.lower)
.then(self.upper.cmp(&other.upper))
}
}
impl<B: PartialOrd + Ord> PartialOrd for RangeInner<B> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
/// One end of a range.
///
/// `UPPER` records whether this is an upper (`true`) or lower (`false`, the default) bound;
/// the ordering and comparison code for bounds branches on it.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct RangeBound<B, const UPPER: bool = false> {
/// Whether the bound's own value belongs to the range.
pub inclusive: bool,
/// The bound's value; `None` means the bound is infinite.
pub bound: Option<B>,
}
impl<const UPPER: bool, D: Display> Display for RangeBound<D, UPPER> {
    /// Writes the bound's value, or nothing at all when the bound is infinite.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(value) = &self.bound {
            Display::fmt(value, f)
        } else {
            Ok(())
        }
    }
}
impl<const UPPER: bool, D: Ord> Ord for RangeBound<D, UPPER> {
    /// Orders two bounds of the same kind (both lower or both upper).
    ///
    /// Bounds are compared by value first. For lower bounds an infinite bound sorts before
    /// every finite one; for upper bounds it sorts after. Ties on the value are broken by
    /// inclusivity: an inclusive lower bound sorts before an exclusive one, and an
    /// inclusive upper bound sorts after an exclusive one.
    fn cmp(&self, other: &Self) -> Ordering {
        let mut by_value = self.bound.cmp(&other.bound);
        // `Option` sorts `None` first, which is backwards for an infinite upper bound.
        if UPPER && (self.bound.is_none() != other.bound.is_none()) {
            by_value = by_value.reverse();
        }

        let by_inclusivity = match (self.inclusive, other.inclusive) {
            (true, true) | (false, false) => Ordering::Equal,
            (true, false) => {
                if UPPER {
                    Ordering::Greater
                } else {
                    Ordering::Less
                }
            }
            (false, true) => {
                if UPPER {
                    Ordering::Less
                } else {
                    Ordering::Greater
                }
            }
        };

        by_value.then(by_inclusivity)
    }
}
impl<const UPPER: bool, D: PartialOrd + Ord> PartialOrd for RangeBound<D, UPPER> {
    /// Always returns `Some`, delegating to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// A range's lower bound (`UPPER = false`).
pub type RangeLowerBound<B> = RangeBound<B, false>;
/// A range's upper bound (`UPPER = true`).
pub type RangeUpperBound<B> = RangeBound<B, true>;
impl<'a, const UPPER: bool, B: Copy + Ord + PartialOrd + Display + Debug> RangeBound<B, UPPER>
where
    Datum<'a>: From<B>,
{
    /// Compares this bound's value against `elem`.
    ///
    /// An infinite upper bound compares greater than every element and an infinite lower
    /// bound compares less than every element.
    ///
    /// # Panics
    ///
    /// Panics if the bound's value cannot be converted into `T`.
    fn elem_cmp<T: RangeOps<'a>>(&self, elem: &T) -> Ordering
    where
        <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
    {
        match self.bound {
            Some(bound) => {
                let bound = <T>::unwrap_datum(bound.into());
                Ord::cmp(&bound, elem)
            }
            None => {
                if UPPER {
                    Ordering::Greater
                } else {
                    Ordering::Less
                }
            }
        }
    }

    /// Reports whether `elem` lies on the permitted side of this bound.
    ///
    /// An element equal to the bound is accepted only if the bound is inclusive.
    fn satisfied_by<T: RangeOps<'a>>(&self, elem: &T) -> bool
    where
        <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
    {
        match (self.elem_cmp(elem), UPPER) {
            (Ordering::Equal, _) => self.inclusive,
            // An upper bound admits smaller elements; a lower bound admits larger ones.
            (Ordering::Greater, true) | (Ordering::Less, false) => true,
            (Ordering::Greater, false) | (Ordering::Less, true) => false,
        }
    }

    /// Compares this bound with `other`, which may be a bound of the opposite kind.
    ///
    /// Bounds of the same kind use the regular `Ord` implementation. Otherwise, if either
    /// bound is infinite the result is `Greater` when `self` is an upper bound and `Less`
    /// when it is a lower bound. Finite bounds are compared by value; equal values compare
    /// `Equal` only when both bounds are inclusive, and otherwise an upper bound is `Less`
    /// and a lower bound is `Greater`.
    fn range_bound_cmp<const OTHER_UPPER: bool>(
        &self,
        other: &RangeBound<B, OTHER_UPPER>,
    ) -> Ordering {
        if UPPER == OTHER_UPPER {
            // Same kind of bound: rebuild `other` with our own `UPPER` so `Ord` applies.
            let same_kind: RangeBound<B, UPPER> = RangeBound {
                inclusive: other.inclusive,
                bound: other.bound,
            };
            return self.cmp(&same_kind);
        }

        match (self.bound, other.bound) {
            (Some(mine), Some(theirs)) => {
                let tie_break = if self.inclusive && other.inclusive {
                    Ordering::Equal
                } else if UPPER {
                    Ordering::Less
                } else {
                    Ordering::Greater
                };
                Ord::cmp(&mine, &theirs).then(tie_break)
            }
            _ => {
                if UPPER {
                    Ordering::Greater
                } else {
                    Ordering::Less
                }
            }
        }
    }
}
impl<'a, const UPPER: bool> RangeBound<Datum<'a>, UPPER> {
    /// Creates a bound from `d`, treating `Datum::Null` as an infinite bound.
    pub fn new(d: Datum<'a>, inclusive: bool) -> RangeBound<Datum<'a>, UPPER> {
        let bound = if matches!(d, Datum::Null) {
            None
        } else {
            Some(d)
        };
        RangeBound { inclusive, bound }
    }

    /// Canonicalizes this bound in place.
    ///
    /// Infinite bounds become exclusive. `Int32`, `Int64` and `Date` bounds are rewritten
    /// so that lower bounds are inclusive and upper bounds exclusive. `Numeric`,
    /// `Timestamp` and `TimestampTz` bounds are left as they are.
    ///
    /// # Errors
    ///
    /// Returns `InvalidRangeError::CanonicalizationOverflow` if stepping the value fails.
    ///
    /// # Panics
    ///
    /// Panics for any other kind of `Datum`.
    fn canonicalize(&mut self) -> Result<(), InvalidRangeError> {
        match self.bound {
            None => {
                self.inclusive = false;
                Ok(())
            }
            Some(d @ Datum::Int32(_)) => self.canonicalize_inner::<i32>(d),
            Some(d @ Datum::Int64(_)) => self.canonicalize_inner::<i64>(d),
            Some(d @ Datum::Date(_)) => self.canonicalize_inner::<Date>(d),
            Some(Datum::Numeric(..) | Datum::Timestamp(..) | Datum::TimestampTz(..)) => Ok(()),
            Some(d) => unreachable!("{d:?} not yet supported in ranges"),
        }
    }

    /// Rewrites an exclusive lower bound or an inclusive upper bound over `T`.
    ///
    /// The value is replaced by `T::step` of itself and the inclusivity is flipped, so
    /// lower bounds end up inclusive and upper bounds exclusive. Bounds already in that
    /// form are untouched.
    ///
    /// # Errors
    ///
    /// Returns `InvalidRangeError::CanonicalizationOverflow` if `step` returns `None`.
    fn canonicalize_inner<T: RangeOps<'a>>(&mut self, d: Datum<'a>) -> Result<(), InvalidRangeError>
    where
        <T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
    {
        if UPPER != self.inclusive {
            // Lower and inclusive, or upper and exclusive: already canonical.
            return Ok(());
        }

        let stepped = <T>::unwrap_datum(d).step().ok_or_else(|| {
            InvalidRangeError::CanonicalizationOverflow(T::err_type_name().into())
        })?;
        self.bound = Some(stepped.into());
        self.inclusive = !UPPER;
        Ok(())
    }
}
/// Errors that can arise while constructing or operating on ranges.
#[derive(
Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect,
)]
pub enum InvalidRangeError {
/// Both bounds are finite and the lower bound is greater than the upper bound.
MisorderedRangeBounds,
/// Stepping a bound during canonicalization produced no value; carries the name of the
/// element type.
CanonicalizationOverflow(Box<str>),
/// A range bound flags string was not one of the accepted two-character forms.
InvalidRangeBoundFlags,
/// A union was requested of two ranges that neither overlap nor are adjacent.
DiscontiguousUnion,
/// A difference was requested whose result would be split into two pieces.
DiscontiguousDifference,
/// Displayed as "range constructor flags argument must not be null"; not produced in
/// the code shown here.
NullRangeBoundFlags,
}
impl Display for InvalidRangeError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
InvalidRangeError::MisorderedRangeBounds => {
f.write_str("range lower bound must be less than or equal to range upper bound")
}
InvalidRangeError::CanonicalizationOverflow(t) => {
write!(f, "{} out of range", t)
}
InvalidRangeError::InvalidRangeBoundFlags => f.write_str("invalid range bound flags"),
InvalidRangeError::DiscontiguousUnion => {
f.write_str("result of range union would not be contiguous")
}
InvalidRangeError::DiscontiguousDifference => {
f.write_str("result of range difference would not be contiguous")
}
InvalidRangeError::NullRangeBoundFlags => {
f.write_str("range constructor flags argument must not be null")
}
}
}
}
impl Error for InvalidRangeError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}
impl From<InvalidRangeError> for String {
    /// Converts the error into its `Display` text.
    fn from(err: InvalidRangeError) -> Self {
        ToString::to_string(&err)
    }
}
impl RustType<ProtoInvalidRangeError> for InvalidRangeError {
fn into_proto(&self) -> ProtoInvalidRangeError {
use proto_invalid_range_error::*;
use Kind::*;
let kind = match self {
InvalidRangeError::MisorderedRangeBounds => MisorderedRangeBounds(()),
InvalidRangeError::CanonicalizationOverflow(s) => {
CanonicalizationOverflow(s.into_proto())
}
InvalidRangeError::InvalidRangeBoundFlags => InvalidRangeBoundFlags(()),
InvalidRangeError::DiscontiguousUnion => DiscontiguousUnion(()),
InvalidRangeError::DiscontiguousDifference => DiscontiguousDifference(()),
InvalidRangeError::NullRangeBoundFlags => NullRangeBoundFlags(()),
};
ProtoInvalidRangeError { kind: Some(kind) }
}
fn from_proto(proto: ProtoInvalidRangeError) -> Result<Self, TryFromProtoError> {
use proto_invalid_range_error::Kind::*;
match proto.kind {
Some(kind) => Ok(match kind {
MisorderedRangeBounds(()) => InvalidRangeError::MisorderedRangeBounds,
CanonicalizationOverflow(s) => {
InvalidRangeError::CanonicalizationOverflow(s.into())
}
InvalidRangeBoundFlags(()) => InvalidRangeError::InvalidRangeBoundFlags,
DiscontiguousUnion(()) => InvalidRangeError::DiscontiguousUnion,
DiscontiguousDifference(()) => InvalidRangeError::DiscontiguousDifference,
NullRangeBoundFlags(()) => InvalidRangeError::NullRangeBoundFlags,
}),
None => Err(TryFromProtoError::missing_field(
"`ProtoInvalidRangeError::kind`",
)),
}
}
}
/// Parses a range bound flags string into `(lower_inclusive, upper_inclusive)`.
///
/// The string must be exactly two characters: `(` or `[` for the lower bound followed by
/// `)` or `]` for the upper bound. Square brackets mean inclusive, parentheses exclusive.
///
/// # Errors
///
/// Returns `InvalidRangeError::InvalidRangeBoundFlags` for any other input, including the
/// empty string and strings with surrounding or trailing characters.
pub fn parse_range_bound_flags(flags: &str) -> Result<(bool, bool), InvalidRangeError> {
    let mut chars = flags.chars();
    // The third `next()` must be `None`, which rejects anything after the two brackets.
    match (chars.next(), chars.next(), chars.next()) {
        (Some(lower @ ('(' | '[')), Some(upper @ (')' | ']')), None) => {
            Ok((lower == '[', upper == ']'))
        }
        _ => Err(InvalidRangeError::InvalidRangeBoundFlags),
    }
}
impl<'a, T: FromSql<'a>> FromSql<'a> for Range<T> {
    /// Decodes a range from its binary wire representation.
    ///
    /// `ty` selects the element type used to decode each finite bound. Bounds are decoded
    /// lower first, then upper.
    ///
    /// # Errors
    ///
    /// Returns an error if `ty` is not one of the range types listed in `accepts`, if the
    /// range framing cannot be parsed, or if decoding a bound value via `T::from_sql`
    /// fails.
    fn from_sql(ty: &PgType, raw: &'a [u8]) -> Result<Range<T>, Box<dyn Error + Sync + Send>> {
        let inner_typ = match ty {
            &PgType::INT4_RANGE => PgType::INT4,
            &PgType::INT8_RANGE => PgType::INT8,
            &PgType::DATE_RANGE => PgType::DATE,
            &PgType::NUM_RANGE => PgType::NUMERIC,
            &PgType::TS_RANGE => PgType::TIMESTAMP,
            &PgType::TSTZ_RANGE => PgType::TIMESTAMPTZ,
            // Callers are expected to check `accepts` first, but report a mismatch as an
            // error rather than panicking since this function is already fallible.
            other => return Err(format!("cannot decode {other} as a range").into()),
        };

        // Decodes one wire-level bound into its optional value and inclusivity.
        let decode = |bound: types::RangeBound<Option<&'a [u8]>>| -> Result<
            (Option<T>, bool),
            Box<dyn Error + Sync + Send>,
        > {
            let (bytes, inclusive) = match bound {
                types::RangeBound::Inclusive(bytes) => (bytes, true),
                types::RangeBound::Exclusive(bytes) => (bytes, false),
                types::RangeBound::Unbounded => (None, false),
            };
            let value = bytes
                .map(|bytes| T::from_sql(&inner_typ, bytes))
                .transpose()?;
            Ok((value, inclusive))
        };

        let inner = match types::range_from_sql(raw)? {
            types::Range::Empty => None,
            types::Range::Nonempty(lower, upper) => {
                let (lower_bound, lower_inclusive) = decode(lower)?;
                let (upper_bound, upper_inclusive) = decode(upper)?;
                Some(RangeInner {
                    lower: RangeBound {
                        bound: lower_bound,
                        inclusive: lower_inclusive,
                    },
                    upper: RangeBound {
                        bound: upper_bound,
                        inclusive: upper_inclusive,
                    },
                })
            }
        };

        Ok(Range { inner })
    }

    /// Reports whether `ty` is one of the range types this impl can decode.
    fn accepts(ty: &PgType) -> bool {
        matches!(
            ty,
            &PgType::INT4_RANGE
                | &PgType::INT8_RANGE
                | &PgType::DATE_RANGE
                | &PgType::NUM_RANGE
                | &PgType::TS_RANGE
                | &PgType::TSTZ_RANGE
        )
    }
}