Skip to main content

mz_sql/session/vars/
constraints.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Defines constraints that can be imposed on variables.
11
12use std::fmt::Debug;
13use std::ops::{RangeBounds, RangeFrom, RangeInclusive};
14use std::time::Duration;
15
16use mz_repr::adt::numeric::Numeric;
17use mz_repr::bytes::ByteSize;
18
19use super::{Value, Var, VarError};
20
21pub static NUMERIC_NON_NEGATIVE: NumericNonNegNonNan = NumericNonNegNonNan;
22
23pub static NON_ZERO_DURATION: NonZeroDuration = NonZeroDuration;
24
25pub static NUMERIC_BOUNDED_0_1_INCLUSIVE: NumericInRange<RangeInclusive<f64>> =
26    NumericInRange(0.0f64..=1.0);
27
28pub static BYTESIZE_AT_LEAST_1MB: ByteSizeInRange<RangeFrom<ByteSize>> =
29    ByteSizeInRange(ByteSize::mb(1)..);
30
31#[derive(Debug)]
32pub enum ValueConstraint {
33    /// Variable is read-only and cannot be updated.
34    ReadOnly,
35    /// The variables value can be updated, but only to a fixed value.
36    Fixed,
37    // Arbitrary constraints over values.
38    Domain(&'static dyn DynDomainConstraint),
39}
40
41impl ValueConstraint {
42    pub fn check_constraint(
43        &self,
44        var: &dyn Var,
45        cur_value: &dyn Value,
46        new_value: &dyn Value,
47    ) -> Result<(), VarError> {
48        match self {
49            ValueConstraint::ReadOnly => return Err(VarError::ReadOnlyParameter(var.name())),
50            ValueConstraint::Fixed => {
51                if cur_value != new_value {
52                    return Err(VarError::FixedValueParameter {
53                        name: var.name(),
54                        value: cur_value.format(),
55                    });
56                }
57            }
58            ValueConstraint::Domain(check) => check.check(var, new_value)?,
59        }
60
61        Ok(())
62    }
63}
64
65impl Clone for ValueConstraint {
66    fn clone(&self) -> Self {
67        match self {
68            ValueConstraint::Fixed => ValueConstraint::Fixed,
69            ValueConstraint::ReadOnly => ValueConstraint::ReadOnly,
70            ValueConstraint::Domain(c) => ValueConstraint::Domain(*c),
71        }
72    }
73}
74
75/// A type erased version of [`DomainConstraint`] that we can reference on a [`VarDefinition`].
76///
77/// [`VarDefinition`]: crate::session::vars::definitions::VarDefinition
78pub trait DynDomainConstraint: Debug + Send + Sync + 'static {
79    fn check(&self, var: &dyn Var, v: &dyn Value) -> Result<(), VarError>;
80}
81
82impl<D> DynDomainConstraint for D
83where
84    D: DomainConstraint + Send + Sync + 'static,
85    D::Value: Value,
86{
87    fn check(&self, var: &dyn Var, v: &dyn Value) -> Result<(), VarError> {
88        let val = v
89            .as_any()
90            .downcast_ref::<D::Value>()
91            .expect("type should match");
92        self.check(var, val)
93    }
94}
95pub trait DomainConstraint: Debug + Send + Sync + 'static {
96    type Value;
97
98    fn check(&self, var: &dyn Var, v: &Self::Value) -> Result<(), VarError>;
99}
100
101#[derive(Debug, Clone, Eq, PartialEq)]
102pub struct NumericNonNegNonNan;
103
104impl DomainConstraint for NumericNonNegNonNan {
105    type Value = Numeric;
106
107    fn check(&self, var: &dyn Var, n: &Numeric) -> Result<(), VarError> {
108        if n.is_nan() || n.is_negative() {
109            Err(VarError::InvalidParameterValue {
110                name: var.name(),
111                invalid_values: vec![n.to_string()],
112                reason: "only supports non-negative, non-NaN numeric values".to_string(),
113            })
114        } else {
115            Ok(())
116        }
117    }
118}
119
120#[derive(Debug, Clone, Eq, PartialEq)]
121pub struct NonZeroDuration;
122
123impl DomainConstraint for NonZeroDuration {
124    type Value = Duration;
125
126    fn check(&self, var: &dyn Var, d: &Duration) -> Result<(), VarError> {
127        if d.is_zero() {
128            Err(VarError::InvalidParameterValue {
129                name: var.name(),
130                invalid_values: vec![format!("{:?}", d)],
131                reason: "only supports non-zero durations".to_string(),
132            })
133        } else {
134            Ok(())
135        }
136    }
137}
138
139#[derive(Debug, Clone, Eq, PartialEq)]
140pub struct NumericInRange<R>(pub R);
141
142impl<R> DomainConstraint for NumericInRange<R>
143where
144    R: RangeBounds<f64> + std::fmt::Debug + Send + Sync + 'static,
145{
146    type Value = Numeric;
147
148    fn check(&self, var: &dyn Var, n: &Numeric) -> Result<(), VarError> {
149        let n: f64 = (*n)
150            .try_into()
151            .map_err(|_e| VarError::InvalidParameterValue {
152                name: var.name(),
153                invalid_values: vec![n.to_string()],
154                // This first check can fail if the value is NaN, out of range,
155                // OR if it underflows (i.e. is very close to 0 without actually being 0, and the closest
156                // representable float is 0).
157                //
158                // The underflow case is very unlikely to be accidentally hit by a user, so let's
159                // not make the error message more confusing by talking about it, even though that makes
160                // the error message slightly inaccurate.
161                //
162                // If the user tries to set the paramater to 0.000<hundreds more zeros>001
163                // and gets the message "only supports values in range [0.0..=1.0]", I think they will
164                // understand, or at least accept, what's going on.
165                reason: format!("only supports values in range {:?}", self.0),
166            })?;
167        if !self.0.contains(&n) {
168            Err(VarError::InvalidParameterValue {
169                name: var.name(),
170                invalid_values: vec![n.to_string()],
171                reason: format!("only supports values in range {:?}", self.0),
172            })
173        } else {
174            Ok(())
175        }
176    }
177}
178
179#[derive(Debug, Clone, Eq, PartialEq)]
180pub struct ByteSizeInRange<R>(pub R);
181
182impl<R> DomainConstraint for ByteSizeInRange<R>
183where
184    R: RangeBounds<ByteSize> + std::fmt::Debug + Send + Sync + 'static,
185{
186    type Value = ByteSize;
187
188    fn check(&self, var: &dyn Var, size: &ByteSize) -> Result<(), VarError> {
189        if self.0.contains(size) {
190            Ok(())
191        } else {
192            Err(VarError::InvalidParameterValue {
193                name: var.name(),
194                invalid_values: vec![size.to_string()],
195                reason: format!("only supports values in range {:?}", self.0),
196            })
197        }
198    }
199}