Skip to main content

mz_repr/adt/
char.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::error::Error;
11use std::fmt;
12
13use anyhow::bail;
14use mz_lowertest::MzReflect;
15use mz_ore::cast::CastFrom;
16use mz_proto::{RustType, TryFromProtoError};
17use proptest::arbitrary::Arbitrary;
18use proptest::strategy::{BoxedStrategy, Strategy};
19use serde::{Deserialize, Serialize};
20
21include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.char.rs"));
22
// https://github.com/postgres/postgres/blob/REL_14_0/src/include/access/htup_details.h#L577-L584
//
// Upper bound used when validating a `CharLength` (10 * 1024 * 1024). The same
// number is reported in the `InvalidCharLengthError` message.
const MAX_LENGTH: u32 = 10_485_760;
25
/// A marker type indicating that a Rust string should be interpreted as a
/// [`SqlScalarType::Char`].
///
/// The wrapper adds no data of its own: it holds a single public field of any
/// type that can be borrowed as a `str`.
///
/// [`SqlScalarType::Char`]: crate::SqlScalarType::Char
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Char<S: AsRef<str>>(pub S);
32
/// The `length` of a [`SqlScalarType::Char`].
///
/// This newtype wrapper ensures that the length is within the valid range
/// when it is built through the `TryFrom<i64>` conversion. The inner field is
/// `pub(crate)`, so code inside this crate can still construct a value
/// directly without going through that validation.
///
/// [`SqlScalarType::Char`]: crate::SqlScalarType::Char
#[derive(
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Hash,
    Serialize,
    Deserialize,
    MzReflect
)]
pub struct CharLength(pub(crate) u32);
52
53impl CharLength {
54    /// A length of one.
55    pub const ONE: CharLength = CharLength(1);
56
57    /// Consumes the newtype wrapper, returning the inner `u32`.
58    pub fn into_u32(self) -> u32 {
59        self.0
60    }
61}
62
63impl TryFrom<i64> for CharLength {
64    type Error = InvalidCharLengthError;
65
66    fn try_from(length: i64) -> Result<Self, Self::Error> {
67        match u32::try_from(length) {
68            Ok(length) if length > 0 && length < MAX_LENGTH => Ok(CharLength(length)),
69            _ => Err(InvalidCharLengthError),
70        }
71    }
72}
73
74impl Arbitrary for CharLength {
75    type Parameters = ();
76    type Strategy = BoxedStrategy<CharLength>;
77
78    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
79        proptest::arbitrary::any::<u32>()
80            // We cap the maximum CharLength to prevent generating massive
81            // strings which can greatly slow down tests and are relatively
82            // uninteresting.
83            .prop_map(|len| CharLength(len % 300))
84            .boxed()
85    }
86}
87
/// The error returned when constructing a [`CharLength`] from an invalid value.
///
/// The type carries no data; its `Display` output reports the accepted range.
#[derive(Debug, Clone)]
pub struct InvalidCharLengthError;
91
92impl fmt::Display for InvalidCharLengthError {
93    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
94        write!(
95            f,
96            "length for type character must be between 1 and {}",
97            MAX_LENGTH
98        )
99    }
100}
101
// Empty impl: no methods are overridden, so the trait's defaults apply.
impl Error for InvalidCharLengthError {}
103
/// Controls how to handle trailing whitespace at the end of bpchar data.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum CharWhiteSpace {
    /// Trim trailing whitespace from strings, which is appropriate for storing
    /// bpchar data in Materialize. bpchar data is stored in datums with its
    /// trailing whitespace trimmed to enforce the same equality semantics as
    /// PG, while also allowing us to use bit-wise equality on rows.
    Trim,
    /// Blank pad strings, which is appropriate for returning bpchar data out of Materialize.
    Pad,
}

impl CharWhiteSpace {
    /// Applies this whitespace policy to `s` and returns the result as an
    /// owned `String`.
    ///
    /// * `Trim` strips trailing whitespace and ignores `length`.
    /// * `Pad` pads `s` on the right to a width of `length` when one is
    ///   given, and returns `s` unchanged otherwise. Input that is already
    ///   at least that wide is not shortened.
    fn process_str(&self, s: &str, length: Option<usize>) -> String {
        match (*self, length) {
            (CharWhiteSpace::Trim, _) => s.trim_end().to_owned(),
            (CharWhiteSpace::Pad, Some(width)) => format!("{:width$}", s, width = width),
            // No width to pad to; this occurs when e.g. printing lists.
            (CharWhiteSpace::Pad, None) => s.to_owned(),
        }
    }
}
129
130/// Returns `s` as a `String` with options to enforce char and varchar
131/// semantics.
132///
133/// # Arguments
134/// * `s` - The `str` to format
135/// * `length` - An optional maximum length for the string
136/// * `fail_on_len` - Return an error if `s`'s character count exceeds the
137///   specified maximum length.
138/// * `white_space` - Express how to handle trailing whitespace on `s`
139///
140/// This function should only fail when `fail_on_len` is `true` _and_ `length`
141/// is present and exceeded.
142fn format_char_str(
143    s: &str,
144    length: Option<CharLength>,
145    fail_on_len: bool,
146    white_space: CharWhiteSpace,
147) -> Result<String, anyhow::Error> {
148    Ok(match length {
149        // Note that length is 1-indexed, so finding `None` means the string's
150        // characters don't exceed the length, while finding `Some` means it
151        // does.
152        Some(l) => {
153            let l = usize::cast_from(l.into_u32());
154            // The number of chars in a string is always less or equal to the length of the string.
155            // Hence, if the string is shorter than the length, we do not have to check for
156            // the maximum length.
157            if s.len() < l {
158                return Ok(white_space.process_str(s, Some(l)));
159            }
160            match s.char_indices().nth(l) {
161                None => white_space.process_str(s, Some(l)),
162                Some((idx, _)) => {
163                    if !fail_on_len || s[idx..].chars().all(|c| c.is_ascii_whitespace()) {
164                        white_space.process_str(&s[..idx], Some(l))
165                    } else {
166                        bail!("{} exceeds maximum length of {}", s, l)
167                    }
168                }
169            }
170        }
171        None => white_space.process_str(s, None),
172    })
173}
174
/// Checks `s` against the optional `length` limit and returns a `String`
/// version of it with all whitespace trimmed from the end.
///
/// When `length` is present, `s` is cut down to at most `length` characters
/// before trimming. Characters beyond the limit are dropped silently if
/// `fail_on_len` is `false` or if they are all ASCII whitespace.
///
/// The value returned is appropriate to store in `Datum::String`, but _is not_
/// appropriate to return to clients.
///
/// # Errors
///
/// This function should only fail when `fail_on_len` is `true` _and_ `length`
/// is present and exceeded by non-whitespace characters.
pub fn format_str_trim(
    s: &str,
    length: Option<CharLength>,
    fail_on_len: bool,
) -> Result<String, anyhow::Error> {
    format_char_str(s, length, fail_on_len, CharWhiteSpace::Trim)
}
190
191/// Ensures that `s` has fewer than `length` characters, and returns a `String`
192/// version of it with blank padding so that its width is `length` characters.
193///
194/// The value returned is appropriate to return to clients, but _is not_
195/// appropriate to store in `Datum::String`.
196pub fn format_str_pad(s: &str, length: Option<CharLength>) -> String {
197    format_char_str(s, length, false, CharWhiteSpace::Pad).unwrap()
198}
199
200impl RustType<ProtoCharLength> for CharLength {
201    fn into_proto(&self) -> ProtoCharLength {
202        ProtoCharLength { value: self.0 }
203    }
204
205    fn from_proto(proto: ProtoCharLength) -> Result<Self, TryFromProtoError> {
206        Ok(CharLength(proto.value))
207    }
208}
209
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use proptest::prelude::*;

    use super::*;

    proptest! {
        // Property: sending any generated `CharLength` through
        // `protobuf_roundtrip` with `ProtoCharLength` succeeds and yields a
        // value equal to the input.
        #[mz_ore::test]
        fn char_length_protobuf_roundtrip(expect in any::<CharLength>()) {
            let actual = protobuf_roundtrip::<_, ProtoCharLength>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }
}