1use crate::like::StringArrayType;
22
23use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder};
24use arrow_array::cast::AsArray;
25use arrow_array::*;
26use arrow_buffer::NullBuffer;
27use arrow_data::{ArrayData, ArrayDataBuilder};
28use arrow_schema::{ArrowError, DataType, Field};
29use regex::Regex;
30
31use std::collections::HashMap;
32use std::sync::Arc;
33
34#[deprecated(since = "54.0.0", note = "please use `regex_is_match` instead")]
42pub fn regexp_is_match_utf8<OffsetSize: OffsetSizeTrait>(
43 array: &GenericStringArray<OffsetSize>,
44 regex_array: &GenericStringArray<OffsetSize>,
45 flags_array: Option<&GenericStringArray<OffsetSize>>,
46) -> Result<BooleanArray, ArrowError> {
47 regexp_is_match(array, regex_array, flags_array)
48}
49
50pub fn regexp_is_match<'a, S1, S2, S3>(
84 array: &'a S1,
85 regex_array: &'a S2,
86 flags_array: Option<&'a S3>,
87) -> Result<BooleanArray, ArrowError>
88where
89 &'a S1: StringArrayType<'a>,
90 &'a S2: StringArrayType<'a>,
91 &'a S3: StringArrayType<'a>,
92{
93 if array.len() != regex_array.len() {
94 return Err(ArrowError::ComputeError(
95 "Cannot perform comparison operation on arrays of different length".to_string(),
96 ));
97 }
98
99 let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());
100
101 let mut patterns: HashMap<String, Regex> = HashMap::new();
102 let mut result = BooleanBufferBuilder::new(array.len());
103
104 let complete_pattern = match flags_array {
105 Some(flags) => Box::new(
106 regex_array
107 .iter()
108 .zip(flags.iter())
109 .map(|(pattern, flags)| {
110 pattern.map(|pattern| match flags {
111 Some(flag) => format!("(?{flag}){pattern}"),
112 None => pattern.to_string(),
113 })
114 }),
115 ) as Box<dyn Iterator<Item = Option<String>>>,
116 None => Box::new(
117 regex_array
118 .iter()
119 .map(|pattern| pattern.map(|pattern| pattern.to_string())),
120 ),
121 };
122
123 array
124 .iter()
125 .zip(complete_pattern)
126 .map(|(value, pattern)| {
127 match (value, pattern) {
128 (Some(_), Some(pattern)) if pattern == *"" => {
131 result.append(true);
132 }
133 (Some(value), Some(pattern)) => {
134 let existing_pattern = patterns.get(&pattern);
135 let re = match existing_pattern {
136 Some(re) => re,
137 None => {
138 let re = Regex::new(pattern.as_str()).map_err(|e| {
139 ArrowError::ComputeError(format!(
140 "Regular expression did not compile: {e:?}"
141 ))
142 })?;
143 patterns.entry(pattern).or_insert(re)
144 }
145 };
146 result.append(re.is_match(value));
147 }
148 _ => result.append(false),
149 }
150 Ok(())
151 })
152 .collect::<Result<Vec<()>, ArrowError>>()?;
153
154 let data = unsafe {
155 ArrayDataBuilder::new(DataType::Boolean)
156 .len(array.len())
157 .buffers(vec![result.into()])
158 .nulls(nulls)
159 .build_unchecked()
160 };
161
162 Ok(BooleanArray::from(data))
163}
164
165#[deprecated(since = "54.0.0", note = "please use `regex_is_match_scalar` instead")]
170pub fn regexp_is_match_utf8_scalar<OffsetSize: OffsetSizeTrait>(
171 array: &GenericStringArray<OffsetSize>,
172 regex: &str,
173 flag: Option<&str>,
174) -> Result<BooleanArray, ArrowError> {
175 regexp_is_match_scalar(array, regex, flag)
176}
177
178pub fn regexp_is_match_scalar<'a, S>(
203 array: &'a S,
204 regex: &str,
205 flag: Option<&str>,
206) -> Result<BooleanArray, ArrowError>
207where
208 &'a S: StringArrayType<'a>,
209{
210 let null_bit_buffer = array.nulls().map(|x| x.inner().sliced());
211 let mut result = BooleanBufferBuilder::new(array.len());
212
213 let pattern = match flag {
214 Some(flag) => format!("(?{flag}){regex}"),
215 None => regex.to_string(),
216 };
217
218 if pattern.is_empty() {
219 result.append_n(array.len(), true);
220 } else {
221 let re = Regex::new(pattern.as_str()).map_err(|e| {
222 ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
223 })?;
224 for i in 0..array.len() {
225 let value = array.value(i);
226 result.append(re.is_match(value));
227 }
228 }
229
230 let buffer = result.into();
231 let data = unsafe {
232 ArrayData::new_unchecked(
233 DataType::Boolean,
234 array.len(),
235 None,
236 null_bit_buffer,
237 0,
238 vec![buffer],
239 vec![],
240 )
241 };
242
243 Ok(BooleanArray::from(data))
244}
245
246fn regexp_array_match<OffsetSize: OffsetSizeTrait>(
247 array: &GenericStringArray<OffsetSize>,
248 regex_array: &GenericStringArray<OffsetSize>,
249 flags_array: Option<&GenericStringArray<OffsetSize>>,
250) -> Result<ArrayRef, ArrowError> {
251 let mut patterns: HashMap<String, Regex> = HashMap::new();
252 let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
253 let mut list_builder = ListBuilder::new(builder);
254
255 let complete_pattern = match flags_array {
256 Some(flags) => Box::new(
257 regex_array
258 .iter()
259 .zip(flags.iter())
260 .map(|(pattern, flags)| {
261 pattern.map(|pattern| match flags {
262 Some(value) => format!("(?{value}){pattern}"),
263 None => pattern.to_string(),
264 })
265 }),
266 ) as Box<dyn Iterator<Item = Option<String>>>,
267 None => Box::new(
268 regex_array
269 .iter()
270 .map(|pattern| pattern.map(|pattern| pattern.to_string())),
271 ),
272 };
273
274 array
275 .iter()
276 .zip(complete_pattern)
277 .map(|(value, pattern)| {
278 match (value, pattern) {
279 (Some(_), Some(pattern)) if pattern == *"" => {
282 list_builder.values().append_value("");
283 list_builder.append(true);
284 }
285 (Some(value), Some(pattern)) => {
286 let existing_pattern = patterns.get(&pattern);
287 let re = match existing_pattern {
288 Some(re) => re,
289 None => {
290 let re = Regex::new(pattern.as_str()).map_err(|e| {
291 ArrowError::ComputeError(format!(
292 "Regular expression did not compile: {e:?}"
293 ))
294 })?;
295 patterns.entry(pattern).or_insert(re)
296 }
297 };
298 match re.captures(value) {
299 Some(caps) => {
300 let mut iter = caps.iter();
301 if caps.len() > 1 {
302 iter.next();
303 }
304 for m in iter.flatten() {
305 list_builder.values().append_value(m.as_str());
306 }
307
308 list_builder.append(true);
309 }
310 None => list_builder.append(false),
311 }
312 }
313 _ => list_builder.append(false),
314 }
315 Ok(())
316 })
317 .collect::<Result<Vec<()>, ArrowError>>()?;
318 Ok(Arc::new(list_builder.finish()))
319}
320
321fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(
322 regex_array: &'a dyn Array,
323 flag_array: Option<&'a dyn Array>,
324) -> (Option<&'a str>, Option<&'a str>) {
325 let regex = regex_array.as_string::<OffsetSize>();
326 let regex = regex.is_valid(0).then(|| regex.value(0));
327
328 if let Some(flag_array) = flag_array {
329 let flag = flag_array.as_string::<OffsetSize>();
330 (regex, flag.is_valid(0).then(|| flag.value(0)))
331 } else {
332 (regex, None)
333 }
334}
335
336fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
337 array: &GenericStringArray<OffsetSize>,
338 regex: &Regex,
339) -> Result<ArrayRef, ArrowError> {
340 let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
341 let mut list_builder = ListBuilder::new(builder);
342
343 array
344 .iter()
345 .map(|value| {
346 match value {
347 Some(_) if regex.as_str() == "" => {
350 list_builder.values().append_value("");
351 list_builder.append(true);
352 }
353 Some(value) => match regex.captures(value) {
354 Some(caps) => {
355 let mut iter = caps.iter();
356 if caps.len() > 1 {
357 iter.next();
358 }
359 for m in iter.flatten() {
360 list_builder.values().append_value(m.as_str());
361 }
362
363 list_builder.append(true);
364 }
365 None => list_builder.append(false),
366 },
367 _ => list_builder.append(false),
368 }
369 Ok(())
370 })
371 .collect::<Result<Vec<()>, ArrowError>>()?;
372
373 Ok(Arc::new(list_builder.finish()))
374}
375
376pub fn regexp_match(
401 array: &dyn Array,
402 regex_array: &dyn Datum,
403 flags_array: Option<&dyn Datum>,
404) -> Result<ArrayRef, ArrowError> {
405 let (rhs, is_rhs_scalar) = regex_array.get();
406
407 if array.data_type() != rhs.data_type() {
408 return Err(ArrowError::ComputeError(
409 "regexp_match() requires both array and pattern to be either Utf8 or LargeUtf8"
410 .to_string(),
411 ));
412 }
413
414 let (flags, is_flags_scalar) = match flags_array {
415 Some(flags) => {
416 let (flags, is_flags_scalar) = flags.get();
417 (Some(flags), Some(is_flags_scalar))
418 }
419 None => (None, None),
420 };
421
422 if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() {
423 return Err(ArrowError::ComputeError(
424 "regexp_match() requires both pattern and flags to be either scalar or array"
425 .to_string(),
426 ));
427 }
428
429 if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() {
430 return Err(ArrowError::ComputeError(
431 "regexp_match() requires both pattern and flags to be either string or largestring"
432 .to_string(),
433 ));
434 }
435
436 if is_rhs_scalar {
437 let (regex, flag) = match rhs.data_type() {
439 DataType::Utf8 => get_scalar_pattern_flag::<i32>(rhs, flags),
440 DataType::LargeUtf8 => get_scalar_pattern_flag::<i64>(rhs, flags),
441 _ => {
442 return Err(ArrowError::ComputeError(
443 "regexp_match() requires pattern to be either Utf8 or LargeUtf8".to_string(),
444 ));
445 }
446 };
447
448 if regex.is_none() {
449 return Ok(new_null_array(
450 &DataType::List(Arc::new(Field::new(
451 "item",
452 array.data_type().clone(),
453 true,
454 ))),
455 array.len(),
456 ));
457 }
458
459 let regex = regex.unwrap();
460
461 let pattern = if let Some(flag) = flag {
462 format!("(?{flag}){regex}")
463 } else {
464 regex.to_string()
465 };
466
467 let re = Regex::new(pattern.as_str()).map_err(|e| {
468 ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
469 })?;
470
471 match array.data_type() {
472 DataType::Utf8 => regexp_scalar_match(array.as_string::<i32>(), &re),
473 DataType::LargeUtf8 => regexp_scalar_match(array.as_string::<i64>(), &re),
474 _ => Err(ArrowError::ComputeError(
475 "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(),
476 )),
477 }
478 } else {
479 match array.data_type() {
480 DataType::Utf8 => {
481 let regex_array = rhs.as_string();
482 let flags_array = flags.map(|flags| flags.as_string());
483 regexp_array_match(array.as_string::<i32>(), regex_array, flags_array)
484 }
485 DataType::LargeUtf8 => {
486 let regex_array = rhs.as_string();
487 let flags_array = flags.map(|flags| flags.as_string());
488 regexp_array_match(array.as_string::<i64>(), regex_array, flags_array)
489 }
490 _ => Err(ArrowError::ComputeError(
491 "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(),
492 )),
493 }
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn match_single_group() {
        let values = vec![
            Some("abc-005-def"),
            Some("X-7-5"),
            Some("X545"),
            None,
            Some("foobarbequebaz"),
            Some("foobarbequebaz"),
        ];
        let array = StringArray::from(values);
        let mut pattern_values = vec![r".*-(\d*)-.*"; 4];
        pattern_values.push(r"(bar)(bequ1e)");
        pattern_values.push("");
        let pattern = GenericStringArray::<i32>::from(pattern_values);
        let actual = regexp_match(&array, &pattern, None).unwrap();
        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.values().append_value("005");
        expected_builder.append(true);
        expected_builder.values().append_value("7");
        expected_builder.append(true);
        expected_builder.append(false);
        expected_builder.append(false);
        expected_builder.append(false);
        expected_builder.values().append_value("");
        expected_builder.append(true);
        let expected = expected_builder.finish();
        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
        assert_eq!(&expected, result);
    }

    #[test]
    fn match_single_group_with_flags() {
        let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
        let array = StringArray::from(values);
        let pattern = StringArray::from(vec![r"x.*-(\d*)-.*"; 4]);
        let flags = StringArray::from(vec!["i"; 4]);
        let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap();
        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::with_capacity(0, 0);
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.append(false);
        expected_builder.values().append_value("7");
        expected_builder.append(true);
        expected_builder.append(false);
        expected_builder.append(false);
        let expected = expected_builder.finish();
        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
        assert_eq!(&expected, result);
    }

    #[test]
    fn match_scalar_pattern() {
        let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
        let array = StringArray::from(values);
        let pattern = Scalar::new(StringArray::from(vec![r"x.*-(\d*)-.*"; 1]));
        let flags = Scalar::new(StringArray::from(vec!["i"; 1]));
        let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap();
        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::with_capacity(0, 0);
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.append(false);
        expected_builder.values().append_value("7");
        expected_builder.append(true);
        expected_builder.append(false);
        expected_builder.append(false);
        let expected = expected_builder.finish();
        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
        assert_eq!(&expected, result);

        // Same expectation without flags, using a lowercase input instead.
        let values = vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None];
        let array = StringArray::from(values);
        let actual = regexp_match(&array, &pattern, None).unwrap();
        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
        assert_eq!(&expected, result);
    }

    #[test]
    fn match_scalar_no_pattern() {
        let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
        let array = StringArray::from(values);
        let pattern = Scalar::new(new_null_array(&DataType::Utf8, 1));
        let actual = regexp_match(&array, &pattern, None).unwrap();
        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::with_capacity(0, 0);
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.append(false);
        expected_builder.append(false);
        expected_builder.append(false);
        expected_builder.append(false);
        let expected = expected_builder.finish();
        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
        assert_eq!(&expected, result);
    }

    #[test]
    fn test_single_group_not_skip_match() {
        // Without capture groups the whole match (group 0) is returned. The
        // pattern array has one entry per value so the test does not depend
        // on `zip` truncating mismatched lengths.
        let array = StringArray::from(vec![Some("foo"), Some("bar")]);
        let pattern = GenericStringArray::<i32>::from(vec![r"foo"; 2]);
        let actual = regexp_match(&array, &pattern, None).unwrap();
        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.values().append_value("foo");
        expected_builder.append(true);
        expected_builder.append(false);
        let expected = expected_builder.finish();
        assert_eq!(&expected, result);
    }

    /// Generates a test comparing `$op(&left, &right, flags)` element-wise
    /// against `$expected`. The second arm also passes an array of flags.
    macro_rules! test_flag_utf8 {
        ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let left = $left;
                let right = $right;
                let res = $op(&left, &right, None).unwrap();
                let expected = $expected;
                assert_eq!(expected.len(), res.len());
                for i in 0..res.len() {
                    let v = res.value(i);
                    assert_eq!(v, expected[i]);
                }
            }
        };
        ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let left = $left;
                let right = $right;
                let flag = Some($flag);
                let res = $op(&left, &right, flag.as_ref()).unwrap();
                let expected = $expected;
                assert_eq!(expected.len(), res.len());
                for i in 0..res.len() {
                    let v = res.value(i);
                    assert_eq!(v, expected[i]);
                }
            }
        };
    }

    /// Generates a test comparing `$op(&left, pattern, flag)` with a scalar
    /// `&str` pattern element-wise against `$expected`. The second arm also
    /// passes a scalar flag.
    macro_rules! test_flag_utf8_scalar {
        ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let left = $left;
                let res = $op(&left, $right, None).unwrap();
                let expected = $expected;
                assert_eq!(expected.len(), res.len());
                for i in 0..res.len() {
                    let v = res.value(i);
                    assert_eq!(
                        v,
                        expected[i],
                        "unexpected result when comparing {} at position {} to {} ",
                        left.value(i),
                        i,
                        $right
                    );
                }
            }
        };
        ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let left = $left;
                let flag = Some($flag);
                let res = $op(&left, $right, flag).unwrap();
                let expected = $expected;
                assert_eq!(expected.len(), res.len());
                for i in 0..res.len() {
                    let v = res.value(i);
                    assert_eq!(
                        v,
                        expected[i],
                        "unexpected result when comparing {} at position {} to {} ",
                        left.value(i),
                        i,
                        $right
                    );
                }
            }
        };
    }

    test_flag_utf8!(
        test_array_regexp_is_match_utf8,
        StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
        regexp_is_match::<StringArray, StringArray, StringArray>,
        [true, false, true, false, false, true]
    );
    test_flag_utf8!(
        test_array_regexp_is_match_utf8_insensitive,
        StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
        StringArray::from(vec!["i"; 6]),
        regexp_is_match,
        [true, true, true, true, false, true]
    );

    test_flag_utf8_scalar!(
        test_array_regexp_is_match_utf8_scalar,
        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
        "^ar",
        regexp_is_match_scalar,
        [true, false, false, false]
    );
    test_flag_utf8_scalar!(
        test_array_regexp_is_match_utf8_scalar_empty,
        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
        "",
        regexp_is_match_scalar,
        [true, true, true, true]
    );
    test_flag_utf8_scalar!(
        test_array_regexp_is_match_utf8_scalar_insensitive,
        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
        "^ar",
        "i",
        regexp_is_match_scalar,
        [true, true, false, false]
    );

    test_flag_utf8!(
        test_array_regexp_is_match,
        StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
        StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
        regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
        [true, false, true, false, false, true]
    );
    test_flag_utf8!(
        test_array_regexp_is_match_2,
        StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
        regexp_is_match::<StringViewArray, GenericStringArray<i32>, GenericStringArray<i32>>,
        [true, false, true, false, false, true]
    );
    test_flag_utf8!(
        test_array_regexp_is_match_insensitive,
        StringViewArray::from(vec![
            "Official Rust implementation of Apache Arrow",
            "apache/arrow-rs",
            "apache/arrow-rs",
            "parquet",
            "parquet",
            "row",
            "row",
        ]),
        StringViewArray::from(vec![
            ".*rust implement.*",
            "^ap",
            "^AP",
            "et$",
            "ET$",
            "foo",
            ""
        ]),
        StringViewArray::from(vec!["i"; 7]),
        regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
        [true, true, true, true, true, false, true]
    );
    test_flag_utf8!(
        test_array_regexp_is_match_insensitive_2,
        LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
        StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
        StringArray::from(vec!["i"; 6]),
        regexp_is_match::<GenericStringArray<i64>, StringViewArray, GenericStringArray<i32>>,
        [true, true, true, true, false, true]
    );

    test_flag_utf8_scalar!(
        test_array_regexp_is_match_scalar,
        StringViewArray::from(vec![
            "apache/arrow-rs",
            "APACHE/ARROW-RS",
            "parquet",
            "PARQUET",
        ]),
        "^ap",
        regexp_is_match_scalar::<StringViewArray>,
        [true, false, false, false]
    );
    test_flag_utf8_scalar!(
        test_array_regexp_is_match_scalar_empty,
        StringViewArray::from(vec![
            "apache/arrow-rs",
            "APACHE/ARROW-RS",
            "parquet",
            "PARQUET",
        ]),
        "",
        regexp_is_match_scalar::<StringViewArray>,
        [true, true, true, true]
    );
    test_flag_utf8_scalar!(
        test_array_regexp_is_match_scalar_insensitive,
        StringViewArray::from(vec![
            "apache/arrow-rs",
            "APACHE/ARROW-RS",
            "parquet",
            "PARQUET",
        ]),
        "^ap",
        "i",
        regexp_is_match_scalar::<StringViewArray>,
        [true, true, false, false]
    );
}