mz_repr/adt/
regex.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Regular expressions.
11
12use std::borrow::Cow;
13use std::cmp::Ordering;
14use std::fmt;
15use std::hash::{Hash, Hasher};
16use std::ops::Deref;
17
18use mz_lowertest::MzReflect;
19use regex::{Error, RegexBuilder};
20use serde::de::Error as DeError;
21use serde::ser::SerializeStruct;
22use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
23
24/// A hashable, comparable, and serializable regular expression type.
25///
/// The [`regex::Regex`] type, the de facto standard regex type in Rust, does
/// not implement [`PartialOrd`], [`Ord`], [`PartialEq`], [`Eq`], or [`Hash`].
28/// The omissions are reasonable. There is no natural definition of ordering for
29/// regexes. There *is* a natural definition of equality—whether two regexes
30/// describe the same regular language—but that is an expensive property to
31/// compute, and [`PartialEq`] is generally expected to be fast to compute.
32///
33/// This type wraps [`regex::Regex`] and imbues it with implementations of the
34/// above traits. Two regexes are considered equal iff their string
35/// representation is identical, plus flags, such as `case_insensitive`,
36/// are identical. The [`PartialOrd`], [`Ord`], and [`Hash`] implementations
37/// are similarly based upon the string representation plus flags. As
38/// mentioned above, this is not the natural equivalence relation for regexes: for
39/// example, the regexes `aa*` and `a+` define the same language, but would not
40/// compare as equal with this implementation of [`PartialEq`]. Still, it is
41/// often useful to have _some_ equivalence relation available (e.g., to store
42/// types containing regexes in a hashmap) even if the equivalence relation is
43/// imperfect.
44///
/// [`regex::Regex`] is hard to serialize (because of the compiled code), so our approach is to
/// instead serialize this wrapper struct, where we skip serializing the actual regex field, and
/// we reconstruct the regex field from the other fields upon deserialization.
/// (Earlier, serialization was buggy due to <https://github.com/tailhook/serde-regex/issues/14>,
/// and also because we made the same mistake in our own protobuf serialization code.)
#[derive(Debug, Clone, MzReflect)]
pub struct Regex {
    /// The case-insensitivity flag. The constructors in this file pass it to
    /// the regex builder; it takes part in equality, ordering, hashing, and
    /// serialization.
    pub case_insensitive: bool,
    /// The "dot matches newline" flag. The constructors in this file pass it
    /// to the regex builder; it takes part in equality, ordering, hashing, and
    /// serialization.
    pub dot_matches_new_line: bool,
    /// The compiled regex. It is not serialized itself: only its pattern
    /// string is written out, and the regex is rebuilt from the pattern and
    /// the two flags on deserialization.
    pub regex: regex::Regex,
}
56
57impl Regex {
58    /// A simple constructor for the default setting of `dot_matches_new_line: true`.
59    /// See <https://www.postgresql.org/docs/current/functions-matching.html#POSIX-MATCHING-RULES>
60    /// "newline-sensitive matching"
61    pub fn new(pattern: &str, case_insensitive: bool) -> Result<Regex, Error> {
62        Self::new_dot_matches_new_line(pattern, case_insensitive, true)
63    }
64
65    /// Allows explicitly setting `dot_matches_new_line`.
66    pub fn new_dot_matches_new_line(
67        pattern: &str,
68        case_insensitive: bool,
69        dot_matches_new_line: bool,
70    ) -> Result<Regex, Error> {
71        let mut regex_builder = RegexBuilder::new(pattern);
72        regex_builder.case_insensitive(case_insensitive);
73        regex_builder.dot_matches_new_line(dot_matches_new_line);
74        Ok(Regex {
75            case_insensitive,
76            dot_matches_new_line,
77            regex: regex_builder.build()?,
78        })
79    }
80
81    /// Returns the pattern string of the regex.
82    pub fn pattern(&self) -> &str {
83        // `as_str` returns the raw pattern as provided during construction,
84        // and doesn't include any of the flags.
85        self.regex.as_str()
86    }
87}
88
89impl PartialEq<Regex> for Regex {
90    fn eq(&self, other: &Regex) -> bool {
91        self.pattern() == other.pattern()
92            && self.case_insensitive == other.case_insensitive
93            && self.dot_matches_new_line == other.dot_matches_new_line
94    }
95}
96
// Marker impl: the `PartialEq` impl compares the pattern string and two
// `bool` flags.
impl Eq for Regex {}
98
99impl PartialOrd for Regex {
100    fn partial_cmp(&self, other: &Regex) -> Option<Ordering> {
101        Some(self.cmp(other))
102    }
103}
104
105impl Ord for Regex {
106    fn cmp(&self, other: &Regex) -> Ordering {
107        (
108            self.pattern(),
109            self.case_insensitive,
110            self.dot_matches_new_line,
111        )
112            .cmp(&(
113                other.pattern(),
114                other.case_insensitive,
115                other.dot_matches_new_line,
116            ))
117    }
118}
119
120impl Hash for Regex {
121    fn hash<H: Hasher>(&self, hasher: &mut H) {
122        self.pattern().hash(hasher);
123        self.case_insensitive.hash(hasher);
124        self.dot_matches_new_line.hash(hasher);
125    }
126}
127
128impl Deref for Regex {
129    type Target = regex::Regex;
130
131    fn deref(&self) -> &regex::Regex {
132        &self.regex
133    }
134}
135
136impl Serialize for Regex {
137    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
138    where
139        S: Serializer,
140    {
141        let mut state = serializer.serialize_struct("Regex", 3)?;
142        state.serialize_field("pattern", &self.pattern())?;
143        state.serialize_field("case_insensitive", &self.case_insensitive)?;
144        state.serialize_field("dot_matches_new_line", &self.dot_matches_new_line)?;
145        state.end()
146    }
147}
148
149impl<'de> Deserialize<'de> for Regex {
150    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
151    where
152        D: Deserializer<'de>,
153    {
154        enum Field {
155            Pattern,
156            CaseInsensitive,
157            DotMatchesNewLine,
158        }
159
160        impl<'de> Deserialize<'de> for Field {
161            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
162            where
163                D: Deserializer<'de>,
164            {
165                struct FieldVisitor;
166
167                impl<'de> de::Visitor<'de> for FieldVisitor {
168                    type Value = Field;
169
170                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
171                        formatter.write_str(
172                            "pattern string or case_insensitive bool or dot_matches_new_line bool",
173                        )
174                    }
175
176                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
177                    where
178                        E: de::Error,
179                    {
180                        match value {
181                            "pattern" => Ok(Field::Pattern),
182                            "case_insensitive" => Ok(Field::CaseInsensitive),
183                            "dot_matches_new_line" => Ok(Field::DotMatchesNewLine),
184                            _ => Err(de::Error::unknown_field(value, FIELDS)),
185                        }
186                    }
187                }
188
189                deserializer.deserialize_identifier(FieldVisitor)
190            }
191        }
192
193        struct RegexVisitor;
194
195        impl<'de> de::Visitor<'de> for RegexVisitor {
196            type Value = Regex;
197
198            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
199                formatter.write_str("Regex serialized by the manual Serialize impl from above")
200            }
201
202            fn visit_seq<V>(self, mut seq: V) -> Result<Regex, V::Error>
203            where
204                V: de::SeqAccess<'de>,
205            {
206                let pattern = seq
207                    .next_element::<Cow<str>>()?
208                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
209                let case_insensitive = seq
210                    .next_element()?
211                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
212                let dot_matches_new_line = seq
213                    .next_element()?
214                    .ok_or_else(|| de::Error::invalid_length(2, &self))?;
215                Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
216                    .map_err(|err| {
217                        V::Error::custom(format!(
218                            "Unable to recreate regex during deserialization: {}",
219                            err
220                        ))
221                    })
222            }
223
224            fn visit_map<V>(self, mut map: V) -> Result<Regex, V::Error>
225            where
226                V: de::MapAccess<'de>,
227            {
228                let mut pattern: Option<Cow<str>> = None;
229                let mut case_insensitive: Option<bool> = None;
230                let mut dot_matches_new_line: Option<bool> = None;
231                while let Some(key) = map.next_key()? {
232                    match key {
233                        Field::Pattern => {
234                            if pattern.is_some() {
235                                return Err(de::Error::duplicate_field("pattern"));
236                            }
237                            pattern = Some(map.next_value()?);
238                        }
239                        Field::CaseInsensitive => {
240                            if case_insensitive.is_some() {
241                                return Err(de::Error::duplicate_field("case_insensitive"));
242                            }
243                            case_insensitive = Some(map.next_value()?);
244                        }
245                        Field::DotMatchesNewLine => {
246                            if dot_matches_new_line.is_some() {
247                                return Err(de::Error::duplicate_field("dot_matches_new_line"));
248                            }
249                            dot_matches_new_line = Some(map.next_value()?);
250                        }
251                    }
252                }
253                let pattern = pattern.ok_or_else(|| de::Error::missing_field("pattern"))?;
254                let case_insensitive =
255                    case_insensitive.ok_or_else(|| de::Error::missing_field("case_insensitive"))?;
256                let dot_matches_new_line = dot_matches_new_line
257                    .ok_or_else(|| de::Error::missing_field("dot_matches_new_line"))?;
258                Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
259                    .map_err(|err| {
260                        V::Error::custom(format!(
261                            "Unable to recreate regex during deserialization: {}",
262                            err
263                        ))
264                    })
265            }
266        }
267
268        const FIELDS: &[&str] = &["pattern", "case_insensitive", "dot_matches_new_line"];
269        deserializer.deserialize_struct("Regex", FIELDS, RegexVisitor)
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Input whose two halves are separated by a newline, so that `A.*B`
    /// (case-insensitively) matches it only when `.` matches newlines.
    const MULTILINE_INPUT: &str = "axxx\nxxxb";

    /// Serializes `regex` to a JSON string and deserializes it back.
    fn json_roundtrip(regex: &Regex) -> Regex {
        let serialized: String = serde_json::to_string(regex).unwrap();
        serde_json::from_str(&serialized).unwrap()
    }

    /// Asserts that `roundtrip` behaves like `orig` on `haystack`.
    ///
    /// An equality check alone would not be enough, because `Eq` doesn't look
    /// at the actual compiled regex object. So the regex functionality itself
    /// is exercised, and `Eq` is checked on top of that to confirm that the
    /// pattern and flags survived the roundtrip.
    #[track_caller]
    fn assert_roundtrip(orig: &Regex, roundtrip: &Regex, haystack: &str, expect_match: bool) {
        assert_eq!(orig.regex.is_match(haystack), expect_match);
        assert_eq!(roundtrip.regex.is_match(haystack), expect_match);
        assert_eq!(orig.pattern(), roundtrip.pattern());
        assert_eq!(orig, roundtrip);
    }

    /// This was failing before due to the derived serde serialization being incorrect, because of
    /// <https://github.com/tailhook/serde-regex/issues/14>.
    /// Nowadays, we use our own handwritten Serialize/Deserialize impls for our Regex wrapper struct.
    #[mz_ore::test]
    fn regex_serde_case_insensitive() {
        let pattern = "AAA";
        let orig_regex = Regex::new(pattern, true).unwrap();
        let roundtrip_result = json_roundtrip(&orig_regex);
        // Concentrate on case sensitivity: the upper-case pattern must match
        // lower-case input both before and after the roundtrip.
        assert_roundtrip(&orig_regex, &roundtrip_result, "aaa", true);
        assert_eq!(pattern, roundtrip_result.pattern());
    }

    /// Test the roundtripping of `dot_matches_new_line`.
    /// (Similar to the above `regex_serde_case_insensitive`.)
    #[mz_ore::test]
    fn regex_serde_dot_matches_new_line() {
        let pattern = "A.*B";
        let cases = [
            // dot_matches_new_line: true
            (
                Regex::new_dot_matches_new_line(pattern, true, true).unwrap(),
                true,
            ),
            // dot_matches_new_line: false
            (
                Regex::new_dot_matches_new_line(pattern, true, false).unwrap(),
                false,
            ),
            // dot_matches_new_line: default
            (Regex::new(pattern, true).unwrap(), true),
        ];
        for (orig_regex, expect_match) in &cases {
            let roundtrip_result = json_roundtrip(orig_regex);
            assert_roundtrip(orig_regex, &roundtrip_result, MULTILINE_INPUT, *expect_match);
            assert_eq!(pattern, roundtrip_result.pattern());
        }
    }

    /// Test deserializing from a reader rather than from an in-memory string,
    /// for both JSON and bincode.
    #[mz_ore::test]
    fn regex_serde_from_reader() {
        let pattern = "A.*B";
        let orig_regex = Regex::new_dot_matches_new_line(pattern, true, true).unwrap();

        // JSON, via a reader.
        let serialized: String = serde_json::to_string(&orig_regex).unwrap();
        let roundtrip_result: Regex = serde_json::from_reader(serialized.as_bytes()).unwrap();
        assert_roundtrip(&orig_regex, &roundtrip_result, MULTILINE_INPUT, true);
        assert_eq!(pattern, roundtrip_result.pattern());

        // bincode, via a reader.
        let serialized = bincode::serialize(&orig_regex).unwrap();
        let roundtrip_result: Regex = bincode::deserialize_from(&*serialized).unwrap();
        assert_roundtrip(&orig_regex, &roundtrip_result, MULTILINE_INPUT, true);
        assert_eq!(pattern, roundtrip_result.pattern());
    }
}