1use std::mem;
11use std::str::FromStr;
12
13use derivative::Derivative;
14use mz_lowertest::MzReflect;
15use mz_ore::fmt::FormatBuffer;
16use mz_repr::adt::regex::Regex;
17use serde::{Deserialize, Serialize};
18
19use crate::scalar::EvalError;
20
/// Case-sensitive patterns that parse into at most this many subpatterns are
/// matched directly by `is_match_subpatterns`; longer ones (and every
/// case-insensitive pattern) are compiled into a regex instead.
const MAX_SUBPATTERNS: usize = 5;

/// The escape character used when no other escape behavior is requested.
const DEFAULT_ESCAPE: char = '\\';
/// `DEFAULT_ESCAPE` escaped by itself: how a literal backslash is spelled in a
/// normalized pattern.
const DOUBLED_ESCAPE: &str = r"\\";
/// How the escape character of a LIKE pattern is interpreted.
#[derive(Debug, Clone, Copy)]
pub enum EscapeBehavior {
    /// There is no escape character: every character, including `\`, is
    /// taken literally. Parsing an empty string produces this variant.
    Disabled,
    /// The contained character escapes whatever character follows it.
    Char(char),
}
36
37impl Default for EscapeBehavior {
38    fn default() -> EscapeBehavior {
39        EscapeBehavior::Char(DEFAULT_ESCAPE)
40    }
41}
42
43impl FromStr for EscapeBehavior {
44    type Err = EvalError;
45
46    fn from_str(s: &str) -> Result<EscapeBehavior, EvalError> {
47        let mut chars = s.chars();
48        match chars.next() {
49            None => Ok(EscapeBehavior::Disabled),
50            Some(c) => match chars.next() {
51                None => Ok(EscapeBehavior::Char(c)),
52                Some(_) => Err(EvalError::LikeEscapeTooLong),
53            },
54        }
55    }
56}
57
58pub fn normalize_pattern(pattern: &str, escape: EscapeBehavior) -> Result<String, EvalError> {
60    match escape {
61        EscapeBehavior::Disabled => Ok(pattern.replace(DEFAULT_ESCAPE, DOUBLED_ESCAPE)),
62        EscapeBehavior::Char(DEFAULT_ESCAPE) => Ok(pattern.into()),
63        EscapeBehavior::Char(custom_escape_char) => {
64            let mut p = String::with_capacity(2 * pattern.len());
65            let mut cs = pattern.chars();
66            while let Some(c) = cs.next() {
67                if c == custom_escape_char {
68                    match cs.next() {
69                        Some(c2) => {
70                            p.push(DEFAULT_ESCAPE);
71                            p.push(c2);
72                        }
73                        None => return Err(EvalError::UnterminatedLikeEscapeSequence),
74                    }
75                } else if c == DEFAULT_ESCAPE {
76                    p.push_str(DOUBLED_ESCAPE);
77                } else {
78                    p.push(c);
79                }
80            }
81            p.shrink_to_fit();
82            Ok(p)
83        }
84    }
85}
86
87pub use matcher::Matcher;
99use matcher::MatcherImpl;
100
// `Matcher` and `MatcherImpl` live in their own module so that this lint
// allowance is scoped to them rather than to the whole file.
// NOTE(review): the allow presumably exists because the `derivative`-generated
// `PartialOrd` impl is not written in the canonical `Some(self.cmp(other))`
// form that clippy expects — confirm before removing it.
#[allow(clippy::non_canonical_partial_ord_impl)]
mod matcher {
    use super::*;

    /// A compiled LIKE/ILIKE pattern, produced by `compile`.
    ///
    /// Equality, ordering, and hashing consider only `pattern` and
    /// `case_insensitive`; the compiled `matcher_impl` is ignored by those
    /// derived traits (see the `derivative` attributes).
    #[derive(Debug, Clone, Deserialize, Serialize, Derivative, MzReflect)]
    #[derivative(Eq, PartialEq, Ord, PartialOrd, Hash)]
    pub struct Matcher {
        /// The pattern text this matcher was compiled from.
        pub pattern: String,
        /// Whether matching ignores case (ILIKE rather than LIKE).
        pub case_insensitive: bool,
        /// The compiled form of `pattern`. `compile` builds it from the two
        /// fields above, so it is excluded from comparisons and hashing.
        #[derivative(
            PartialEq = "ignore",
            Hash = "ignore",
            Ord = "ignore",
            PartialOrd = "ignore"
        )]
        pub(super) matcher_impl: MatcherImpl,
    }

    impl Matcher {
        /// Reports whether the whole of `text` matches the pattern.
        pub fn is_match(&self, text: &str) -> bool {
            match &self.matcher_impl {
                MatcherImpl::String(subpatterns) => is_match_subpatterns(subpatterns, text),
                MatcherImpl::Regex(r) => r.is_match(text),
            }
        }
    }

    /// The strategy `compile` chose for evaluating a `Matcher`.
    #[derive(Debug, Clone, Deserialize, Serialize, MzReflect)]
    pub(super) enum MatcherImpl {
        /// Direct matching over parsed subpatterns (used for case-sensitive
        /// patterns with at most `MAX_SUBPATTERNS` subpatterns).
        String(Vec<Subpattern>),
        /// Matching via the anchored regex built by `build_regex`.
        Regex(Regex),
    }
}
138
139pub fn compile(pattern: &str, case_insensitive: bool) -> Result<Matcher, EvalError> {
141    if pattern.len() > 8 << 10 {
149        return Err(EvalError::LikePatternTooLong);
150    }
151    let subpatterns = build_subpatterns(pattern)?;
152    let matcher_impl = match case_insensitive || subpatterns.len() > MAX_SUBPATTERNS {
153        false => MatcherImpl::String(subpatterns),
154        true => MatcherImpl::Regex(build_regex(&subpatterns, case_insensitive)?),
155    };
156    Ok(Matcher {
157        pattern: pattern.into(),
158        case_insensitive,
159        matcher_impl,
160    })
161}
162
/// One segment of a parsed LIKE pattern: a run of wildcards followed by a
/// literal suffix.
///
/// `build_subpatterns` splits a pattern into a sequence of these. Text matches
/// when it can be cut into consecutive pieces, each matched by the
/// corresponding subpattern.
#[derive(Debug, Default, Clone, Deserialize, Serialize, MzReflect)]
struct Subpattern {
    /// Number of characters (one per `_` in the wildcard run) that must be
    /// skipped before the suffix.
    consume: usize,
    /// Whether the wildcard run contained a `%`, allowing any number of
    /// additional characters before the suffix.
    many: bool,
    /// Literal text (escapes already resolved) that must follow the wildcards.
    suffix: String,
}
197
198impl Subpattern {
199    fn write_regex_to(&self, r: &mut String) {
201        match self.consume {
202            0 => {
203                if self.many {
204                    r.push_str(".*");
205                }
206            }
207            1 => {
208                r.push('.');
209                if self.many {
210                    r.push('+');
211                }
212            }
213            n => {
214                r.push_str(".{");
215                write!(r, "{}", n);
216                if self.many {
217                    r.push(',');
218                }
219                r.push('}');
220            }
221        }
222        regex_syntax::escape_into(&self.suffix, r);
223    }
224}
225
226fn is_match_subpatterns(subpatterns: &[Subpattern], mut text: &str) -> bool {
227    let (subpattern, subpatterns) = match subpatterns {
228        [] => return text.is_empty(),
229        [subpattern, subpatterns @ ..] => (subpattern, subpatterns),
230    };
231    if subpattern.consume > 0 {
233        let mut chars = text.chars();
234        if chars.nth(subpattern.consume - 1).is_none() {
235            return false;
236        }
237        text = chars.as_str();
238    }
239    if subpattern.many {
240        if subpattern.suffix.len() == 0 {
256            assert!(
258                subpatterns.is_empty(),
259                "empty suffix in middle of a pattern"
260            );
261            return true;
262        }
263        let mut found = text.rfind(&subpattern.suffix);
265        loop {
266            match found {
267                None => return false,
268                Some(offset) => {
269                    let mut end = offset + subpattern.suffix.len();
270                    if is_match_subpatterns(subpatterns, &text[end..]) {
271                        return true;
272                    }
273                    if offset == 0 {
275                        return false;
276                    }
277                    loop {
279                        end -= 1;
280                        if text.is_char_boundary(end) {
281                            break;
282                        }
283                    }
284                    found = text[..end].rfind(&subpattern.suffix);
285                }
286            }
287        }
288    }
289    if !text.starts_with(&subpattern.suffix) {
291        return false;
292    }
293    is_match_subpatterns(subpatterns, &text[subpattern.suffix.len()..])
294}
295
296fn build_subpatterns(pattern: &str) -> Result<Vec<Subpattern>, EvalError> {
298    let mut subpatterns = Vec::with_capacity(MAX_SUBPATTERNS);
299    let mut current = Subpattern::default();
300    let mut in_wildcard = true;
301    let mut in_escape = false;
302    for c in pattern.chars() {
303        match c {
304            c if !in_escape && c == DEFAULT_ESCAPE => {
305                in_escape = true;
306                in_wildcard = false;
307            }
308            '_' if !in_escape => {
309                if !in_wildcard {
310                    current.suffix.shrink_to_fit();
311                    subpatterns.push(mem::take(&mut current));
312                    in_wildcard = true;
313                }
314                current.consume += 1;
315            }
316            '%' if !in_escape => {
317                if !in_wildcard {
318                    current.suffix.shrink_to_fit();
319                    subpatterns.push(mem::take(&mut current));
320                    in_wildcard = true;
321                }
322                current.many = true;
323            }
324            c => {
325                current.suffix.push(c);
326                in_escape = false;
327                in_wildcard = false;
328            }
329        }
330    }
331    if in_escape {
332        return Err(EvalError::UnterminatedLikeEscapeSequence);
333    }
334    current.suffix.shrink_to_fit();
335    subpatterns.push(current);
336    subpatterns.shrink_to_fit();
337    Ok(subpatterns)
338}
339
340fn build_regex(subpatterns: &[Subpattern], case_insensitive: bool) -> Result<Regex, EvalError> {
342    let mut r = String::from("^");
343    for sp in subpatterns {
344        sp.write_regex_to(&mut r);
345    }
346    r.push('$');
347    match Regex::new(&r, case_insensitive) {
348        Ok(regex) => Ok(regex),
349        Err(regex::Error::CompiledTooBig(_)) => Err(EvalError::LikePatternTooLong),
350        Err(e) => Err(EvalError::Internal(
351            format!("build_regex produced invalid regex: {}", e).into(),
352        )),
353    }
354}
355
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_normalize_pattern() {
        struct TestCase<'a> {
            pattern: &'a str,
            escape: EscapeBehavior,
            expected: &'a str,
        }
        let test_cases = vec![
            TestCase {
                pattern: "",
                escape: EscapeBehavior::Disabled,
                expected: "",
            },
            TestCase {
                pattern: "ban%na!",
                escape: EscapeBehavior::default(),
                expected: "ban%na!",
            },
            TestCase {
                pattern: "ban%%%na!",
                escape: EscapeBehavior::Char('%'),
                expected: "ban\\%\\na!",
            },
            TestCase {
                pattern: "ban%na\\!",
                escape: EscapeBehavior::Char('n'),
                expected: "ba\\%\\a\\\\!",
            },
            TestCase {
                pattern: "ban%na\\!",
                escape: EscapeBehavior::Disabled,
                expected: "ban%na\\\\!",
            },
            TestCase {
                pattern: "ban\\na!",
                escape: EscapeBehavior::Char('n'),
                expected: "ba\\\\\\a!",
            },
            TestCase {
                pattern: "ban\\\\na!",
                escape: EscapeBehavior::Char('n'),
                expected: "ba\\\\\\\\\\a!",
            },
            TestCase {
                pattern: "food",
                escape: EscapeBehavior::Char('o'),
                expected: "f\\od",
            },
            TestCase {
                pattern: "漢漢",
                escape: EscapeBehavior::Char('漢'),
                expected: "\\漢",
            },
        ];

        for input in test_cases {
            let actual = normalize_pattern(input.pattern, input.escape).unwrap();
            assert!(
                actual == input.expected,
                "normalize_pattern({:?}, {:?}):\n\tactual: {:?}\n\texpected: {:?}\n",
                input.pattern,
                input.escape,
                actual,
                input.expected,
            );
        }
    }

    #[mz_ore::test]
    fn test_normalize_pattern_unterminated_escape() {
        // A custom escape character at the very end has nothing to escape.
        assert!(matches!(
            normalize_pattern("abn", EscapeBehavior::Char('n')),
            Err(EvalError::UnterminatedLikeEscapeSequence)
        ));
    }

    #[mz_ore::test]
    fn test_escape_behavior_from_str() {
        assert!(matches!(
            EscapeBehavior::from_str(""),
            Ok(EscapeBehavior::Disabled)
        ));
        assert!(matches!(
            EscapeBehavior::from_str("x"),
            Ok(EscapeBehavior::Char('x'))
        ));
        assert!(matches!(
            EscapeBehavior::from_str("漢"),
            Ok(EscapeBehavior::Char('漢'))
        ));
    }

    #[mz_ore::test]
    fn test_escape_too_long() {
        match EscapeBehavior::from_str("foo") {
            Err(EvalError::LikeEscapeTooLong) => {}
            _ => {
                panic!("expected error when using escape string with >1 character");
            }
        }
    }

    #[mz_ore::test]
    fn test_compile_errors() {
        assert!(matches!(
            compile("ab\\", false),
            Err(EvalError::UnterminatedLikeEscapeSequence)
        ));
        assert!(matches!(
            compile(&"a".repeat((8 << 10) + 1), false),
            Err(EvalError::LikePatternTooLong)
        ));
        assert!(compile(&"a".repeat(8 << 10), false).is_ok());
    }

    #[mz_ore::test]
    fn test_like() {
        struct Input<'a> {
            haystack: &'a str,
            matches: bool,
        }
        let input = |haystack, matches| Input { haystack, matches };
        struct Pattern<'a> {
            needle: &'a str,
            case_insensitive: bool,
            inputs: Vec<Input<'a>>,
        }
        let test_cases = vec![
            Pattern {
                needle: "ban%na!",
                case_insensitive: false,
                inputs: vec![input("banana!", true)],
            },
            Pattern {
                needle: "foo",
                case_insensitive: true,
                inputs: vec![
                    input("", false),
                    input("f", false),
                    input("fo", false),
                    input("foo", true),
                    input("FOO", true),
                    input("Foo", true),
                    input("fOO", true),
                    input("food", false),
                ],
            },
            Pattern {
                needle: "",
                case_insensitive: false,
                inputs: vec![input("", true), input("a", false)],
            },
            Pattern {
                needle: "%",
                case_insensitive: false,
                inputs: vec![input("", true), input("abc", true), input("a\nb", true)],
            },
            Pattern {
                needle: "_",
                case_insensitive: false,
                inputs: vec![
                    input("", false),
                    input("a", true),
                    input("漢", true),
                    input("ab", false),
                ],
            },
            Pattern {
                needle: "a_b",
                case_insensitive: false,
                inputs: vec![input("a\nb", true), input("axb", true), input("ab", false)],
            },
            Pattern {
                needle: "a%b",
                case_insensitive: false,
                inputs: vec![
                    input("ab", true),
                    input("axyzb", true),
                    input("a", false),
                    input("ba", false),
                    input("abc", false),
                ],
            },
            // Two `%` segments with the same suffix exercise backtracking.
            Pattern {
                needle: "%ab%ab",
                case_insensitive: false,
                inputs: vec![
                    input("abab", true),
                    input("ababab", true),
                    input("xabyab", true),
                    input("ab", false),
                    input("aba", false),
                ],
            },
            Pattern {
                needle: "a%b_c",
                case_insensitive: false,
                inputs: vec![
                    input("abxc", true),
                    input("abxbyc", true),
                    input("abc", false),
                    input("abxbycz", false),
                ],
            },
            Pattern {
                needle: "%a_b%c",
                case_insensitive: false,
                inputs: vec![
                    input("aXbc", true),
                    input("aaXbc", true),
                    input("aXYbc", false),
                    input("aXb", false),
                ],
            },
            // Escaped wildcards and an escaped backslash are literal.
            Pattern {
                needle: "a\\%b",
                case_insensitive: false,
                inputs: vec![input("a%b", true), input("axb", false), input("a%%b", false)],
            },
            Pattern {
                needle: "a\\_b",
                case_insensitive: false,
                inputs: vec![input("a_b", true), input("axb", false)],
            },
            Pattern {
                needle: "a\\\\b",
                case_insensitive: false,
                inputs: vec![input("a\\b", true), input("ab", false)],
            },
            // Regex path: metacharacters in the pattern must be escaped.
            Pattern {
                needle: "a.b",
                case_insensitive: true,
                inputs: vec![input("a.b", true), input("A.B", true), input("axb", false)],
            },
            Pattern {
                needle: "f_o%",
                case_insensitive: true,
                inputs: vec![input("FOO", true), input("fXoBAR", true), input("fo", false)],
            },
            // More than MAX_SUBPATTERNS subpatterns: case-sensitive regex path.
            Pattern {
                needle: "a%b%c%d%e%f%g",
                case_insensitive: false,
                inputs: vec![
                    input("abcdefg", true),
                    input("a1b2c3d4e5f6g", true),
                    input("abcdef", false),
                    input("gabcdefg", false),
                ],
            },
        ];

        for tc in test_cases {
            let matcher = compile(tc.needle, tc.case_insensitive).unwrap();
            for input in tc.inputs {
                let actual = matcher.is_match(input.haystack);
                assert!(
                    actual == input.matches,
                    "{:?} {} {:?}:\n\tactual: {:?}\n\texpected: {:?}\n",
                    input.haystack,
                    match tc.case_insensitive {
                        true => "ILIKE",
                        false => "LIKE",
                    },
                    tc.needle,
                    actual,
                    input.matches,
                );
            }
        }
    }
}