use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
use std::convert::{From, TryFrom};
use super::{
TFieldIdentifier, TInputProtocol, TInputProtocolFactory, TListIdentifier, TMapIdentifier,
TMessageIdentifier, TMessageType,
};
use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
use crate::transport::{TReadTransport, TWriteTransport};
use crate::{ProtocolError, ProtocolErrorKind};
/// Header word for a strict (versioned) binary message: version marker `0x8001`
/// in the high 16 bits; the message type is OR-ed into the low byte when written.
const BINARY_PROTOCOL_VERSION_1: u32 = 0x8001 << 16;
/// Reads values encoded with the Thrift binary protocol from a transport.
///
/// When `strict` is `true`, incoming message headers must use the versioned
/// layout; otherwise the legacy unversioned header is accepted as well.
#[derive(Debug)]
pub struct TBinaryInputProtocol<T: TReadTransport> {
    strict: bool,
    /// Transport that encoded bytes are read from.
    pub transport: T,
}
impl<T> TBinaryInputProtocol<T>
where
    T: TReadTransport,
{
    /// Creates a binary input protocol reading from `transport`.
    ///
    /// `strict` controls whether unversioned (legacy) message headers are
    /// rejected by `read_message_begin`.
    pub fn new(transport: T, strict: bool) -> TBinaryInputProtocol<T> {
        TBinaryInputProtocol { strict, transport }
    }
}
impl<T> TInputProtocol for TBinaryInputProtocol<T>
where
T: TReadTransport,
{
#[allow(clippy::collapsible_if)]
fn read_message_begin(&mut self) -> crate::Result<TMessageIdentifier> {
let mut first_bytes = vec![0; 4];
self.transport.read_exact(&mut first_bytes[..])?;
if (first_bytes[0] & 0x80) != 0 {
if first_bytes[0..2] != [0x80, 0x01] {
Err(crate::Error::Protocol(ProtocolError {
kind: ProtocolErrorKind::BadVersion,
message: format!("received bad version: {:?}", &first_bytes[0..2]),
}))
} else {
let message_type: TMessageType = TryFrom::try_from(first_bytes[3])?;
let name = self.read_string()?;
let sequence_number = self.read_i32()?;
Ok(TMessageIdentifier::new(name, message_type, sequence_number))
}
} else {
if self.strict {
Err(crate::Error::Protocol(ProtocolError {
kind: ProtocolErrorKind::BadVersion,
message: format!("received bad version: {:?}", &first_bytes[0..2]),
}))
} else {
let name_size = BigEndian::read_i32(&first_bytes) as usize;
let mut name_buf: Vec<u8> = vec![0; name_size];
self.transport.read_exact(&mut name_buf)?;
let name = String::from_utf8(name_buf)?;
let message_type: TMessageType = self.read_byte().and_then(TryFrom::try_from)?;
let sequence_number = self.read_i32()?;
Ok(TMessageIdentifier::new(name, message_type, sequence_number))
}
}
}
fn read_message_end(&mut self) -> crate::Result<()> {
Ok(())
}
fn read_struct_begin(&mut self) -> crate::Result<Option<TStructIdentifier>> {
Ok(None)
}
fn read_struct_end(&mut self) -> crate::Result<()> {
Ok(())
}
fn read_field_begin(&mut self) -> crate::Result<TFieldIdentifier> {
let field_type_byte = self.read_byte()?;
let field_type = field_type_from_u8(field_type_byte)?;
let id = match field_type {
TType::Stop => Ok(0),
_ => self.read_i16(),
}?;
Ok(TFieldIdentifier::new::<Option<String>, String, i16>(
None, field_type, id,
))
}
fn read_field_end(&mut self) -> crate::Result<()> {
Ok(())
}
fn read_bytes(&mut self) -> crate::Result<Vec<u8>> {
let num_bytes = self.transport.read_i32::<BigEndian>()? as usize;
let mut buf = vec![0u8; num_bytes];
self.transport
.read_exact(&mut buf)
.map(|_| buf)
.map_err(From::from)
}
fn read_bool(&mut self) -> crate::Result<bool> {
let b = self.read_i8()?;
match b {
0 => Ok(false),
_ => Ok(true),
}
}
fn read_i8(&mut self) -> crate::Result<i8> {
self.transport.read_i8().map_err(From::from)
}
fn read_i16(&mut self) -> crate::Result<i16> {
self.transport.read_i16::<BigEndian>().map_err(From::from)
}
fn read_i32(&mut self) -> crate::Result<i32> {
self.transport.read_i32::<BigEndian>().map_err(From::from)
}
fn read_i64(&mut self) -> crate::Result<i64> {
self.transport.read_i64::<BigEndian>().map_err(From::from)
}
fn read_double(&mut self) -> crate::Result<f64> {
self.transport.read_f64::<BigEndian>().map_err(From::from)
}
fn read_string(&mut self) -> crate::Result<String> {
let bytes = self.read_bytes()?;
String::from_utf8(bytes).map_err(From::from)
}
fn read_list_begin(&mut self) -> crate::Result<TListIdentifier> {
let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
Ok(TListIdentifier::new(element_type, size))
}
fn read_list_end(&mut self) -> crate::Result<()> {
Ok(())
}
fn read_set_begin(&mut self) -> crate::Result<TSetIdentifier> {
let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
Ok(TSetIdentifier::new(element_type, size))
}
fn read_set_end(&mut self) -> crate::Result<()> {
Ok(())
}
fn read_map_begin(&mut self) -> crate::Result<TMapIdentifier> {
let key_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let value_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
Ok(TMapIdentifier::new(key_type, value_type, size))
}
fn read_map_end(&mut self) -> crate::Result<()> {
Ok(())
}
fn read_byte(&mut self) -> crate::Result<u8> {
self.transport.read_u8().map_err(From::from)
}
}
/// Stateless factory for binary input protocols.
#[derive(Default)]
pub struct TBinaryInputProtocolFactory;

impl TBinaryInputProtocolFactory {
    /// Returns a new factory; equivalent to `Default::default()`.
    pub fn new() -> TBinaryInputProtocolFactory {
        Self
    }
}
impl TInputProtocolFactory for TBinaryInputProtocolFactory {
    /// Wraps `transport` in a strict binary input protocol.
    fn create(&self, transport: Box<dyn TReadTransport + Send>) -> Box<dyn TInputProtocol + Send> {
        let strict = true;
        let protocol = TBinaryInputProtocol::new(transport, strict);
        Box::new(protocol)
    }
}
/// Writes values encoded with the Thrift binary protocol to a transport.
///
/// When `strict` is `true`, message headers are written in the versioned
/// layout; otherwise the legacy unversioned layout is used.
#[derive(Debug)]
pub struct TBinaryOutputProtocol<T: TWriteTransport> {
    strict: bool,
    /// Transport that encoded bytes are written to.
    pub transport: T,
}
impl<T: TWriteTransport> TBinaryOutputProtocol<T> {
    /// Creates a binary output protocol writing to `transport`.
    ///
    /// `strict` selects the versioned message-header layout in
    /// `write_message_begin`.
    pub fn new(transport: T, strict: bool) -> TBinaryOutputProtocol<T> {
        Self { transport, strict }
    }
}
impl<T> TOutputProtocol for TBinaryOutputProtocol<T>
where
    T: TWriteTransport,
{
    /// Writes a message header.
    ///
    /// Strict mode writes one big-endian `u32` (protocol version OR message
    /// type), then the name and the `i32` sequence number. Non-strict mode
    /// writes the name, a one-byte message type, then the sequence number.
    fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> crate::Result<()> {
        if self.strict {
            let message_type: u8 = identifier.message_type.into();
            let header = BINARY_PROTOCOL_VERSION_1 | (message_type as u32);
            self.transport.write_u32::<BigEndian>(header)?;
            self.write_string(&identifier.name)?;
            self.write_i32(identifier.sequence_number)
        } else {
            self.write_string(&identifier.name)?;
            self.write_byte(identifier.message_type.into())?;
            self.write_i32(identifier.sequence_number)
        }
    }

    fn write_message_end(&mut self) -> crate::Result<()> {
        Ok(())
    }

    /// Struct names are not encoded by the binary protocol; writes nothing.
    fn write_struct_begin(&mut self, _: &TStructIdentifier) -> crate::Result<()> {
        Ok(())
    }

    fn write_struct_end(&mut self) -> crate::Result<()> {
        Ok(())
    }

    /// Writes a field header: the one-byte field type, followed by the `i16`
    /// field id when one is present.
    ///
    /// # Errors
    ///
    /// Returns an `Unknown` protocol error when a non-`Stop` field has no id.
    fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> crate::Result<()> {
        if identifier.id.is_none() && identifier.field_type != TType::Stop {
            return Err(crate::Error::Protocol(ProtocolError {
                kind: ProtocolErrorKind::Unknown,
                message: format!(
                    "cannot write identifier {:?} without sequence number",
                    &identifier
                ),
            }));
        }

        self.write_byte(field_type_to_u8(identifier.field_type))?;
        if let Some(id) = identifier.id {
            self.write_i16(id)
        } else {
            Ok(())
        }
    }

    fn write_field_end(&mut self) -> crate::Result<()> {
        Ok(())
    }

    /// Writes the single `Stop` type byte that terminates a struct's fields.
    fn write_field_stop(&mut self) -> crate::Result<()> {
        self.write_byte(field_type_to_u8(TType::Stop))
    }

    /// Writes `b` as an `i32` big-endian length prefix followed by the bytes.
    ///
    /// # Errors
    ///
    /// Returns a `SizeLimit` protocol error if `b` is longer than `i32::MAX`
    /// bytes, since the length cannot be represented on the wire.
    fn write_bytes(&mut self, b: &[u8]) -> crate::Result<()> {
        // A plain `as` cast would silently wrap for oversized buffers and emit
        // a length prefix that disagrees with the payload.
        let len = i32::try_from(b.len()).map_err(|_| {
            crate::Error::Protocol(ProtocolError {
                kind: ProtocolErrorKind::SizeLimit,
                message: format!("cannot write {} bytes: length exceeds i32::MAX", b.len()),
            })
        })?;
        self.write_i32(len)?;
        self.transport.write_all(b).map_err(From::from)
    }

    /// Writes a bool as one byte: `1` for `true`, `0` for `false`.
    fn write_bool(&mut self, b: bool) -> crate::Result<()> {
        self.write_i8(i8::from(b))
    }

    fn write_i8(&mut self, i: i8) -> crate::Result<()> {
        self.transport.write_i8(i).map_err(From::from)
    }

    fn write_i16(&mut self, i: i16) -> crate::Result<()> {
        self.transport.write_i16::<BigEndian>(i).map_err(From::from)
    }

    fn write_i32(&mut self, i: i32) -> crate::Result<()> {
        self.transport.write_i32::<BigEndian>(i).map_err(From::from)
    }

    fn write_i64(&mut self, i: i64) -> crate::Result<()> {
        self.transport.write_i64::<BigEndian>(i).map_err(From::from)
    }

    fn write_double(&mut self, d: f64) -> crate::Result<()> {
        self.transport.write_f64::<BigEndian>(d).map_err(From::from)
    }

    /// Writes the string's UTF-8 bytes with a length prefix.
    fn write_string(&mut self, s: &str) -> crate::Result<()> {
        self.write_bytes(s.as_bytes())
    }

    /// Writes a list header: one-byte element type and `i32` size.
    fn write_list_begin(&mut self, identifier: &TListIdentifier) -> crate::Result<()> {
        self.write_byte(field_type_to_u8(identifier.element_type))?;
        self.write_i32(identifier.size)
    }

    fn write_list_end(&mut self) -> crate::Result<()> {
        Ok(())
    }

    /// Writes a set header: one-byte element type and `i32` size.
    fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> crate::Result<()> {
        self.write_byte(field_type_to_u8(identifier.element_type))?;
        self.write_i32(identifier.size)
    }

    fn write_set_end(&mut self) -> crate::Result<()> {
        Ok(())
    }

    /// Writes a map header: one-byte key type, one-byte value type, `i32` size.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` protocol error (and writes nothing) if the
    /// identifier lacks either the key type or the value type.
    fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> crate::Result<()> {
        // Both element types are mandatory in this encoding; report a missing
        // one as an error rather than panicking inside the library.
        let (key_type, val_type) = match (identifier.key_type, identifier.value_type) {
            (Some(key_type), Some(val_type)) => (key_type, val_type),
            _ => {
                return Err(crate::Error::Protocol(ProtocolError {
                    kind: ProtocolErrorKind::InvalidData,
                    message: format!(
                        "map identifier to write should contain key and value types: {:?}",
                        identifier
                    ),
                }));
            }
        };
        self.write_byte(field_type_to_u8(key_type))?;
        self.write_byte(field_type_to_u8(val_type))?;
        self.write_i32(identifier.size)
    }

    fn write_map_end(&mut self) -> crate::Result<()> {
        Ok(())
    }

    fn flush(&mut self) -> crate::Result<()> {
        self.transport.flush().map_err(From::from)
    }

    fn write_byte(&mut self, b: u8) -> crate::Result<()> {
        self.transport.write_u8(b).map_err(From::from)
    }
}
/// Stateless factory for binary output protocols.
#[derive(Default)]
pub struct TBinaryOutputProtocolFactory;

impl TBinaryOutputProtocolFactory {
    /// Returns a new factory; equivalent to `Default::default()`.
    pub fn new() -> TBinaryOutputProtocolFactory {
        Self
    }
}
impl TOutputProtocolFactory for TBinaryOutputProtocolFactory {
    /// Wraps `transport` in a strict binary output protocol.
    fn create(
        &self,
        transport: Box<dyn TWriteTransport + Send>,
    ) -> Box<dyn TOutputProtocol + Send> {
        let strict = true;
        let protocol = TBinaryOutputProtocol::new(transport, strict);
        Box::new(protocol)
    }
}
/// Maps a `TType` to its one-byte binary-protocol type code.
///
/// `Utf7` has no code of its own and is written as `String` (`0x0B`).
fn field_type_to_u8(field_type: TType) -> u8 {
    match field_type {
        // Control markers.
        TType::Stop => 0x00,
        TType::Void => 0x01,
        // Scalars.
        TType::Bool => 0x02,
        TType::I08 => 0x03,
        TType::I16 => 0x06,
        TType::I32 => 0x08,
        TType::I64 => 0x0A,
        TType::Double => 0x04,
        // Text.
        TType::String | TType::Utf7 => 0x0B,
        TType::Utf8 => 0x10,
        TType::Utf16 => 0x11,
        // Composites.
        TType::Struct => 0x0C,
        TType::Map => 0x0D,
        TType::Set => 0x0E,
        TType::List => 0x0F,
    }
}
fn field_type_from_u8(b: u8) -> crate::Result<TType> {
match b {
0x00 => Ok(TType::Stop),
0x01 => Ok(TType::Void),
0x02 => Ok(TType::Bool),
0x03 => Ok(TType::I08), 0x04 => Ok(TType::Double),
0x06 => Ok(TType::I16),
0x08 => Ok(TType::I32),
0x0A => Ok(TType::I64),
0x0B => Ok(TType::String), 0x0C => Ok(TType::Struct),
0x0D => Ok(TType::Map),
0x0E => Ok(TType::Set),
0x0F => Ok(TType::List),
0x10 => Ok(TType::Utf8),
0x11 => Ok(TType::Utf16),
unkn => Err(crate::Error::Protocol(ProtocolError {
kind: ProtocolErrorKind::InvalidData,
message: format!("cannot convert {} to TType", unkn),
})),
}
}
#[cfg(test)]
mod tests {
    use crate::protocol::{
        TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
        TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType,
    };
    use crate::transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf};

    use super::*;

    #[test]
    fn must_write_strict_message_call_begin() {
        let (_, mut o_prot) = test_objects(true);
        let ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
        assert!(o_prot.write_message_begin(&ident).is_ok());

        #[rustfmt::skip]
        let expected: [u8; 16] = [
            0x80, 0x01, 0x00, 0x01, // version 1 | Call
            0x00, 0x00, 0x00, 0x04, // name length
            0x74, 0x65, 0x73, 0x74, // "test"
            0x00, 0x00, 0x00, 0x01, // sequence number
        ];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_write_non_strict_message_call_begin() {
        let (_, mut o_prot) = test_objects(false);
        let ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
        assert!(o_prot.write_message_begin(&ident).is_ok());

        #[rustfmt::skip]
        let expected: [u8; 13] = [
            0x00, 0x00, 0x00, 0x04, // name length
            0x74, 0x65, 0x73, 0x74, // "test"
            0x01,                   // Call
            0x00, 0x00, 0x00, 0x01, // sequence number
        ];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_write_strict_message_reply_begin() {
        let (_, mut o_prot) = test_objects(true);
        let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10);
        assert!(o_prot.write_message_begin(&ident).is_ok());

        #[rustfmt::skip]
        let expected: [u8; 16] = [
            0x80, 0x01, 0x00, 0x02, // version 1 | Reply
            0x00, 0x00, 0x00, 0x04, // name length
            0x74, 0x65, 0x73, 0x74, // "test"
            0x00, 0x00, 0x00, 0x0A, // sequence number
        ];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_write_non_strict_message_reply_begin() {
        let (_, mut o_prot) = test_objects(false);
        let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10);
        assert!(o_prot.write_message_begin(&ident).is_ok());

        #[rustfmt::skip]
        let expected: [u8; 13] = [
            0x00, 0x00, 0x00, 0x04, // name length
            0x74, 0x65, 0x73, 0x74, // "test"
            0x02,                   // Reply
            0x00, 0x00, 0x00, 0x0A, // sequence number
        ];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_strict_message_begin() {
        let (mut i_prot, mut o_prot) = test_objects(true);
        let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
        assert!(o_prot.write_message_begin(&sent_ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_ident = assert_success!(i_prot.read_message_begin());
        assert_eq!(&received_ident, &sent_ident);
    }

    #[test]
    fn must_round_trip_non_strict_message_begin() {
        let (mut i_prot, mut o_prot) = test_objects(false);
        let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
        assert!(o_prot.write_message_begin(&sent_ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_ident = assert_success!(i_prot.read_message_begin());
        assert_eq!(&received_ident, &sent_ident);
    }

    #[test]
    fn must_write_message_end() {
        assert_no_write(|o| o.write_message_end(), true);
    }

    #[test]
    fn must_write_struct_begin() {
        let ident = TStructIdentifier::new("foo");
        assert_no_write(|o| o.write_struct_begin(&ident), true);
    }

    #[test]
    fn must_write_struct_end() {
        assert_no_write(|o| o.write_struct_end(), true);
    }

    #[test]
    fn must_write_field_begin() {
        let (_, mut o_prot) = test_objects(true);
        let field_ident = TFieldIdentifier::new("some_field", TType::String, 22);
        assert!(o_prot.write_field_begin(&field_ident).is_ok());

        // type String (0x0B), then field id 22 as i16
        let expected: [u8; 3] = [0x0B, 0x00, 0x16];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_field_begin() {
        let (mut i_prot, mut o_prot) = test_objects(true);
        let sent_field_ident = TFieldIdentifier::new("foo", TType::I64, 20);
        assert!(o_prot.write_field_begin(&sent_field_ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        // The field name is not on the wire, so it comes back as `None`.
        let expected_ident = TFieldIdentifier {
            name: None,
            field_type: TType::I64,
            id: Some(20),
        };
        let received_ident = assert_success!(i_prot.read_field_begin());
        assert_eq!(&received_ident, &expected_ident);
    }

    #[test]
    fn must_write_stop_field() {
        let (_, mut o_prot) = test_objects(true);
        assert!(o_prot.write_field_stop().is_ok());

        let expected: [u8; 1] = [0x00];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_field_stop() {
        let (mut i_prot, mut o_prot) = test_objects(true);
        assert!(o_prot.write_field_stop().is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        // A Stop field carries no id; the reader reports it as 0.
        let expected_ident = TFieldIdentifier {
            name: None,
            field_type: TType::Stop,
            id: Some(0),
        };
        let received_ident = assert_success!(i_prot.read_field_begin());
        assert_eq!(&received_ident, &expected_ident);
    }

    #[test]
    fn must_write_field_end() {
        assert_no_write(|o| o.write_field_end(), true);
    }

    #[test]
    fn must_write_list_begin() {
        let (_, mut o_prot) = test_objects(true);
        let list_ident = TListIdentifier::new(TType::Bool, 5);
        assert!(o_prot.write_list_begin(&list_ident).is_ok());

        // element type Bool, then size 5 as i32
        let expected: [u8; 5] = [0x02, 0x00, 0x00, 0x00, 0x05];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_list_begin() {
        let (mut i_prot, mut o_prot) = test_objects(true);
        let ident = TListIdentifier::new(TType::List, 900);
        assert!(o_prot.write_list_begin(&ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_ident = assert_success!(i_prot.read_list_begin());
        assert_eq!(&received_ident, &ident);
    }

    #[test]
    fn must_write_list_end() {
        assert_no_write(|o| o.write_list_end(), true);
    }

    #[test]
    fn must_write_set_begin() {
        let (_, mut o_prot) = test_objects(true);
        let set_ident = TSetIdentifier::new(TType::I16, 7);
        assert!(o_prot.write_set_begin(&set_ident).is_ok());

        // element type I16, then size 7 as i32
        let expected: [u8; 5] = [0x06, 0x00, 0x00, 0x00, 0x07];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_set_begin() {
        let (mut i_prot, mut o_prot) = test_objects(true);
        let ident = TSetIdentifier::new(TType::I64, 2000);
        assert!(o_prot.write_set_begin(&ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_ident = assert_success!(i_prot.read_set_begin());
        assert_eq!(&received_ident, &ident);
    }

    #[test]
    fn must_write_set_end() {
        assert_no_write(|o| o.write_set_end(), true);
    }

    #[test]
    fn must_write_map_begin() {
        let (_, mut o_prot) = test_objects(true);
        let map_ident = TMapIdentifier::new(TType::I64, TType::Struct, 32);
        assert!(o_prot.write_map_begin(&map_ident).is_ok());

        // key type I64, value type Struct, then size 32 as i32
        let expected: [u8; 6] = [0x0A, 0x0C, 0x00, 0x00, 0x00, 0x20];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_round_trip_map_begin() {
        let (mut i_prot, mut o_prot) = test_objects(true);
        let ident = TMapIdentifier::new(TType::Map, TType::Set, 100);
        assert!(o_prot.write_map_begin(&ident).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_ident = assert_success!(i_prot.read_map_begin());
        assert_eq!(&received_ident, &ident);
    }

    #[test]
    fn must_write_map_end() {
        assert_no_write(|o| o.write_map_end(), true);
    }

    #[test]
    fn must_write_bool_true() {
        let (_, mut o_prot) = test_objects(true);
        assert!(o_prot.write_bool(true).is_ok());

        let expected: [u8; 1] = [0x01];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_write_bool_false() {
        let (_, mut o_prot) = test_objects(true);
        assert!(o_prot.write_bool(false).is_ok());

        let expected: [u8; 1] = [0x00];
        assert_eq_written_bytes!(o_prot, expected);
    }

    #[test]
    fn must_read_bool_true() {
        let (mut i_prot, _) = test_objects(true);
        set_readable_bytes!(i_prot, &[0x01]);

        let read_bool = assert_success!(i_prot.read_bool());
        assert!(read_bool);
    }

    #[test]
    fn must_read_bool_false() {
        let (mut i_prot, _) = test_objects(true);
        set_readable_bytes!(i_prot, &[0x00]);

        let read_bool = assert_success!(i_prot.read_bool());
        assert!(!read_bool);
    }

    #[test]
    fn must_allow_any_non_zero_value_to_be_interpreted_as_bool_true() {
        let (mut i_prot, _) = test_objects(true);
        set_readable_bytes!(i_prot, &[0xAC]);

        let read_bool = assert_success!(i_prot.read_bool());
        assert!(read_bool);
    }

    #[test]
    fn must_write_bytes() {
        let (_, mut o_prot) = test_objects(true);
        let bytes: [u8; 10] = [0x0A, 0xCC, 0xD1, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF];
        assert!(o_prot.write_bytes(&bytes).is_ok());

        let buf = o_prot.transport.write_bytes();
        // i32 length prefix, then the payload verbatim
        assert_eq!(&buf[0..4], [0x00, 0x00, 0x00, 0x0A]);
        assert_eq!(&buf[4..], bytes);
    }

    #[test]
    fn must_round_trip_bytes() {
        let (mut i_prot, mut o_prot) = test_objects(true);

        #[rustfmt::skip]
        let bytes: [u8; 25] = [
            0x20, 0xFD, 0x18, 0x84, 0x99,
            0x12, 0xAB, 0xBB, 0x45, 0xDF,
            0x34, 0xDC, 0x98, 0xA4, 0x6D,
            0xF3, 0x99, 0xB4, 0xB7, 0xD4,
            0x9C, 0xA5, 0xB3, 0xC9, 0x88,
        ];
        assert!(o_prot.write_bytes(&bytes).is_ok());

        copy_write_buffer_to_read_buffer!(o_prot);

        let received_bytes = assert_success!(i_prot.read_bytes());
        assert_eq!(&received_bytes, &bytes);
    }

    /// Builds a reader/writer pair sharing one in-memory buffer channel.
    fn test_objects(
        strict: bool,
    ) -> (
        TBinaryInputProtocol<ReadHalf<TBufferChannel>>,
        TBinaryOutputProtocol<WriteHalf<TBufferChannel>>,
    ) {
        let channel = TBufferChannel::with_capacity(40, 40);
        let (read_half, write_half) = channel.split().unwrap();
        (
            TBinaryInputProtocol::new(read_half, strict),
            TBinaryOutputProtocol::new(write_half, strict),
        )
    }

    /// Asserts that `write_fn` succeeds without emitting any bytes.
    fn assert_no_write<F>(mut write_fn: F, strict: bool)
    where
        F: FnMut(&mut TBinaryOutputProtocol<WriteHalf<TBufferChannel>>) -> crate::Result<()>,
    {
        let (_, mut o_prot) = test_objects(strict);
        assert!(write_fn(&mut o_prot).is_ok());
        assert_eq!(o_prot.transport.write_bytes().len(), 0);
    }
}