1use std::borrow::Cow;
13use std::cmp::Ordering;
14use std::fmt;
15use std::hash::{Hash, Hasher};
16use std::ops::Deref;
17
18use mz_lowertest::MzReflect;
19use regex::{Error, RegexBuilder};
20use serde::de::Error as DeError;
21use serde::ser::SerializeStruct;
22use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
23
/// Upper bound (10 MiB) handed to `RegexBuilder::size_limit` when compiling a pattern;
/// patterns whose compiled form would exceed it fail with a `regex` crate error.
const MAX_REGEX_SIZE_AFTER_COMPILATION: usize = 10 << 20;

/// Upper bound (1 MiB) on the length, in bytes, of a pattern's source text. Longer
/// patterns are rejected before any compilation work is attempted.
const MAX_REGEX_SIZE_BEFORE_COMPILATION: usize = 1 << 20;
40
/// A compiled regular expression together with the flags it was compiled with.
///
/// Equality, ordering, hashing and (de)serialization are all defined in terms of
/// `(pattern, case_insensitive, dot_matches_new_line)` rather than the compiled
/// automaton. Construct values with `Regex::new` or `Regex::new_dot_matches_new_line`,
/// which enforce the pattern size limits.
///
/// NOTE(review): the fields are public, so code that assigns them directly can make the
/// flags disagree with how `regex` was actually compiled; nothing here prevents that.
#[derive(Debug, Clone, MzReflect)]
pub struct Regex {
    /// The case-insensitivity flag passed to the builder when `regex` was compiled.
    pub case_insensitive: bool,
    /// The flag passed to the builder that controls whether `.` also matches `\n`.
    pub dot_matches_new_line: bool,
    /// The compiled regular expression; its `as_str()` is the source pattern.
    pub regex: regex::Regex,
}
73
74impl Regex {
75 pub fn new(pattern: &str, case_insensitive: bool) -> Result<Regex, RegexCompilationError> {
79 Self::new_dot_matches_new_line(pattern, case_insensitive, true)
80 }
81
82 pub fn new_dot_matches_new_line(
84 pattern: &str,
85 case_insensitive: bool,
86 dot_matches_new_line: bool,
87 ) -> Result<Regex, RegexCompilationError> {
88 if pattern.len() > MAX_REGEX_SIZE_BEFORE_COMPILATION {
89 return Err(RegexCompilationError::PatternTooLarge {
90 pattern_size: pattern.len(),
91 });
92 }
93 let mut regex_builder = RegexBuilder::new(pattern);
94 regex_builder.case_insensitive(case_insensitive);
95 regex_builder.dot_matches_new_line(dot_matches_new_line);
96 regex_builder.size_limit(MAX_REGEX_SIZE_AFTER_COMPILATION);
97 Ok(Regex {
98 case_insensitive,
99 dot_matches_new_line,
100 regex: regex_builder.build()?,
101 })
102 }
103
104 pub fn pattern(&self) -> &str {
106 self.regex.as_str()
109 }
110}
111
/// Errors returned when building a [`Regex`].
#[derive(Debug, Clone)]
pub enum RegexCompilationError {
    /// The `regex` crate's builder rejected the pattern. This covers, for example,
    /// syntax errors or exceeding the configured compiled-size limit.
    RegexError(Error),
    /// The pattern text is longer than `MAX_REGEX_SIZE_BEFORE_COMPILATION` bytes.
    /// `pattern_size` is the actual length of the rejected pattern, in bytes.
    PatternTooLarge { pattern_size: usize },
}
120
121impl fmt::Display for RegexCompilationError {
122 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
123 match self {
124 RegexCompilationError::RegexError(e) => write!(f, "{}", e),
125 RegexCompilationError::PatternTooLarge {
126 pattern_size: patter_size,
127 } => write!(
128 f,
129 "regex pattern too large ({} bytes, max {} bytes)",
130 patter_size, MAX_REGEX_SIZE_BEFORE_COMPILATION
131 ),
132 }
133 }
134}
135
136impl From<Error> for RegexCompilationError {
137 fn from(e: Error) -> Self {
138 RegexCompilationError::RegexError(e)
139 }
140}
141
142impl PartialEq<Regex> for Regex {
143 fn eq(&self, other: &Regex) -> bool {
144 self.pattern() == other.pattern()
145 && self.case_insensitive == other.case_insensitive
146 && self.dot_matches_new_line == other.dot_matches_new_line
147 }
148}
149
150impl Eq for Regex {}
151
152impl PartialOrd for Regex {
153 fn partial_cmp(&self, other: &Regex) -> Option<Ordering> {
154 Some(self.cmp(other))
155 }
156}
157
158impl Ord for Regex {
159 fn cmp(&self, other: &Regex) -> Ordering {
160 (
161 self.pattern(),
162 self.case_insensitive,
163 self.dot_matches_new_line,
164 )
165 .cmp(&(
166 other.pattern(),
167 other.case_insensitive,
168 other.dot_matches_new_line,
169 ))
170 }
171}
172
173impl Hash for Regex {
174 fn hash<H: Hasher>(&self, hasher: &mut H) {
175 self.pattern().hash(hasher);
176 self.case_insensitive.hash(hasher);
177 self.dot_matches_new_line.hash(hasher);
178 }
179}
180
181impl Deref for Regex {
182 type Target = regex::Regex;
183
184 fn deref(&self) -> ®ex::Regex {
185 &self.regex
186 }
187}
188
189impl Serialize for Regex {
190 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
191 where
192 S: Serializer,
193 {
194 let mut state = serializer.serialize_struct("Regex", 3)?;
195 state.serialize_field("pattern", &self.pattern())?;
196 state.serialize_field("case_insensitive", &self.case_insensitive)?;
197 state.serialize_field("dot_matches_new_line", &self.dot_matches_new_line)?;
198 state.end()
199 }
200}
201
202impl<'de> Deserialize<'de> for Regex {
203 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
204 where
205 D: Deserializer<'de>,
206 {
207 enum Field {
208 Pattern,
209 CaseInsensitive,
210 DotMatchesNewLine,
211 }
212
213 impl<'de> Deserialize<'de> for Field {
214 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
215 where
216 D: Deserializer<'de>,
217 {
218 struct FieldVisitor;
219
220 impl<'de> de::Visitor<'de> for FieldVisitor {
221 type Value = Field;
222
223 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
224 formatter.write_str(
225 "pattern string or case_insensitive bool or dot_matches_new_line bool",
226 )
227 }
228
229 fn visit_str<E>(self, value: &str) -> Result<Field, E>
230 where
231 E: de::Error,
232 {
233 match value {
234 "pattern" => Ok(Field::Pattern),
235 "case_insensitive" => Ok(Field::CaseInsensitive),
236 "dot_matches_new_line" => Ok(Field::DotMatchesNewLine),
237 _ => Err(de::Error::unknown_field(value, FIELDS)),
238 }
239 }
240 }
241
242 deserializer.deserialize_identifier(FieldVisitor)
243 }
244 }
245
246 struct RegexVisitor;
247
248 impl<'de> de::Visitor<'de> for RegexVisitor {
249 type Value = Regex;
250
251 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
252 formatter.write_str("Regex serialized by the manual Serialize impl from above")
253 }
254
255 fn visit_seq<V>(self, mut seq: V) -> Result<Regex, V::Error>
256 where
257 V: de::SeqAccess<'de>,
258 {
259 let pattern = seq
260 .next_element::<Cow<str>>()?
261 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
262 let case_insensitive = seq
263 .next_element()?
264 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
265 let dot_matches_new_line = seq
266 .next_element()?
267 .ok_or_else(|| de::Error::invalid_length(2, &self))?;
268 Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
269 .map_err(|err| {
270 V::Error::custom(format!(
271 "Unable to recreate regex during deserialization: {}",
272 err
273 ))
274 })
275 }
276
277 fn visit_map<V>(self, mut map: V) -> Result<Regex, V::Error>
278 where
279 V: de::MapAccess<'de>,
280 {
281 let mut pattern: Option<Cow<str>> = None;
282 let mut case_insensitive: Option<bool> = None;
283 let mut dot_matches_new_line: Option<bool> = None;
284 while let Some(key) = map.next_key()? {
285 match key {
286 Field::Pattern => {
287 if pattern.is_some() {
288 return Err(de::Error::duplicate_field("pattern"));
289 }
290 pattern = Some(map.next_value()?);
291 }
292 Field::CaseInsensitive => {
293 if case_insensitive.is_some() {
294 return Err(de::Error::duplicate_field("case_insensitive"));
295 }
296 case_insensitive = Some(map.next_value()?);
297 }
298 Field::DotMatchesNewLine => {
299 if dot_matches_new_line.is_some() {
300 return Err(de::Error::duplicate_field("dot_matches_new_line"));
301 }
302 dot_matches_new_line = Some(map.next_value()?);
303 }
304 }
305 }
306 let pattern = pattern.ok_or_else(|| de::Error::missing_field("pattern"))?;
307 let case_insensitive =
308 case_insensitive.ok_or_else(|| de::Error::missing_field("case_insensitive"))?;
309 let dot_matches_new_line = dot_matches_new_line
310 .ok_or_else(|| de::Error::missing_field("dot_matches_new_line"))?;
311 Regex::new_dot_matches_new_line(&pattern, case_insensitive, dot_matches_new_line)
312 .map_err(|err| {
313 V::Error::custom(format!(
314 "Unable to recreate regex during deserialization: {}",
315 err
316 ))
317 })
318 }
319 }
320
321 const FIELDS: &[&str] = &["pattern", "case_insensitive", "dot_matches_new_line"];
322 deserializer.deserialize_struct("Regex", FIELDS, RegexVisitor)
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Case-insensitive matching survives a JSON round trip.
    #[mz_ore::test]
    fn regex_serde_case_insensitive() {
        let pattern = "AAA";
        let orig_regex = Regex::new(pattern, true).unwrap();
        let serialized: String = serde_json::to_string(&orig_regex).unwrap();
        let roundtrip_result: Regex = serde_json::from_str(&serialized).unwrap();
        assert!(orig_regex.regex.is_match("aaa"));
        assert!(roundtrip_result.regex.is_match("aaa"));
        assert_eq!(pattern, roundtrip_result.pattern());
        assert_eq!(orig_regex, roundtrip_result);
    }

    /// The dot-matches-newline flag survives a JSON round trip in every configuration.
    #[mz_ore::test]
    fn regex_serde_dot_matches_new_line() {
        {
            // Explicitly enabled: `.*` may span the newline.
            let pattern = "A.*B";
            let orig_regex = Regex::new_dot_matches_new_line(pattern, true, true).unwrap();
            let serialized: String = serde_json::to_string(&orig_regex).unwrap();
            let roundtrip_result: Regex = serde_json::from_str(&serialized).unwrap();
            assert!(orig_regex.regex.is_match("axxx\nxxxb"));
            assert!(roundtrip_result.regex.is_match("axxx\nxxxb"));
            assert_eq!(pattern, roundtrip_result.pattern());
            assert_eq!(orig_regex, roundtrip_result);
        }
        {
            // Explicitly disabled: `.*` stops at the newline.
            let pattern = "A.*B";
            let orig_regex = Regex::new_dot_matches_new_line(pattern, true, false).unwrap();
            let serialized: String = serde_json::to_string(&orig_regex).unwrap();
            let roundtrip_result: Regex = serde_json::from_str(&serialized).unwrap();
            assert!(!orig_regex.regex.is_match("axxx\nxxxb"));
            assert!(!roundtrip_result.regex.is_match("axxx\nxxxb"));
            assert_eq!(pattern, roundtrip_result.pattern());
            assert_eq!(orig_regex, roundtrip_result);
        }
        {
            // `Regex::new` enables dot-matches-newline by default.
            let pattern = "A.*B";
            let orig_regex = Regex::new(pattern, true).unwrap();
            assert!(orig_regex.dot_matches_new_line);
            let serialized: String = serde_json::to_string(&orig_regex).unwrap();
            let roundtrip_result: Regex = serde_json::from_str(&serialized).unwrap();
            assert!(orig_regex.regex.is_match("axxx\nxxxb"));
            assert!(roundtrip_result.regex.is_match("axxx\nxxxb"));
            assert_eq!(pattern, roundtrip_result.pattern());
            assert_eq!(orig_regex, roundtrip_result);
        }
    }

    /// Exercises both deserialization paths: streaming JSON (map form) and bincode
    /// (sequence form).
    #[mz_ore::test]
    fn regex_serde_from_reader() {
        let pattern = "A.*B";
        let orig_regex = Regex::new_dot_matches_new_line(pattern, true, true).unwrap();

        let serialized: String = serde_json::to_string(&orig_regex).unwrap();
        let roundtrip_result: Regex = serde_json::from_reader(serialized.as_bytes()).unwrap();

        assert!(orig_regex.regex.is_match("axxx\nxxxb"));
        assert!(roundtrip_result.regex.is_match("axxx\nxxxb"));
        assert_eq!(pattern, roundtrip_result.pattern());
        assert_eq!(orig_regex, roundtrip_result);

        let serialized = bincode::serialize(&orig_regex).unwrap();
        let roundtrip_result: Regex = bincode::deserialize_from(&*serialized).unwrap();

        assert!(orig_regex.regex.is_match("axxx\nxxxb"));
        assert!(roundtrip_result.regex.is_match("axxx\nxxxb"));
        assert_eq!(pattern, roundtrip_result.pattern());
        assert_eq!(orig_regex, roundtrip_result);
    }

    /// Patterns longer than the pre-compilation limit are rejected with their size.
    #[mz_ore::test]
    fn regex_pattern_too_large() {
        let limit = MAX_REGEX_SIZE_BEFORE_COMPILATION;
        let pattern = "a".repeat(limit + 1);
        match Regex::new(&pattern, false) {
            Err(RegexCompilationError::PatternTooLarge { pattern_size }) => {
                assert_eq!(pattern_size, limit + 1);
            }
            other => panic!("expected PatternTooLarge, got {:?}", other),
        }
    }

    /// Malformed or incomplete serialized regexes are reported as errors, not panics.
    #[mz_ore::test]
    fn regex_deserialize_errors() {
        // The pattern itself does not compile.
        let invalid = r#"{"pattern":"(","case_insensitive":false,"dot_matches_new_line":false}"#;
        let err = serde_json::from_str::<Regex>(invalid).unwrap_err();
        assert!(
            err.to_string()
                .contains("Unable to recreate regex during deserialization"),
            "unexpected error: {}",
            err
        );

        // A required field is missing.
        let missing = r#"{"pattern":"a","case_insensitive":false}"#;
        assert!(serde_json::from_str::<Regex>(missing).is_err());

        // An unexpected field is present.
        let unknown = r#"{"pattern":"a","case_insensitive":false,"bogus":true}"#;
        assert!(serde_json::from_str::<Regex>(unknown).is_err());
    }

    /// Equality and ordering take both flags into account, not just the pattern text.
    #[mz_ore::test]
    fn regex_eq_and_ord_consider_flags() {
        let dot_nl = Regex::new_dot_matches_new_line("a", true, true).unwrap();
        let no_dot_nl = Regex::new_dot_matches_new_line("a", true, false).unwrap();
        let case_sensitive = Regex::new_dot_matches_new_line("a", false, true).unwrap();

        assert_ne!(dot_nl, no_dot_nl);
        assert_ne!(dot_nl, case_sensitive);
        assert_eq!(dot_nl, Regex::new("a", true).unwrap());
        assert_eq!(no_dot_nl.cmp(&dot_nl), Ordering::Less);
        assert_eq!(case_sensitive.cmp(&dot_nl), Ordering::Less);
    }
}