use std::cmp::Ordering;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use mz_lowertest::MzReflect;
use mz_proto::{RustType, TryFromProtoError};
use proptest::prelude::any;
use proptest::prop_compose;
use regex::{Error, RegexBuilder};
use serde::de::Error as DeError;
use serde::ser::SerializeStruct;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.regex.rs"));
#[derive(Debug, Clone, MzReflect)]
pub struct Regex {
pub case_insensitive: bool,
pub dot_matches_new_line: bool,
pub regex: regex::Regex,
}
impl Regex {
pub fn new(pattern: &str, case_insensitive: bool) -> Result<Regex, Error> {
Self::new_dot_matches_new_line(pattern, case_insensitive, true)
}
pub fn new_dot_matches_new_line(
pattern: &str,
case_insensitive: bool,
dot_matches_new_line: bool,
) -> Result<Regex, Error> {
let mut regex_builder = RegexBuilder::new(pattern);
regex_builder.case_insensitive(case_insensitive);
regex_builder.dot_matches_new_line(dot_matches_new_line);
Ok(Regex {
case_insensitive,
dot_matches_new_line,
regex: regex_builder.build()?,
})
}
pub fn pattern(&self) -> &str {
self.regex.as_str()
}
}
impl PartialEq<Regex> for Regex {
fn eq(&self, other: &Regex) -> bool {
self.pattern() == other.pattern()
&& self.case_insensitive == other.case_insensitive
&& self.dot_matches_new_line == other.dot_matches_new_line
}
}
impl Eq for Regex {}
impl PartialOrd for Regex {
fn partial_cmp(&self, other: &Regex) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Regex {
fn cmp(&self, other: &Regex) -> Ordering {
(
self.pattern(),
self.case_insensitive,
self.dot_matches_new_line,
)
.cmp(&(
other.pattern(),
other.case_insensitive,
other.dot_matches_new_line,
))
}
}
impl Hash for Regex {
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.pattern().hash(hasher);
self.case_insensitive.hash(hasher);
self.dot_matches_new_line.hash(hasher);
}
}
impl Deref for Regex {
type Target = regex::Regex;
fn deref(&self) -> ®ex::Regex {
&self.regex
}
}
impl RustType<ProtoRegex> for Regex {
fn into_proto(&self) -> ProtoRegex {
ProtoRegex {
pattern: self.pattern().to_owned(),
case_insensitive: self.case_insensitive,
dot_matches_new_line: self.dot_matches_new_line,
}
}
fn from_proto(proto: ProtoRegex) -> Result<Self, TryFromProtoError> {
Ok(Regex::new_dot_matches_new_line(
&proto.pattern,
proto.case_insensitive,
proto.dot_matches_new_line,
)?)
}
}
impl Serialize for Regex {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut state = serializer.serialize_struct("Regex", 3)?;
state.serialize_field("pattern", &self.pattern())?;
state.serialize_field("case_insensitive", &self.case_insensitive)?;
state.serialize_field("dot_matches_new_line", &self.dot_matches_new_line)?;
state.end()
}
}
impl<'de> Deserialize<'de> for Regex {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
enum Field {
Pattern,
CaseInsensitive,
DotMatchesNewLine,
}
impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
where
D: Deserializer<'de>,
{
struct FieldVisitor;
impl<'de> de::Visitor<'de> for FieldVisitor {
type Value = Field;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str(
"pattern string or case_insensitive bool or dot_matches_new_line bool",
)
}
fn visit_str<E>(self, value: &str) -> Result<Field, E>
where
E: de::Error,
{
match value {
"pattern" => Ok(Field::Pattern),
"case_insensitive" => Ok(Field::CaseInsensitive),
"dot_matches_new_line" => Ok(Field::DotMatchesNewLine),
_ => Err(de::Error::unknown_field(value, FIELDS)),
}
}
}
deserializer.deserialize_identifier(FieldVisitor)
}
}
struct RegexVisitor;
impl<'de> de::Visitor<'de> for RegexVisitor {
type Value = Regex;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("Regex serialized by the manual Serialize impl from above")
}
fn visit_seq<V>(self, mut seq: V) -> Result<Regex, V::Error>
where
V: de::SeqAccess<'de>,
{
let pattern = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
let case_insensitive = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
let dot_matches_new_line = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(2, &self))?;
Regex::new_dot_matches_new_line(pattern, case_insensitive, dot_matches_new_line)
.map_err(|err| {
V::Error::custom(format!(
"Unable to recreate regex during deserialization: {}",
err
))
})
}
fn visit_map<V>(self, mut map: V) -> Result<Regex, V::Error>
where
V: de::MapAccess<'de>,
{
let mut pattern: Option<&str> = None;
let mut case_insensitive: Option<bool> = None;
let mut dot_matches_new_line: Option<bool> = None;
while let Some(key) = map.next_key()? {
match key {
Field::Pattern => {
if pattern.is_some() {
return Err(de::Error::duplicate_field("pattern"));
}
pattern = Some(map.next_value()?);
}
Field::CaseInsensitive => {
if case_insensitive.is_some() {
return Err(de::Error::duplicate_field("case_insensitive"));
}
case_insensitive = Some(map.next_value()?);
}
Field::DotMatchesNewLine => {
if dot_matches_new_line.is_some() {
return Err(de::Error::duplicate_field("dot_matches_new_line"));
}
dot_matches_new_line = Some(map.next_value()?);
}
}
}
let pattern = pattern.ok_or_else(|| de::Error::missing_field("pattern"))?;
let case_insensitive =
case_insensitive.ok_or_else(|| de::Error::missing_field("case_insensitive"))?;
let dot_matches_new_line = dot_matches_new_line
.ok_or_else(|| de::Error::missing_field("dot_matches_new_line"))?;
Regex::new_dot_matches_new_line(pattern, case_insensitive, dot_matches_new_line)
.map_err(|err| {
V::Error::custom(format!(
"Unable to recreate regex during deserialization: {}",
err
))
})
}
}
const FIELDS: &[&str] = &["pattern", "case_insensitive", "dot_matches_new_line"];
deserializer.deserialize_struct("Regex", FIELDS, RegexVisitor)
}
}
const BEGINNING_SYMBOLS: &str = r"((\\A)|\^)?";
const CHARACTERS: &str = r"[\.x]{1}";
const REPETITIONS: &str = r"((\*|\+|\?|(\{[1-9],?\}))\??)?";
const END_SYMBOLS: &str = r"(\$|(\\z))?";
prop_compose! {
pub fn any_regex()
(b in BEGINNING_SYMBOLS, c in CHARACTERS,
r in REPETITIONS, e in END_SYMBOLS, case_insensitive in any::<bool>(), dot_matches_new_line in any::<bool>())
-> Regex {
let string = format!("{}{}{}{}", b, c, r, e);
let regex = Regex::new_dot_matches_new_line(&string, case_insensitive, dot_matches_new_line).unwrap();
assert_eq!(regex.pattern(), string);
regex
}
}
#[cfg(test)]
mod tests {
use mz_ore::assert_ok;
use mz_proto::protobuf_roundtrip;
use proptest::prelude::*;
use super::*;
proptest! {
#[mz_ore::test]
#[cfg_attr(miri, ignore)] fn regex_protobuf_roundtrip( expect in any_regex() ) {
let actual = protobuf_roundtrip::<_, ProtoRegex>(&expect);
assert_ok!(actual);
assert_eq!(actual.unwrap(), expect);
}
}
#[mz_ore::test]
fn regex_serde_case_insensitive() {
let pattern = "AAA";
let orig_regex = Regex::new(pattern, true).unwrap();
let serialized: String = serde_json::to_string(&orig_regex).unwrap();
let roundtrip_result: Regex = serde_json::from_str(&serialized).unwrap();
assert_eq!(orig_regex.regex.is_match("aaa"), true);
assert_eq!(roundtrip_result.regex.is_match("aaa"), true);
assert_eq!(pattern, roundtrip_result.pattern());
}
#[mz_ore::test]
fn regex_serde_dot_matches_new_line() {
{
let pattern = "A.*B";
let orig_regex = Regex::new_dot_matches_new_line(pattern, true, true).unwrap();
let serialized: String = serde_json::to_string(&orig_regex).unwrap();
let roundtrip_result: Regex = serde_json::from_str(&serialized).unwrap();
assert_eq!(orig_regex.regex.is_match("axxx\nxxxb"), true);
assert_eq!(roundtrip_result.regex.is_match("axxx\nxxxb"), true);
assert_eq!(pattern, roundtrip_result.pattern());
}
{
let pattern = "A.*B";
let orig_regex = Regex::new_dot_matches_new_line(pattern, true, false).unwrap();
let serialized: String = serde_json::to_string(&orig_regex).unwrap();
let roundtrip_result: Regex = serde_json::from_str(&serialized).unwrap();
assert_eq!(orig_regex.regex.is_match("axxx\nxxxb"), false);
assert_eq!(roundtrip_result.regex.is_match("axxx\nxxxb"), false);
assert_eq!(pattern, roundtrip_result.pattern());
}
{
let pattern = "A.*B";
let orig_regex = Regex::new(pattern, true).unwrap();
let serialized: String = serde_json::to_string(&orig_regex).unwrap();
let roundtrip_result: Regex = serde_json::from_str(&serialized).unwrap();
assert_eq!(orig_regex.regex.is_match("axxx\nxxxb"), true);
assert_eq!(roundtrip_result.regex.is_match("axxx\nxxxb"), true);
assert_eq!(pattern, roundtrip_result.pattern());
}
}
}