1use std::mem;
11use std::str::FromStr;
12
13use derivative::Derivative;
14use mz_lowertest::MzReflect;
15use mz_ore::fmt::FormatBuffer;
16use mz_proto::{ProtoType, RustType, TryFromProtoError};
17use mz_repr::adt::regex::Regex;
18use proptest::prelude::{Arbitrary, Strategy};
19use serde::{Deserialize, Serialize};
20
21use crate::scalar::EvalError;
22
23include!(concat!(env!("OUT_DIR"), "/mz_expr.scalar.like_pattern.rs"));
24
/// Maximum number of subpatterns handled by the hand-written subpattern
/// matcher; `compile` switches to a regex for patterns with more than this
/// many. Also used as the initial capacity in `build_subpatterns`.
const MAX_SUBPATTERNS: usize = 5;

/// The escape character the matcher understands; `normalize_pattern` rewrites
/// patterns using any other escape character into this one.
const DEFAULT_ESCAPE: char = '\\';
/// A doubled [`DEFAULT_ESCAPE`], i.e. an escaped literal backslash.
const DOUBLED_ESCAPE: &str = "\\\\";
31
/// How escape characters are interpreted in a LIKE pattern.
#[derive(Clone, Copy, Debug)]
pub enum EscapeBehavior {
    /// No character acts as an escape. Backslashes are matched literally
    /// (`normalize_pattern` doubles them); `_` and `%` remain wildcards.
    Disabled,
    /// The given character escapes the character that follows it.
    Char(char),
}
40
41impl Default for EscapeBehavior {
42 fn default() -> EscapeBehavior {
43 EscapeBehavior::Char(DEFAULT_ESCAPE)
44 }
45}
46
47impl FromStr for EscapeBehavior {
48 type Err = EvalError;
49
50 fn from_str(s: &str) -> Result<EscapeBehavior, EvalError> {
51 let mut chars = s.chars();
52 match chars.next() {
53 None => Ok(EscapeBehavior::Disabled),
54 Some(c) => match chars.next() {
55 None => Ok(EscapeBehavior::Char(c)),
56 Some(_) => Err(EvalError::LikeEscapeTooLong),
57 },
58 }
59 }
60}
61
62pub fn normalize_pattern(pattern: &str, escape: EscapeBehavior) -> Result<String, EvalError> {
64 match escape {
65 EscapeBehavior::Disabled => Ok(pattern.replace(DEFAULT_ESCAPE, DOUBLED_ESCAPE)),
66 EscapeBehavior::Char(DEFAULT_ESCAPE) => Ok(pattern.into()),
67 EscapeBehavior::Char(custom_escape_char) => {
68 let mut p = String::with_capacity(2 * pattern.len());
69 let mut cs = pattern.chars();
70 while let Some(c) = cs.next() {
71 if c == custom_escape_char {
72 match cs.next() {
73 Some(c2) => {
74 p.push(DEFAULT_ESCAPE);
75 p.push(c2);
76 }
77 None => return Err(EvalError::UnterminatedLikeEscapeSequence),
78 }
79 } else if c == DEFAULT_ESCAPE {
80 p.push_str(DOUBLED_ESCAPE);
81 } else {
82 p.push(c);
83 }
84 }
85 p.shrink_to_fit();
86 Ok(p)
87 }
88 }
89}
90
91pub use matcher::Matcher;
103use matcher::MatcherImpl;
104
105#[allow(clippy::non_canonical_partial_ord_impl)]
109mod matcher {
110 use super::*;
111
112 #[derive(Debug, Clone, Deserialize, Serialize, Derivative, MzReflect)]
114 #[derivative(Eq, PartialEq, Ord, PartialOrd, Hash)]
115 pub struct Matcher {
116 pub pattern: String,
117 pub case_insensitive: bool,
118 #[derivative(
119 PartialEq = "ignore",
120 Hash = "ignore",
121 Ord = "ignore",
122 PartialOrd = "ignore"
123 )]
124 pub(super) matcher_impl: MatcherImpl,
125 }
126
127 impl Matcher {
128 pub fn is_match(&self, text: &str) -> bool {
129 match &self.matcher_impl {
130 MatcherImpl::String(subpatterns) => is_match_subpatterns(subpatterns, text),
131 MatcherImpl::Regex(r) => r.is_match(text),
132 }
133 }
134 }
135
136 impl RustType<ProtoMatcher> for Matcher {
137 fn into_proto(&self) -> ProtoMatcher {
138 ProtoMatcher {
139 pattern: self.pattern.clone(),
140 case_insensitive: self.case_insensitive,
141 }
142 }
143
144 fn from_proto(proto: ProtoMatcher) -> Result<Self, TryFromProtoError> {
145 compile(proto.pattern.as_str(), proto.case_insensitive).map_err(|eval_err| {
146 TryFromProtoError::LikePatternDeserializationError(eval_err.to_string())
147 })
148 }
149 }
150
151 #[derive(Debug, Clone, Deserialize, Serialize, MzReflect)]
152 pub(super) enum MatcherImpl {
153 String(Vec<Subpattern>),
154 Regex(Regex),
155 }
156}
157
158pub fn compile(pattern: &str, case_insensitive: bool) -> Result<Matcher, EvalError> {
160 if pattern.len() > 8 << 10 {
168 return Err(EvalError::LikePatternTooLong);
169 }
170 let subpatterns = build_subpatterns(pattern)?;
171 let matcher_impl = match case_insensitive || subpatterns.len() > MAX_SUBPATTERNS {
172 false => MatcherImpl::String(subpatterns),
173 true => MatcherImpl::Regex(build_regex(&subpatterns, case_insensitive)?),
174 };
175 Ok(Matcher {
176 pattern: pattern.into(),
177 case_insensitive,
178 matcher_impl,
179 })
180}
181
182pub fn any_matcher() -> impl Strategy<Value = Matcher> {
183 (
198 r"([[:alnum:]]|[[:cntrl:]]|([[[:punct:]]&&[^\\]])|[[:space:]]|华|_|%|(\\_)|(\\%)|(\\\\)){0, 50}",
199 bool::arbitrary(),
200 )
201 .prop_map(|(pattern, case_insensitive)| compile(&pattern, case_insensitive).unwrap())
202}
203
/// One segment of a LIKE pattern: a run of `_`/`%` wildcards followed by a
/// literal suffix.
#[derive(Debug, Default, Clone, Deserialize, Serialize, MzReflect)]
struct Subpattern {
    /// Number of characters that must be skipped (one per `_`).
    consume: usize,
    /// Whether a `%` was present, allowing any number of further characters
    /// after the `consume` skipped ones.
    many: bool,
    /// Literal text that must follow the wildcard-matched characters; may be
    /// empty only for the last subpattern.
    suffix: String,
}
238
239impl Subpattern {
240 fn write_regex_to(&self, r: &mut String) {
242 match self.consume {
243 0 => {
244 if self.many {
245 r.push_str(".*");
246 }
247 }
248 1 => {
249 r.push('.');
250 if self.many {
251 r.push('+');
252 }
253 }
254 n => {
255 r.push_str(".{");
256 write!(r, "{}", n);
257 if self.many {
258 r.push(',');
259 }
260 r.push('}');
261 }
262 }
263 regex_syntax::escape_into(&self.suffix, r);
264 }
265}
266
267impl RustType<ProtoSubpattern> for Subpattern {
268 fn into_proto(&self) -> ProtoSubpattern {
269 ProtoSubpattern {
270 consume: self.consume.into_proto(),
271 many: self.many,
272 suffix: self.suffix.clone(),
273 }
274 }
275
276 fn from_proto(proto: ProtoSubpattern) -> Result<Self, TryFromProtoError> {
277 Ok(Subpattern {
278 consume: proto.consume.into_rust()?,
279 many: proto.many,
280 suffix: proto.suffix,
281 })
282 }
283}
284
285fn is_match_subpatterns(subpatterns: &[Subpattern], mut text: &str) -> bool {
286 let (subpattern, subpatterns) = match subpatterns {
287 [] => return text.is_empty(),
288 [subpattern, subpatterns @ ..] => (subpattern, subpatterns),
289 };
290 if subpattern.consume > 0 {
292 let mut chars = text.chars();
293 if chars.nth(subpattern.consume - 1).is_none() {
294 return false;
295 }
296 text = chars.as_str();
297 }
298 if subpattern.many {
299 if subpattern.suffix.len() == 0 {
315 assert!(
317 subpatterns.is_empty(),
318 "empty suffix in middle of a pattern"
319 );
320 return true;
321 }
322 let mut found = text.rfind(&subpattern.suffix);
324 loop {
325 match found {
326 None => return false,
327 Some(offset) => {
328 let mut end = offset + subpattern.suffix.len();
329 if is_match_subpatterns(subpatterns, &text[end..]) {
330 return true;
331 }
332 if offset == 0 {
334 return false;
335 }
336 loop {
338 end -= 1;
339 if text.is_char_boundary(end) {
340 break;
341 }
342 }
343 found = text[..end].rfind(&subpattern.suffix);
344 }
345 }
346 }
347 }
348 if !text.starts_with(&subpattern.suffix) {
350 return false;
351 }
352 is_match_subpatterns(subpatterns, &text[subpattern.suffix.len()..])
353}
354
355fn build_subpatterns(pattern: &str) -> Result<Vec<Subpattern>, EvalError> {
357 let mut subpatterns = Vec::with_capacity(MAX_SUBPATTERNS);
358 let mut current = Subpattern::default();
359 let mut in_wildcard = true;
360 let mut in_escape = false;
361 for c in pattern.chars() {
362 match c {
363 c if !in_escape && c == DEFAULT_ESCAPE => {
364 in_escape = true;
365 in_wildcard = false;
366 }
367 '_' if !in_escape => {
368 if !in_wildcard {
369 current.suffix.shrink_to_fit();
370 subpatterns.push(mem::take(&mut current));
371 in_wildcard = true;
372 }
373 current.consume += 1;
374 }
375 '%' if !in_escape => {
376 if !in_wildcard {
377 current.suffix.shrink_to_fit();
378 subpatterns.push(mem::take(&mut current));
379 in_wildcard = true;
380 }
381 current.many = true;
382 }
383 c => {
384 current.suffix.push(c);
385 in_escape = false;
386 in_wildcard = false;
387 }
388 }
389 }
390 if in_escape {
391 return Err(EvalError::UnterminatedLikeEscapeSequence);
392 }
393 current.suffix.shrink_to_fit();
394 subpatterns.push(current);
395 subpatterns.shrink_to_fit();
396 Ok(subpatterns)
397}
398
399fn build_regex(subpatterns: &[Subpattern], case_insensitive: bool) -> Result<Regex, EvalError> {
401 let mut r = String::from("^");
402 for sp in subpatterns {
403 sp.write_regex_to(&mut r);
404 }
405 r.push('$');
406 match Regex::new(&r, case_insensitive) {
407 Ok(regex) => Ok(regex),
408 Err(regex::Error::CompiledTooBig(_)) => Err(EvalError::LikePatternTooLong),
409 Err(e) => Err(EvalError::Internal(
410 format!("build_regex produced invalid regex: {}", e).into(),
411 )),
412 }
413}
414
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_normalize_pattern() {
        // (pattern, escape behavior, expected normalized pattern)
        let cases: [(&str, EscapeBehavior, &str); 9] = [
            ("", EscapeBehavior::Disabled, ""),
            ("ban%na!", EscapeBehavior::default(), "ban%na!"),
            ("ban%%%na!", EscapeBehavior::Char('%'), "ban\\%\\na!"),
            ("ban%na\\!", EscapeBehavior::Char('n'), "ba\\%\\a\\\\!"),
            ("ban%na\\!", EscapeBehavior::Disabled, "ban%na\\\\!"),
            ("ban\\na!", EscapeBehavior::Char('n'), "ba\\\\\\a!"),
            ("ban\\\\na!", EscapeBehavior::Char('n'), "ba\\\\\\\\\\a!"),
            ("food", EscapeBehavior::Char('o'), "f\\od"),
            ("漢漢", EscapeBehavior::Char('漢'), "\\漢"),
        ];

        for (pattern, escape, expected) in cases {
            let actual = normalize_pattern(pattern, escape).unwrap();
            assert!(
                actual == expected,
                "normalize_pattern({:?}, {:?}):\n\tactual: {:?}\n\texpected: {:?}\n",
                pattern,
                escape,
                actual,
                expected,
            );
        }
    }

    #[mz_ore::test]
    fn test_escape_too_long() {
        let parsed = EscapeBehavior::from_str("foo");
        if !matches!(parsed, Err(EvalError::LikeEscapeTooLong)) {
            panic!("expected error when using escape string with >1 character");
        }
    }

    #[mz_ore::test]
    fn test_like() {
        // (needle, case_insensitive, [(haystack, expected match)])
        let cases: Vec<(&str, bool, Vec<(&str, bool)>)> = vec![
            ("ban%na!", false, vec![("banana!", true)]),
            (
                "foo",
                true,
                vec![
                    ("", false),
                    ("f", false),
                    ("fo", false),
                    ("foo", true),
                    ("FOO", true),
                    ("Foo", true),
                    ("fOO", true),
                    ("food", false),
                ],
            ),
        ];

        for (needle, case_insensitive, inputs) in cases {
            let matcher = compile(needle, case_insensitive).unwrap();
            let op = if case_insensitive { "ILIKE" } else { "LIKE" };
            for (haystack, expected) in inputs {
                let actual = matcher.is_match(haystack);
                assert!(
                    actual == expected,
                    "{:?} {} {:?}:\n\tactual: {:?}\n\texpected: {:?}\n",
                    haystack,
                    op,
                    needle,
                    actual,
                    expected,
                );
            }
        }
    }
}