mz_sql/session/vars/
constraints.rs1use std::fmt::Debug;
13use std::ops::{RangeBounds, RangeFrom, RangeInclusive};
14use std::time::Duration;
15
16use mz_repr::adt::numeric::Numeric;
17use mz_repr::bytes::ByteSize;
18
19use super::{Value, Var, VarError};
20
21pub static NUMERIC_NON_NEGATIVE: NumericNonNegNonNan = NumericNonNegNonNan;
22
23pub static NON_ZERO_DURATION: NonZeroDuration = NonZeroDuration;
24
25pub static NUMERIC_BOUNDED_0_1_INCLUSIVE: NumericInRange<RangeInclusive<f64>> =
26 NumericInRange(0.0f64..=1.0);
27
28pub static BYTESIZE_AT_LEAST_1MB: ByteSizeInRange<RangeFrom<ByteSize>> =
29 ByteSizeInRange(ByteSize::mb(1)..);
30
31#[derive(Debug)]
32pub enum ValueConstraint {
33 ReadOnly,
35 Fixed,
37 Domain(&'static dyn DynDomainConstraint),
39}
40
41impl ValueConstraint {
42 pub fn check_constraint(
43 &self,
44 var: &dyn Var,
45 cur_value: &dyn Value,
46 new_value: &dyn Value,
47 ) -> Result<(), VarError> {
48 match self {
49 ValueConstraint::ReadOnly => return Err(VarError::ReadOnlyParameter(var.name())),
50 ValueConstraint::Fixed => {
51 if cur_value != new_value {
52 return Err(VarError::FixedValueParameter {
53 name: var.name(),
54 value: cur_value.format(),
55 });
56 }
57 }
58 ValueConstraint::Domain(check) => check.check(var, new_value)?,
59 }
60
61 Ok(())
62 }
63}
64
65impl Clone for ValueConstraint {
66 fn clone(&self) -> Self {
67 match self {
68 ValueConstraint::Fixed => ValueConstraint::Fixed,
69 ValueConstraint::ReadOnly => ValueConstraint::ReadOnly,
70 ValueConstraint::Domain(c) => ValueConstraint::Domain(*c),
71 }
72 }
73}
74
75pub trait DynDomainConstraint: Debug + Send + Sync + 'static {
79 fn check(&self, var: &dyn Var, v: &dyn Value) -> Result<(), VarError>;
80}
81
82impl<D> DynDomainConstraint for D
83where
84 D: DomainConstraint + Send + Sync + 'static,
85 D::Value: Value,
86{
87 fn check(&self, var: &dyn Var, v: &dyn Value) -> Result<(), VarError> {
88 let val = v
89 .as_any()
90 .downcast_ref::<D::Value>()
91 .expect("type should match");
92 self.check(var, val)
93 }
94}
95pub trait DomainConstraint: Debug + Send + Sync + 'static {
96 type Value;
97
98 fn check(&self, var: &dyn Var, v: &Self::Value) -> Result<(), VarError>;
99}
100
101#[derive(Debug, Clone, Eq, PartialEq)]
102pub struct NumericNonNegNonNan;
103
104impl DomainConstraint for NumericNonNegNonNan {
105 type Value = Numeric;
106
107 fn check(&self, var: &dyn Var, n: &Numeric) -> Result<(), VarError> {
108 if n.is_nan() || n.is_negative() {
109 Err(VarError::InvalidParameterValue {
110 name: var.name(),
111 invalid_values: vec![n.to_string()],
112 reason: "only supports non-negative, non-NaN numeric values".to_string(),
113 })
114 } else {
115 Ok(())
116 }
117 }
118}
119
120#[derive(Debug, Clone, Eq, PartialEq)]
121pub struct NonZeroDuration;
122
123impl DomainConstraint for NonZeroDuration {
124 type Value = Duration;
125
126 fn check(&self, var: &dyn Var, d: &Duration) -> Result<(), VarError> {
127 if d.is_zero() {
128 Err(VarError::InvalidParameterValue {
129 name: var.name(),
130 invalid_values: vec![format!("{:?}", d)],
131 reason: "only supports non-zero durations".to_string(),
132 })
133 } else {
134 Ok(())
135 }
136 }
137}
138
139#[derive(Debug, Clone, Eq, PartialEq)]
140pub struct NumericInRange<R>(pub R);
141
142impl<R> DomainConstraint for NumericInRange<R>
143where
144 R: RangeBounds<f64> + std::fmt::Debug + Send + Sync + 'static,
145{
146 type Value = Numeric;
147
148 fn check(&self, var: &dyn Var, n: &Numeric) -> Result<(), VarError> {
149 let n: f64 = (*n)
150 .try_into()
151 .map_err(|_e| VarError::InvalidParameterValue {
152 name: var.name(),
153 invalid_values: vec![n.to_string()],
154 reason: format!("only supports values in range {:?}", self.0),
166 })?;
167 if !self.0.contains(&n) {
168 Err(VarError::InvalidParameterValue {
169 name: var.name(),
170 invalid_values: vec![n.to_string()],
171 reason: format!("only supports values in range {:?}", self.0),
172 })
173 } else {
174 Ok(())
175 }
176 }
177}
178
179#[derive(Debug, Clone, Eq, PartialEq)]
180pub struct ByteSizeInRange<R>(pub R);
181
182impl<R> DomainConstraint for ByteSizeInRange<R>
183where
184 R: RangeBounds<ByteSize> + std::fmt::Debug + Send + Sync + 'static,
185{
186 type Value = ByteSize;
187
188 fn check(&self, var: &dyn Var, size: &ByteSize) -> Result<(), VarError> {
189 if self.0.contains(size) {
190 Ok(())
191 } else {
192 Err(VarError::InvalidParameterValue {
193 name: var.name(),
194 invalid_values: vec![size.to_string()],
195 reason: format!("only supports values in range {:?}", self.0),
196 })
197 }
198 }
199}