Skip to main content

mz_repr/adt/
varchar.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::error::Error;
11use std::fmt;
12
13use anyhow::bail;
14use mz_lowertest::MzReflect;
15use mz_ore::cast::CastFrom;
16use mz_proto::{RustType, TryFromProtoError};
17#[cfg(any(test, feature = "proptest"))]
18use proptest::arbitrary::Arbitrary;
19#[cfg(any(test, feature = "proptest"))]
20use proptest::strategy::{BoxedStrategy, Strategy};
21use serde::{Deserialize, Serialize};
22
23include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.varchar.rs"));
24
/// Upper limit used when validating the declared length of a
/// `character varying` type (10 MiB worth of characters).
///
/// The value is taken from PostgreSQL; see
/// <https://github.com/postgres/postgres/blob/REL_14_0/src/include/access/htup_details.h#L577-L584>.
pub const MAX_MAX_LENGTH: u32 = 10 * 1024 * 1024;
27
/// Wrapper that tags a string-like value as a [`SqlScalarType::VarChar`].
///
/// The wrapper adds no data of its own; the wrapped string is reachable
/// through the public `0` field.
///
/// [`SqlScalarType::VarChar`]: crate::SqlScalarType::VarChar
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct VarChar<S>(pub S)
where
    S: AsRef<str>;
34
35/// The `max_length` of a [`SqlScalarType::VarChar`].
36///
37/// This newtype wrapper ensures that the length is within the valid range.
38///
39/// [`SqlScalarType::VarChar`]: crate::SqlScalarType::VarChar
40#[derive(
41    Debug,
42    Clone,
43    Copy,
44    Eq,
45    PartialEq,
46    Ord,
47    PartialOrd,
48    Hash,
49    Serialize,
50    Deserialize,
51    MzReflect
52)]
53pub struct VarCharMaxLength(pub(crate) u32);
54
55impl VarCharMaxLength {
56    /// Consumes the newtype wrapper, returning the inner `u32`.
57    pub fn into_u32(self) -> u32 {
58        self.0
59    }
60}
61
62impl TryFrom<i64> for VarCharMaxLength {
63    type Error = InvalidVarCharMaxLengthError;
64
65    fn try_from(max_length: i64) -> Result<Self, Self::Error> {
66        match u32::try_from(max_length) {
67            Ok(max_length) if max_length > 0 && max_length < MAX_MAX_LENGTH => {
68                Ok(VarCharMaxLength(max_length))
69            }
70            _ => Err(InvalidVarCharMaxLengthError),
71        }
72    }
73}
74
75impl RustType<ProtoVarCharMaxLength> for VarCharMaxLength {
76    fn into_proto(&self) -> ProtoVarCharMaxLength {
77        ProtoVarCharMaxLength { value: self.0 }
78    }
79
80    fn from_proto(proto: ProtoVarCharMaxLength) -> Result<Self, TryFromProtoError> {
81        Ok(VarCharMaxLength(proto.value))
82    }
83}
84
#[cfg(any(test, feature = "proptest"))]
impl Arbitrary for VarCharMaxLength {
    type Parameters = ();
    type Strategy = BoxedStrategy<VarCharMaxLength>;

    /// Generates lengths in `0..300` by reducing an arbitrary `u32` modulo 300.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        // Keeping generated lengths small avoids producing massive strings,
        // which slow tests down considerably while adding little of interest.
        let raw_lengths = proptest::arbitrary::any::<u32>();
        raw_lengths
            .prop_map(|raw| VarCharMaxLength(raw % 300))
            .boxed()
    }
}
99
100/// The error returned when constructing a [`VarCharMaxLength`] from an invalid
101/// value.
102#[derive(Debug, Clone)]
103pub struct InvalidVarCharMaxLengthError;
104
105impl fmt::Display for InvalidVarCharMaxLengthError {
106    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
107        write!(
108            f,
109            "length for type character varying must be between 1 and {}",
110            MAX_MAX_LENGTH
111        )
112    }
113}
114
115impl Error for InvalidVarCharMaxLengthError {}
116
117pub fn format_str(
118    s: &str,
119    length: Option<VarCharMaxLength>,
120    fail_on_len: bool,
121) -> Result<&str, anyhow::Error> {
122    Ok(match length {
123        // Note that length is 1-indexed, so finding `None` means the string's
124        // characters don't exceed the length, while finding `Some` means it
125        // does.
126        Some(l) => {
127            let l = usize::cast_from(l.into_u32());
128            match s.char_indices().nth(l) {
129                None => s,
130                Some((idx, _)) => {
131                    if !fail_on_len || s[idx..].chars().all(|c| c.is_ascii_whitespace()) {
132                        &s[..idx]
133                    } else {
134                        bail!("{} exceeds maximum length of {}", s, l)
135                    }
136                }
137            }
138        }
139        None => s,
140    })
141}
142
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use proptest::prelude::*;

    use super::*;

    proptest! {
        // Round-tripping any generated `VarCharMaxLength` through its
        // protobuf representation must succeed and give back an equal value.
        #[mz_ore::test]
        fn var_char_max_length_protobuf_roundtrip(original in any::<VarCharMaxLength>()) {
            let decoded = protobuf_roundtrip::<_, ProtoVarCharMaxLength>(&original);
            assert_ok!(decoded);
            assert_eq!(decoded.unwrap(), original);
        }
    }
}