1use std::mem;
11use std::str::FromStr;
12
13use derivative::Derivative;
14use mz_lowertest::MzReflect;
15use mz_ore::fmt::FormatBuffer;
16use mz_repr::adt::regex::Regex;
17use serde::{Deserialize, Serialize};
18
19use crate::scalar::EvalError;
20
/// Subpattern-count threshold used by `compile`: a pattern that decomposes
/// into more than this many subpatterns is matched with a compiled regex
/// rather than with `is_match_subpatterns`. Also used as the initial capacity
/// of the vector built by `build_subpatterns`.
const MAX_SUBPATTERNS: usize = 5;

/// The escape character recognized by `build_subpatterns`, and the character
/// used by `EscapeBehavior::default()`.
const DEFAULT_ESCAPE: char = '\\';
/// Two default escape characters in a row, i.e. an escaped literal backslash.
/// `normalize_pattern` substitutes this for each literal backslash it finds.
const DOUBLED_ESCAPE: &str = "\\\\";
27
/// Describes which character, if any, acts as the escape character in a LIKE
/// pattern.
#[derive(Clone, Copy, Debug)]
pub enum EscapeBehavior {
    /// No character acts as an escape. `normalize_pattern` doubles every
    /// backslash in such a pattern so that it is treated literally.
    Disabled,
    /// The contained character acts as the escape character.
    Char(char),
}
36
37impl Default for EscapeBehavior {
38 fn default() -> EscapeBehavior {
39 EscapeBehavior::Char(DEFAULT_ESCAPE)
40 }
41}
42
43impl FromStr for EscapeBehavior {
44 type Err = EvalError;
45
46 fn from_str(s: &str) -> Result<EscapeBehavior, EvalError> {
47 let mut chars = s.chars();
48 match chars.next() {
49 None => Ok(EscapeBehavior::Disabled),
50 Some(c) => match chars.next() {
51 None => Ok(EscapeBehavior::Char(c)),
52 Some(_) => Err(EvalError::LikeEscapeTooLong),
53 },
54 }
55 }
56}
57
58pub fn normalize_pattern(pattern: &str, escape: EscapeBehavior) -> Result<String, EvalError> {
60 match escape {
61 EscapeBehavior::Disabled => Ok(pattern.replace(DEFAULT_ESCAPE, DOUBLED_ESCAPE)),
62 EscapeBehavior::Char(DEFAULT_ESCAPE) => Ok(pattern.into()),
63 EscapeBehavior::Char(custom_escape_char) => {
64 let mut p = String::with_capacity(2 * pattern.len());
65 let mut cs = pattern.chars();
66 while let Some(c) = cs.next() {
67 if c == custom_escape_char {
68 match cs.next() {
69 Some(c2) => {
70 p.push(DEFAULT_ESCAPE);
71 p.push(c2);
72 }
73 None => return Err(EvalError::UnterminatedLikeEscapeSequence),
74 }
75 } else if c == DEFAULT_ESCAPE {
76 p.push_str(DOUBLED_ESCAPE);
77 } else {
78 p.push(c);
79 }
80 }
81 p.shrink_to_fit();
82 Ok(p)
83 }
84 }
85}
86
87pub use matcher::Matcher;
99use matcher::MatcherImpl;
100
// NOTE(review): this lint is presumably tripped by the `PartialOrd` impl that
// `derivative` generates for `Matcher`; the types appear to live in their own
// module so the `allow` can cover that generated code — confirm.
#[allow(clippy::non_canonical_partial_ord_impl)]
mod matcher {
    use super::*;

    /// A compiled LIKE pattern that can be tested against strings.
    ///
    /// Equality, ordering, and hashing consider only `pattern` and
    /// `case_insensitive`; `matcher_impl` is excluded from those impls.
    #[derive(Debug, Clone, Deserialize, Serialize, Derivative, MzReflect)]
    #[derivative(Eq, PartialEq, Ord, PartialOrd, Hash)]
    pub struct Matcher {
        /// The pattern text this matcher was compiled from.
        pub pattern: String,
        /// Whether matching ignores case.
        pub case_insensitive: bool,
        /// The strategy used to evaluate the pattern; chosen by `compile`.
        #[derivative(
            PartialEq = "ignore",
            Hash = "ignore",
            Ord = "ignore",
            PartialOrd = "ignore"
        )]
        pub(super) matcher_impl: MatcherImpl,
    }

    impl Matcher {
        /// Returns whether `text` matches this pattern, delegating to the
        /// strategy stored in `matcher_impl`.
        pub fn is_match(&self, text: &str) -> bool {
            match &self.matcher_impl {
                MatcherImpl::String(subpatterns) => is_match_subpatterns(subpatterns, text),
                MatcherImpl::Regex(r) => r.is_match(text),
            }
        }
    }

    /// The evaluation strategy behind a `Matcher`.
    #[derive(Debug, Clone, Deserialize, Serialize, MzReflect)]
    pub(super) enum MatcherImpl {
        /// Match by walking the decomposed subpatterns with
        /// `is_match_subpatterns`.
        String(Vec<Subpattern>),
        /// Match with a compiled regular expression.
        Regex(Regex),
    }
}
138
139pub fn compile(pattern: &str, case_insensitive: bool) -> Result<Matcher, EvalError> {
141 if pattern.len() > 8 << 10 {
149 return Err(EvalError::LikePatternTooLong);
150 }
151 let subpatterns = build_subpatterns(pattern)?;
152 let matcher_impl = match case_insensitive || subpatterns.len() > MAX_SUBPATTERNS {
153 false => MatcherImpl::String(subpatterns),
154 true => MatcherImpl::Regex(build_regex(&subpatterns, case_insensitive)?),
155 };
156 Ok(Matcher {
157 pattern: pattern.into(),
158 case_insensitive,
159 matcher_impl,
160 })
161}
162
/// One segment of a LIKE pattern: a run of wildcards followed by a literal.
///
/// `build_subpatterns` decomposes a whole pattern into a sequence of these.
#[derive(Debug, Default, Clone, Deserialize, Serialize, MzReflect)]
struct Subpattern {
    /// Number of `_` wildcards in the run, i.e. how many characters must be
    /// consumed before the suffix.
    consume: usize,
    /// Whether the run contains a `%`, i.e. whether any number of additional
    /// characters may be consumed before the suffix.
    many: bool,
    /// Literal text that must follow the wildcard run (escapes already
    /// resolved).
    suffix: String,
}
197
198impl Subpattern {
199 fn write_regex_to(&self, r: &mut String) {
201 match self.consume {
202 0 => {
203 if self.many {
204 r.push_str(".*");
205 }
206 }
207 1 => {
208 r.push('.');
209 if self.many {
210 r.push('+');
211 }
212 }
213 n => {
214 r.push_str(".{");
215 write!(r, "{}", n);
216 if self.many {
217 r.push(',');
218 }
219 r.push('}');
220 }
221 }
222 regex_syntax::escape_into(&self.suffix, r);
223 }
224}
225
226fn is_match_subpatterns(subpatterns: &[Subpattern], mut text: &str) -> bool {
227 let (subpattern, subpatterns) = match subpatterns {
228 [] => return text.is_empty(),
229 [subpattern, subpatterns @ ..] => (subpattern, subpatterns),
230 };
231 if subpattern.consume > 0 {
233 let mut chars = text.chars();
234 if chars.nth(subpattern.consume - 1).is_none() {
235 return false;
236 }
237 text = chars.as_str();
238 }
239 if subpattern.many {
240 if subpattern.suffix.len() == 0 {
256 assert!(
258 subpatterns.is_empty(),
259 "empty suffix in middle of a pattern"
260 );
261 return true;
262 }
263 let mut found = text.rfind(&subpattern.suffix);
265 loop {
266 match found {
267 None => return false,
268 Some(offset) => {
269 let mut end = offset + subpattern.suffix.len();
270 if is_match_subpatterns(subpatterns, &text[end..]) {
271 return true;
272 }
273 if offset == 0 {
275 return false;
276 }
277 loop {
279 end -= 1;
280 if text.is_char_boundary(end) {
281 break;
282 }
283 }
284 found = text[..end].rfind(&subpattern.suffix);
285 }
286 }
287 }
288 }
289 if !text.starts_with(&subpattern.suffix) {
291 return false;
292 }
293 is_match_subpatterns(subpatterns, &text[subpattern.suffix.len()..])
294}
295
296fn build_subpatterns(pattern: &str) -> Result<Vec<Subpattern>, EvalError> {
298 let mut subpatterns = Vec::with_capacity(MAX_SUBPATTERNS);
299 let mut current = Subpattern::default();
300 let mut in_wildcard = true;
301 let mut in_escape = false;
302 for c in pattern.chars() {
303 match c {
304 c if !in_escape && c == DEFAULT_ESCAPE => {
305 in_escape = true;
306 in_wildcard = false;
307 }
308 '_' if !in_escape => {
309 if !in_wildcard {
310 current.suffix.shrink_to_fit();
311 subpatterns.push(mem::take(&mut current));
312 in_wildcard = true;
313 }
314 current.consume += 1;
315 }
316 '%' if !in_escape => {
317 if !in_wildcard {
318 current.suffix.shrink_to_fit();
319 subpatterns.push(mem::take(&mut current));
320 in_wildcard = true;
321 }
322 current.many = true;
323 }
324 c => {
325 current.suffix.push(c);
326 in_escape = false;
327 in_wildcard = false;
328 }
329 }
330 }
331 if in_escape {
332 return Err(EvalError::UnterminatedLikeEscapeSequence);
333 }
334 current.suffix.shrink_to_fit();
335 subpatterns.push(current);
336 subpatterns.shrink_to_fit();
337 Ok(subpatterns)
338}
339
340fn build_regex(subpatterns: &[Subpattern], case_insensitive: bool) -> Result<Regex, EvalError> {
342 let mut r = String::from("^");
343 for sp in subpatterns {
344 sp.write_regex_to(&mut r);
345 }
346 r.push('$');
347 match Regex::new(&r, case_insensitive) {
348 Ok(regex) => Ok(regex),
349 Err(regex::Error::CompiledTooBig(_)) => Err(EvalError::LikePatternTooLong),
350 Err(e) => Err(EvalError::Internal(
351 format!("build_regex produced invalid regex: {}", e).into(),
352 )),
353 }
354}
355
#[cfg(test)]
mod test {
    use super::*;

    /// Checks `normalize_pattern` against a table of
    /// `(pattern, escape, expected)` rows.
    #[mz_ore::test]
    fn test_normalize_pattern() {
        let cases = [
            ("", EscapeBehavior::Disabled, ""),
            ("ban%na!", EscapeBehavior::default(), "ban%na!"),
            ("ban%%%na!", EscapeBehavior::Char('%'), "ban\\%\\na!"),
            ("ban%na\\!", EscapeBehavior::Char('n'), "ba\\%\\a\\\\!"),
            ("ban%na\\!", EscapeBehavior::Disabled, "ban%na\\\\!"),
            ("ban\\na!", EscapeBehavior::Char('n'), "ba\\\\\\a!"),
            ("ban\\\\na!", EscapeBehavior::Char('n'), "ba\\\\\\\\\\a!"),
            ("food", EscapeBehavior::Char('o'), "f\\od"),
            ("漢漢", EscapeBehavior::Char('漢'), "\\漢"),
        ];

        for (pattern, escape, expected) in cases {
            let actual = normalize_pattern(pattern, escape).unwrap();
            assert!(
                actual == expected,
                "normalize_pattern({:?}, {:?}):\n\tactual: {:?}\n\texpected: {:?}\n",
                pattern,
                escape,
                actual,
                expected,
            );
        }
    }

    /// An escape string of more than one character must be rejected.
    #[mz_ore::test]
    fn test_escape_too_long() {
        let parsed = EscapeBehavior::from_str("foo");
        assert!(
            matches!(parsed, Err(EvalError::LikeEscapeTooLong)),
            "expected error when using escape string with >1 character"
        );
    }

    /// Compiles each needle and checks it against its haystacks. Rows are
    /// `(needle, case_insensitive, [(haystack, should_match)])`.
    #[mz_ore::test]
    fn test_like() {
        let cases: Vec<(&str, bool, Vec<(&str, bool)>)> = vec![
            ("ban%na!", false, vec![("banana!", true)]),
            (
                "foo",
                true,
                vec![
                    ("", false),
                    ("f", false),
                    ("fo", false),
                    ("foo", true),
                    ("FOO", true),
                    ("Foo", true),
                    ("fOO", true),
                    ("food", false),
                ],
            ),
        ];

        for (needle, case_insensitive, inputs) in cases {
            let matcher = compile(needle, case_insensitive).unwrap();
            let operator = if case_insensitive { "ILIKE" } else { "LIKE" };
            for (haystack, should_match) in inputs {
                let actual = matcher.is_match(haystack);
                assert!(
                    actual == should_match,
                    "{:?} {} {:?}:\n\tactual: {:?}\n\texpected: {:?}\n",
                    haystack,
                    operator,
                    needle,
                    actual,
                    should_match,
                );
            }
        }
    }
}