Skip to main content

mz_repr/adt/
varchar.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use std::error::Error;
11use std::fmt;
12
13use anyhow::bail;
14use mz_lowertest::MzReflect;
15use mz_ore::cast::CastFrom;
16use mz_proto::{RustType, TryFromProtoError};
17use proptest::arbitrary::Arbitrary;
18use proptest::strategy::{BoxedStrategy, Strategy};
19use serde::{Deserialize, Serialize};
20
21include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.varchar.rs"));
22
/// Upper bound on the `max_length` of a `character varying(n)` type.
///
/// Taken from PostgreSQL's `MaxAttrSize`, which is defined as
/// `10 * 1024 * 1024`; see
/// <https://github.com/postgres/postgres/blob/REL_14_0/src/include/access/htup_details.h#L577-L584>.
pub const MAX_MAX_LENGTH: u32 = 10 * 1024 * 1024;
25
/// A marker type indicating that a Rust string should be interpreted as a
/// [`SqlScalarType::VarChar`].
///
/// The wrapped value may be any string-like type (`S: AsRef<str>`). The
/// wrapper only tags the value: its field is public and this type performs no
/// length check of its own.
///
/// [`SqlScalarType::VarChar`]: crate::SqlScalarType::VarChar
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct VarChar<S: AsRef<str>>(pub S);
32
/// The `max_length` of a [`SqlScalarType::VarChar`], i.e. the `n` in
/// `character varying(n)`.
///
/// This newtype wrapper ensures that the length is within the valid range
/// when it is built through the checked `TryFrom<i64>` conversion.
///
/// The inner field is `pub(crate)`, so code inside this crate can construct a
/// value directly without that check. The derived `Deserialize` also wraps
/// whatever `u32` it reads without re-validating it.
///
/// [`SqlScalarType::VarChar`]: crate::SqlScalarType::VarChar
#[derive(
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Hash,
    Serialize,
    Deserialize,
    MzReflect
)]
pub struct VarCharMaxLength(pub(crate) u32);
52
53impl VarCharMaxLength {
54    /// Consumes the newtype wrapper, returning the inner `u32`.
55    pub fn into_u32(self) -> u32 {
56        self.0
57    }
58}
59
60impl TryFrom<i64> for VarCharMaxLength {
61    type Error = InvalidVarCharMaxLengthError;
62
63    fn try_from(max_length: i64) -> Result<Self, Self::Error> {
64        match u32::try_from(max_length) {
65            Ok(max_length) if max_length > 0 && max_length < MAX_MAX_LENGTH => {
66                Ok(VarCharMaxLength(max_length))
67            }
68            _ => Err(InvalidVarCharMaxLengthError),
69        }
70    }
71}
72
73impl RustType<ProtoVarCharMaxLength> for VarCharMaxLength {
74    fn into_proto(&self) -> ProtoVarCharMaxLength {
75        ProtoVarCharMaxLength { value: self.0 }
76    }
77
78    fn from_proto(proto: ProtoVarCharMaxLength) -> Result<Self, TryFromProtoError> {
79        Ok(VarCharMaxLength(proto.value))
80    }
81}
82
83impl Arbitrary for VarCharMaxLength {
84    type Parameters = ();
85    type Strategy = BoxedStrategy<VarCharMaxLength>;
86
87    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
88        proptest::arbitrary::any::<u32>()
89            // We cap the maximum VarCharMaxLength to prevent generating
90            // massive strings which can greatly slow down tests and are
91            // relatively uninteresting.
92            .prop_map(|len| VarCharMaxLength(len % 300))
93            .boxed()
94    }
95}
96
97/// The error returned when constructing a [`VarCharMaxLength`] from an invalid
98/// value.
99#[derive(Debug, Clone)]
100pub struct InvalidVarCharMaxLengthError;
101
102impl fmt::Display for InvalidVarCharMaxLengthError {
103    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
104        write!(
105            f,
106            "length for type character varying must be between 1 and {}",
107            MAX_MAX_LENGTH
108        )
109    }
110}
111
112impl Error for InvalidVarCharMaxLengthError {}
113
114pub fn format_str(
115    s: &str,
116    length: Option<VarCharMaxLength>,
117    fail_on_len: bool,
118) -> Result<&str, anyhow::Error> {
119    Ok(match length {
120        // Note that length is 1-indexed, so finding `None` means the string's
121        // characters don't exceed the length, while finding `Some` means it
122        // does.
123        Some(l) => {
124            let l = usize::cast_from(l.into_u32());
125            match s.char_indices().nth(l) {
126                None => s,
127                Some((idx, _)) => {
128                    if !fail_on_len || s[idx..].chars().all(|c| c.is_ascii_whitespace()) {
129                        &s[..idx]
130                    } else {
131                        bail!("{} exceeds maximum length of {}", s, l)
132                    }
133                }
134            }
135        }
136        None => s,
137    })
138}
139
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use proptest::prelude::*;

    use super::*;

    proptest! {
        // Every generated `VarCharMaxLength` must survive a protobuf
        // encode/decode cycle unchanged.
        #[mz_ore::test]
        fn var_char_max_length_protobuf_roundtrip(expect in any::<VarCharMaxLength>()) {
            let decoded = protobuf_roundtrip::<_, ProtoVarCharMaxLength>(&expect);
            assert_ok!(decoded);
            let decoded = decoded.unwrap();
            assert_eq!(decoded, expect);
        }
    }
}