parquet/
thrift.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Custom thrift definitions
19
20pub use thrift::protocol::TCompactOutputProtocol;
21use thrift::protocol::{
22    TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
23    TOutputProtocol, TSetIdentifier, TStructIdentifier, TType,
24};
25
26/// Reads and writes the struct to Thrift protocols.
27///
28/// Unlike [`thrift::protocol::TSerializable`] this uses generics instead of trait objects
29pub trait TSerializable: Sized {
30    /// Reads the struct from the input Thrift protocol
31    fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<Self>;
32    /// Writes the struct to the output Thrift protocol
33    fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()>;
34}
35
/// A more performant implementation of [`TCompactInputProtocol`] that decodes
/// directly from a borrowed byte slice.
///
/// [`TCompactInputProtocol`]: thrift::protocol::TCompactInputProtocol
pub(crate) struct TCompactSliceInputProtocol<'a> {
    /// Input bytes that have not been consumed yet; every read advances this slice.
    buf: &'a [u8],
    /// Id of the field most recently read in the struct currently being decoded.
    last_read_field_id: i16,
    /// Saved values of `last_read_field_id` for the enclosing structs; one entry is
    /// pushed each time a nested struct begins and popped when it ends.
    read_field_id_stack: Vec<i16>,
    /// Value of a boolean field, captured while reading its field header.
    ///
    /// A boolean field and its value share a single byte, and the value is only
    /// requested after the field id has been read, so it is parked here in between.
    pending_read_bool_value: Option<bool>,
}
50
51impl<'a> TCompactSliceInputProtocol<'a> {
52    pub fn new(buf: &'a [u8]) -> Self {
53        Self {
54            buf,
55            last_read_field_id: 0,
56            read_field_id_stack: Vec::with_capacity(16),
57            pending_read_bool_value: None,
58        }
59    }
60
61    pub fn as_slice(&self) -> &'a [u8] {
62        self.buf
63    }
64
65    fn read_vlq(&mut self) -> thrift::Result<u64> {
66        let mut in_progress = 0;
67        let mut shift = 0;
68        loop {
69            let byte = self.read_byte()?;
70            in_progress |= ((byte & 0x7F) as u64) << shift;
71            shift += 7;
72            if byte & 0x80 == 0 {
73                return Ok(in_progress);
74            }
75        }
76    }
77
78    fn read_zig_zag(&mut self) -> thrift::Result<i64> {
79        let val = self.read_vlq()?;
80        Ok((val >> 1) as i64 ^ -((val & 1) as i64))
81    }
82
83    fn read_list_set_begin(&mut self) -> thrift::Result<(TType, i32)> {
84        let header = self.read_byte()?;
85        let element_type = collection_u8_to_type(header & 0x0F)?;
86
87        let possible_element_count = (header & 0xF0) >> 4;
88        let element_count = if possible_element_count != 15 {
89            // high bits set high if count and type encoded separately
90            possible_element_count as i32
91        } else {
92            self.read_vlq()? as _
93        };
94
95        Ok((element_type, element_count))
96    }
97}
98
99impl TInputProtocol for TCompactSliceInputProtocol<'_> {
100    fn read_message_begin(&mut self) -> thrift::Result<TMessageIdentifier> {
101        unimplemented!()
102    }
103
104    fn read_message_end(&mut self) -> thrift::Result<()> {
105        unimplemented!()
106    }
107
108    fn read_struct_begin(&mut self) -> thrift::Result<Option<TStructIdentifier>> {
109        self.read_field_id_stack.push(self.last_read_field_id);
110        self.last_read_field_id = 0;
111        Ok(None)
112    }
113
114    fn read_struct_end(&mut self) -> thrift::Result<()> {
115        self.last_read_field_id = self
116            .read_field_id_stack
117            .pop()
118            .expect("should have previous field ids");
119        Ok(())
120    }
121
122    fn read_field_begin(&mut self) -> thrift::Result<TFieldIdentifier> {
123        // we can read at least one byte, which is:
124        // - the type
125        // - the field delta and the type
126        let field_type = self.read_byte()?;
127        let field_delta = (field_type & 0xF0) >> 4;
128        let field_type = match field_type & 0x0F {
129            0x01 => {
130                self.pending_read_bool_value = Some(true);
131                Ok(TType::Bool)
132            }
133            0x02 => {
134                self.pending_read_bool_value = Some(false);
135                Ok(TType::Bool)
136            }
137            ttu8 => u8_to_type(ttu8),
138        }?;
139
140        match field_type {
141            TType::Stop => Ok(
142                TFieldIdentifier::new::<Option<String>, String, Option<i16>>(
143                    None,
144                    TType::Stop,
145                    None,
146                ),
147            ),
148            _ => {
149                if field_delta != 0 {
150                    self.last_read_field_id += field_delta as i16;
151                } else {
152                    self.last_read_field_id = self.read_i16()?;
153                };
154
155                Ok(TFieldIdentifier {
156                    name: None,
157                    field_type,
158                    id: Some(self.last_read_field_id),
159                })
160            }
161        }
162    }
163
164    fn read_field_end(&mut self) -> thrift::Result<()> {
165        Ok(())
166    }
167
168    fn read_bool(&mut self) -> thrift::Result<bool> {
169        match self.pending_read_bool_value.take() {
170            Some(b) => Ok(b),
171            None => {
172                let b = self.read_byte()?;
173                match b {
174                    0x01 => Ok(true),
175                    0x02 => Ok(false),
176                    unkn => Err(thrift::Error::Protocol(thrift::ProtocolError {
177                        kind: thrift::ProtocolErrorKind::InvalidData,
178                        message: format!("cannot convert {} into bool", unkn),
179                    })),
180                }
181            }
182        }
183    }
184
185    fn read_bytes(&mut self) -> thrift::Result<Vec<u8>> {
186        let len = self.read_vlq()? as usize;
187        let ret = self.buf.get(..len).ok_or_else(eof_error)?.to_vec();
188        self.buf = &self.buf[len..];
189        Ok(ret)
190    }
191
192    fn read_i8(&mut self) -> thrift::Result<i8> {
193        Ok(self.read_byte()? as _)
194    }
195
196    fn read_i16(&mut self) -> thrift::Result<i16> {
197        Ok(self.read_zig_zag()? as _)
198    }
199
200    fn read_i32(&mut self) -> thrift::Result<i32> {
201        Ok(self.read_zig_zag()? as _)
202    }
203
204    fn read_i64(&mut self) -> thrift::Result<i64> {
205        self.read_zig_zag()
206    }
207
208    fn read_double(&mut self) -> thrift::Result<f64> {
209        let slice = (self.buf[..8]).try_into().unwrap();
210        self.buf = &self.buf[8..];
211        Ok(f64::from_le_bytes(slice))
212    }
213
214    fn read_string(&mut self) -> thrift::Result<String> {
215        let bytes = self.read_bytes()?;
216        String::from_utf8(bytes).map_err(From::from)
217    }
218
219    fn read_list_begin(&mut self) -> thrift::Result<TListIdentifier> {
220        let (element_type, element_count) = self.read_list_set_begin()?;
221        Ok(TListIdentifier::new(element_type, element_count))
222    }
223
224    fn read_list_end(&mut self) -> thrift::Result<()> {
225        Ok(())
226    }
227
228    fn read_set_begin(&mut self) -> thrift::Result<TSetIdentifier> {
229        unimplemented!()
230    }
231
232    fn read_set_end(&mut self) -> thrift::Result<()> {
233        unimplemented!()
234    }
235
236    fn read_map_begin(&mut self) -> thrift::Result<TMapIdentifier> {
237        unimplemented!()
238    }
239
240    fn read_map_end(&mut self) -> thrift::Result<()> {
241        Ok(())
242    }
243
244    #[inline]
245    fn read_byte(&mut self) -> thrift::Result<u8> {
246        let ret = *self.buf.first().ok_or_else(eof_error)?;
247        self.buf = &self.buf[1..];
248        Ok(ret)
249    }
250}
251
252fn collection_u8_to_type(b: u8) -> thrift::Result<TType> {
253    match b {
254        0x01 => Ok(TType::Bool),
255        o => u8_to_type(o),
256    }
257}
258
259fn u8_to_type(b: u8) -> thrift::Result<TType> {
260    match b {
261        0x00 => Ok(TType::Stop),
262        0x03 => Ok(TType::I08), // equivalent to TType::Byte
263        0x04 => Ok(TType::I16),
264        0x05 => Ok(TType::I32),
265        0x06 => Ok(TType::I64),
266        0x07 => Ok(TType::Double),
267        0x08 => Ok(TType::String),
268        0x09 => Ok(TType::List),
269        0x0A => Ok(TType::Set),
270        0x0B => Ok(TType::Map),
271        0x0C => Ok(TType::Struct),
272        unkn => Err(thrift::Error::Protocol(thrift::ProtocolError {
273            kind: thrift::ProtocolErrorKind::InvalidData,
274            message: format!("cannot convert {} into TType", unkn),
275        })),
276    }
277}
278
279fn eof_error() -> thrift::Error {
280    thrift::Error::Transport(thrift::TransportError {
281        kind: thrift::TransportErrorKind::EndOfFile,
282        message: "Unexpected EOF".to_string(),
283    })
284}