mz_repr/adt/
regex.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Regular expressions.
11
12use std::borrow::Cow;
13use std::cmp::Ordering;
14use std::fmt;
15use std::hash::{Hash, Hasher};
16use std::ops::Deref;
17
18use mz_lowertest::MzReflect;
19use mz_proto::{RustType, TryFromProtoError};
20use proptest::prelude::any;
21use proptest::prop_compose;
22use regex::{Error, RegexBuilder};
23use serde::de::Error as DeError;
24use serde::ser::SerializeStruct;
25use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
26
27include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.regex.rs"));
28
29/// A hashable, comparable, and serializable regular expression type.
30///
31/// The  [`regex::Regex`] type, the de facto standard regex type in Rust, does
32/// not implement [`PartialOrd`], [`Ord`] [`PartialEq`], [`Eq`], or [`Hash`].
33/// The omissions are reasonable. There is no natural definition of ordering for
34/// regexes. There *is* a natural definition of equality—whether two regexes
35/// describe the same regular language—but that is an expensive property to
36/// compute, and [`PartialEq`] is generally expected to be fast to compute.
37///
38/// This type wraps [`regex::Regex`] and imbues it with implementations of the
39/// above traits. Two regexes are considered equal iff their string
40/// representation is identical, plus flags, such as `case_insensitive`,
41/// are identical. The [`PartialOrd`], [`Ord`], and [`Hash`] implementations
42/// are similarly based upon the string representation plus flags. As
43/// mentioned above, this is not the natural equivalence relation for regexes: for
44/// example, the regexes `aa*` and `a+` define the same language, but would not
45/// compare as equal with this implementation of [`PartialEq`]. Still, it is
46/// often useful to have _some_ equivalence relation available (e.g., to store
47/// types containing regexes in a hashmap) even if the equivalence relation is
48/// imperfect.
49///
50/// [regex::Regex] is hard to serialize (because of the compiled code), so our approach is to
51/// instead serialize this wrapper struct, where we skip serializing the actual regex field, and
52/// we reconstruct the regex field from the other fields upon deserialization.
53/// (Earlier, serialization was buggy due to <https://github.com/tailhook/serde-regex/issues/14>,
54/// and also making the same mistake in our own protobuf serialization code.)
55#[derive(Debug, Clone, MzReflect)]
56pub struct Regex {
57    pub case_insensitive: bool,
58    pub dot_matches_new_line: bool,
59    pub regex: regex::Regex,
60}
61
62impl Regex {
63    /// A simple constructor for the default setting of `dot_matches_new_line: true`.
64    /// See <https://www.postgresql.org/docs/current/functions-matching.html#POSIX-MATCHING-RULES>
65    /// "newline-sensitive matching"
66    pub fn new(pattern: &str, case_insensitive: bool) -> Result<Regex, Error> {
67        Self::new_dot_matches_new_line(pattern, case_insensitive, true)
68    }
69
70    /// Allows explicitly setting `dot_matches_new_line`.
71    pub fn new_dot_matches_new_line(
72        pattern: &str,
73        case_insensitive: bool,
74        dot_matches_new_line: bool,
75    ) -> Result<Regex, Error> {
76        let mut regex_builder = RegexBuilder::new(pattern);
77        regex_builder.case_insensitive(case_insensitive);
78        regex_builder.dot_matches_new_line(dot_matches_new_line);
79        Ok(Regex {
80            case_insensitive,
81            dot_matches_new_line,
82            regex: regex_builder.build()?,
83        })
84    }
85
86    /// Returns the pattern string of the regex.
87    pub fn pattern(&self) -> &str {
88        // `as_str` returns the raw pattern as provided during construction,
89        // and doesn't include any of the flags.
90        self.regex.as_str()
91    }
92}
93
94impl PartialEq<Regex> for Regex {
95    fn eq(&self, other: &Regex) -> bool {
96        self.pattern() == other.pattern()
97            && self.case_insensitive == other.case_insensitive
98            && self.dot_matches_new_line == other.dot_matches_new_line
99    }
100}
101
102impl Eq for Regex {}
103
104impl PartialOrd for Regex {
105    fn partial_cmp(&self, other: &Regex) -> Option<Ordering> {
106        Some(self.cmp(other))
107    }
108}
109
110impl Ord for Regex {
111    fn cmp(&self, other: &Regex) -> Ordering {
112        (
113            self.pattern(),
114            self.case_insensitive,
115            self.dot_matches_new_line,
116        )
117            .cmp(&(
118                other.pattern(),
119                other.case_insensitive,
120                other.dot_matches_new_line,
121            ))
122    }
123}
124
125impl Hash for Regex {
126    fn hash<H: Hasher>(&self, hasher: &mut H) {
127        self.pattern().hash(hasher);
128        self.case_insensitive.hash(hasher);
129        self.dot_matches_new_line.hash(hasher);
130    }
131}
132
133impl Deref for Regex {
134    type Target = regex::Regex;
135
136    fn deref(&self) -> &regex::Regex {
137        &self.regex
138    }
139}
140
141impl RustType<ProtoRegex> for Regex {
142    fn into_proto(&self) -> ProtoRegex {
143        ProtoRegex {
144            pattern: self.pattern().to_owned(),
145            case_insensitive: self.case_insensitive,
146            dot_matches_new_line: self.dot_matches_new_line,
147        }
148    }
149
150    fn from_proto(proto: ProtoRegex) -> Result<Self, TryFromProtoError> {
151        Ok(Regex::new_dot_matches_new_line(
152            &proto.pattern,
153            proto.case_insensitive,
154            proto.dot_matches_new_line,
155        )?)
156    }
157}
158
159impl Serialize for Regex {
160    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
161    where
162        S: Serializer,
163    {
164        let mut state = serializer.serialize_struct("Regex", 3)?;
165        state.serialize_field("pattern", &self.pattern())?;
166        state.serialize_field("case_insensitive", &self.case_insensitive)?;
167        state.serialize_field("dot_matches_new_line", &self.dot_matches_new_line)?;
168        state.end()
169    }
170}
171
172impl<'de> Deserialize<'de> for Regex {
173    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
174    where
175        D: Deserializer<'de>,
176    {
177        enum Field {
178            Pattern,
179            CaseInsensitive,
180            DotMatchesNewLine,
181        }
182
183        impl<'de> Deserialize<'de> for Field {
184            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
185            where
186                D: Deserializer<'de>,
187            {
188                struct FieldVisitor;
189
190                impl<'de> de::Visitor<'de> for FieldVisitor {
191                    type Value = Field;
192
193                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
194                        formatter.write_str(
195                            "pattern string or case_insensitive bool or dot_matches_new_line bool",
196                        )
197                    }
198
199                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
200                    where
201                        E: de::Error,
202                    {
203                        match value {
204                            "pattern" => Ok(Field::Pattern),
205                            "case_insensitive" => Ok(Field::CaseInsensitive),
206                            "dot_matches_new_line" => Ok(Field::DotMatchesNewLine),
207                            _ => Err(de::Error::unknown_field(value, FIELDS)),
208                        }
209                    }
210                }
211
212                deserializer.deserialize_identifier(FieldVisitor)
213            }
214        }
215
216        struct RegexVisitor;
217
218        impl<'de> de::Visitor<'de> for RegexVisitor {
219            type Value = Regex;
220
221            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
222                formatter.write_str("Regex serialized by the manual Serialize impl from above")
223            }
224
225            fn visit_seq<V>(self, mut seq: V) -> Result<Regex, V::Error>
226            where
227                V: de::SeqAccess<'de>,
228            {
229                let pattern = seq
230                    .next_element::<Cow<str>>()?
231                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
232                let case_insensitive = seq
233                    .next_element()?
234                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
235                let dot_matches_new_line = seq
236                    .next_element()?
237                    .ok_or_else(|| de::Error::invalid_length(2, &self))?;
238                Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
239                    .map_err(|err| {
240                        V::Error::custom(format!(
241                            "Unable to recreate regex during deserialization: {}",
242                            err
243                        ))
244                    })
245            }
246
247            fn visit_map<V>(self, mut map: V) -> Result<Regex, V::Error>
248            where
249                V: de::MapAccess<'de>,
250            {
251                let mut pattern: Option<Cow<str>> = None;
252                let mut case_insensitive: Option<bool> = None;
253                let mut dot_matches_new_line: Option<bool> = None;
254                while let Some(key) = map.next_key()? {
255                    match key {
256                        Field::Pattern => {
257                            if pattern.is_some() {
258                                return Err(de::Error::duplicate_field("pattern"));
259                            }
260                            pattern = Some(map.next_value()?);
261                        }
262                        Field::CaseInsensitive => {
263                            if case_insensitive.is_some() {
264                                return Err(de::Error::duplicate_field("case_insensitive"));
265                            }
266                            case_insensitive = Some(map.next_value()?);
267                        }
268                        Field::DotMatchesNewLine => {
269                            if dot_matches_new_line.is_some() {
270                                return Err(de::Error::duplicate_field("dot_matches_new_line"));
271                            }
272                            dot_matches_new_line = Some(map.next_value()?);
273                        }
274                    }
275                }
276                let pattern = pattern.ok_or_else(|| de::Error::missing_field("pattern"))?;
277                let case_insensitive =
278                    case_insensitive.ok_or_else(|| de::Error::missing_field("case_insensitive"))?;
279                let dot_matches_new_line = dot_matches_new_line
280                    .ok_or_else(|| de::Error::missing_field("dot_matches_new_line"))?;
281                Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
282                    .map_err(|err| {
283                        V::Error::custom(format!(
284                            "Unable to recreate regex during deserialization: {}",
285                            err
286                        ))
287                    })
288            }
289        }
290
291        const FIELDS: &[&str] = &["pattern", "case_insensitive", "dot_matches_new_line"];
292        deserializer.deserialize_struct("Regex", FIELDS, RegexVisitor)
293    }
294}
295
296// TODO: this is not really high priority, but this could modified to generate a
297// greater variety of regexes. Ignoring the beginning-of-file/line and EOF/EOL
298// symbols, the only regexes being generated are `.{#repetitions}` and
299// `x{#repetitions}`.
300const BEGINNING_SYMBOLS: &str = r"((\\A)|\^)?";
301const CHARACTERS: &str = r"[\.x]{1}";
302const REPETITIONS: &str = r"((\*|\+|\?|(\{[1-9],?\}))\??)?";
303const END_SYMBOLS: &str = r"(\$|(\\z))?";
304
305prop_compose! {
306    pub fn any_regex()
307                (b in BEGINNING_SYMBOLS, c in CHARACTERS,
308                 r in REPETITIONS, e in END_SYMBOLS, case_insensitive in any::<bool>(), dot_matches_new_line in any::<bool>())
309                -> Regex {
310        let string = format!("{}{}{}{}", b, c, r, e);
311        let regex = Regex::new_dot_matches_new_line(&string, case_insensitive, dot_matches_new_line).unwrap();
312        assert_eq!(regex.pattern(), string);
313        regex
314    }
315}
316
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use proptest::prelude::*;

    use super::*;

    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)] // too slow
        fn regex_protobuf_roundtrip(expect in any_regex()) {
            let actual = protobuf_roundtrip::<_, ProtoRegex>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    /// This was failing before due to the derived serde serialization being incorrect, because of
    /// <https://github.com/tailhook/serde-regex/issues/14>.
    /// Nowadays, we use our own handwritten Serialize/Deserialize impls for our Regex wrapper struct.
    #[mz_ore::test]
    fn regex_serde_case_insensitive() {
        let pattern = "AAA";
        let orig_regex = Regex::new(pattern, true).unwrap();
        let serialized: String = serde_json::to_string(&orig_regex).unwrap();
        let roundtrip_result: Regex = serde_json::from_str(&serialized).unwrap();
        // `Eq` only compares the pattern string and the flag fields, so it cannot tell
        // whether the compiled regex actually honors the flags. Check that directly by
        // exercising case insensitivity on the compiled regex.
        assert!(orig_regex.regex.is_match("aaa"));
        assert!(roundtrip_result.regex.is_match("aaa"));
        assert_eq!(pattern, roundtrip_result.pattern());
        // The pattern and flag fields must survive the roundtrip as well.
        assert_eq!(orig_regex, roundtrip_result);
    }

    /// Test the roundtripping of `dot_matches_new_line`.
    /// (Similar to the above `regex_serde_case_insensitive`.)
    #[mz_ore::test]
    fn regex_serde_dot_matches_new_line() {
        let pattern = "A.*B";
        let haystack = "axxx\nxxxb";
        // Each case pairs a regex with whether it should match across the newline.
        let cases = [
            // dot_matches_new_line: true
            (Regex::new_dot_matches_new_line(pattern, true, true).unwrap(), true),
            // dot_matches_new_line: false
            (Regex::new_dot_matches_new_line(pattern, true, false).unwrap(), false),
            // dot_matches_new_line: default
            (Regex::new(pattern, true).unwrap(), true),
        ];
        for (orig_regex, expected) in cases {
            let serialized: String = serde_json::to_string(&orig_regex).unwrap();
            let roundtrip_result: Regex = serde_json::from_str(&serialized).unwrap();
            assert_eq!(orig_regex.regex.is_match(haystack), expected);
            assert_eq!(roundtrip_result.regex.is_match(haystack), expected);
            assert_eq!(pattern, roundtrip_result.pattern());
        }
    }

    #[mz_ore::test]
    fn regex_serde_from_reader() {
        let pattern = "A.*B";
        let orig_regex = Regex::new_dot_matches_new_line(pattern, true, true).unwrap();

        let serialized: String = serde_json::to_string(&orig_regex).unwrap();
        let roundtrip_result: Regex = serde_json::from_reader(serialized.as_bytes()).unwrap();

        assert!(orig_regex.regex.is_match("axxx\nxxxb"));
        assert!(roundtrip_result.regex.is_match("axxx\nxxxb"));
        assert_eq!(pattern, roundtrip_result.pattern());

        let serialized = bincode::serialize(&orig_regex).unwrap();
        let roundtrip_result: Regex = bincode::deserialize_from(&*serialized).unwrap();

        assert!(orig_regex.regex.is_match("axxx\nxxxb"));
        assert!(roundtrip_result.regex.is_match("axxx\nxxxb"));
        assert_eq!(pattern, roundtrip_result.pattern());
    }

    /// Input that cannot be turned back into a `Regex` must fail deserialization
    /// with a descriptive error rather than panicking.
    #[mz_ore::test]
    fn regex_serde_invalid_input() {
        // The pattern does not compile (unclosed group).
        let err = serde_json::from_str::<Regex>(
            r#"{"pattern":"(","case_insensitive":false,"dot_matches_new_line":true}"#,
        )
        .unwrap_err();
        assert!(
            err.to_string()
                .contains("Unable to recreate regex during deserialization"),
            "unexpected error: {err}"
        );

        // A required field is absent.
        let err = serde_json::from_str::<Regex>(r#"{"pattern":"a","case_insensitive":false}"#)
            .unwrap_err();
        assert!(
            err.to_string().contains("missing field `dot_matches_new_line`"),
            "unexpected error: {err}"
        );
    }
}