1use std::mem;
11use std::str::FromStr;
12
13use derivative::Derivative;
14use mz_lowertest::MzReflect;
15use mz_ore::fmt::FormatBuffer;
16use mz_repr::adt::regex::{Regex, RegexCompilationError};
17use serde::{Deserialize, Serialize};
18
19use crate::scalar::EvalError;
20
/// Patterns that split into more than this many [`Subpattern`]s are compiled to a
/// regex instead of being matched by the hand-written subpattern matcher.
const MAX_SUBPATTERNS: usize = 5;

/// The escape character understood by [`build_subpatterns`]; patterns written with
/// any other escape behavior are rewritten to use it by [`normalize_pattern`].
const DEFAULT_ESCAPE: char = '\\';
/// [`DEFAULT_ESCAPE`] escaping itself, i.e. a literal backslash in a normalized pattern.
const DOUBLED_ESCAPE: &str = r"\\";
27
/// How the escape character of a LIKE pattern is configured.
#[derive(Debug, Clone, Copy)]
pub enum EscapeBehavior {
    /// No escape character: every character other than the `_` and `%` wildcards,
    /// including `\`, is taken literally.
    Disabled,
    /// The given character makes the character that follows it literal.
    Char(char),
}
36
37impl Default for EscapeBehavior {
38 fn default() -> EscapeBehavior {
39 EscapeBehavior::Char(DEFAULT_ESCAPE)
40 }
41}
42
43impl FromStr for EscapeBehavior {
44 type Err = EvalError;
45
46 fn from_str(s: &str) -> Result<EscapeBehavior, EvalError> {
47 let mut chars = s.chars();
48 match chars.next() {
49 None => Ok(EscapeBehavior::Disabled),
50 Some(c) => match chars.next() {
51 None => Ok(EscapeBehavior::Char(c)),
52 Some(_) => Err(EvalError::LikeEscapeTooLong),
53 },
54 }
55 }
56}
57
58pub fn normalize_pattern(pattern: &str, escape: EscapeBehavior) -> Result<String, EvalError> {
60 match escape {
61 EscapeBehavior::Disabled => Ok(pattern.replace(DEFAULT_ESCAPE, DOUBLED_ESCAPE)),
62 EscapeBehavior::Char(DEFAULT_ESCAPE) => Ok(pattern.into()),
63 EscapeBehavior::Char(custom_escape_char) => {
64 let mut p = String::with_capacity(2 * pattern.len());
65 let mut cs = pattern.chars();
66 while let Some(c) = cs.next() {
67 if c == custom_escape_char {
68 match cs.next() {
69 Some(c2) => {
70 p.push(DEFAULT_ESCAPE);
71 p.push(c2);
72 }
73 None => return Err(EvalError::UnterminatedLikeEscapeSequence),
74 }
75 } else if c == DEFAULT_ESCAPE {
76 p.push_str(DOUBLED_ESCAPE);
77 } else {
78 p.push(c);
79 }
80 }
81 p.shrink_to_fit();
82 Ok(p)
83 }
84 }
85}
86
87pub use matcher::Matcher;
99use matcher::MatcherImpl;
100
101#[allow(clippy::non_canonical_partial_ord_impl)]
105mod matcher {
106 use super::*;
107
108 #[derive(Debug, Clone, Deserialize, Serialize, Derivative, MzReflect)]
110 #[derivative(Eq, PartialEq, Ord, PartialOrd, Hash)]
111 pub struct Matcher {
112 pub pattern: String,
113 pub case_insensitive: bool,
114 #[derivative(
115 PartialEq = "ignore",
116 Hash = "ignore",
117 Ord = "ignore",
118 PartialOrd = "ignore"
119 )]
120 pub(super) matcher_impl: MatcherImpl,
121 }
122
123 impl Matcher {
124 pub fn is_match(&self, text: &str) -> bool {
125 match &self.matcher_impl {
126 MatcherImpl::String(subpatterns) => is_match_subpatterns(subpatterns, text),
127 MatcherImpl::Regex(r) => r.is_match(text),
128 }
129 }
130 }
131
132 #[derive(Debug, Clone, Deserialize, Serialize, MzReflect)]
133 pub(super) enum MatcherImpl {
134 String(Vec<Subpattern>),
135 Regex(Regex),
136 }
137}
138
139pub fn compile(pattern: &str, case_insensitive: bool) -> Result<Matcher, EvalError> {
141 if pattern.len() > 8 << 10 {
149 return Err(EvalError::LikePatternTooLong);
150 }
151 let subpatterns = build_subpatterns(pattern)?;
152 let matcher_impl = match case_insensitive || subpatterns.len() > MAX_SUBPATTERNS {
153 false => MatcherImpl::String(subpatterns),
154 true => MatcherImpl::Regex(build_regex(&subpatterns, case_insensitive)?),
155 };
156 Ok(Matcher {
157 pattern: pattern.into(),
158 case_insensitive,
159 matcher_impl,
160 })
161}
162
163#[derive(Debug, Default, Clone, Deserialize, Serialize, MzReflect)]
189struct Subpattern {
190 consume: usize,
192 many: bool,
194 suffix: String,
196}
197
198impl Subpattern {
199 fn write_regex_to(&self, r: &mut String) {
201 match self.consume {
202 0 => {
203 if self.many {
204 r.push_str(".*");
205 }
206 }
207 1 => {
208 r.push('.');
209 if self.many {
210 r.push('+');
211 }
212 }
213 n => {
214 r.push_str(".{");
215 write!(r, "{}", n);
216 if self.many {
217 r.push(',');
218 }
219 r.push('}');
220 }
221 }
222 regex_syntax::escape_into(&self.suffix, r);
223 }
224}
225
226fn is_match_subpatterns(subpatterns: &[Subpattern], mut text: &str) -> bool {
227 let (subpattern, subpatterns) = match subpatterns {
228 [] => return text.is_empty(),
229 [subpattern, subpatterns @ ..] => (subpattern, subpatterns),
230 };
231 if subpattern.consume > 0 {
233 let mut chars = text.chars();
234 if chars.nth(subpattern.consume - 1).is_none() {
235 return false;
236 }
237 text = chars.as_str();
238 }
239 if subpattern.many {
240 if subpattern.suffix.len() == 0 {
256 assert!(
258 subpatterns.is_empty(),
259 "empty suffix in middle of a pattern"
260 );
261 return true;
262 }
263 let mut found = text.rfind(&subpattern.suffix);
265 loop {
266 match found {
267 None => return false,
268 Some(offset) => {
269 let mut end = offset + subpattern.suffix.len();
270 if is_match_subpatterns(subpatterns, &text[end..]) {
271 return true;
272 }
273 if offset == 0 {
275 return false;
276 }
277 loop {
279 end -= 1;
280 if text.is_char_boundary(end) {
281 break;
282 }
283 }
284 found = text[..end].rfind(&subpattern.suffix);
285 }
286 }
287 }
288 }
289 if !text.starts_with(&subpattern.suffix) {
291 return false;
292 }
293 is_match_subpatterns(subpatterns, &text[subpattern.suffix.len()..])
294}
295
296fn build_subpatterns(pattern: &str) -> Result<Vec<Subpattern>, EvalError> {
298 let mut subpatterns = Vec::with_capacity(MAX_SUBPATTERNS);
299 let mut current = Subpattern::default();
300 let mut in_wildcard = true;
301 let mut in_escape = false;
302 for c in pattern.chars() {
303 match c {
304 c if !in_escape && c == DEFAULT_ESCAPE => {
305 in_escape = true;
306 in_wildcard = false;
307 }
308 '_' if !in_escape => {
309 if !in_wildcard {
310 current.suffix.shrink_to_fit();
311 subpatterns.push(mem::take(&mut current));
312 in_wildcard = true;
313 }
314 current.consume += 1;
315 }
316 '%' if !in_escape => {
317 if !in_wildcard {
318 current.suffix.shrink_to_fit();
319 subpatterns.push(mem::take(&mut current));
320 in_wildcard = true;
321 }
322 current.many = true;
323 }
324 c => {
325 current.suffix.push(c);
326 in_escape = false;
327 in_wildcard = false;
328 }
329 }
330 }
331 if in_escape {
332 return Err(EvalError::UnterminatedLikeEscapeSequence);
333 }
334 current.suffix.shrink_to_fit();
335 subpatterns.push(current);
336 subpatterns.shrink_to_fit();
337 Ok(subpatterns)
338}
339
340fn build_regex(subpatterns: &[Subpattern], case_insensitive: bool) -> Result<Regex, EvalError> {
342 let mut r = String::from("^");
343 for sp in subpatterns {
344 sp.write_regex_to(&mut r);
345 }
346 r.push('$');
347 match Regex::new(&r, case_insensitive) {
348 Ok(regex) => Ok(regex),
349 Err(RegexCompilationError::PatternTooLarge { .. }) => Err(EvalError::LikePatternTooLong),
350 Err(RegexCompilationError::RegexError(regex::Error::CompiledTooBig(_))) => {
351 Err(EvalError::LikePatternTooLong)
352 }
353 Err(e) => Err(EvalError::Internal(
354 format!("build_regex produced invalid regex: {}", e).into(),
355 )),
356 }
357}
358
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_normalize_pattern() {
        struct TestCase<'a> {
            pattern: &'a str,
            escape: EscapeBehavior,
            expected: &'a str,
        }
        let test_cases = vec![
            TestCase {
                pattern: "",
                escape: EscapeBehavior::Disabled,
                expected: "",
            },
            TestCase {
                pattern: "ban%na!",
                escape: EscapeBehavior::default(),
                expected: "ban%na!",
            },
            TestCase {
                pattern: "ban%%%na!",
                escape: EscapeBehavior::Char('%'),
                expected: "ban\\%\\na!",
            },
            TestCase {
                pattern: "ban%na\\!",
                escape: EscapeBehavior::Char('n'),
                expected: "ba\\%\\a\\\\!",
            },
            TestCase {
                pattern: "ban%na\\!",
                escape: EscapeBehavior::Disabled,
                expected: "ban%na\\\\!",
            },
            TestCase {
                pattern: "ban\\na!",
                escape: EscapeBehavior::Char('n'),
                expected: "ba\\\\\\a!",
            },
            TestCase {
                pattern: "ban\\\\na!",
                escape: EscapeBehavior::Char('n'),
                expected: "ba\\\\\\\\\\a!",
            },
            TestCase {
                pattern: "food",
                escape: EscapeBehavior::Char('o'),
                expected: "f\\od",
            },
            TestCase {
                pattern: "漢漢",
                escape: EscapeBehavior::Char('漢'),
                expected: "\\漢",
            },
        ];

        for input in test_cases {
            let actual = normalize_pattern(input.pattern, input.escape).unwrap();
            assert!(
                actual == input.expected,
                "normalize_pattern({:?}, {:?}):\n\tactual: {:?}\n\texpected: {:?}\n",
                input.pattern,
                input.escape,
                actual,
                input.expected,
            );
        }
    }

    #[mz_ore::test]
    fn test_normalize_pattern_unterminated_escape() {
        // A custom escape character at the very end has nothing to escape.
        for (pattern, escape) in [("abc!", '!'), ("ban", 'n')] {
            match normalize_pattern(pattern, EscapeBehavior::Char(escape)) {
                Err(EvalError::UnterminatedLikeEscapeSequence) => {}
                _ => panic!(
                    "expected unterminated escape error for {:?} with escape {:?}",
                    pattern, escape
                ),
            }
        }
    }

    #[mz_ore::test]
    fn test_escape_too_long() {
        match EscapeBehavior::from_str("foo") {
            Err(EvalError::LikeEscapeTooLong) => {}
            _ => {
                panic!("expected error when using escape string with >1 character");
            }
        }
    }

    #[mz_ore::test]
    fn test_escape_from_str() {
        assert!(matches!(
            EscapeBehavior::from_str(""),
            Ok(EscapeBehavior::Disabled)
        ));
        assert!(matches!(
            EscapeBehavior::from_str("!"),
            Ok(EscapeBehavior::Char('!'))
        ));
        assert!(matches!(
            EscapeBehavior::from_str("漢"),
            Ok(EscapeBehavior::Char('漢'))
        ));
    }

    #[mz_ore::test]
    fn test_compile_errors() {
        // The length limit is inclusive of 8 KiB.
        assert!(compile(&"x".repeat(8 << 10), false).is_ok());
        assert!(matches!(
            compile(&"x".repeat((8 << 10) + 1), false),
            Err(EvalError::LikePatternTooLong)
        ));
        assert!(matches!(
            compile("abc\\", false),
            Err(EvalError::UnterminatedLikeEscapeSequence)
        ));
    }

    #[mz_ore::test]
    fn test_like() {
        struct Input<'a> {
            haystack: &'a str,
            matches: bool,
        }
        let input = |haystack, matches| Input { haystack, matches };
        struct Pattern<'a> {
            needle: &'a str,
            case_insensitive: bool,
            inputs: Vec<Input<'a>>,
        }
        let test_cases = vec![
            Pattern {
                needle: "ban%na!",
                case_insensitive: false,
                inputs: vec![
                    input("banana!", true),
                    input("banna!", true),
                    input("bana!", false),
                    input("banana!x", false),
                ],
            },
            Pattern {
                needle: "foo",
                case_insensitive: true,
                inputs: vec![
                    input("", false),
                    input("f", false),
                    input("fo", false),
                    input("foo", true),
                    input("FOO", true),
                    input("Foo", true),
                    input("fOO", true),
                    input("food", false),
                ],
            },
            Pattern {
                needle: "",
                case_insensitive: false,
                inputs: vec![input("", true), input("a", false)],
            },
            Pattern {
                needle: "%",
                case_insensitive: false,
                inputs: vec![input("", true), input("anything", true)],
            },
            // Escaped wildcards are literal.
            Pattern {
                needle: "100\\%",
                case_insensitive: false,
                inputs: vec![
                    input("100%", true),
                    input("1000", false),
                    input("100%x", false),
                ],
            },
            // Overlapping occurrences of the suffix must all be considered.
            Pattern {
                needle: "%aa_",
                case_insensitive: false,
                inputs: vec![
                    input("aaa", true),
                    input("xaab", true),
                    input("aa", false),
                ],
            },
            // `_` consumes one character, not one byte.
            Pattern {
                needle: "_漢%",
                case_insensitive: false,
                inputs: vec![
                    input("x漢", true),
                    input("漢漢", true),
                    input("漢", false),
                ],
            },
            // Several `%` pieces on the subpattern matcher.
            Pattern {
                needle: "%a%a%a%a%b",
                case_insensitive: false,
                inputs: vec![
                    input("aaaab", true),
                    input("xaxaxaxaxb", true),
                    input("aaab", false),
                    input("baaaa", false),
                ],
            },
            // More than MAX_SUBPATTERNS subpatterns: case-sensitive regex path.
            Pattern {
                needle: "a_b_c_d_e_f",
                case_insensitive: false,
                inputs: vec![
                    input("aXbXcXdXeXf", true),
                    input("abcdef", false),
                    input("AXBXCXDXEXF", false),
                ],
            },
            // ILIKE with wildcards.
            Pattern {
                needle: "b%N_!",
                case_insensitive: true,
                inputs: vec![
                    input("BANANA!", true),
                    input("bn!", false),
                    input("BANANA", false),
                ],
            },
        ];

        for tc in test_cases {
            let matcher = compile(tc.needle, tc.case_insensitive).unwrap();
            for input in tc.inputs {
                let actual = matcher.is_match(input.haystack);
                assert!(
                    actual == input.matches,
                    "{:?} {} {:?}:\n\tactual: {:?}\n\texpected: {:?}\n",
                    input.haystack,
                    match tc.case_insensitive {
                        true => "ILIKE",
                        false => "LIKE",
                    },
                    tc.needle,
                    actual,
                    input.matches,
                );
            }
        }
    }
}