mz_repr/adt/
regex.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Regular expressions.
11
12use std::cmp::Ordering;
13use std::fmt;
14use std::hash::{Hash, Hasher};
15use std::ops::Deref;
16
17use mz_lowertest::MzReflect;
18use mz_proto::{RustType, TryFromProtoError};
19use proptest::prelude::any;
20use proptest::prop_compose;
21use regex::{Error, RegexBuilder};
22use serde::de::Error as DeError;
23use serde::ser::SerializeStruct;
24use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
25
26include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.regex.rs"));
27
/// A hashable, comparable, and serializable regular expression type.
///
/// The [`regex::Regex`] type, the de facto standard regex type in Rust, does
/// not implement [`PartialOrd`], [`Ord`], [`PartialEq`], [`Eq`], or [`Hash`].
/// The omissions are reasonable. There is no natural definition of ordering for
/// regexes. There *is* a natural definition of equality—whether two regexes
/// describe the same regular language—but that is an expensive property to
/// compute, and [`PartialEq`] is generally expected to be fast to compute.
///
/// This type wraps [`regex::Regex`] and imbues it with implementations of the
/// above traits. Two regexes are considered equal iff their string
/// representations are identical and their flags, such as `case_insensitive`,
/// are identical. The [`PartialOrd`], [`Ord`], and [`Hash`] implementations
/// are similarly based upon the string representation plus flags. As
/// mentioned above, this is not the natural equivalence relation for regexes: for
/// example, the regexes `aa*` and `a+` define the same language, but would not
/// compare as equal with this implementation of [`PartialEq`]. Still, it is
/// often useful to have _some_ equivalence relation available (e.g., to store
/// types containing regexes in a hashmap) even if the equivalence relation is
/// imperfect.
///
/// [`regex::Regex`] is hard to serialize (because of the compiled code), so our approach is to
/// instead serialize this wrapper struct, where we skip serializing the actual regex field, and
/// we reconstruct the regex field from the other fields upon deserialization.
/// (Earlier, serialization was buggy due to <https://github.com/tailhook/serde-regex/issues/14>,
/// and also due to making the same mistake in our own protobuf serialization code.)
#[derive(Debug, Clone, MzReflect)]
pub struct Regex {
    /// Whether the regex was built with case-insensitive matching enabled.
    pub case_insensitive: bool,
    /// Whether the regex was built with `.` also matching newline characters.
    pub dot_matches_new_line: bool,
    /// The compiled regex. It is not serialized; it is rebuilt from the pattern
    /// string and the two flags above when deserializing.
    pub regex: regex::Regex,
}
60
61impl Regex {
62    /// A simple constructor for the default setting of `dot_matches_new_line: true`.
63    /// See <https://www.postgresql.org/docs/current/functions-matching.html#POSIX-MATCHING-RULES>
64    /// "newline-sensitive matching"
65    pub fn new(pattern: &str, case_insensitive: bool) -> Result<Regex, Error> {
66        Self::new_dot_matches_new_line(pattern, case_insensitive, true)
67    }
68
69    /// Allows explicitly setting `dot_matches_new_line`.
70    pub fn new_dot_matches_new_line(
71        pattern: &str,
72        case_insensitive: bool,
73        dot_matches_new_line: bool,
74    ) -> Result<Regex, Error> {
75        let mut regex_builder = RegexBuilder::new(pattern);
76        regex_builder.case_insensitive(case_insensitive);
77        regex_builder.dot_matches_new_line(dot_matches_new_line);
78        Ok(Regex {
79            case_insensitive,
80            dot_matches_new_line,
81            regex: regex_builder.build()?,
82        })
83    }
84
85    /// Returns the pattern string of the regex.
86    pub fn pattern(&self) -> &str {
87        // `as_str` returns the raw pattern as provided during construction,
88        // and doesn't include any of the flags.
89        self.regex.as_str()
90    }
91}
92
93impl PartialEq<Regex> for Regex {
94    fn eq(&self, other: &Regex) -> bool {
95        self.pattern() == other.pattern()
96            && self.case_insensitive == other.case_insensitive
97            && self.dot_matches_new_line == other.dot_matches_new_line
98    }
99}
100
// `eq` compares only the pattern string and two `bool` flags, all of which
// have total equality, so the relation is reflexive and `Eq` is sound.
impl Eq for Regex {}
102
103impl PartialOrd for Regex {
104    fn partial_cmp(&self, other: &Regex) -> Option<Ordering> {
105        Some(self.cmp(other))
106    }
107}
108
109impl Ord for Regex {
110    fn cmp(&self, other: &Regex) -> Ordering {
111        (
112            self.pattern(),
113            self.case_insensitive,
114            self.dot_matches_new_line,
115        )
116            .cmp(&(
117                other.pattern(),
118                other.case_insensitive,
119                other.dot_matches_new_line,
120            ))
121    }
122}
123
124impl Hash for Regex {
125    fn hash<H: Hasher>(&self, hasher: &mut H) {
126        self.pattern().hash(hasher);
127        self.case_insensitive.hash(hasher);
128        self.dot_matches_new_line.hash(hasher);
129    }
130}
131
132impl Deref for Regex {
133    type Target = regex::Regex;
134
135    fn deref(&self) -> &regex::Regex {
136        &self.regex
137    }
138}
139
140impl RustType<ProtoRegex> for Regex {
141    fn into_proto(&self) -> ProtoRegex {
142        ProtoRegex {
143            pattern: self.pattern().to_owned(),
144            case_insensitive: self.case_insensitive,
145            dot_matches_new_line: self.dot_matches_new_line,
146        }
147    }
148
149    fn from_proto(proto: ProtoRegex) -> Result<Self, TryFromProtoError> {
150        Ok(Regex::new_dot_matches_new_line(
151            &proto.pattern,
152            proto.case_insensitive,
153            proto.dot_matches_new_line,
154        )?)
155    }
156}
157
158impl Serialize for Regex {
159    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
160    where
161        S: Serializer,
162    {
163        let mut state = serializer.serialize_struct("Regex", 3)?;
164        state.serialize_field("pattern", &self.pattern())?;
165        state.serialize_field("case_insensitive", &self.case_insensitive)?;
166        state.serialize_field("dot_matches_new_line", &self.dot_matches_new_line)?;
167        state.end()
168    }
169}
170
171impl<'de> Deserialize<'de> for Regex {
172    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173    where
174        D: Deserializer<'de>,
175    {
176        enum Field {
177            Pattern,
178            CaseInsensitive,
179            DotMatchesNewLine,
180        }
181
182        impl<'de> Deserialize<'de> for Field {
183            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
184            where
185                D: Deserializer<'de>,
186            {
187                struct FieldVisitor;
188
189                impl<'de> de::Visitor<'de> for FieldVisitor {
190                    type Value = Field;
191
192                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
193                        formatter.write_str(
194                            "pattern string or case_insensitive bool or dot_matches_new_line bool",
195                        )
196                    }
197
198                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
199                    where
200                        E: de::Error,
201                    {
202                        match value {
203                            "pattern" => Ok(Field::Pattern),
204                            "case_insensitive" => Ok(Field::CaseInsensitive),
205                            "dot_matches_new_line" => Ok(Field::DotMatchesNewLine),
206                            _ => Err(de::Error::unknown_field(value, FIELDS)),
207                        }
208                    }
209                }
210
211                deserializer.deserialize_identifier(FieldVisitor)
212            }
213        }
214
215        struct RegexVisitor;
216
217        impl<'de> de::Visitor<'de> for RegexVisitor {
218            type Value = Regex;
219
220            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
221                formatter.write_str("Regex serialized by the manual Serialize impl from above")
222            }
223
224            fn visit_seq<V>(self, mut seq: V) -> Result<Regex, V::Error>
225            where
226                V: de::SeqAccess<'de>,
227            {
228                let pattern = seq
229                    .next_element()?
230                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
231                let case_insensitive = seq
232                    .next_element()?
233                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
234                let dot_matches_new_line = seq
235                    .next_element()?
236                    .ok_or_else(|| de::Error::invalid_length(2, &self))?;
237                Regex::new_dot_matches_new_line(pattern, case_insensitive, dot_matches_new_line)
238                    .map_err(|err| {
239                        V::Error::custom(format!(
240                            "Unable to recreate regex during deserialization: {}",
241                            err
242                        ))
243                    })
244            }
245
246            fn visit_map<V>(self, mut map: V) -> Result<Regex, V::Error>
247            where
248                V: de::MapAccess<'de>,
249            {
250                let mut pattern: Option<&str> = None;
251                let mut case_insensitive: Option<bool> = None;
252                let mut dot_matches_new_line: Option<bool> = None;
253                while let Some(key) = map.next_key()? {
254                    match key {
255                        Field::Pattern => {
256                            if pattern.is_some() {
257                                return Err(de::Error::duplicate_field("pattern"));
258                            }
259                            pattern = Some(map.next_value()?);
260                        }
261                        Field::CaseInsensitive => {
262                            if case_insensitive.is_some() {
263                                return Err(de::Error::duplicate_field("case_insensitive"));
264                            }
265                            case_insensitive = Some(map.next_value()?);
266                        }
267                        Field::DotMatchesNewLine => {
268                            if dot_matches_new_line.is_some() {
269                                return Err(de::Error::duplicate_field("dot_matches_new_line"));
270                            }
271                            dot_matches_new_line = Some(map.next_value()?);
272                        }
273                    }
274                }
275                let pattern = pattern.ok_or_else(|| de::Error::missing_field("pattern"))?;
276                let case_insensitive =
277                    case_insensitive.ok_or_else(|| de::Error::missing_field("case_insensitive"))?;
278                let dot_matches_new_line = dot_matches_new_line
279                    .ok_or_else(|| de::Error::missing_field("dot_matches_new_line"))?;
280                Regex::new_dot_matches_new_line(pattern, case_insensitive, dot_matches_new_line)
281                    .map_err(|err| {
282                        V::Error::custom(format!(
283                            "Unable to recreate regex during deserialization: {}",
284                            err
285                        ))
286                    })
287            }
288        }
289
290        const FIELDS: &[&str] = &["pattern", "case_insensitive", "dot_matches_new_line"];
291        deserializer.deserialize_struct("Regex", FIELDS, RegexVisitor)
292    }
293}
294
// TODO: this is not really high priority, but this could be modified to generate a
// greater variety of regexes. Ignoring the beginning-of-file/line and EOF/EOL
// symbols, the only regexes being generated are `.{#repetitions}` and
// `x{#repetitions}`.
//
// Each constant below is itself a regex; `any_regex` uses them as string
// generators and concatenates the generated pieces into one pattern.

// Optional start anchor: `\A` or `^`.
const BEGINNING_SYMBOLS: &str = r"((\\A)|\^)?";
// Exactly one character to match: `.` or `x`.
const CHARACTERS: &str = r"[\.x]{1}";
// Optional repetition (`*`, `+`, `?`, `{n}`, or `{n,}` with a single digit
// 1-9), optionally followed by a `?`.
const REPETITIONS: &str = r"((\*|\+|\?|(\{[1-9],?\}))\??)?";
// Optional end anchor: `$` or `\z`.
const END_SYMBOLS: &str = r"(\$|(\\z))?";
303
304prop_compose! {
305    pub fn any_regex()
306                (b in BEGINNING_SYMBOLS, c in CHARACTERS,
307                 r in REPETITIONS, e in END_SYMBOLS, case_insensitive in any::<bool>(), dot_matches_new_line in any::<bool>())
308                -> Regex {
309        let string = format!("{}{}{}{}", b, c, r, e);
310        let regex = Regex::new_dot_matches_new_line(&string, case_insensitive, dot_matches_new_line).unwrap();
311        assert_eq!(regex.pattern(), string);
312        regex
313    }
314}
315
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use proptest::prelude::*;

    use super::*;

    /// Serializes `regex` to JSON and deserializes the result back.
    fn json_roundtrip(regex: &Regex) -> Regex {
        let serialized: String = serde_json::to_string(regex).unwrap();
        serde_json::from_str(&serialized).unwrap()
    }

    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)] // too slow
        fn regex_protobuf_roundtrip( expect in any_regex() ) {
            let actual =  protobuf_roundtrip::<_, ProtoRegex>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    /// This was failing before due to the derived serde serialization being incorrect, because of
    /// <https://github.com/tailhook/serde-regex/issues/14>.
    /// Nowadays, we use our own handwritten Serialize/Deserialize impls for our Regex wrapper struct.
    #[mz_ore::test]
    fn regex_serde_case_insensitive() {
        let pattern = "AAA";
        let orig_regex = Regex::new(pattern, true).unwrap();
        let roundtrip_result = json_roundtrip(&orig_regex);
        // `Eq` only compares the pattern and the recorded flags, not the compiled regex object,
        // so it is not sufficient on its own. Also test the actual regex functionality
        // (concentrating on case sensitivity).
        assert_eq!(roundtrip_result, orig_regex);
        assert!(orig_regex.regex.is_match("aaa"));
        assert!(roundtrip_result.regex.is_match("aaa"));
        assert_eq!(pattern, roundtrip_result.pattern());
    }

    /// Test the roundtripping of `dot_matches_new_line`.
    /// (Similar to the above `regex_serde_case_insensitive`.)
    #[mz_ore::test]
    fn regex_serde_dot_matches_new_line() {
        let pattern = "A.*B";
        let haystack = "axxx\nxxxb";
        // Each case pairs a regex with whether it should match across the newline.
        let cases = [
            // dot_matches_new_line: true
            (
                Regex::new_dot_matches_new_line(pattern, true, true).unwrap(),
                true,
            ),
            // dot_matches_new_line: false
            (
                Regex::new_dot_matches_new_line(pattern, true, false).unwrap(),
                false,
            ),
            // dot_matches_new_line: default
            (Regex::new(pattern, true).unwrap(), true),
        ];
        for (orig_regex, expect_match) in cases {
            let roundtrip_result = json_roundtrip(&orig_regex);
            assert_eq!(roundtrip_result, orig_regex);
            assert_eq!(orig_regex.regex.is_match(haystack), expect_match);
            assert_eq!(roundtrip_result.regex.is_match(haystack), expect_match);
            assert_eq!(pattern, roundtrip_result.pattern());
        }
    }
}