use std::borrow::Cow;
use std::char::CharTryFromError;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt;
use std::num::{NonZeroU64, TryFromIntError};
use std::sync::Arc;
use mz_ore::cast::CastFrom;
use mz_ore::num::{NonNeg, NonNegError};
use num::Signed;
use proptest::prelude::Strategy;
use prost::UnknownEnumValue;
use uuid::Uuid;
// Optional submodules, each compiled only when its Cargo feature is enabled.
#[cfg(feature = "chrono")]
pub mod chrono;
#[cfg(feature = "tokio-postgres")]
pub mod tokio_postgres;
// Splices in the Rust code written to `OUT_DIR/mz_proto.rs` at build time.
// NOTE(review): presumably this defines the generated protobuf types used below
// (e.g. `ProtoU128`, `ProtoDuration`), which are not declared elsewhere in this
// file — confirm against the build script.
include!(concat!(env!("OUT_DIR"), "/mz_proto.rs"));
/// Error produced when a protobuf value cannot be converted into its Rust
/// representation (see [`RustType::from_proto`]).
///
/// Variants that hold a foreign error type can be built from that type via
/// `From`, so conversion code can propagate those errors with `?`.
#[derive(Debug)]
pub enum TryFromProtoError {
/// A checked integer conversion failed because the value was out of range.
TryFromIntError(TryFromIntError),
/// Converting a value into [`NonNeg`] failed.
NonNegError(NonNegError),
/// A `u32` could not be converted into a `char`.
CharTryFromError(CharTryFromError),
/// A date could not be converted; holds a description of the failure.
DateConversionError(String),
/// Wraps a [`regex::Error`].
RegexError(regex::Error),
/// Packing a row failed; holds a description of the failure.
RowConversionError(String),
/// Wraps a [`serde_json::Error`].
DeserializationError(serde_json::Error),
/// A required field was absent; holds the name of the field.
MissingField(String),
/// An enum value was not recognized; holds a description of the value
/// (for example a field name, or `value N` for a raw discriminant).
UnknownEnumVariant(String),
/// A `ShardId` value was invalid; holds the offending value.
InvalidShardId(String),
/// Free-form message, displayed verbatim.
CodecMismatch(String),
/// Free-form message, displayed verbatim.
InvalidPersistState(String),
/// Free-form message, displayed verbatim.
InvalidSemverVersion(String),
/// Wraps an [`http::uri::InvalidUri`].
InvalidUri(http::uri::InvalidUri),
/// Wraps a [`globset::Error`].
GlobError(globset::Error),
/// Wraps a [`url::ParseError`].
InvalidUrl(url::ParseError),
/// Free-form message, displayed verbatim.
InvalidBitFlags(String),
/// Decoding a LIKE/ILIKE pattern failed; holds a description of the failure.
LikePatternDeserializationError(String),
/// A field held an invalid value; free-form message, displayed verbatim.
InvalidFieldError(String),
}
impl TryFromProtoError {
pub fn missing_field<T: ToString>(s: T) -> TryFromProtoError {
TryFromProtoError::MissingField(s.to_string())
}
pub fn unknown_enum_variant<T: ToString>(s: T) -> TryFromProtoError {
TryFromProtoError::UnknownEnumVariant(s.to_string())
}
}
impl From<TryFromIntError> for TryFromProtoError {
fn from(error: TryFromIntError) -> Self {
TryFromProtoError::TryFromIntError(error)
}
}
impl From<NonNegError> for TryFromProtoError {
fn from(error: NonNegError) -> Self {
TryFromProtoError::NonNegError(error)
}
}
impl From<CharTryFromError> for TryFromProtoError {
fn from(error: CharTryFromError) -> Self {
TryFromProtoError::CharTryFromError(error)
}
}
impl From<UnknownEnumValue> for TryFromProtoError {
fn from(UnknownEnumValue(n): UnknownEnumValue) -> Self {
TryFromProtoError::UnknownEnumVariant(format!("value {n}"))
}
}
impl From<regex::Error> for TryFromProtoError {
fn from(error: regex::Error) -> Self {
TryFromProtoError::RegexError(error)
}
}
impl From<serde_json::Error> for TryFromProtoError {
fn from(error: serde_json::Error) -> Self {
TryFromProtoError::DeserializationError(error)
}
}
impl From<http::uri::InvalidUri> for TryFromProtoError {
fn from(error: http::uri::InvalidUri) -> Self {
TryFromProtoError::InvalidUri(error)
}
}
impl From<globset::Error> for TryFromProtoError {
fn from(error: globset::Error) -> Self {
TryFromProtoError::GlobError(error)
}
}
impl From<url::ParseError> for TryFromProtoError {
fn from(error: url::ParseError) -> Self {
TryFromProtoError::InvalidUrl(error)
}
}
impl std::fmt::Display for TryFromProtoError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use TryFromProtoError::*;
match self {
TryFromIntError(error) => error.fmt(f),
NonNegError(error) => error.fmt(f),
CharTryFromError(error) => error.fmt(f),
DateConversionError(msg) => write!(f, "Date conversion failed: `{}`", msg),
RegexError(error) => error.fmt(f),
DeserializationError(error) => error.fmt(f),
RowConversionError(msg) => write!(f, "Row packing failed: `{}`", msg),
MissingField(field) => write!(f, "Missing value for `{}`", field),
UnknownEnumVariant(field) => write!(f, "Unknown enum value for `{}`", field),
InvalidShardId(value) => write!(f, "Invalid value of ShardId found: `{}`", value),
CodecMismatch(error) => error.fmt(f),
InvalidPersistState(error) => error.fmt(f),
InvalidSemverVersion(error) => error.fmt(f),
InvalidUri(error) => error.fmt(f),
GlobError(error) => error.fmt(f),
InvalidUrl(error) => error.fmt(f),
InvalidBitFlags(error) => error.fmt(f),
LikePatternDeserializationError(inner_error) => write!(
f,
"Protobuf deserialization failed for a LIKE/ILIKE pattern: `{}`",
inner_error
),
InvalidFieldError(error) => error.fmt(f),
}
}
}
/// Renders the error through its `Display` impl, so code that reports errors
/// as plain `String`s can convert with `.into()` or `?`.
impl From<TryFromProtoError> for String {
    fn from(error: TryFromProtoError) -> Self {
        format!("{error}")
    }
}
impl std::error::Error for TryFromProtoError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
use TryFromProtoError::*;
match self {
TryFromIntError(error) => Some(error),
NonNegError(error) => Some(error),
CharTryFromError(error) => Some(error),
RegexError(error) => Some(error),
DeserializationError(error) => Some(error),
DateConversionError(_) => None,
RowConversionError(_) => None,
MissingField(_) => None,
UnknownEnumVariant(_) => None,
InvalidShardId(_) => None,
CodecMismatch(_) => None,
InvalidPersistState(_) => None,
InvalidSemverVersion(_) => None,
InvalidUri(error) => Some(error),
GlobError(error) => Some(error),
InvalidUrl(error) => Some(error),
InvalidBitFlags(_) => None,
LikePatternDeserializationError(_) => None,
InvalidFieldError(_) => None,
}
}
}
pub fn any_uuid() -> impl Strategy<Value = Uuid> {
(0..u128::MAX).prop_map(Uuid::from_u128)
}
/// Associates a Rust type with a protobuf message type `Proto`, which the Rust
/// type converts to and from via its [`RustType`] impl.
pub trait ProtoRepr: Sized + RustType<Self::Proto> {
/// The protobuf message type that `Self` converts to and from.
type Proto: ::prost::Message;
}
/// Conversion between a Rust type and its protobuf representation `Proto`.
///
/// [`ProtoType`] provides the same conversions from the protobuf side through a
/// blanket impl over this trait.
pub trait RustType<Proto>: Sized {
/// Converts `self` into its protobuf representation.
fn into_proto(&self) -> Proto;
/// Like [`RustType::into_proto`], but takes `self` by value.
///
/// The default borrows `self` and delegates to `into_proto`; implementors can
/// override it when consuming `self` allows a cheaper conversion.
fn into_proto_owned(self) -> Proto {
self.into_proto()
}
/// Converts a protobuf value back into `Self`, returning a
/// [`TryFromProtoError`] if the value cannot be represented.
fn from_proto(proto: Proto) -> Result<Self, TryFromProtoError>;
}
/// Conversion between a single `(key, value)` map entry and a protobuf entry
/// message. Used by the `RustType<Vec<T>>` impl for `BTreeMap<K, V>`.
pub trait ProtoMapEntry<K, V> {
/// Builds the protobuf entry from a borrowed key/value pair.
fn from_rust<'a>(entry: (&'a K, &'a V)) -> Self;
/// Converts the protobuf entry back into an owned key/value pair.
fn into_rust(self) -> Result<(K, V), TryFromProtoError>;
}
/// Implements the identity [`RustType`] conversion for types that are their own
/// protobuf representation: `into_proto` clones and `from_proto` never fails.
macro_rules! rust_type_id {
    ($($t:ty),*) => {
        $(
            impl RustType<$t> for $t {
                #[inline]
                fn into_proto(&self) -> $t {
                    Clone::clone(self)
                }

                #[inline]
                fn from_proto(proto: $t) -> Result<Self, TryFromProtoError> {
                    Ok(proto)
                }
            }
        )*
    };
}

rust_type_id![
    bool, f32, f64, i32, i64, String, u32, u64, Vec<u8>
];
/// Encodes an optional non-zero `u64` as a plain `u64`, using `0` for `None`.
///
/// Decoding cannot fail: `0` maps back to `None` and every other value to
/// `Some`.
impl RustType<u64> for Option<NonZeroU64> {
    fn into_proto(&self) -> u64 {
        self.map_or(0, NonZeroU64::get)
    }

    fn from_proto(proto: u64) -> Result<Self, TryFromProtoError> {
        // `NonZeroU64::new` already returns `None` exactly for zero.
        let decoded = NonZeroU64::new(proto);
        Ok(decoded)
    }
}
/// Encodes a `BTreeMap` as a list of entry messages `T`, one per key/value
/// pair, in ascending key order.
impl<K, V, T> RustType<Vec<T>> for BTreeMap<K, V>
where
    K: std::cmp::Eq + std::cmp::Ord,
    T: ProtoMapEntry<K, V>,
{
    fn into_proto(&self) -> Vec<T> {
        let entries = self.iter().map(|(key, value)| T::from_rust((key, value)));
        entries.collect::<Vec<T>>()
    }

    fn from_proto(proto: Vec<T>) -> Result<Self, TryFromProtoError> {
        // Collecting into `Result` stops at the first entry that fails to convert.
        let pairs = proto.into_iter().map(T::into_rust);
        pairs.collect()
    }
}
/// Encodes a `BTreeSet` as a list of its encoded elements, in ascending order.
impl<R, P> RustType<Vec<P>> for BTreeSet<R>
where
    R: RustType<P> + std::cmp::Ord,
{
    fn into_proto(&self) -> Vec<P> {
        let mut encoded = Vec::with_capacity(self.len());
        for element in self {
            encoded.push(R::into_proto(element));
        }
        encoded
    }

    fn from_proto(proto: Vec<P>) -> Result<Self, TryFromProtoError> {
        // Collecting into `Result` stops at the first element that fails to convert.
        let decoded = proto.into_iter().map(R::from_proto);
        decoded.collect()
    }
}
/// Encodes a `Vec<R>` element by element; decoding fails on the first element
/// that cannot be converted.
impl<R, P> RustType<Vec<P>> for Vec<R>
where
    R: RustType<P>,
{
    fn into_proto(&self) -> Vec<P> {
        let mut encoded = Vec::with_capacity(self.len());
        encoded.extend(self.iter().map(R::into_proto));
        encoded
    }

    fn from_proto(proto: Vec<P>) -> Result<Self, TryFromProtoError> {
        let mut decoded = Vec::with_capacity(proto.len());
        for item in proto {
            decoded.push(R::from_proto(item)?);
        }
        Ok(decoded)
    }
}
/// Encodes a boxed slice element by element, like `Vec<R>`; decoding builds a
/// `Vec` and converts it into a boxed slice.
impl<R, P> RustType<Vec<P>> for Box<[R]>
where
    R: RustType<P>,
{
    fn into_proto(&self) -> Vec<P> {
        let mut encoded = Vec::with_capacity(self.len());
        for element in self.iter() {
            encoded.push(R::into_proto(element));
        }
        encoded
    }

    fn from_proto(proto: Vec<P>) -> Result<Self, TryFromProtoError> {
        let decoded = proto
            .into_iter()
            .map(R::from_proto)
            .collect::<Result<Vec<R>, TryFromProtoError>>()?;
        Ok(decoded.into_boxed_slice())
    }
}
/// Encodes `Option<R>` as `Option<P>`. `None` passes through unchanged, and a
/// failing inner conversion is propagated.
impl<R, P> RustType<Option<P>> for Option<R>
where
    R: RustType<P>,
{
    fn into_proto(&self) -> Option<P> {
        Option::map(self.as_ref(), R::into_proto)
    }

    fn from_proto(proto: Option<P>) -> Result<Self, TryFromProtoError> {
        match proto {
            None => Ok(None),
            Some(inner) => R::from_proto(inner).map(Some),
        }
    }
}
/// Encodes a `Box<R>` as a boxed protobuf value `Box<P>`.
impl<R, P> RustType<Box<P>> for Box<R>
where
    R: RustType<P>,
{
    fn into_proto(&self) -> Box<P> {
        let inner: &R = self;
        Box::new(R::into_proto(inner))
    }

    fn from_proto(proto: Box<P>) -> Result<Self, TryFromProtoError> {
        let unboxed = *proto;
        R::from_proto(unboxed).map(Box::new)
    }
}
/// An `Arc<R>` is encoded exactly like `R`; decoding wraps the result in a new
/// `Arc`.
impl<R, P> RustType<P> for Arc<R>
where
    R: RustType<P>,
{
    fn into_proto(&self) -> P {
        R::into_proto(&**self)
    }

    fn from_proto(proto: P) -> Result<Self, TryFromProtoError> {
        R::from_proto(proto).map(Arc::new)
    }
}
/// Encodes a pair component-wise. Decoding converts the first component before
/// the second and stops at the first failure.
impl<R1, R2, P1, P2> RustType<(P1, P2)> for (R1, R2)
where
    R1: RustType<P1>,
    R2: RustType<P2>,
{
    fn into_proto(&self) -> (P1, P2) {
        let (left, right) = self;
        (R1::into_proto(left), R2::into_proto(right))
    }

    fn from_proto(proto: (P1, P2)) -> Result<Self, TryFromProtoError> {
        let (left, right) = proto;
        Ok((R1::from_proto(left)?, R2::from_proto(right)?))
    }
}
/// The unit type is its own (trivial) protobuf representation.
impl RustType<()> for () {
    fn into_proto(&self) {}

    fn from_proto((): ()) -> Result<Self, TryFromProtoError> {
        Ok(())
    }
}
/// `usize` is carried as `u64`. Decoding fails with a `TryFromIntError` when the
/// value does not fit in the target's `usize`.
impl RustType<u64> for usize {
    fn into_proto(&self) -> u64 {
        let value = *self;
        u64::cast_from(value)
    }

    fn from_proto(proto: u64) -> Result<Self, TryFromProtoError> {
        let narrowed = usize::try_from(proto)?;
        Ok(narrowed)
    }
}
/// A `char` is carried as its `u32` scalar value. Decoding rejects values that
/// are not valid Unicode scalar values.
impl RustType<u32> for char {
    fn into_proto(&self) -> u32 {
        u32::from(*self)
    }

    fn from_proto(proto: u32) -> Result<Self, TryFromProtoError> {
        let decoded = char::try_from(proto)?;
        Ok(decoded)
    }
}
impl RustType<u32> for u8 {
fn into_proto(&self) -> u32 {
u32::from(*self)
}
fn from_proto(proto: u32) -> Result<Self, TryFromProtoError> {
u8::try_from(proto).map_err(TryFromProtoError::from)
}
}
impl RustType<u32> for u16 {
fn into_proto(&self) -> u32 {
u32::from(*self)
}
fn from_proto(repr: u32) -> Result<Self, TryFromProtoError> {
u16::try_from(repr).map_err(TryFromProtoError::from)
}
}
impl RustType<i32> for i8 {
fn into_proto(&self) -> i32 {
i32::from(*self)
}
fn from_proto(proto: i32) -> Result<Self, TryFromProtoError> {
i8::try_from(proto).map_err(TryFromProtoError::from)
}
}
impl RustType<i32> for i16 {
fn into_proto(&self) -> i32 {
i32::from(*self)
}
fn from_proto(repr: i32) -> Result<Self, TryFromProtoError> {
i16::try_from(repr).map_err(TryFromProtoError::from)
}
}
/// Splits a `u128` into high and low `u64` halves for the wire and reassembles
/// them on decode. Decoding cannot fail.
impl RustType<ProtoU128> for u128 {
    // The `as` casts below intentionally keep only the low 64 bits of their
    // operand, which is exactly the half being extracted.
    #[allow(clippy::as_conversions)]
    fn into_proto(&self) -> ProtoU128 {
        let value = *self;
        ProtoU128 {
            hi: (value >> 64) as u64,
            lo: value as u64,
        }
    }

    fn from_proto(proto: ProtoU128) -> Result<Self, TryFromProtoError> {
        let high = u128::from(proto.hi) << 64;
        let low = u128::from(proto.lo);
        Ok(high | low)
    }
}
/// Stores a [`Uuid`] as its 128-bit integer value, reusing the `u128` encoding.
impl RustType<ProtoU128> for Uuid {
    fn into_proto(&self) -> ProtoU128 {
        let bits: u128 = self.as_u128();
        bits.into_proto()
    }

    fn from_proto(proto: ProtoU128) -> Result<Self, TryFromProtoError> {
        u128::from_proto(proto).map(Uuid::from_u128)
    }
}
impl RustType<u64> for std::num::NonZeroUsize {
fn into_proto(&self) -> u64 {
usize::from(*self).into_proto()
}
fn from_proto(proto: u64) -> Result<Self, TryFromProtoError> {
Ok(usize::from_proto(proto)?.try_into()?)
}
}
impl<T> RustType<T> for NonNeg<T>
where
T: Clone + Signed + fmt::Display,
{
fn into_proto(&self) -> T {
(**self).clone()
}
fn from_proto(proto: T) -> Result<Self, TryFromProtoError> {
Ok(NonNeg::<T>::try_from(proto)?)
}
}
impl RustType<ProtoDuration> for std::time::Duration {
fn into_proto(&self) -> ProtoDuration {
ProtoDuration {
secs: self.as_secs(),
nanos: self.subsec_nanos(),
}
}
fn from_proto(proto: ProtoDuration) -> Result<Self, TryFromProtoError> {
Ok(std::time::Duration::new(proto.secs, proto.nanos))
}
}
/// Encodes a `Cow<str>` as an owned `String`. Decoding always yields
/// `Cow::Owned`.
impl<'a> RustType<String> for Cow<'a, str> {
    fn into_proto(&self) -> String {
        String::from(&**self)
    }

    fn from_proto(proto: String) -> Result<Self, TryFromProtoError> {
        Ok(Self::Owned(proto))
    }
}
/// Encodes a `Box<str>` as a `String` and converts back without copying the
/// text on decode.
impl RustType<String> for Box<str> {
    fn into_proto(&self) -> String {
        String::from(&**self)
    }

    fn from_proto(proto: String) -> Result<Self, TryFromProtoError> {
        Ok(proto.into_boxed_str())
    }
}
/// The protobuf-side view of [`RustType`]: lets a protobuf value be converted
/// with `proto.into_rust()` instead of `Rust::from_proto(proto)`.
///
/// It is implemented for every `P` through the blanket impl over [`RustType`];
/// implement `RustType` rather than this trait.
pub trait ProtoType<Rust>: Sized {
    /// Converts `self` into the Rust type `Rust`.
    fn into_rust(self) -> Result<Rust, TryFromProtoError>;
    /// Builds the protobuf value from a borrowed Rust value.
    fn from_rust(rust: &Rust) -> Self;
}

impl<P, R> ProtoType<R> for P
where
    R: RustType<P>,
{
    #[inline]
    fn into_rust(self) -> Result<R, TryFromProtoError> {
        <R as RustType<P>>::from_proto(self)
    }

    #[inline]
    fn from_rust(rust: &R) -> Self {
        <R as RustType<P>>::into_proto(rust)
    }
}
pub fn any_duration() -> impl Strategy<Value = std::time::Duration> {
(0..u64::MAX, 0..1_000_000_000u32)
.prop_map(|(secs, nanos)| std::time::Duration::new(secs, nanos))
}
/// Converts an optional protobuf field into its Rust type, treating `None` as a
/// missing required field.
pub trait IntoRustIfSome<T> {
    /// Converts `Some(proto)` via [`RustType::from_proto`]. Returns
    /// [`TryFromProtoError::missing_field`] for `field` when `self` is `None`.
    fn into_rust_if_some<S: ToString>(self, field: S) -> Result<T, TryFromProtoError>;
}

impl<R, P> IntoRustIfSome<R> for Option<P>
where
    R: RustType<P>,
{
    fn into_rust_if_some<S: ToString>(self, field: S) -> Result<R, TryFromProtoError> {
        match self {
            Some(proto) => R::from_proto(proto),
            None => Err(TryFromProtoError::missing_field(field)),
        }
    }
}
pub trait TryIntoIfSome<T> {
fn try_into_if_some<S: ToString>(self, field: S) -> Result<T, TryFromProtoError>;
}
impl<T, U> TryIntoIfSome<T> for Option<U>
where
T: TryFrom<U, Error = TryFromProtoError>,
{
fn try_into_if_some<S: ToString>(self, field: S) -> Result<T, TryFromProtoError> {
self.ok_or_else(|| TryFromProtoError::missing_field(field))?
.try_into()
}
}
/// Encodes `val` to protobuf bytes via `P`, decodes the bytes back into a `P`,
/// and converts that into `R`. Handy in tests to check that a value survives
/// the wire round trip.
///
/// # Errors
/// Fails if the bytes cannot be decoded as `P`, or if converting the decoded
/// message back into `R` fails.
pub fn protobuf_roundtrip<R, P>(val: &R) -> anyhow::Result<R>
where
    P: ProtoType<R> + ::prost::Message + Default,
{
    let encoded: Vec<u8> = P::from_rust(val).encode_to_vec();
    let decoded = P::decode(encoded.as_slice())?;
    let restored = <P as ProtoType<R>>::into_rust(decoded)?;
    Ok(restored)
}
/// Property tests; currently checks that `std::time::Duration` survives a
/// protobuf round trip through `ProtoDuration`.
#[cfg(test)]
mod tests {
use mz_ore::assert_ok;
use proptest::prelude::*;
use super::*;
// Each property below runs on 4096 generated inputs; the test is ignored
// when running under Miri.
proptest! {
#![proptest_config(ProptestConfig::with_cases(4096))]
#[mz_ore::test]
#[cfg_attr(miri, ignore)] fn duration_protobuf_roundtrip(expect in any_duration() ) {
// Encode, decode and convert back; the result must equal the input.
let actual = protobuf_roundtrip::<_, ProtoDuration>(&expect);
assert_ok!(actual);
assert_eq!(actual.unwrap(), expect);
}
}
}