mz_sql/session/vars/
constraints.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Defines constraints that can be imposed on variables.
11
12use std::fmt::Debug;
13use std::ops::{RangeBounds, RangeFrom, RangeInclusive};
14
15use mz_repr::adt::numeric::Numeric;
16use mz_repr::bytes::ByteSize;
17
18use super::{Value, Var, VarError};
19
20pub static NUMERIC_NON_NEGATIVE: NumericNonNegNonNan = NumericNonNegNonNan;
21
22pub static NUMERIC_BOUNDED_0_1_INCLUSIVE: NumericInRange<RangeInclusive<f64>> =
23    NumericInRange(0.0f64..=1.0);
24
25pub static BYTESIZE_AT_LEAST_1MB: ByteSizeInRange<RangeFrom<ByteSize>> =
26    ByteSizeInRange(ByteSize::mb(1)..);
27
28#[derive(Debug)]
29pub enum ValueConstraint {
30    /// Variable is read-only and cannot be updated.
31    ReadOnly,
32    /// The variables value can be updated, but only to a fixed value.
33    Fixed,
34    // Arbitrary constraints over values.
35    Domain(&'static dyn DynDomainConstraint),
36}
37
38impl ValueConstraint {
39    pub fn check_constraint(
40        &self,
41        var: &dyn Var,
42        cur_value: &dyn Value,
43        new_value: &dyn Value,
44    ) -> Result<(), VarError> {
45        match self {
46            ValueConstraint::ReadOnly => return Err(VarError::ReadOnlyParameter(var.name())),
47            ValueConstraint::Fixed => {
48                if cur_value != new_value {
49                    return Err(VarError::FixedValueParameter {
50                        name: var.name(),
51                        value: cur_value.format(),
52                    });
53                }
54            }
55            ValueConstraint::Domain(check) => check.check(var, new_value)?,
56        }
57
58        Ok(())
59    }
60}
61
62impl Clone for ValueConstraint {
63    fn clone(&self) -> Self {
64        match self {
65            ValueConstraint::Fixed => ValueConstraint::Fixed,
66            ValueConstraint::ReadOnly => ValueConstraint::ReadOnly,
67            ValueConstraint::Domain(c) => ValueConstraint::Domain(*c),
68        }
69    }
70}
71
72/// A type erased version of [`DomainConstraint`] that we can reference on a [`VarDefinition`].
73///
74/// [`VarDefinition`]: crate::session::vars::definitions::VarDefinition
75pub trait DynDomainConstraint: Debug + Send + Sync + 'static {
76    fn check(&self, var: &dyn Var, v: &dyn Value) -> Result<(), VarError>;
77}
78
79impl<D> DynDomainConstraint for D
80where
81    D: DomainConstraint + Send + Sync + 'static,
82    D::Value: Value,
83{
84    fn check(&self, var: &dyn Var, v: &dyn Value) -> Result<(), VarError> {
85        let val = v
86            .as_any()
87            .downcast_ref::<D::Value>()
88            .expect("type should match");
89        self.check(var, val)
90    }
91}
92pub trait DomainConstraint: Debug + Send + Sync + 'static {
93    type Value;
94
95    fn check(&self, var: &dyn Var, v: &Self::Value) -> Result<(), VarError>;
96}
97
/// Domain constraint rejecting negative and NaN [`Numeric`] values.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NumericNonNegNonNan;
100
101impl DomainConstraint for NumericNonNegNonNan {
102    type Value = Numeric;
103
104    fn check(&self, var: &dyn Var, n: &Numeric) -> Result<(), VarError> {
105        if n.is_nan() || n.is_negative() {
106            Err(VarError::InvalidParameterValue {
107                name: var.name(),
108                invalid_values: vec![n.to_string()],
109                reason: "only supports non-negative, non-NaN numeric values".to_string(),
110            })
111        } else {
112            Ok(())
113        }
114    }
115}
116
/// Domain constraint requiring a [`Numeric`] value, converted to `f64`, to lie within
/// the wrapped range.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NumericInRange<R>(pub R);
119
120impl<R> DomainConstraint for NumericInRange<R>
121where
122    R: RangeBounds<f64> + std::fmt::Debug + Send + Sync + 'static,
123{
124    type Value = Numeric;
125
126    fn check(&self, var: &dyn Var, n: &Numeric) -> Result<(), VarError> {
127        let n: f64 = (*n)
128            .try_into()
129            .map_err(|_e| VarError::InvalidParameterValue {
130                name: var.name(),
131                invalid_values: vec![n.to_string()],
132                // This first check can fail if the value is NaN, out of range,
133                // OR if it underflows (i.e. is very close to 0 without actually being 0, and the closest
134                // representable float is 0).
135                //
136                // The underflow case is very unlikely to be accidentally hit by a user, so let's
137                // not make the error message more confusing by talking about it, even though that makes
138                // the error message slightly inaccurate.
139                //
140                // If the user tries to set the paramater to 0.000<hundreds more zeros>001
141                // and gets the message "only supports values in range [0.0..=1.0]", I think they will
142                // understand, or at least accept, what's going on.
143                reason: format!("only supports values in range {:?}", self.0),
144            })?;
145        if !self.0.contains(&n) {
146            Err(VarError::InvalidParameterValue {
147                name: var.name(),
148                invalid_values: vec![n.to_string()],
149                reason: format!("only supports values in range {:?}", self.0),
150            })
151        } else {
152            Ok(())
153        }
154    }
155}
156
/// Domain constraint requiring a [`ByteSize`] value to lie within the wrapped range.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ByteSizeInRange<R>(pub R);
159
160impl<R> DomainConstraint for ByteSizeInRange<R>
161where
162    R: RangeBounds<ByteSize> + std::fmt::Debug + Send + Sync + 'static,
163{
164    type Value = ByteSize;
165
166    fn check(&self, var: &dyn Var, size: &ByteSize) -> Result<(), VarError> {
167        if self.0.contains(size) {
168            Ok(())
169        } else {
170            Err(VarError::InvalidParameterValue {
171                name: var.name(),
172                invalid_values: vec![size.to_string()],
173                reason: format!("only supports values in range {:?}", self.0),
174            })
175        }
176    }
177}