1use std::borrow::Cow;
13use std::cmp::Ordering;
14use std::fmt;
15use std::hash::{Hash, Hasher};
16use std::ops::Deref;
17
18use mz_lowertest::MzReflect;
19use regex::{Error, RegexBuilder};
20use serde::de::Error as DeError;
21use serde::ser::SerializeStruct;
22use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
23
24#[derive(Debug, Clone, MzReflect)]
51pub struct Regex {
52 pub case_insensitive: bool,
53 pub dot_matches_new_line: bool,
54 pub regex: regex::Regex,
55}
56
57impl Regex {
58 pub fn new(pattern: &str, case_insensitive: bool) -> Result<Regex, Error> {
62 Self::new_dot_matches_new_line(pattern, case_insensitive, true)
63 }
64
65 pub fn new_dot_matches_new_line(
67 pattern: &str,
68 case_insensitive: bool,
69 dot_matches_new_line: bool,
70 ) -> Result<Regex, Error> {
71 let mut regex_builder = RegexBuilder::new(pattern);
72 regex_builder.case_insensitive(case_insensitive);
73 regex_builder.dot_matches_new_line(dot_matches_new_line);
74 Ok(Regex {
75 case_insensitive,
76 dot_matches_new_line,
77 regex: regex_builder.build()?,
78 })
79 }
80
81 pub fn pattern(&self) -> &str {
83 self.regex.as_str()
86 }
87}
88
89impl PartialEq<Regex> for Regex {
90 fn eq(&self, other: &Regex) -> bool {
91 self.pattern() == other.pattern()
92 && self.case_insensitive == other.case_insensitive
93 && self.dot_matches_new_line == other.dot_matches_new_line
94 }
95}
96
97impl Eq for Regex {}
98
99impl PartialOrd for Regex {
100 fn partial_cmp(&self, other: &Regex) -> Option<Ordering> {
101 Some(self.cmp(other))
102 }
103}
104
105impl Ord for Regex {
106 fn cmp(&self, other: &Regex) -> Ordering {
107 (
108 self.pattern(),
109 self.case_insensitive,
110 self.dot_matches_new_line,
111 )
112 .cmp(&(
113 other.pattern(),
114 other.case_insensitive,
115 other.dot_matches_new_line,
116 ))
117 }
118}
119
120impl Hash for Regex {
121 fn hash<H: Hasher>(&self, hasher: &mut H) {
122 self.pattern().hash(hasher);
123 self.case_insensitive.hash(hasher);
124 self.dot_matches_new_line.hash(hasher);
125 }
126}
127
128impl Deref for Regex {
129 type Target = regex::Regex;
130
131 fn deref(&self) -> ®ex::Regex {
132 &self.regex
133 }
134}
135
136impl Serialize for Regex {
137 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
138 where
139 S: Serializer,
140 {
141 let mut state = serializer.serialize_struct("Regex", 3)?;
142 state.serialize_field("pattern", &self.pattern())?;
143 state.serialize_field("case_insensitive", &self.case_insensitive)?;
144 state.serialize_field("dot_matches_new_line", &self.dot_matches_new_line)?;
145 state.end()
146 }
147}
148
149impl<'de> Deserialize<'de> for Regex {
150 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
151 where
152 D: Deserializer<'de>,
153 {
154 enum Field {
155 Pattern,
156 CaseInsensitive,
157 DotMatchesNewLine,
158 }
159
160 impl<'de> Deserialize<'de> for Field {
161 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
162 where
163 D: Deserializer<'de>,
164 {
165 struct FieldVisitor;
166
167 impl<'de> de::Visitor<'de> for FieldVisitor {
168 type Value = Field;
169
170 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
171 formatter.write_str(
172 "pattern string or case_insensitive bool or dot_matches_new_line bool",
173 )
174 }
175
176 fn visit_str<E>(self, value: &str) -> Result<Field, E>
177 where
178 E: de::Error,
179 {
180 match value {
181 "pattern" => Ok(Field::Pattern),
182 "case_insensitive" => Ok(Field::CaseInsensitive),
183 "dot_matches_new_line" => Ok(Field::DotMatchesNewLine),
184 _ => Err(de::Error::unknown_field(value, FIELDS)),
185 }
186 }
187 }
188
189 deserializer.deserialize_identifier(FieldVisitor)
190 }
191 }
192
193 struct RegexVisitor;
194
195 impl<'de> de::Visitor<'de> for RegexVisitor {
196 type Value = Regex;
197
198 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
199 formatter.write_str("Regex serialized by the manual Serialize impl from above")
200 }
201
202 fn visit_seq<V>(self, mut seq: V) -> Result<Regex, V::Error>
203 where
204 V: de::SeqAccess<'de>,
205 {
206 let pattern = seq
207 .next_element::<Cow<str>>()?
208 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
209 let case_insensitive = seq
210 .next_element()?
211 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
212 let dot_matches_new_line = seq
213 .next_element()?
214 .ok_or_else(|| de::Error::invalid_length(2, &self))?;
215 Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
216 .map_err(|err| {
217 V::Error::custom(format!(
218 "Unable to recreate regex during deserialization: {}",
219 err
220 ))
221 })
222 }
223
224 fn visit_map<V>(self, mut map: V) -> Result<Regex, V::Error>
225 where
226 V: de::MapAccess<'de>,
227 {
228 let mut pattern: Option<Cow<str>> = None;
229 let mut case_insensitive: Option<bool> = None;
230 let mut dot_matches_new_line: Option<bool> = None;
231 while let Some(key) = map.next_key()? {
232 match key {
233 Field::Pattern => {
234 if pattern.is_some() {
235 return Err(de::Error::duplicate_field("pattern"));
236 }
237 pattern = Some(map.next_value()?);
238 }
239 Field::CaseInsensitive => {
240 if case_insensitive.is_some() {
241 return Err(de::Error::duplicate_field("case_insensitive"));
242 }
243 case_insensitive = Some(map.next_value()?);
244 }
245 Field::DotMatchesNewLine => {
246 if dot_matches_new_line.is_some() {
247 return Err(de::Error::duplicate_field("dot_matches_new_line"));
248 }
249 dot_matches_new_line = Some(map.next_value()?);
250 }
251 }
252 }
253 let pattern = pattern.ok_or_else(|| de::Error::missing_field("pattern"))?;
254 let case_insensitive =
255 case_insensitive.ok_or_else(|| de::Error::missing_field("case_insensitive"))?;
256 let dot_matches_new_line = dot_matches_new_line
257 .ok_or_else(|| de::Error::missing_field("dot_matches_new_line"))?;
258 Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
259 .map_err(|err| {
260 V::Error::custom(format!(
261 "Unable to recreate regex during deserialization: {}",
262 err
263 ))
264 })
265 }
266 }
267
268 const FIELDS: &[&str] = &["pattern", "case_insensitive", "dot_matches_new_line"];
269 deserializer.deserialize_struct("Regex", FIELDS, RegexVisitor)
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Haystack in which the `a` and the `b` are separated by a newline, so a
    /// case-insensitive `A.*B` only matches when `.` may match `\n`.
    const MULTILINE_HAYSTACK: &str = "axxx\nxxxb";

    /// Serializes `regex` to a JSON string and deserializes it again.
    #[track_caller]
    fn json_roundtrip(regex: &Regex) -> Regex {
        let serialized: String = serde_json::to_string(regex).unwrap();
        serde_json::from_str(&serialized).unwrap()
    }

    /// Asserts that `roundtrip` is indistinguishable from `orig`: same match
    /// result on `haystack`, same pattern string, and equal under the manual
    /// `PartialEq` impl (which also compares both flags).
    #[track_caller]
    fn assert_roundtrip(
        orig: &Regex,
        roundtrip: &Regex,
        pattern: &str,
        haystack: &str,
        expected_match: bool,
    ) {
        assert_eq!(orig.regex.is_match(haystack), expected_match);
        assert_eq!(roundtrip.regex.is_match(haystack), expected_match);
        assert_eq!(pattern, roundtrip.pattern());
        assert_eq!(orig, roundtrip);
    }

    #[mz_ore::test]
    fn regex_serde_case_insensitive() {
        let pattern = "AAA";
        let orig_regex = Regex::new(pattern, true).unwrap();
        let roundtrip_result = json_roundtrip(&orig_regex);
        // The lowercase haystack only matches if case-insensitivity survived.
        assert_roundtrip(&orig_regex, &roundtrip_result, pattern, "aaa", true);
    }

    #[mz_ore::test]
    fn regex_serde_dot_matches_new_line() {
        let pattern = "A.*B";

        // Explicitly enabled: `.` matches the newline.
        {
            let orig_regex = Regex::new_dot_matches_new_line(pattern, true, true).unwrap();
            let roundtrip_result = json_roundtrip(&orig_regex);
            assert_roundtrip(
                &orig_regex,
                &roundtrip_result,
                pattern,
                MULTILINE_HAYSTACK,
                true,
            );
        }

        // Explicitly disabled: `.` stops at the newline.
        {
            let orig_regex = Regex::new_dot_matches_new_line(pattern, true, false).unwrap();
            let roundtrip_result = json_roundtrip(&orig_regex);
            assert_roundtrip(
                &orig_regex,
                &roundtrip_result,
                pattern,
                MULTILINE_HAYSTACK,
                false,
            );
        }

        // Default constructor: behaves like the explicitly enabled case.
        {
            let orig_regex = Regex::new(pattern, true).unwrap();
            let roundtrip_result = json_roundtrip(&orig_regex);
            assert_roundtrip(
                &orig_regex,
                &roundtrip_result,
                pattern,
                MULTILINE_HAYSTACK,
                true,
            );
        }
    }

    #[mz_ore::test]
    fn regex_serde_from_reader() {
        let pattern = "A.*B";
        let orig_regex = Regex::new_dot_matches_new_line(pattern, true, true).unwrap();

        // JSON, deserialized from a reader instead of from a `&str`.
        let serialized: String = serde_json::to_string(&orig_regex).unwrap();
        let roundtrip_result: Regex = serde_json::from_reader(serialized.as_bytes()).unwrap();
        assert_roundtrip(
            &orig_regex,
            &roundtrip_result,
            pattern,
            MULTILINE_HAYSTACK,
            true,
        );

        // bincode, also deserialized from a reader.
        let serialized = bincode::serialize(&orig_regex).unwrap();
        let roundtrip_result: Regex = bincode::deserialize_from(&*serialized).unwrap();
        assert_roundtrip(
            &orig_regex,
            &roundtrip_result,
            pattern,
            MULTILINE_HAYSTACK,
            true,
        );
    }
}