1use std::cmp::Ordering;
13use std::fmt;
14use std::hash::{Hash, Hasher};
15use std::ops::Deref;
16
17use mz_lowertest::MzReflect;
18use mz_proto::{RustType, TryFromProtoError};
19use proptest::prelude::any;
20use proptest::prop_compose;
21use regex::{Error, RegexBuilder};
22use serde::de::Error as DeError;
23use serde::ser::SerializeStruct;
24use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
25
26include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.regex.rs"));
27
/// A compiled regular expression stored together with the flags it was built
/// with.
///
/// Equality, ordering, hashing, and serialization in this file all work from
/// the pattern string plus the two flags rather than from the compiled
/// automaton.
#[derive(Debug, Clone, MzReflect)]
pub struct Regex {
    /// Whether the regex was built with case-insensitive matching.
    pub case_insensitive: bool,
    /// Whether the regex was built so that `.` also matches a newline.
    pub dot_matches_new_line: bool,
    /// The compiled regex itself.
    pub regex: regex::Regex,
}
60
61impl Regex {
62 pub fn new(pattern: &str, case_insensitive: bool) -> Result<Regex, Error> {
66 Self::new_dot_matches_new_line(pattern, case_insensitive, true)
67 }
68
69 pub fn new_dot_matches_new_line(
71 pattern: &str,
72 case_insensitive: bool,
73 dot_matches_new_line: bool,
74 ) -> Result<Regex, Error> {
75 let mut regex_builder = RegexBuilder::new(pattern);
76 regex_builder.case_insensitive(case_insensitive);
77 regex_builder.dot_matches_new_line(dot_matches_new_line);
78 Ok(Regex {
79 case_insensitive,
80 dot_matches_new_line,
81 regex: regex_builder.build()?,
82 })
83 }
84
85 pub fn pattern(&self) -> &str {
87 self.regex.as_str()
90 }
91}
92
93impl PartialEq<Regex> for Regex {
94 fn eq(&self, other: &Regex) -> bool {
95 self.pattern() == other.pattern()
96 && self.case_insensitive == other.case_insensitive
97 && self.dot_matches_new_line == other.dot_matches_new_line
98 }
99}
100
101impl Eq for Regex {}
102
103impl PartialOrd for Regex {
104 fn partial_cmp(&self, other: &Regex) -> Option<Ordering> {
105 Some(self.cmp(other))
106 }
107}
108
109impl Ord for Regex {
110 fn cmp(&self, other: &Regex) -> Ordering {
111 (
112 self.pattern(),
113 self.case_insensitive,
114 self.dot_matches_new_line,
115 )
116 .cmp(&(
117 other.pattern(),
118 other.case_insensitive,
119 other.dot_matches_new_line,
120 ))
121 }
122}
123
124impl Hash for Regex {
125 fn hash<H: Hasher>(&self, hasher: &mut H) {
126 self.pattern().hash(hasher);
127 self.case_insensitive.hash(hasher);
128 self.dot_matches_new_line.hash(hasher);
129 }
130}
131
132impl Deref for Regex {
133 type Target = regex::Regex;
134
135 fn deref(&self) -> ®ex::Regex {
136 &self.regex
137 }
138}
139
140impl RustType<ProtoRegex> for Regex {
141 fn into_proto(&self) -> ProtoRegex {
142 ProtoRegex {
143 pattern: self.pattern().to_owned(),
144 case_insensitive: self.case_insensitive,
145 dot_matches_new_line: self.dot_matches_new_line,
146 }
147 }
148
149 fn from_proto(proto: ProtoRegex) -> Result<Self, TryFromProtoError> {
150 Ok(Regex::new_dot_matches_new_line(
151 &proto.pattern,
152 proto.case_insensitive,
153 proto.dot_matches_new_line,
154 )?)
155 }
156}
157
158impl Serialize for Regex {
159 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
160 where
161 S: Serializer,
162 {
163 let mut state = serializer.serialize_struct("Regex", 3)?;
164 state.serialize_field("pattern", &self.pattern())?;
165 state.serialize_field("case_insensitive", &self.case_insensitive)?;
166 state.serialize_field("dot_matches_new_line", &self.dot_matches_new_line)?;
167 state.end()
168 }
169}
170
171impl<'de> Deserialize<'de> for Regex {
172 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173 where
174 D: Deserializer<'de>,
175 {
176 enum Field {
177 Pattern,
178 CaseInsensitive,
179 DotMatchesNewLine,
180 }
181
182 impl<'de> Deserialize<'de> for Field {
183 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
184 where
185 D: Deserializer<'de>,
186 {
187 struct FieldVisitor;
188
189 impl<'de> de::Visitor<'de> for FieldVisitor {
190 type Value = Field;
191
192 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
193 formatter.write_str(
194 "pattern string or case_insensitive bool or dot_matches_new_line bool",
195 )
196 }
197
198 fn visit_str<E>(self, value: &str) -> Result<Field, E>
199 where
200 E: de::Error,
201 {
202 match value {
203 "pattern" => Ok(Field::Pattern),
204 "case_insensitive" => Ok(Field::CaseInsensitive),
205 "dot_matches_new_line" => Ok(Field::DotMatchesNewLine),
206 _ => Err(de::Error::unknown_field(value, FIELDS)),
207 }
208 }
209 }
210
211 deserializer.deserialize_identifier(FieldVisitor)
212 }
213 }
214
215 struct RegexVisitor;
216
217 impl<'de> de::Visitor<'de> for RegexVisitor {
218 type Value = Regex;
219
220 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
221 formatter.write_str("Regex serialized by the manual Serialize impl from above")
222 }
223
224 fn visit_seq<V>(self, mut seq: V) -> Result<Regex, V::Error>
225 where
226 V: de::SeqAccess<'de>,
227 {
228 let pattern = seq
229 .next_element()?
230 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
231 let case_insensitive = seq
232 .next_element()?
233 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
234 let dot_matches_new_line = seq
235 .next_element()?
236 .ok_or_else(|| de::Error::invalid_length(2, &self))?;
237 Regex::new_dot_matches_new_line(pattern, case_insensitive, dot_matches_new_line)
238 .map_err(|err| {
239 V::Error::custom(format!(
240 "Unable to recreate regex during deserialization: {}",
241 err
242 ))
243 })
244 }
245
246 fn visit_map<V>(self, mut map: V) -> Result<Regex, V::Error>
247 where
248 V: de::MapAccess<'de>,
249 {
250 let mut pattern: Option<&str> = None;
251 let mut case_insensitive: Option<bool> = None;
252 let mut dot_matches_new_line: Option<bool> = None;
253 while let Some(key) = map.next_key()? {
254 match key {
255 Field::Pattern => {
256 if pattern.is_some() {
257 return Err(de::Error::duplicate_field("pattern"));
258 }
259 pattern = Some(map.next_value()?);
260 }
261 Field::CaseInsensitive => {
262 if case_insensitive.is_some() {
263 return Err(de::Error::duplicate_field("case_insensitive"));
264 }
265 case_insensitive = Some(map.next_value()?);
266 }
267 Field::DotMatchesNewLine => {
268 if dot_matches_new_line.is_some() {
269 return Err(de::Error::duplicate_field("dot_matches_new_line"));
270 }
271 dot_matches_new_line = Some(map.next_value()?);
272 }
273 }
274 }
275 let pattern = pattern.ok_or_else(|| de::Error::missing_field("pattern"))?;
276 let case_insensitive =
277 case_insensitive.ok_or_else(|| de::Error::missing_field("case_insensitive"))?;
278 let dot_matches_new_line = dot_matches_new_line
279 .ok_or_else(|| de::Error::missing_field("dot_matches_new_line"))?;
280 Regex::new_dot_matches_new_line(pattern, case_insensitive, dot_matches_new_line)
281 .map_err(|err| {
282 V::Error::custom(format!(
283 "Unable to recreate regex during deserialization: {}",
284 err
285 ))
286 })
287 }
288 }
289
290 const FIELDS: &[&str] = &["pattern", "case_insensitive", "dot_matches_new_line"];
291 deserializer.deserialize_struct("Regex", FIELDS, RegexVisitor)
292 }
293}
294
// Building blocks for `any_regex`. Each constant is itself a regex that is
// used there as a string-generating strategy; the generated pieces are
// concatenated into the pattern under test.

/// Optional start anchor: generates `\A`, `^`, or nothing.
const BEGINNING_SYMBOLS: &str = r"((\\A)|\^)?";
/// Exactly one character: `.` or `x`.
const CHARACTERS: &str = r"[\.x]{1}";
/// Optional repetition operator: `*`, `+`, `?`, `{d}` or `{d,}` with `d` in
/// `1..=9`, each optionally followed by `?`.
const REPETITIONS: &str = r"((\*|\+|\?|(\{[1-9],?\}))\??)?";
/// Optional end anchor: generates `$`, `\z`, or nothing.
const END_SYMBOLS: &str = r"(\$|(\\z))?";
303
304prop_compose! {
305 pub fn any_regex()
306 (b in BEGINNING_SYMBOLS, c in CHARACTERS,
307 r in REPETITIONS, e in END_SYMBOLS, case_insensitive in any::<bool>(), dot_matches_new_line in any::<bool>())
308 -> Regex {
309 let string = format!("{}{}{}{}", b, c, r, e);
310 let regex = Regex::new_dot_matches_new_line(&string, case_insensitive, dot_matches_new_line).unwrap();
311 assert_eq!(regex.pattern(), string);
312 regex
313 }
314}
315
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use proptest::prelude::*;

    use super::*;

    /// Serializes `regex` to a JSON string and deserializes it back.
    fn json_roundtrip(regex: &Regex) -> Regex {
        let serialized: String = serde_json::to_string(regex).unwrap();
        serde_json::from_str(&serialized).unwrap()
    }

    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn regex_protobuf_roundtrip(original in any_regex()) {
            let decoded = protobuf_roundtrip::<_, ProtoRegex>(&original);
            assert_ok!(decoded);
            assert_eq!(decoded.unwrap(), original);
        }
    }

    /// The case-insensitivity flag must survive a serde round trip.
    #[mz_ore::test]
    fn regex_serde_case_insensitive() {
        let pattern = "AAA";
        let orig_regex = Regex::new(pattern, true).unwrap();
        let roundtrip_result = json_roundtrip(&orig_regex);
        assert!(orig_regex.regex.is_match("aaa"));
        assert!(roundtrip_result.regex.is_match("aaa"));
        assert_eq!(pattern, roundtrip_result.pattern());
    }

    /// The dot-matches-newline flag must survive a serde round trip, both when
    /// set explicitly and when left at the default chosen by `Regex::new`.
    #[mz_ore::test]
    fn regex_serde_dot_matches_new_line() {
        let pattern = "A.*B";
        let haystack = "axxx\nxxxb";
        let cases = [
            (
                Regex::new_dot_matches_new_line(pattern, true, true).unwrap(),
                true,
            ),
            (
                Regex::new_dot_matches_new_line(pattern, true, false).unwrap(),
                false,
            ),
            (Regex::new(pattern, true).unwrap(), true),
        ];
        for (orig_regex, should_match) in cases {
            let roundtrip_result = json_roundtrip(&orig_regex);
            assert_eq!(orig_regex.regex.is_match(haystack), should_match);
            assert_eq!(roundtrip_result.regex.is_match(haystack), should_match);
            assert_eq!(pattern, roundtrip_result.pattern());
        }
    }
}