mz_repr/adt/
regex.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Regular expressions.
11
12use std::borrow::Cow;
13use std::cmp::Ordering;
14use std::fmt;
15use std::hash::{Hash, Hasher};
16use std::ops::Deref;
17
18use mz_lowertest::MzReflect;
19use regex::{Error, RegexBuilder};
20use serde::de::Error as DeError;
21use serde::ser::SerializeStruct;
22use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
23
/// Upper bound, in bytes, on the size of a regex after compilation (10 MiB).
///
/// This equals the `regex` crate's own default size limit at the time of writing.
///
/// Note: this number is quoted in our user-facing docs, under "String operators" in the function
/// reference, so the two must be kept in sync.
const MAX_REGEX_SIZE_AFTER_COMPILATION: usize = 10 << 20;
30
/// Upper bound, in bytes, on the length of a regex pattern before compilation (1 MiB).
///
/// A second limit is needed on top of `MAX_REGEX_SIZE_AFTER_COMPILATION`: although the `regex`
/// crate promises that its `size_limit` option (which we set to that other limit) prevents
/// excessive resource usage, this doesn't seem to be the case. Since we compile regexes in envd,
/// we need strict limits to prevent envd OOMs.
/// See <https://github.com/MaterializeInc/database-issues/issues/9907> for an example.
///
/// Note: this number is quoted in our user-facing docs, under "String operators" in the function
/// reference, so the two must be kept in sync.
const MAX_REGEX_SIZE_BEFORE_COMPILATION: usize = 1 << 20;
40
/// A hashable, comparable, and serializable regular expression type.
///
/// The [`regex::Regex`] type, the de facto standard regex type in Rust, does
/// not implement [`PartialOrd`], [`Ord`], [`PartialEq`], [`Eq`], or [`Hash`].
/// The omissions are reasonable. There is no natural definition of ordering for
/// regexes. There *is* a natural definition of equality—whether two regexes
/// describe the same regular language—but that is an expensive property to
/// compute, and [`PartialEq`] is generally expected to be fast to compute.
///
/// This type wraps [`regex::Regex`] and imbues it with implementations of the
/// above traits. Two regexes are considered equal iff their string
/// representations are identical and their flags, such as `case_insensitive`,
/// are identical. The [`PartialOrd`], [`Ord`], and [`Hash`] implementations
/// are similarly based upon the string representation plus flags. As
/// mentioned above, this is not the natural equivalence relation for regexes: for
/// example, the regexes `aa*` and `a+` define the same language, but would not
/// compare as equal with this implementation of [`PartialEq`]. Still, it is
/// often useful to have _some_ equivalence relation available (e.g., to store
/// types containing regexes in a hashmap) even if the equivalence relation is
/// imperfect.
///
/// [`regex::Regex`] is hard to serialize (because of the compiled code), so our approach is to
/// instead serialize this wrapper struct, where we skip serializing the actual regex field, and
/// we reconstruct the regex field from the other fields upon deserialization.
/// (Earlier, serialization was buggy due to <https://github.com/tailhook/serde-regex/issues/14>,
/// and also due to making the same mistake in our own protobuf serialization code.)
#[derive(Debug, Clone, MzReflect)]
pub struct Regex {
    /// Whether the pattern was compiled with case-insensitive matching.
    pub case_insensitive: bool,
    /// Whether the pattern was compiled so that `.` also matches a newline.
    pub dot_matches_new_line: bool,
    /// The compiled regex. Its pattern string, together with the two flags above, is what
    /// equality, ordering, hashing, and serialization are based on.
    // NOTE(review): the fields are public, so the flags can be made to disagree with how
    // `regex` was actually compiled; the constructors on this type keep them consistent.
    pub regex: regex::Regex,
}
73
74impl Regex {
75    /// A simple constructor for the default setting of `dot_matches_new_line: true`.
76    /// See <https://www.postgresql.org/docs/current/functions-matching.html#POSIX-MATCHING-RULES>
77    /// "newline-sensitive matching"
78    pub fn new(pattern: &str, case_insensitive: bool) -> Result<Regex, RegexCompilationError> {
79        Self::new_dot_matches_new_line(pattern, case_insensitive, true)
80    }
81
82    /// Allows explicitly setting `dot_matches_new_line`.
83    pub fn new_dot_matches_new_line(
84        pattern: &str,
85        case_insensitive: bool,
86        dot_matches_new_line: bool,
87    ) -> Result<Regex, RegexCompilationError> {
88        if pattern.len() > MAX_REGEX_SIZE_BEFORE_COMPILATION {
89            return Err(RegexCompilationError::PatternTooLarge {
90                pattern_size: pattern.len(),
91            });
92        }
93        let mut regex_builder = RegexBuilder::new(pattern);
94        regex_builder.case_insensitive(case_insensitive);
95        regex_builder.dot_matches_new_line(dot_matches_new_line);
96        regex_builder.size_limit(MAX_REGEX_SIZE_AFTER_COMPILATION);
97        Ok(Regex {
98            case_insensitive,
99            dot_matches_new_line,
100            regex: regex_builder.build()?,
101        })
102    }
103
104    /// Returns the pattern string of the regex.
105    pub fn pattern(&self) -> &str {
106        // `as_str` returns the raw pattern as provided during construction,
107        // and doesn't include any of the flags.
108        self.regex.as_str()
109    }
110}
111
/// Error type for regex compilation failures.
///
/// This is the error returned by [`Regex::new`] and [`Regex::new_dot_matches_new_line`].
#[derive(Debug, Clone)]
pub enum RegexCompilationError {
    /// Wrapper for the `regex` crate's [`Error`] type, produced when building the regex fails.
    RegexError(Error),
    /// The pattern is longer than `MAX_REGEX_SIZE_BEFORE_COMPILATION` bytes. `pattern_size` is
    /// the length of the rejected pattern, in bytes.
    PatternTooLarge { pattern_size: usize },
}
120
121impl fmt::Display for RegexCompilationError {
122    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
123        match self {
124            RegexCompilationError::RegexError(e) => write!(f, "{}", e),
125            RegexCompilationError::PatternTooLarge {
126                pattern_size: patter_size,
127            } => write!(
128                f,
129                "regex pattern too large ({} bytes, max {} bytes)",
130                patter_size, MAX_REGEX_SIZE_BEFORE_COMPILATION
131            ),
132        }
133    }
134}
135
136impl From<Error> for RegexCompilationError {
137    fn from(e: Error) -> Self {
138        RegexCompilationError::RegexError(e)
139    }
140}
141
142impl PartialEq<Regex> for Regex {
143    fn eq(&self, other: &Regex) -> bool {
144        self.pattern() == other.pattern()
145            && self.case_insensitive == other.case_insensitive
146            && self.dot_matches_new_line == other.dot_matches_new_line
147    }
148}
149
// Marker impl: `Eq` adds no methods. The `PartialEq` impl for `Regex` compares a string and two
// booleans, so the comparison is a full equivalence relation.
impl Eq for Regex {}
151
152impl PartialOrd for Regex {
153    fn partial_cmp(&self, other: &Regex) -> Option<Ordering> {
154        Some(self.cmp(other))
155    }
156}
157
158impl Ord for Regex {
159    fn cmp(&self, other: &Regex) -> Ordering {
160        (
161            self.pattern(),
162            self.case_insensitive,
163            self.dot_matches_new_line,
164        )
165            .cmp(&(
166                other.pattern(),
167                other.case_insensitive,
168                other.dot_matches_new_line,
169            ))
170    }
171}
172
173impl Hash for Regex {
174    fn hash<H: Hasher>(&self, hasher: &mut H) {
175        self.pattern().hash(hasher);
176        self.case_insensitive.hash(hasher);
177        self.dot_matches_new_line.hash(hasher);
178    }
179}
180
181impl Deref for Regex {
182    type Target = regex::Regex;
183
184    fn deref(&self) -> &regex::Regex {
185        &self.regex
186    }
187}
188
189impl Serialize for Regex {
190    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
191    where
192        S: Serializer,
193    {
194        let mut state = serializer.serialize_struct("Regex", 3)?;
195        state.serialize_field("pattern", &self.pattern())?;
196        state.serialize_field("case_insensitive", &self.case_insensitive)?;
197        state.serialize_field("dot_matches_new_line", &self.dot_matches_new_line)?;
198        state.end()
199    }
200}
201
202impl<'de> Deserialize<'de> for Regex {
203    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
204    where
205        D: Deserializer<'de>,
206    {
207        enum Field {
208            Pattern,
209            CaseInsensitive,
210            DotMatchesNewLine,
211        }
212
213        impl<'de> Deserialize<'de> for Field {
214            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
215            where
216                D: Deserializer<'de>,
217            {
218                struct FieldVisitor;
219
220                impl<'de> de::Visitor<'de> for FieldVisitor {
221                    type Value = Field;
222
223                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
224                        formatter.write_str(
225                            "pattern string or case_insensitive bool or dot_matches_new_line bool",
226                        )
227                    }
228
229                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
230                    where
231                        E: de::Error,
232                    {
233                        match value {
234                            "pattern" => Ok(Field::Pattern),
235                            "case_insensitive" => Ok(Field::CaseInsensitive),
236                            "dot_matches_new_line" => Ok(Field::DotMatchesNewLine),
237                            _ => Err(de::Error::unknown_field(value, FIELDS)),
238                        }
239                    }
240                }
241
242                deserializer.deserialize_identifier(FieldVisitor)
243            }
244        }
245
246        struct RegexVisitor;
247
248        impl<'de> de::Visitor<'de> for RegexVisitor {
249            type Value = Regex;
250
251            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
252                formatter.write_str("Regex serialized by the manual Serialize impl from above")
253            }
254
255            fn visit_seq<V>(self, mut seq: V) -> Result<Regex, V::Error>
256            where
257                V: de::SeqAccess<'de>,
258            {
259                let pattern = seq
260                    .next_element::<Cow<str>>()?
261                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
262                let case_insensitive = seq
263                    .next_element()?
264                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
265                let dot_matches_new_line = seq
266                    .next_element()?
267                    .ok_or_else(|| de::Error::invalid_length(2, &self))?;
268                Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
269                    .map_err(|err| {
270                        V::Error::custom(format!(
271                            "Unable to recreate regex during deserialization: {}",
272                            err
273                        ))
274                    })
275            }
276
277            fn visit_map<V>(self, mut map: V) -> Result<Regex, V::Error>
278            where
279                V: de::MapAccess<'de>,
280            {
281                let mut pattern: Option<Cow<str>> = None;
282                let mut case_insensitive: Option<bool> = None;
283                let mut dot_matches_new_line: Option<bool> = None;
284                while let Some(key) = map.next_key()? {
285                    match key {
286                        Field::Pattern => {
287                            if pattern.is_some() {
288                                return Err(de::Error::duplicate_field("pattern"));
289                            }
290                            pattern = Some(map.next_value()?);
291                        }
292                        Field::CaseInsensitive => {
293                            if case_insensitive.is_some() {
294                                return Err(de::Error::duplicate_field("case_insensitive"));
295                            }
296                            case_insensitive = Some(map.next_value()?);
297                        }
298                        Field::DotMatchesNewLine => {
299                            if dot_matches_new_line.is_some() {
300                                return Err(de::Error::duplicate_field("dot_matches_new_line"));
301                            }
302                            dot_matches_new_line = Some(map.next_value()?);
303                        }
304                    }
305                }
306                let pattern = pattern.ok_or_else(|| de::Error::missing_field("pattern"))?;
307                let case_insensitive =
308                    case_insensitive.ok_or_else(|| de::Error::missing_field("case_insensitive"))?;
309                let dot_matches_new_line = dot_matches_new_line
310                    .ok_or_else(|| de::Error::missing_field("dot_matches_new_line"))?;
311                Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
312                    .map_err(|err| {
313                        V::Error::custom(format!(
314                            "Unable to recreate regex during deserialization: {}",
315                            err
316                        ))
317                    })
318            }
319        }
320
321        const FIELDS: &[&str] = &["pattern", "case_insensitive", "dot_matches_new_line"];
322        deserializer.deserialize_struct("Regex", FIELDS, RegexVisitor)
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Lowercase input with an embedded newline, used to probe `A.*B` below.
    const MULTILINE_INPUT: &str = "axxx\nxxxb";

    /// Serializes `regex` to a JSON string and deserializes it back.
    fn json_roundtrip(regex: &Regex) -> Regex {
        let serialized: String = serde_json::to_string(regex).unwrap();
        serde_json::from_str(&serialized).unwrap()
    }

    /// This was failing before due to the derived serde serialization being incorrect, because of
    /// <https://github.com/tailhook/serde-regex/issues/14>.
    /// Nowadays, we use our own handwritten Serialize/Deserialize impls for our Regex wrapper struct.
    #[mz_ore::test]
    fn regex_serde_case_insensitive() {
        let pattern = "AAA";
        let orig_regex = Regex::new(pattern, true).unwrap();
        let roundtrip_result = json_roundtrip(&orig_regex);
        // Comparing `orig_regex` with `roundtrip_result` for equality wouldn't prove anything,
        // because `Eq` doesn't look at the actual regex object. So exercise the actual regex
        // functionality instead (concentrating on case sensitivity).
        assert!(orig_regex.regex.is_match("aaa"));
        assert!(roundtrip_result.regex.is_match("aaa"));
        assert_eq!(pattern, roundtrip_result.pattern());
    }

    /// Test the roundtripping of `dot_matches_new_line`.
    /// (Similar to the above `regex_serde_case_insensitive`.)
    #[mz_ore::test]
    fn regex_serde_dot_matches_new_line() {
        let pattern = "A.*B";
        // Each case pairs a regex with whether it is expected to match `MULTILINE_INPUT`.
        let cases = [
            // dot_matches_new_line: true
            (
                Regex::new_dot_matches_new_line(pattern, true, true).unwrap(),
                true,
            ),
            // dot_matches_new_line: false
            (
                Regex::new_dot_matches_new_line(pattern, true, false).unwrap(),
                false,
            ),
            // dot_matches_new_line: default
            (Regex::new(pattern, true).unwrap(), true),
        ];

        for (orig_regex, expect_match) in &cases {
            let roundtrip_result = json_roundtrip(orig_regex);
            assert_eq!(orig_regex.regex.is_match(MULTILINE_INPUT), *expect_match);
            assert_eq!(
                roundtrip_result.regex.is_match(MULTILINE_INPUT),
                *expect_match
            );
            assert_eq!(pattern, roundtrip_result.pattern());
        }
    }

    /// Round-trips through reader-based deserialization, first with JSON and then with bincode.
    #[mz_ore::test]
    fn regex_serde_from_reader() {
        let pattern = "A.*B";
        let orig_regex = Regex::new_dot_matches_new_line(pattern, true, true).unwrap();
        assert!(orig_regex.regex.is_match(MULTILINE_INPUT));

        let json: String = serde_json::to_string(&orig_regex).unwrap();
        let from_json: Regex = serde_json::from_reader(json.as_bytes()).unwrap();
        assert!(from_json.regex.is_match(MULTILINE_INPUT));
        assert_eq!(pattern, from_json.pattern());

        let bytes = bincode::serialize(&orig_regex).unwrap();
        let from_bincode: Regex = bincode::deserialize_from(&*bytes).unwrap();
        assert!(from_bincode.regex.is_match(MULTILINE_INPUT));
        assert_eq!(pattern, from_bincode.pattern());
    }
}