Skip to main content

mz_repr/adt/
char.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::error::Error;
11use std::fmt;
12
13use anyhow::bail;
14use mz_lowertest::MzReflect;
15use mz_ore::cast::CastFrom;
16use mz_proto::{RustType, TryFromProtoError};
17#[cfg(any(test, feature = "proptest"))]
18use proptest::arbitrary::Arbitrary;
19#[cfg(any(test, feature = "proptest"))]
20use proptest::strategy::{BoxedStrategy, Strategy};
21use serde::{Deserialize, Serialize};
22
23include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.char.rs"));
24
// Mirrors PostgreSQL's `MaxAttrSize` (10 MiB), defined as `10 * 1024 * 1024` at:
// https://github.com/postgres/postgres/blob/REL_14_0/src/include/access/htup_details.h#L577-L584
const MAX_LENGTH: u32 = 10 * 1024 * 1024;
27
/// Marker wrapper: a Rust string wrapped in `Char` is meant to be interpreted
/// as a [`SqlScalarType::Char`] value rather than as some other string type.
///
/// [`SqlScalarType::Char`]: crate::SqlScalarType::Char
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Char<S>(pub S)
where
    S: AsRef<str>;
34
35/// The `length` of a [`SqlScalarType::Char`].
36///
37/// This newtype wrapper ensures that the length is within the valid range.
38///
39/// [`SqlScalarType::Char`]: crate::SqlScalarType::Char
40#[derive(
41    Debug,
42    Clone,
43    Copy,
44    Eq,
45    PartialEq,
46    Ord,
47    PartialOrd,
48    Hash,
49    Serialize,
50    Deserialize,
51    MzReflect
52)]
53pub struct CharLength(pub(crate) u32);
54
55impl CharLength {
56    /// A length of one.
57    pub const ONE: CharLength = CharLength(1);
58
59    /// Consumes the newtype wrapper, returning the inner `u32`.
60    pub fn into_u32(self) -> u32 {
61        self.0
62    }
63}
64
65impl TryFrom<i64> for CharLength {
66    type Error = InvalidCharLengthError;
67
68    fn try_from(length: i64) -> Result<Self, Self::Error> {
69        match u32::try_from(length) {
70            Ok(length) if length > 0 && length < MAX_LENGTH => Ok(CharLength(length)),
71            _ => Err(InvalidCharLengthError),
72        }
73    }
74}
75
#[cfg(any(test, feature = "proptest"))]
impl Arbitrary for CharLength {
    type Parameters = ();
    type Strategy = BoxedStrategy<CharLength>;

    /// Generates lengths in `1..300`.
    ///
    /// Zero is deliberately excluded: it is not a valid length (the
    /// `TryFrom<i64>` impl rejects it), so generating it would hand tests a
    /// `CharLength` that no validated code path can produce.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        // We cap the maximum CharLength to prevent generating massive
        // strings which can greatly slow down tests and are relatively
        // uninteresting.
        (1..300u32).prop_map(CharLength).boxed()
    }
}
90
91/// The error returned when constructing a [`CharLength`] from an invalid value.
92#[derive(Debug, Clone)]
93pub struct InvalidCharLengthError;
94
95impl fmt::Display for InvalidCharLengthError {
96    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97        write!(
98            f,
99            "length for type character must be between 1 and {}",
100            MAX_LENGTH
101        )
102    }
103}
104
105impl Error for InvalidCharLengthError {}
106
/// Selects what happens to trailing whitespace on bpchar data.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum CharWhiteSpace {
    /// Strip all trailing whitespace. This is the form bpchar data takes when
    /// stored in datums: trimming gives it the same equality semantics as in
    /// PG while still allowing bit-wise equality on rows.
    Trim,
    /// Blank-pad to the declared length. This is the form bpchar data takes
    /// when it is returned out of Materialize.
    Pad,
}

impl CharWhiteSpace {
    /// Applies this whitespace policy to `s`.
    ///
    /// `length` is the declared width in characters, if known. `Trim` ignores
    /// it and removes trailing whitespace. `Pad` right-pads `s` with spaces up
    /// to `length` characters, or returns `s` as-is when there is no length
    /// (e.g. when printing lists). Neither variant truncates `s` to `length`.
    fn process_str(&self, s: &str, length: Option<usize>) -> String {
        match (*self, length) {
            (CharWhiteSpace::Trim, _) => String::from(s.trim_end()),
            (CharWhiteSpace::Pad, Some(width)) => format!("{:width$}", s),
            (CharWhiteSpace::Pad, None) => String::from(s),
        }
    }
}
132
133/// Returns `s` as a `String` with options to enforce char and varchar
134/// semantics.
135///
136/// # Arguments
137/// * `s` - The `str` to format
138/// * `length` - An optional maximum length for the string
139/// * `fail_on_len` - Return an error if `s`'s character count exceeds the
140///   specified maximum length.
141/// * `white_space` - Express how to handle trailing whitespace on `s`
142///
143/// This function should only fail when `fail_on_len` is `true` _and_ `length`
144/// is present and exceeded.
145fn format_char_str(
146    s: &str,
147    length: Option<CharLength>,
148    fail_on_len: bool,
149    white_space: CharWhiteSpace,
150) -> Result<String, anyhow::Error> {
151    Ok(match length {
152        // Note that length is 1-indexed, so finding `None` means the string's
153        // characters don't exceed the length, while finding `Some` means it
154        // does.
155        Some(l) => {
156            let l = usize::cast_from(l.into_u32());
157            // The number of chars in a string is always less or equal to the length of the string.
158            // Hence, if the string is shorter than the length, we do not have to check for
159            // the maximum length.
160            if s.len() < l {
161                return Ok(white_space.process_str(s, Some(l)));
162            }
163            match s.char_indices().nth(l) {
164                None => white_space.process_str(s, Some(l)),
165                Some((idx, _)) => {
166                    if !fail_on_len || s[idx..].chars().all(|c| c.is_ascii_whitespace()) {
167                        white_space.process_str(&s[..idx], Some(l))
168                    } else {
169                        bail!("{} exceeds maximum length of {}", s, l)
170                    }
171                }
172            }
173        }
174        None => white_space.process_str(s, None),
175    })
176}
177
178/// Ensures that `s` has fewer than `length` characters, and returns a `String`
179/// version of it with all whitespace trimmed from the end.
180///
181/// The value returned is appropriate to store in `Datum::String`, but _is not_
182/// appropriate to return to clients.
183///
184/// This function should only fail when `fail_on_len` is `true` _and_ `length`
185/// is present and exceeded.
186pub fn format_str_trim(
187    s: &str,
188    length: Option<CharLength>,
189    fail_on_len: bool,
190) -> Result<String, anyhow::Error> {
191    format_char_str(s, length, fail_on_len, CharWhiteSpace::Trim)
192}
193
194/// Ensures that `s` has fewer than `length` characters, and returns a `String`
195/// version of it with blank padding so that its width is `length` characters.
196///
197/// The value returned is appropriate to return to clients, but _is not_
198/// appropriate to store in `Datum::String`.
199pub fn format_str_pad(s: &str, length: Option<CharLength>) -> String {
200    format_char_str(s, length, false, CharWhiteSpace::Pad).unwrap()
201}
202
203impl RustType<ProtoCharLength> for CharLength {
204    fn into_proto(&self) -> ProtoCharLength {
205        ProtoCharLength { value: self.0 }
206    }
207
208    fn from_proto(proto: ProtoCharLength) -> Result<Self, TryFromProtoError> {
209        Ok(CharLength(proto.value))
210    }
211}
212
// Unit tests for the `char` ADT helpers in this module.
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use proptest::prelude::*;

    use super::*;

    proptest! {
        // Property: every `CharLength` produced by its `Arbitrary` impl
        // converts to `ProtoCharLength` and back without error and without
        // changing value.
        #[mz_ore::test]
        fn char_length_protobuf_roundtrip(expect in any::<CharLength>()) {
            let actual = protobuf_roundtrip::<_, ProtoCharLength>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }
}