1use std::borrow::Cow;
13use std::cmp::Ordering;
14use std::fmt;
15use std::hash::{Hash, Hasher};
16use std::ops::Deref;
17
18use mz_lowertest::MzReflect;
19use mz_proto::{RustType, TryFromProtoError};
20use proptest::prelude::any;
21use proptest::prop_compose;
22use regex::{Error, RegexBuilder};
23use serde::de::Error as DeError;
24use serde::ser::SerializeStruct;
25use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
26
27include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.regex.rs"));
28
29#[derive(Debug, Clone, MzReflect)]
56pub struct Regex {
57 pub case_insensitive: bool,
58 pub dot_matches_new_line: bool,
59 pub regex: regex::Regex,
60}
61
62impl Regex {
63 pub fn new(pattern: &str, case_insensitive: bool) -> Result<Regex, Error> {
67 Self::new_dot_matches_new_line(pattern, case_insensitive, true)
68 }
69
70 pub fn new_dot_matches_new_line(
72 pattern: &str,
73 case_insensitive: bool,
74 dot_matches_new_line: bool,
75 ) -> Result<Regex, Error> {
76 let mut regex_builder = RegexBuilder::new(pattern);
77 regex_builder.case_insensitive(case_insensitive);
78 regex_builder.dot_matches_new_line(dot_matches_new_line);
79 Ok(Regex {
80 case_insensitive,
81 dot_matches_new_line,
82 regex: regex_builder.build()?,
83 })
84 }
85
86 pub fn pattern(&self) -> &str {
88 self.regex.as_str()
91 }
92}
93
94impl PartialEq<Regex> for Regex {
95 fn eq(&self, other: &Regex) -> bool {
96 self.pattern() == other.pattern()
97 && self.case_insensitive == other.case_insensitive
98 && self.dot_matches_new_line == other.dot_matches_new_line
99 }
100}
101
102impl Eq for Regex {}
103
104impl PartialOrd for Regex {
105 fn partial_cmp(&self, other: &Regex) -> Option<Ordering> {
106 Some(self.cmp(other))
107 }
108}
109
110impl Ord for Regex {
111 fn cmp(&self, other: &Regex) -> Ordering {
112 (
113 self.pattern(),
114 self.case_insensitive,
115 self.dot_matches_new_line,
116 )
117 .cmp(&(
118 other.pattern(),
119 other.case_insensitive,
120 other.dot_matches_new_line,
121 ))
122 }
123}
124
125impl Hash for Regex {
126 fn hash<H: Hasher>(&self, hasher: &mut H) {
127 self.pattern().hash(hasher);
128 self.case_insensitive.hash(hasher);
129 self.dot_matches_new_line.hash(hasher);
130 }
131}
132
133impl Deref for Regex {
134 type Target = regex::Regex;
135
136 fn deref(&self) -> ®ex::Regex {
137 &self.regex
138 }
139}
140
141impl RustType<ProtoRegex> for Regex {
142 fn into_proto(&self) -> ProtoRegex {
143 ProtoRegex {
144 pattern: self.pattern().to_owned(),
145 case_insensitive: self.case_insensitive,
146 dot_matches_new_line: self.dot_matches_new_line,
147 }
148 }
149
150 fn from_proto(proto: ProtoRegex) -> Result<Self, TryFromProtoError> {
151 Ok(Regex::new_dot_matches_new_line(
152 &proto.pattern,
153 proto.case_insensitive,
154 proto.dot_matches_new_line,
155 )?)
156 }
157}
158
159impl Serialize for Regex {
160 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
161 where
162 S: Serializer,
163 {
164 let mut state = serializer.serialize_struct("Regex", 3)?;
165 state.serialize_field("pattern", &self.pattern())?;
166 state.serialize_field("case_insensitive", &self.case_insensitive)?;
167 state.serialize_field("dot_matches_new_line", &self.dot_matches_new_line)?;
168 state.end()
169 }
170}
171
172impl<'de> Deserialize<'de> for Regex {
173 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
174 where
175 D: Deserializer<'de>,
176 {
177 enum Field {
178 Pattern,
179 CaseInsensitive,
180 DotMatchesNewLine,
181 }
182
183 impl<'de> Deserialize<'de> for Field {
184 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
185 where
186 D: Deserializer<'de>,
187 {
188 struct FieldVisitor;
189
190 impl<'de> de::Visitor<'de> for FieldVisitor {
191 type Value = Field;
192
193 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
194 formatter.write_str(
195 "pattern string or case_insensitive bool or dot_matches_new_line bool",
196 )
197 }
198
199 fn visit_str<E>(self, value: &str) -> Result<Field, E>
200 where
201 E: de::Error,
202 {
203 match value {
204 "pattern" => Ok(Field::Pattern),
205 "case_insensitive" => Ok(Field::CaseInsensitive),
206 "dot_matches_new_line" => Ok(Field::DotMatchesNewLine),
207 _ => Err(de::Error::unknown_field(value, FIELDS)),
208 }
209 }
210 }
211
212 deserializer.deserialize_identifier(FieldVisitor)
213 }
214 }
215
216 struct RegexVisitor;
217
218 impl<'de> de::Visitor<'de> for RegexVisitor {
219 type Value = Regex;
220
221 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
222 formatter.write_str("Regex serialized by the manual Serialize impl from above")
223 }
224
225 fn visit_seq<V>(self, mut seq: V) -> Result<Regex, V::Error>
226 where
227 V: de::SeqAccess<'de>,
228 {
229 let pattern = seq
230 .next_element::<Cow<str>>()?
231 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
232 let case_insensitive = seq
233 .next_element()?
234 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
235 let dot_matches_new_line = seq
236 .next_element()?
237 .ok_or_else(|| de::Error::invalid_length(2, &self))?;
238 Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
239 .map_err(|err| {
240 V::Error::custom(format!(
241 "Unable to recreate regex during deserialization: {}",
242 err
243 ))
244 })
245 }
246
247 fn visit_map<V>(self, mut map: V) -> Result<Regex, V::Error>
248 where
249 V: de::MapAccess<'de>,
250 {
251 let mut pattern: Option<Cow<str>> = None;
252 let mut case_insensitive: Option<bool> = None;
253 let mut dot_matches_new_line: Option<bool> = None;
254 while let Some(key) = map.next_key()? {
255 match key {
256 Field::Pattern => {
257 if pattern.is_some() {
258 return Err(de::Error::duplicate_field("pattern"));
259 }
260 pattern = Some(map.next_value()?);
261 }
262 Field::CaseInsensitive => {
263 if case_insensitive.is_some() {
264 return Err(de::Error::duplicate_field("case_insensitive"));
265 }
266 case_insensitive = Some(map.next_value()?);
267 }
268 Field::DotMatchesNewLine => {
269 if dot_matches_new_line.is_some() {
270 return Err(de::Error::duplicate_field("dot_matches_new_line"));
271 }
272 dot_matches_new_line = Some(map.next_value()?);
273 }
274 }
275 }
276 let pattern = pattern.ok_or_else(|| de::Error::missing_field("pattern"))?;
277 let case_insensitive =
278 case_insensitive.ok_or_else(|| de::Error::missing_field("case_insensitive"))?;
279 let dot_matches_new_line = dot_matches_new_line
280 .ok_or_else(|| de::Error::missing_field("dot_matches_new_line"))?;
281 Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
282 .map_err(|err| {
283 V::Error::custom(format!(
284 "Unable to recreate regex during deserialization: {}",
285 err
286 ))
287 })
288 }
289 }
290
291 const FIELDS: &[&str] = &["pattern", "case_insensitive", "dot_matches_new_line"];
292 deserializer.deserialize_struct("Regex", FIELDS, RegexVisitor)
293 }
294}
295
296const BEGINNING_SYMBOLS: &str = r"((\\A)|\^)?";
301const CHARACTERS: &str = r"[\.x]{1}";
302const REPETITIONS: &str = r"((\*|\+|\?|(\{[1-9],?\}))\??)?";
303const END_SYMBOLS: &str = r"(\$|(\\z))?";
304
305prop_compose! {
306 pub fn any_regex()
307 (b in BEGINNING_SYMBOLS, c in CHARACTERS,
308 r in REPETITIONS, e in END_SYMBOLS, case_insensitive in any::<bool>(), dot_matches_new_line in any::<bool>())
309 -> Regex {
310 let string = format!("{}{}{}{}", b, c, r, e);
311 let regex = Regex::new_dot_matches_new_line(&string, case_insensitive, dot_matches_new_line).unwrap();
312 assert_eq!(regex.pattern(), string);
313 regex
314 }
315}
316
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use proptest::prelude::*;

    use super::*;

    /// Serializes `regex` to a JSON string and parses it back.
    fn json_roundtrip(regex: &Regex) -> Regex {
        let serialized: String = serde_json::to_string(regex).unwrap();
        serde_json::from_str(&serialized).unwrap()
    }

    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn regex_protobuf_roundtrip(expect in any_regex()) {
            let actual = protobuf_roundtrip::<_, ProtoRegex>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    #[mz_ore::test]
    fn regex_serde_case_insensitive() {
        let pattern = "AAA";
        let orig_regex = Regex::new(pattern, true).unwrap();
        let roundtrip_result = json_roundtrip(&orig_regex);
        assert!(orig_regex.regex.is_match("aaa"));
        assert!(roundtrip_result.regex.is_match("aaa"));
        assert_eq!(pattern, roundtrip_result.pattern());
        assert_eq!(orig_regex, roundtrip_result);
    }

    #[mz_ore::test]
    fn regex_serde_dot_matches_new_line() {
        let pattern = "A.*B";
        let haystack = "axxx\nxxxb";
        // Each case: the regex under test, and whether `.` should match the
        // newline in `haystack`.
        let cases = [
            (
                Regex::new_dot_matches_new_line(pattern, true, true).unwrap(),
                true,
            ),
            (
                Regex::new_dot_matches_new_line(pattern, true, false).unwrap(),
                false,
            ),
            // `Regex::new` defaults to `dot_matches_new_line = true`.
            (Regex::new(pattern, true).unwrap(), true),
        ];
        for (orig_regex, expect_match) in cases {
            let roundtrip_result = json_roundtrip(&orig_regex);
            assert_eq!(orig_regex.regex.is_match(haystack), expect_match);
            assert_eq!(roundtrip_result.regex.is_match(haystack), expect_match);
            assert_eq!(pattern, roundtrip_result.pattern());
        }
    }

    #[mz_ore::test]
    fn regex_serde_from_reader() {
        let pattern = "A.*B";
        let haystack = "axxx\nxxxb";
        let orig_regex = Regex::new_dot_matches_new_line(pattern, true, true).unwrap();

        // Reader-based JSON cannot lend borrowed strings to the deserializer.
        let serialized: String = serde_json::to_string(&orig_regex).unwrap();
        let roundtrip_result: Regex = serde_json::from_reader(serialized.as_bytes()).unwrap();

        assert!(orig_regex.regex.is_match(haystack));
        assert!(roundtrip_result.regex.is_match(haystack));
        assert_eq!(pattern, roundtrip_result.pattern());

        // bincode encodes structs positionally, exercising the sequence path.
        let serialized = bincode::serialize(&orig_regex).unwrap();
        let roundtrip_result: Regex = bincode::deserialize_from(&*serialized).unwrap();

        assert!(orig_regex.regex.is_match(haystack));
        assert!(roundtrip_result.regex.is_match(haystack));
        assert_eq!(pattern, roundtrip_result.pattern());
    }
}