mz_sql/session/vars/
constraints.rs1use std::fmt::Debug;
13use std::ops::{RangeBounds, RangeFrom, RangeInclusive};
14
15use mz_repr::adt::numeric::Numeric;
16use mz_repr::bytes::ByteSize;
17
18use super::{Value, Var, VarError};
19
/// Shared [`NumericNonNegNonNan`] instance: rejects `Numeric` values that are
/// NaN or negative.
pub static NUMERIC_NON_NEGATIVE: NumericNonNegNonNan = NumericNonNegNonNan;

/// Shared [`NumericInRange`] instance: accepts `Numeric` values whose `f64`
/// conversion lies in the closed interval `0.0..=1.0`.
pub static NUMERIC_BOUNDED_0_1_INCLUSIVE: NumericInRange<RangeInclusive<f64>> =
    NumericInRange(0.0f64..=1.0);

/// Shared [`ByteSizeInRange`] instance: accepts any `ByteSize` that is at
/// least `ByteSize::mb(1)`, with no upper bound.
pub static BYTESIZE_AT_LEAST_1MB: ByteSizeInRange<RangeFrom<ByteSize>> =
    ByteSizeInRange(ByteSize::mb(1)..);
27
/// A restriction applied to a proposed new value of a variable; enforced by
/// [`ValueConstraint::check_constraint`].
#[derive(Debug)]
pub enum ValueConstraint {
    /// Every proposed value is rejected with `VarError::ReadOnlyParameter`.
    ReadOnly,
    /// The proposed value must compare equal to the current value; anything
    /// else is rejected with `VarError::FixedValueParameter`.
    Fixed,
    /// The proposed value must pass the referenced type-erased domain check.
    Domain(&'static dyn DynDomainConstraint),
}
37
38impl ValueConstraint {
39 pub fn check_constraint(
40 &self,
41 var: &dyn Var,
42 cur_value: &dyn Value,
43 new_value: &dyn Value,
44 ) -> Result<(), VarError> {
45 match self {
46 ValueConstraint::ReadOnly => return Err(VarError::ReadOnlyParameter(var.name())),
47 ValueConstraint::Fixed => {
48 if cur_value != new_value {
49 return Err(VarError::FixedValueParameter {
50 name: var.name(),
51 value: cur_value.format(),
52 });
53 }
54 }
55 ValueConstraint::Domain(check) => check.check(var, new_value)?,
56 }
57
58 Ok(())
59 }
60}
61
62impl Clone for ValueConstraint {
63 fn clone(&self) -> Self {
64 match self {
65 ValueConstraint::Fixed => ValueConstraint::Fixed,
66 ValueConstraint::ReadOnly => ValueConstraint::ReadOnly,
67 ValueConstraint::Domain(c) => ValueConstraint::Domain(*c),
68 }
69 }
70}
71
/// Type-erased form of [`DomainConstraint`] that accepts the value as
/// `&dyn Value`, so constraints over different value types can be used behind
/// a `&dyn DynDomainConstraint` (as in [`ValueConstraint::Domain`]).
///
/// A blanket implementation in this file provides this trait for every
/// [`DomainConstraint`] whose `Value` type implements `Value`.
pub trait DynDomainConstraint: Debug + Send + Sync + 'static {
    /// Checks `v` as a proposed value for `var`, returning an error if it is
    /// not acceptable.
    ///
    /// # Panics
    ///
    /// The blanket implementation panics if `v` is not of the concrete type
    /// the underlying constraint expects.
    fn check(&self, var: &dyn Var, v: &dyn Value) -> Result<(), VarError>;
}
78
79impl<D> DynDomainConstraint for D
80where
81 D: DomainConstraint + Send + Sync + 'static,
82 D::Value: Value,
83{
84 fn check(&self, var: &dyn Var, v: &dyn Value) -> Result<(), VarError> {
85 let val = v
86 .as_any()
87 .downcast_ref::<D::Value>()
88 .expect("type should match");
89 self.check(var, val)
90 }
91}
/// A validity check over values of a single concrete type.
pub trait DomainConstraint: Debug + Send + Sync + 'static {
    /// The concrete value type this constraint inspects.
    type Value;

    /// Checks `v` as a proposed value for `var`, returning `Ok(())` when it
    /// is acceptable and a [`VarError`] otherwise.
    fn check(&self, var: &dyn Var, v: &Self::Value) -> Result<(), VarError>;
}
97
98#[derive(Debug, Clone, Eq, PartialEq)]
99pub struct NumericNonNegNonNan;
100
101impl DomainConstraint for NumericNonNegNonNan {
102 type Value = Numeric;
103
104 fn check(&self, var: &dyn Var, n: &Numeric) -> Result<(), VarError> {
105 if n.is_nan() || n.is_negative() {
106 Err(VarError::InvalidParameterValue {
107 name: var.name(),
108 invalid_values: vec![n.to_string()],
109 reason: "only supports non-negative, non-NaN numeric values".to_string(),
110 })
111 } else {
112 Ok(())
113 }
114 }
115}
116
117#[derive(Debug, Clone, Eq, PartialEq)]
118pub struct NumericInRange<R>(pub R);
119
120impl<R> DomainConstraint for NumericInRange<R>
121where
122 R: RangeBounds<f64> + std::fmt::Debug + Send + Sync + 'static,
123{
124 type Value = Numeric;
125
126 fn check(&self, var: &dyn Var, n: &Numeric) -> Result<(), VarError> {
127 let n: f64 = (*n)
128 .try_into()
129 .map_err(|_e| VarError::InvalidParameterValue {
130 name: var.name(),
131 invalid_values: vec![n.to_string()],
132 reason: format!("only supports values in range {:?}", self.0),
144 })?;
145 if !self.0.contains(&n) {
146 Err(VarError::InvalidParameterValue {
147 name: var.name(),
148 invalid_values: vec![n.to_string()],
149 reason: format!("only supports values in range {:?}", self.0),
150 })
151 } else {
152 Ok(())
153 }
154 }
155}
156
157#[derive(Debug, Clone, Eq, PartialEq)]
158pub struct ByteSizeInRange<R>(pub R);
159
160impl<R> DomainConstraint for ByteSizeInRange<R>
161where
162 R: RangeBounds<ByteSize> + std::fmt::Debug + Send + Sync + 'static,
163{
164 type Value = ByteSize;
165
166 fn check(&self, var: &dyn Var, size: &ByteSize) -> Result<(), VarError> {
167 if self.0.contains(size) {
168 Ok(())
169 } else {
170 Err(VarError::InvalidParameterValue {
171 name: var.name(),
172 invalid_values: vec![size.to_string()],
173 reason: format!("only supports values in range {:?}", self.0),
174 })
175 }
176 }
177}