// thrift/protocol/binary.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
19use std::convert::{From, TryFrom};
20
21use super::{
22    TFieldIdentifier, TInputProtocol, TInputProtocolFactory, TListIdentifier, TMapIdentifier,
23    TMessageIdentifier, TMessageType,
24};
25use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
26use crate::transport::{TReadTransport, TWriteTransport};
27use crate::{ProtocolError, ProtocolErrorKind};
28
/// Version word for binary protocol v1: `0x8001` in the upper 16 bits.
/// Strict writers OR the message type into the low bits of this value.
const BINARY_PROTOCOL_VERSION_1: u32 = 0x8001 << 16;
30
31/// Read messages encoded in the Thrift simple binary encoding.
32///
33/// There are two available modes: `strict` and `non-strict`, where the
34/// `non-strict` version does not check for the protocol version in the
35/// received message header.
36///
37/// # Examples
38///
39/// Create and use a `TBinaryInputProtocol`.
40///
41/// ```no_run
42/// use thrift::protocol::{TBinaryInputProtocol, TInputProtocol};
43/// use thrift::transport::TTcpChannel;
44///
45/// let mut channel = TTcpChannel::new();
46/// channel.open("localhost:9090").unwrap();
47///
48/// let mut protocol = TBinaryInputProtocol::new(channel, true);
49///
50/// let recvd_bool = protocol.read_bool().unwrap();
51/// let recvd_string = protocol.read_string().unwrap();
52/// ```
53#[derive(Debug)]
54pub struct TBinaryInputProtocol<T>
55where
56    T: TReadTransport,
57{
58    strict: bool,
59    pub transport: T, // FIXME: shouldn't be public
60}
61
62impl<'a, T> TBinaryInputProtocol<T>
63where
64    T: TReadTransport,
65{
66    /// Create a `TBinaryInputProtocol` that reads bytes from `transport`.
67    ///
68    /// Set `strict` to `true` if all incoming messages contain the protocol
69    /// version number in the protocol header.
70    pub fn new(transport: T, strict: bool) -> TBinaryInputProtocol<T> {
71        TBinaryInputProtocol { strict, transport }
72    }
73}
74
75impl<T> TInputProtocol for TBinaryInputProtocol<T>
76where
77    T: TReadTransport,
78{
79    #[allow(clippy::collapsible_if)]
80    fn read_message_begin(&mut self) -> crate::Result<TMessageIdentifier> {
81        let mut first_bytes = vec![0; 4];
82        self.transport.read_exact(&mut first_bytes[..])?;
83
84        // the thrift version header is intentionally negative
85        // so the first check we'll do is see if the sign bit is set
86        // and if so - assume it's the protocol-version header
87        if (first_bytes[0] & 0x80) != 0 {
88            // apparently we got a protocol-version header - check
89            // it, and if it matches, read the rest of the fields
90            if first_bytes[0..2] != [0x80, 0x01] {
91                Err(crate::Error::Protocol(ProtocolError {
92                    kind: ProtocolErrorKind::BadVersion,
93                    message: format!("received bad version: {:?}", &first_bytes[0..2]),
94                }))
95            } else {
96                let message_type: TMessageType = TryFrom::try_from(first_bytes[3])?;
97                let name = self.read_string()?;
98                let sequence_number = self.read_i32()?;
99                Ok(TMessageIdentifier::new(name, message_type, sequence_number))
100            }
101        } else {
102            // apparently we didn't get a protocol-version header,
103            // which happens if the sender is not using the strict protocol
104            if self.strict {
105                // we're in strict mode however, and that always
106                // requires the protocol-version header to be written first
107                Err(crate::Error::Protocol(ProtocolError {
108                    kind: ProtocolErrorKind::BadVersion,
109                    message: format!("received bad version: {:?}", &first_bytes[0..2]),
110                }))
111            } else {
112                // in the non-strict version the first message field
113                // is the message name. strings (byte arrays) are length-prefixed,
114                // so we've just read the length in the first 4 bytes
115                let name_size = BigEndian::read_i32(&first_bytes) as usize;
116                let mut name_buf: Vec<u8> = vec![0; name_size];
117                self.transport.read_exact(&mut name_buf)?;
118                let name = String::from_utf8(name_buf)?;
119
120                // read the rest of the fields
121                let message_type: TMessageType = self.read_byte().and_then(TryFrom::try_from)?;
122                let sequence_number = self.read_i32()?;
123                Ok(TMessageIdentifier::new(name, message_type, sequence_number))
124            }
125        }
126    }
127
128    fn read_message_end(&mut self) -> crate::Result<()> {
129        Ok(())
130    }
131
132    fn read_struct_begin(&mut self) -> crate::Result<Option<TStructIdentifier>> {
133        Ok(None)
134    }
135
136    fn read_struct_end(&mut self) -> crate::Result<()> {
137        Ok(())
138    }
139
140    fn read_field_begin(&mut self) -> crate::Result<TFieldIdentifier> {
141        let field_type_byte = self.read_byte()?;
142        let field_type = field_type_from_u8(field_type_byte)?;
143        let id = match field_type {
144            TType::Stop => Ok(0),
145            _ => self.read_i16(),
146        }?;
147        Ok(TFieldIdentifier::new::<Option<String>, String, i16>(
148            None, field_type, id,
149        ))
150    }
151
152    fn read_field_end(&mut self) -> crate::Result<()> {
153        Ok(())
154    }
155
156    fn read_bytes(&mut self) -> crate::Result<Vec<u8>> {
157        let num_bytes = self.transport.read_i32::<BigEndian>()? as usize;
158        let mut buf = vec![0u8; num_bytes];
159        self.transport
160            .read_exact(&mut buf)
161            .map(|_| buf)
162            .map_err(From::from)
163    }
164
165    fn read_bool(&mut self) -> crate::Result<bool> {
166        let b = self.read_i8()?;
167        match b {
168            0 => Ok(false),
169            _ => Ok(true),
170        }
171    }
172
173    fn read_i8(&mut self) -> crate::Result<i8> {
174        self.transport.read_i8().map_err(From::from)
175    }
176
177    fn read_i16(&mut self) -> crate::Result<i16> {
178        self.transport.read_i16::<BigEndian>().map_err(From::from)
179    }
180
181    fn read_i32(&mut self) -> crate::Result<i32> {
182        self.transport.read_i32::<BigEndian>().map_err(From::from)
183    }
184
185    fn read_i64(&mut self) -> crate::Result<i64> {
186        self.transport.read_i64::<BigEndian>().map_err(From::from)
187    }
188
189    fn read_double(&mut self) -> crate::Result<f64> {
190        self.transport.read_f64::<BigEndian>().map_err(From::from)
191    }
192
193    fn read_string(&mut self) -> crate::Result<String> {
194        let bytes = self.read_bytes()?;
195        String::from_utf8(bytes).map_err(From::from)
196    }
197
198    fn read_list_begin(&mut self) -> crate::Result<TListIdentifier> {
199        let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
200        let size = self.read_i32()?;
201        Ok(TListIdentifier::new(element_type, size))
202    }
203
204    fn read_list_end(&mut self) -> crate::Result<()> {
205        Ok(())
206    }
207
208    fn read_set_begin(&mut self) -> crate::Result<TSetIdentifier> {
209        let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
210        let size = self.read_i32()?;
211        Ok(TSetIdentifier::new(element_type, size))
212    }
213
214    fn read_set_end(&mut self) -> crate::Result<()> {
215        Ok(())
216    }
217
218    fn read_map_begin(&mut self) -> crate::Result<TMapIdentifier> {
219        let key_type: TType = self.read_byte().and_then(field_type_from_u8)?;
220        let value_type: TType = self.read_byte().and_then(field_type_from_u8)?;
221        let size = self.read_i32()?;
222        Ok(TMapIdentifier::new(key_type, value_type, size))
223    }
224
225    fn read_map_end(&mut self) -> crate::Result<()> {
226        Ok(())
227    }
228
229    // utility
230    //
231
232    fn read_byte(&mut self) -> crate::Result<u8> {
233        self.transport.read_u8().map_err(From::from)
234    }
235}
236
/// Stateless factory that produces `TBinaryInputProtocol` instances.
#[derive(Default)]
pub struct TBinaryInputProtocolFactory;

impl TBinaryInputProtocolFactory {
    /// Build a new factory. It carries no state, so every instance is equivalent.
    pub fn new() -> Self {
        Self
    }
}
247
248impl TInputProtocolFactory for TBinaryInputProtocolFactory {
249    fn create(&self, transport: Box<dyn TReadTransport + Send>) -> Box<dyn TInputProtocol + Send> {
250        Box::new(TBinaryInputProtocol::new(transport, true))
251    }
252}
253
254/// Write messages using the Thrift simple binary encoding.
255///
256/// There are two available modes: `strict` and `non-strict`, where the
257/// `strict` version writes the protocol version number in the outgoing message
258/// header and the `non-strict` version does not.
259///
260/// # Examples
261///
262/// Create and use a `TBinaryOutputProtocol`.
263///
264/// ```no_run
265/// use thrift::protocol::{TBinaryOutputProtocol, TOutputProtocol};
266/// use thrift::transport::TTcpChannel;
267///
268/// let mut channel = TTcpChannel::new();
269/// channel.open("localhost:9090").unwrap();
270///
271/// let mut protocol = TBinaryOutputProtocol::new(channel, true);
272///
273/// protocol.write_bool(true).unwrap();
274/// protocol.write_string("test_string").unwrap();
275/// ```
276#[derive(Debug)]
277pub struct TBinaryOutputProtocol<T>
278where
279    T: TWriteTransport,
280{
281    strict: bool,
282    pub transport: T, // FIXME: do not make public; only public for testing!
283}
284
285impl<T> TBinaryOutputProtocol<T>
286where
287    T: TWriteTransport,
288{
289    /// Create a `TBinaryOutputProtocol` that writes bytes to `transport`.
290    ///
291    /// Set `strict` to `true` if all outgoing messages should contain the
292    /// protocol version number in the protocol header.
293    pub fn new(transport: T, strict: bool) -> TBinaryOutputProtocol<T> {
294        TBinaryOutputProtocol { strict, transport }
295    }
296}
297
298impl<T> TOutputProtocol for TBinaryOutputProtocol<T>
299where
300    T: TWriteTransport,
301{
302    fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> crate::Result<()> {
303        if self.strict {
304            let message_type: u8 = identifier.message_type.into();
305            let header = BINARY_PROTOCOL_VERSION_1 | (message_type as u32);
306            self.transport.write_u32::<BigEndian>(header)?;
307            self.write_string(&identifier.name)?;
308            self.write_i32(identifier.sequence_number)
309        } else {
310            self.write_string(&identifier.name)?;
311            self.write_byte(identifier.message_type.into())?;
312            self.write_i32(identifier.sequence_number)
313        }
314    }
315
316    fn write_message_end(&mut self) -> crate::Result<()> {
317        Ok(())
318    }
319
320    fn write_struct_begin(&mut self, _: &TStructIdentifier) -> crate::Result<()> {
321        Ok(())
322    }
323
324    fn write_struct_end(&mut self) -> crate::Result<()> {
325        Ok(())
326    }
327
328    fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> crate::Result<()> {
329        if identifier.id.is_none() && identifier.field_type != TType::Stop {
330            return Err(crate::Error::Protocol(ProtocolError {
331                kind: ProtocolErrorKind::Unknown,
332                message: format!(
333                    "cannot write identifier {:?} without sequence number",
334                    &identifier
335                ),
336            }));
337        }
338
339        self.write_byte(field_type_to_u8(identifier.field_type))?;
340        if let Some(id) = identifier.id {
341            self.write_i16(id)
342        } else {
343            Ok(())
344        }
345    }
346
347    fn write_field_end(&mut self) -> crate::Result<()> {
348        Ok(())
349    }
350
351    fn write_field_stop(&mut self) -> crate::Result<()> {
352        self.write_byte(field_type_to_u8(TType::Stop))
353    }
354
355    fn write_bytes(&mut self, b: &[u8]) -> crate::Result<()> {
356        self.write_i32(b.len() as i32)?;
357        self.transport.write_all(b).map_err(From::from)
358    }
359
360    fn write_bool(&mut self, b: bool) -> crate::Result<()> {
361        if b {
362            self.write_i8(1)
363        } else {
364            self.write_i8(0)
365        }
366    }
367
368    fn write_i8(&mut self, i: i8) -> crate::Result<()> {
369        self.transport.write_i8(i).map_err(From::from)
370    }
371
372    fn write_i16(&mut self, i: i16) -> crate::Result<()> {
373        self.transport.write_i16::<BigEndian>(i).map_err(From::from)
374    }
375
376    fn write_i32(&mut self, i: i32) -> crate::Result<()> {
377        self.transport.write_i32::<BigEndian>(i).map_err(From::from)
378    }
379
380    fn write_i64(&mut self, i: i64) -> crate::Result<()> {
381        self.transport.write_i64::<BigEndian>(i).map_err(From::from)
382    }
383
384    fn write_double(&mut self, d: f64) -> crate::Result<()> {
385        self.transport.write_f64::<BigEndian>(d).map_err(From::from)
386    }
387
388    fn write_string(&mut self, s: &str) -> crate::Result<()> {
389        self.write_bytes(s.as_bytes())
390    }
391
392    fn write_list_begin(&mut self, identifier: &TListIdentifier) -> crate::Result<()> {
393        self.write_byte(field_type_to_u8(identifier.element_type))?;
394        self.write_i32(identifier.size)
395    }
396
397    fn write_list_end(&mut self) -> crate::Result<()> {
398        Ok(())
399    }
400
401    fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> crate::Result<()> {
402        self.write_byte(field_type_to_u8(identifier.element_type))?;
403        self.write_i32(identifier.size)
404    }
405
406    fn write_set_end(&mut self) -> crate::Result<()> {
407        Ok(())
408    }
409
410    fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> crate::Result<()> {
411        let key_type = identifier
412            .key_type
413            .expect("map identifier to write should contain key type");
414        self.write_byte(field_type_to_u8(key_type))?;
415        let val_type = identifier
416            .value_type
417            .expect("map identifier to write should contain value type");
418        self.write_byte(field_type_to_u8(val_type))?;
419        self.write_i32(identifier.size)
420    }
421
422    fn write_map_end(&mut self) -> crate::Result<()> {
423        Ok(())
424    }
425
426    fn flush(&mut self) -> crate::Result<()> {
427        self.transport.flush().map_err(From::from)
428    }
429
430    // utility
431    //
432
433    fn write_byte(&mut self, b: u8) -> crate::Result<()> {
434        self.transport.write_u8(b).map_err(From::from)
435    }
436}
437
/// Stateless factory that produces `TBinaryOutputProtocol` instances.
#[derive(Default)]
pub struct TBinaryOutputProtocolFactory;

impl TBinaryOutputProtocolFactory {
    /// Build a new factory. It carries no state, so every instance is equivalent.
    pub fn new() -> Self {
        Self
    }
}
448
449impl TOutputProtocolFactory for TBinaryOutputProtocolFactory {
450    fn create(
451        &self,
452        transport: Box<dyn TWriteTransport + Send>,
453    ) -> Box<dyn TOutputProtocol + Send> {
454        Box::new(TBinaryOutputProtocol::new(transport, true))
455    }
456}
457
458fn field_type_to_u8(field_type: TType) -> u8 {
459    match field_type {
460        TType::Stop => 0x00,
461        TType::Void => 0x01,
462        TType::Bool => 0x02,
463        TType::I08 => 0x03, // equivalent to TType::Byte
464        TType::Double => 0x04,
465        TType::I16 => 0x06,
466        TType::I32 => 0x08,
467        TType::I64 => 0x0A,
468        TType::String | TType::Utf7 => 0x0B,
469        TType::Struct => 0x0C,
470        TType::Map => 0x0D,
471        TType::Set => 0x0E,
472        TType::List => 0x0F,
473        TType::Utf8 => 0x10,
474        TType::Utf16 => 0x11,
475    }
476}
477
478fn field_type_from_u8(b: u8) -> crate::Result<TType> {
479    match b {
480        0x00 => Ok(TType::Stop),
481        0x01 => Ok(TType::Void),
482        0x02 => Ok(TType::Bool),
483        0x03 => Ok(TType::I08), // Equivalent to TType::Byte
484        0x04 => Ok(TType::Double),
485        0x06 => Ok(TType::I16),
486        0x08 => Ok(TType::I32),
487        0x0A => Ok(TType::I64),
488        0x0B => Ok(TType::String), // technically, also a UTF7, but we'll treat it as string
489        0x0C => Ok(TType::Struct),
490        0x0D => Ok(TType::Map),
491        0x0E => Ok(TType::Set),
492        0x0F => Ok(TType::List),
493        0x10 => Ok(TType::Utf8),
494        0x11 => Ok(TType::Utf16),
495        unkn => Err(crate::Error::Protocol(ProtocolError {
496            kind: ProtocolErrorKind::InvalidData,
497            message: format!("cannot convert {} to TType", unkn),
498        })),
499    }
500}
501
#[cfg(test)]
mod tests {
    // Unit tests for the binary protocol. Each test either checks the exact bytes
    // written to an in-memory channel, or writes values and reads them back
    // through the same channel (round trip).

    use crate::protocol::{
        TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
        TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType,
    };
    use crate::transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf};

    use super::*;

    #[test]
    fn must_write_strict_message_call_begin() {
        let (_, mut o_prot) = test_objects(true);

        let ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
        assert!(o_prot.write_message_begin(&ident).is_ok());

        #[rustfmt::skip]
        let expected: [u8; 16] = [
            0x80, 0x01, 0x00, 0x01, // version 1, message type: Call
            0x00, 0x00, 0x00, 0x04, // name length
            0x74, 0x65, 0x73, 0x74, // name: "test"
            0x00, 0x00, 0x00, 0x01, // sequence number
        ];

        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_write_non_strict_message_call_begin() {
        let (_, mut o_prot) = test_objects(false);

        let ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
        assert!(o_prot.write_message_begin(&ident).is_ok());

        #[rustfmt::skip]
        let expected: [u8; 13] = [
            0x00, 0x00, 0x00, 0x04, // name length
            0x74, 0x65, 0x73, 0x74, // name: "test"
            0x01,                   // message type: Call
            0x00, 0x00, 0x00, 0x01, // sequence number
        ];

        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_write_strict_message_reply_begin() {
        let (_, mut o_prot) = test_objects(true);

        let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10);
        assert!(o_prot.write_message_begin(&ident).is_ok());

        #[rustfmt::skip]
        let expected: [u8; 16] = [
            0x80, 0x01, 0x00, 0x02, // version 1, message type: Reply
            0x00, 0x00, 0x00, 0x04, // name length
            0x74, 0x65, 0x73, 0x74, // name: "test"
            0x00, 0x00, 0x00, 0x0A, // sequence number
        ];

        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_write_non_strict_message_reply_begin() {
        let (_, mut o_prot) = test_objects(false);

        let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10);
        assert!(o_prot.write_message_begin(&ident).is_ok());

        #[rustfmt::skip]
        let expected: [u8; 13] = [
            0x00, 0x00, 0x00, 0x04, // name length
            0x74, 0x65, 0x73, 0x74, // name: "test"
            0x02,                   // message type: Reply
            0x00, 0x00, 0x00, 0x0A, // sequence number
        ];

        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_strict_message_begin() {
        let (mut i_prot, mut o_prot) = test_objects(true);

        let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
        assert!(o_prot.write_message_begin(&sent_ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_ident = assert_success!(i_prot.read_message_begin());
        assert_eq!(&received_ident, &sent_ident);
    }

    #[test]
    fn must_round_trip_non_strict_message_begin() {
        let (mut i_prot, mut o_prot) = test_objects(false);

        let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
        assert!(o_prot.write_message_begin(&sent_ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_ident = assert_success!(i_prot.read_message_begin());
        assert_eq!(&received_ident, &sent_ident);
    }

    #[test]
    fn must_reject_non_strict_message_begin_in_strict_mode() {
        let (mut i_prot, _) = test_objects(true);

        // a well-formed non-strict header: a strict reader must still refuse it
        #[rustfmt::skip]
        let bytes: [u8; 13] = [
            0x00, 0x00, 0x00, 0x04, // name length
            0x74, 0x65, 0x73, 0x74, // name: "test"
            0x01,                   // message type: Call
            0x00, 0x00, 0x00, 0x01, // sequence number
        ];
        set_readable_bytes!(i_prot, &bytes);

        assert_bad_version(i_prot.read_message_begin());
    }

    #[test]
    fn must_reject_message_begin_with_unknown_version() {
        let (mut i_prot, _) = test_objects(true);

        // sign bit set (so this is a version word) but version 2 instead of 1
        set_readable_bytes!(i_prot, &[0x80, 0x02, 0x00, 0x01]);

        assert_bad_version(i_prot.read_message_begin());
    }

    #[test]
    fn must_write_message_end() {
        assert_no_write(|o| o.write_message_end(), true);
    }

    #[test]
    fn must_write_struct_begin() {
        assert_no_write(
            |o| o.write_struct_begin(&TStructIdentifier::new("foo")),
            true,
        );
    }

    #[test]
    fn must_write_struct_end() {
        assert_no_write(|o| o.write_struct_end(), true);
    }

    #[test]
    fn must_write_field_begin() {
        let (_, mut o_prot) = test_objects(true);

        assert!(o_prot
            .write_field_begin(&TFieldIdentifier::new("some_field", TType::String, 22))
            .is_ok());

        // field type (String), then the field id as an i16
        let expected: [u8; 3] = [0x0B, 0x00, 0x16];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_field_begin() {
        let (mut i_prot, mut o_prot) = test_objects(true);

        let sent_field_ident = TFieldIdentifier::new("foo", TType::I64, 20);
        assert!(o_prot.write_field_begin(&sent_field_ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let expected_ident = TFieldIdentifier {
            name: None,
            field_type: TType::I64,
            id: Some(20),
        }; // no name
        let received_ident = assert_success!(i_prot.read_field_begin());
        assert_eq!(&received_ident, &expected_ident);
    }

    #[test]
    fn must_write_stop_field() {
        let (_, mut o_prot) = test_objects(true);

        assert!(o_prot.write_field_stop().is_ok());

        let expected: [u8; 1] = [0x00];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_field_stop() {
        let (mut i_prot, mut o_prot) = test_objects(true);

        assert!(o_prot.write_field_stop().is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let expected_ident = TFieldIdentifier {
            name: None,
            field_type: TType::Stop,
            id: Some(0),
        }; // we get id 0

        let received_ident = assert_success!(i_prot.read_field_begin());
        assert_eq!(&received_ident, &expected_ident);
    }

    #[test]
    fn must_write_field_end() {
        assert_no_write(|o| o.write_field_end(), true);
    }

    #[test]
    fn must_write_list_begin() {
        let (_, mut o_prot) = test_objects(true);

        assert!(o_prot
            .write_list_begin(&TListIdentifier::new(TType::Bool, 5))
            .is_ok());

        // element type (Bool), then the element count as an i32
        let expected: [u8; 5] = [0x02, 0x00, 0x00, 0x00, 0x05];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_list_begin() {
        let (mut i_prot, mut o_prot) = test_objects(true);

        let ident = TListIdentifier::new(TType::List, 900);
        assert!(o_prot.write_list_begin(&ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_ident = assert_success!(i_prot.read_list_begin());
        assert_eq!(&received_ident, &ident);
    }

    #[test]
    fn must_write_list_end() {
        assert_no_write(|o| o.write_list_end(), true);
    }

    #[test]
    fn must_write_set_begin() {
        let (_, mut o_prot) = test_objects(true);

        assert!(o_prot
            .write_set_begin(&TSetIdentifier::new(TType::I16, 7))
            .is_ok());

        // element type (I16), then the element count as an i32
        let expected: [u8; 5] = [0x06, 0x00, 0x00, 0x00, 0x07];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_set_begin() {
        let (mut i_prot, mut o_prot) = test_objects(true);

        let ident = TSetIdentifier::new(TType::I64, 2000);
        assert!(o_prot.write_set_begin(&ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_ident = assert_success!(i_prot.read_set_begin());
        assert_eq!(&received_ident, &ident);
    }

    #[test]
    fn must_write_set_end() {
        assert_no_write(|o| o.write_set_end(), true);
    }

    #[test]
    fn must_write_map_begin() {
        let (_, mut o_prot) = test_objects(true);

        assert!(o_prot
            .write_map_begin(&TMapIdentifier::new(TType::I64, TType::Struct, 32))
            .is_ok());

        // key type (I64), value type (Struct), then the entry count as an i32
        let expected: [u8; 6] = [0x0A, 0x0C, 0x00, 0x00, 0x00, 0x20];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_map_begin() {
        let (mut i_prot, mut o_prot) = test_objects(true);

        let ident = TMapIdentifier::new(TType::Map, TType::Set, 100);
        assert!(o_prot.write_map_begin(&ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_ident = assert_success!(i_prot.read_map_begin());
        assert_eq!(&received_ident, &ident);
    }

    #[test]
    fn must_write_map_end() {
        assert_no_write(|o| o.write_map_end(), true);
    }

    #[test]
    fn must_write_bool_true() {
        let (_, mut o_prot) = test_objects(true);

        assert!(o_prot.write_bool(true).is_ok());

        let expected: [u8; 1] = [0x01];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_write_bool_false() {
        let (_, mut o_prot) = test_objects(true);

        assert!(o_prot.write_bool(false).is_ok());

        let expected: [u8; 1] = [0x00];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_read_bool_true() {
        let (mut i_prot, _) = test_objects(true);

        set_readable_bytes!(i_prot, &[0x01]);

        let read_bool = assert_success!(i_prot.read_bool());
        assert!(read_bool);
    }

    #[test]
    fn must_read_bool_false() {
        let (mut i_prot, _) = test_objects(true);

        set_readable_bytes!(i_prot, &[0x00]);

        let read_bool = assert_success!(i_prot.read_bool());
        assert!(!read_bool);
    }

    #[test]
    fn must_allow_any_non_zero_value_to_be_interpreted_as_bool_true() {
        let (mut i_prot, _) = test_objects(true);

        set_readable_bytes!(i_prot, &[0xAC]);

        let read_bool = assert_success!(i_prot.read_bool());
        assert!(read_bool);
    }

    #[test]
    fn must_write_bytes() {
        let (_, mut o_prot) = test_objects(true);

        let bytes: [u8; 10] = [0x0A, 0xCC, 0xD1, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF];

        assert!(o_prot.write_bytes(&bytes).is_ok());

        let buf = o_prot.transport.write_bytes();
        assert_eq!(&buf[0..4], [0x00, 0x00, 0x00, 0x0A]); // length
        assert_eq!(&buf[4..], bytes); // actual bytes
    }

    #[test]
    fn must_round_trip_bytes() {
        let (mut i_prot, mut o_prot) = test_objects(true);

        #[rustfmt::skip]
        let bytes: [u8; 25] = [
            0x20, 0xFD, 0x18, 0x84, 0x99,
            0x12, 0xAB, 0xBB, 0x45, 0xDF,
            0x34, 0xDC, 0x98, 0xA4, 0x6D,
            0xF3, 0x99, 0xB4, 0xB7, 0xD4,
            0x9C, 0xA5, 0xB3, 0xC9, 0x88,
        ];

        assert!(o_prot.write_bytes(&bytes).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_bytes = assert_success!(i_prot.read_bytes());
        assert_eq!(&received_bytes, &bytes);
    }

    /// Build an input/output protocol pair over one in-memory channel, with
    /// 40-byte read and write buffers, both using the given `strict` mode.
    fn test_objects(
        strict: bool,
    ) -> (
        TBinaryInputProtocol<ReadHalf<TBufferChannel>>,
        TBinaryOutputProtocol<WriteHalf<TBufferChannel>>,
    ) {
        let mem = TBufferChannel::with_capacity(40, 40);

        let (r_mem, w_mem) = mem.split().unwrap();

        let i_prot = TBinaryInputProtocol::new(r_mem, strict);
        let o_prot = TBinaryOutputProtocol::new(w_mem, strict);

        (i_prot, o_prot)
    }

    /// Run `write_fn` against a fresh output protocol. Assert that it succeeds
    /// and that it writes nothing to the transport.
    fn assert_no_write<F>(mut write_fn: F, strict: bool)
    where
        F: FnMut(&mut TBinaryOutputProtocol<WriteHalf<TBufferChannel>>) -> crate::Result<()>,
    {
        let (_, mut o_prot) = test_objects(strict);
        assert!(write_fn(&mut o_prot).is_ok());
        assert_eq!(o_prot.transport.write_bytes().len(), 0);
    }

    /// Assert that `result` is a protocol error of kind `BadVersion`.
    fn assert_bad_version(result: crate::Result<TMessageIdentifier>) {
        match result {
            Err(crate::Error::Protocol(ProtocolError {
                kind: ProtocolErrorKind::BadVersion,
                ..
            })) => {}
            other => panic!("expected a BadVersion protocol error, got {:?}", other),
        }
    }
}