tonic/codec/
prost.rs

1use super::{BufferSettings, Codec, DecodeBuf, Decoder, Encoder};
2use crate::codec::EncodeBuf;
3use crate::Status;
4use prost::Message;
5use std::marker::PhantomData;
6
/// A [`Codec`] that implements `application/grpc+proto` via the prost library.
///
/// `T` is the message type this codec encodes and `U` is the message type it decodes.
#[derive(Debug, Clone)]
pub struct ProstCodec<T, U> {
    _pd: PhantomData<(T, U)>,
}
12
13impl<T, U> ProstCodec<T, U> {
14    /// Configure a ProstCodec with encoder/decoder buffer settings. This is used to control
15    /// how memory is allocated and grows per RPC.
16    pub fn new() -> Self {
17        Self { _pd: PhantomData }
18    }
19}
20
21impl<T, U> Default for ProstCodec<T, U> {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl<T, U> ProstCodec<T, U>
28where
29    T: Message + Send + 'static,
30    U: Message + Default + Send + 'static,
31{
32    /// A tool for building custom codecs based on prost encoding and decoding.
33    /// See the codec_buffers example for one possible way to use this.
34    pub fn raw_encoder(buffer_settings: BufferSettings) -> <Self as Codec>::Encoder {
35        ProstEncoder {
36            _pd: PhantomData,
37            buffer_settings,
38        }
39    }
40
41    /// A tool for building custom codecs based on prost encoding and decoding.
42    /// See the codec_buffers example for one possible way to use this.
43    pub fn raw_decoder(buffer_settings: BufferSettings) -> <Self as Codec>::Decoder {
44        ProstDecoder {
45            _pd: PhantomData,
46            buffer_settings,
47        }
48    }
49}
50
51impl<T, U> Codec for ProstCodec<T, U>
52where
53    T: Message + Send + 'static,
54    U: Message + Default + Send + 'static,
55{
56    type Encode = T;
57    type Decode = U;
58
59    type Encoder = ProstEncoder<T>;
60    type Decoder = ProstDecoder<U>;
61
62    fn encoder(&mut self) -> Self::Encoder {
63        ProstEncoder {
64            _pd: PhantomData,
65            buffer_settings: BufferSettings::default(),
66        }
67    }
68
69    fn decoder(&mut self) -> Self::Decoder {
70        ProstDecoder {
71            _pd: PhantomData,
72            buffer_settings: BufferSettings::default(),
73        }
74    }
75}
76
77/// A [`Encoder`] that knows how to encode `T`.
78#[derive(Debug, Clone, Default)]
79pub struct ProstEncoder<T> {
80    _pd: PhantomData<T>,
81    buffer_settings: BufferSettings,
82}
83
84impl<T> ProstEncoder<T> {
85    /// Get a new encoder with explicit buffer settings
86    pub fn new(buffer_settings: BufferSettings) -> Self {
87        Self {
88            _pd: PhantomData,
89            buffer_settings,
90        }
91    }
92}
93
94impl<T: Message> Encoder for ProstEncoder<T> {
95    type Item = T;
96    type Error = Status;
97
98    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
99        item.encode(buf)
100            .expect("Message only errors if not enough space");
101
102        Ok(())
103    }
104
105    fn buffer_settings(&self) -> BufferSettings {
106        self.buffer_settings
107    }
108}
109
110/// A [`Decoder`] that knows how to decode `U`.
111#[derive(Debug, Clone, Default)]
112pub struct ProstDecoder<U> {
113    _pd: PhantomData<U>,
114    buffer_settings: BufferSettings,
115}
116
117impl<U> ProstDecoder<U> {
118    /// Get a new decoder with explicit buffer settings
119    pub fn new(buffer_settings: BufferSettings) -> Self {
120        Self {
121            _pd: PhantomData,
122            buffer_settings,
123        }
124    }
125}
126
127impl<U: Message + Default> Decoder for ProstDecoder<U> {
128    type Item = U;
129    type Error = Status;
130
131    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
132        let item = Message::decode(buf)
133            .map(Option::Some)
134            .map_err(from_decode_error)?;
135
136        Ok(item)
137    }
138
139    fn buffer_settings(&self) -> BufferSettings {
140        self.buffer_settings
141    }
142}
143
144fn from_decode_error(error: prost::DecodeError) -> crate::Status {
145    // Map Protobuf parse errors to an INTERNAL status code, as per
146    // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
147    Status::internal(error.to_string())
148}
149
#[cfg(test)]
mod tests {
    use crate::codec::compression::SingleMessageCompressionOverride;
    use crate::codec::{
        DecodeBuf, Decoder, EncodeBody, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
    };
    use crate::Status;
    use bytes::{Buf, BufMut, BytesMut};
    use http_body::Body;
    use http_body_util::BodyExt as _;
    use std::pin::pin;

    const LEN: usize = 10000;
    // The maximum uncompressed size in bytes for a message. Set to 2MB.
    const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024;

    /// Frames `msg` as an uncompressed gRPC message: a `0` compression flag, the payload
    /// length as a big-endian `u32`, then the payload bytes.
    fn frame_message(msg: &[u8]) -> BytesMut {
        let len = u32::try_from(msg.len()).expect("test payload length must fit in a u32");
        let mut buf = BytesMut::with_capacity(msg.len() + HEADER_SIZE);
        buf.put_u8(0);
        buf.put_u32(len);
        buf.put_slice(msg);
        buf
    }

    #[tokio::test]
    async fn decode() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; LEN];
        let buf = frame_message(&msg);

        // A partial length equal to the whole frame delivers it in a single data frame.
        let body = body::MockBody::new(&buf[..], LEN + HEADER_SIZE, 0);

        let mut stream = Streaming::new_request(decoder, body, None, None);

        let mut i = 0usize;
        while let Some(output_msg) = stream.message().await.unwrap() {
            assert_eq!(output_msg.len(), msg.len());
            i += 1;
        }
        assert_eq!(i, 1);
    }

    #[tokio::test]
    async fn decode_max_message_size_exceeded() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];
        let buf = frame_message(&msg);

        let body = body::MockBody::new(&buf[..], MAX_MESSAGE_SIZE + HEADER_SIZE + 1, 0);

        let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE));

        let actual = stream.message().await.unwrap_err();

        let expected = Status::out_of_range(format!(
            "Error, decoded message length too large: found {} bytes, the limit is: {} bytes",
            msg.len(),
            MAX_MESSAGE_SIZE
        ));

        assert_eq!(actual.code(), expected.code());
        assert_eq!(actual.message(), expected.message());
    }

    #[tokio::test]
    async fn encode() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; 1024];

        let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            None,
        ));

        while let Some(r) = body.frame().await {
            r.unwrap();
        }
    }

    #[tokio::test]
    async fn encode_max_message_size_exceeded() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];

        let messages = std::iter::once(Ok::<_, Status>(msg));
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(MAX_MESSAGE_SIZE),
        ));

        let frame = body
            .frame()
            .await
            .expect("at least one frame")
            .expect("no error polling frame");
        assert_eq!(
            frame
                .into_trailers()
                .expect("got trailers")
                .get(Status::GRPC_STATUS)
                .expect("grpc-status header"),
            "11"
        );
        assert!(body.is_end_stream());
    }

    // skip on windows because CI stumbles over our 4GB allocation
    #[cfg(not(target_family = "windows"))]
    #[tokio::test]
    async fn encode_too_big() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; u32::MAX as usize + 1];

        let messages = std::iter::once(Ok::<_, Status>(msg));
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(usize::MAX),
        ));

        let frame = body
            .frame()
            .await
            .expect("at least one frame")
            .expect("no error polling frame");
        assert_eq!(
            frame
                .into_trailers()
                .expect("got trailers")
                .get(Status::GRPC_STATUS)
                .expect("grpc-status header"),
            "8"
        );
        assert!(body.is_end_stream());
    }

    /// Encoder that copies the raw bytes of each item into the output buffer.
    #[derive(Debug, Clone, Default)]
    struct MockEncoder {}

    impl Encoder for MockEncoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
            buf.put(&item[..]);
            Ok(())
        }

        fn buffer_settings(&self) -> crate::codec::BufferSettings {
            Default::default()
        }
    }

    /// Decoder that returns the raw bytes of each framed message.
    #[derive(Debug, Clone, Default)]
    struct MockDecoder {}

    impl Decoder for MockDecoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
            // Take every remaining byte of this message instead of assuming it is exactly
            // `LEN` bytes long and contiguous in a single chunk.
            let remaining = buf.remaining();
            let out = buf.copy_to_bytes(remaining).to_vec();
            Ok(Some(out))
        }

        fn buffer_settings(&self) -> crate::codec::BufferSettings {
            Default::default()
        }
    }

    mod body {
        use crate::Status;
        use bytes::Bytes;
        use http_body::{Body, Frame};
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };

        /// A body that yields `data` in at most two data frames, interleaved with
        /// `Pending` polls.
        #[derive(Debug)]
        pub(super) struct MockBody {
            data: Bytes,

            // the size of the first data frame to send
            partial_len: usize,

            // the number of times `poll_frame` has been called while data remained
            count: usize,
        }

        impl MockBody {
            pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
                MockBody {
                    data: Bytes::copy_from_slice(b),
                    partial_len,
                    count,
                }
            }
        }

        impl Body for MockBody {
            type Data = Bytes;
            type Error = Status;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                // Even-numbered calls to poll_frame return data. Odd-numbered calls wake
                // the task and return Pending to simulate a body that arrives slowly.
                let should_send = self.count % 2 == 0;
                let data_len = self.data.len();
                let partial_len = self.partial_len;
                let count = self.count;
                if data_len > 0 {
                    let result = if should_send {
                        // The first frame is `partial_len` bytes; later frames send the rest.
                        let response =
                            self.data
                                .split_to(if count == 0 { partial_len } else { data_len });
                        Poll::Ready(Some(Ok(Frame::data(response))))
                    } else {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    };
                    // make some fake progress
                    self.count += 1;
                    result
                } else {
                    Poll::Ready(None)
                }
            }
        }
    }
}
409        }
410    }
411}