// tonic/codec/decode.rs

1use super::compression::{decompress, CompressionEncoding, CompressionSettings};
2use super::{BufferSettings, DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE};
3use crate::{body::BoxBody, metadata::MetadataMap, Code, Status};
4use bytes::{Buf, BufMut, BytesMut};
5use http::{HeaderMap, StatusCode};
6use http_body::Body;
7use http_body_util::BodyExt;
8use std::{
9    fmt, future,
10    pin::Pin,
11    task::ready,
12    task::{Context, Poll},
13};
14use tokio_stream::Stream;
15use tracing::{debug, trace};
16
17/// Streaming requests and responses.
18///
19/// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface
20/// to fetch the message stream and trailing metadata
21pub struct Streaming<T> {
22    decoder: Box<dyn Decoder<Item = T, Error = Status> + Send + 'static>,
23    inner: StreamingInner,
24}
25
26struct StreamingInner {
27    body: BoxBody,
28    state: State,
29    direction: Direction,
30    buf: BytesMut,
31    trailers: Option<HeaderMap>,
32    decompress_buf: BytesMut,
33    encoding: Option<CompressionEncoding>,
34    max_message_size: Option<usize>,
35}
36
37impl<T> Unpin for Streaming<T> {}
38
39#[derive(Debug, Clone)]
40enum State {
41    ReadHeader,
42    ReadBody {
43        compression: Option<CompressionEncoding>,
44        len: usize,
45    },
46    Error(Option<Status>),
47}
48
49#[derive(Debug, PartialEq, Eq)]
50enum Direction {
51    Request,
52    Response(StatusCode),
53    EmptyResponse,
54}
55
56impl<T> Streaming<T> {
57    /// Create a new streaming response in the grpc response format for decoding a response [Body]
58    /// into message of type T
59    pub fn new_response<B, D>(
60        decoder: D,
61        body: B,
62        status_code: StatusCode,
63        encoding: Option<CompressionEncoding>,
64        max_message_size: Option<usize>,
65    ) -> Self
66    where
67        B: Body + Send + 'static,
68        B::Error: Into<crate::Error>,
69        D: Decoder<Item = T, Error = Status> + Send + 'static,
70    {
71        Self::new(
72            decoder,
73            body,
74            Direction::Response(status_code),
75            encoding,
76            max_message_size,
77        )
78    }
79
80    /// Create empty response. For creating responses that have no content (headers + trailers only)
81    pub fn new_empty<B, D>(decoder: D, body: B) -> Self
82    where
83        B: Body + Send + 'static,
84        B::Error: Into<crate::Error>,
85        D: Decoder<Item = T, Error = Status> + Send + 'static,
86    {
87        Self::new(decoder, body, Direction::EmptyResponse, None, None)
88    }
89
90    /// Create a new streaming request in the grpc response format for decoding a request [Body]
91    /// into message of type T
92    pub fn new_request<B, D>(
93        decoder: D,
94        body: B,
95        encoding: Option<CompressionEncoding>,
96        max_message_size: Option<usize>,
97    ) -> Self
98    where
99        B: Body + Send + 'static,
100        B::Error: Into<crate::Error>,
101        D: Decoder<Item = T, Error = Status> + Send + 'static,
102    {
103        Self::new(
104            decoder,
105            body,
106            Direction::Request,
107            encoding,
108            max_message_size,
109        )
110    }
111
112    fn new<B, D>(
113        decoder: D,
114        body: B,
115        direction: Direction,
116        encoding: Option<CompressionEncoding>,
117        max_message_size: Option<usize>,
118    ) -> Self
119    where
120        B: Body + Send + 'static,
121        B::Error: Into<crate::Error>,
122        D: Decoder<Item = T, Error = Status> + Send + 'static,
123    {
124        let buffer_size = decoder.buffer_settings().buffer_size;
125        Self {
126            decoder: Box::new(decoder),
127            inner: StreamingInner {
128                body: body
129                    .map_frame(|frame| frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())))
130                    .map_err(|err| Status::map_error(err.into()))
131                    .boxed_unsync(),
132                state: State::ReadHeader,
133                direction,
134                buf: BytesMut::with_capacity(buffer_size),
135                trailers: None,
136                decompress_buf: BytesMut::new(),
137                encoding,
138                max_message_size,
139            },
140        }
141    }
142}
143
144impl StreamingInner {
145    fn decode_chunk(
146        &mut self,
147        buffer_settings: BufferSettings,
148    ) -> Result<Option<DecodeBuf<'_>>, Status> {
149        if let State::ReadHeader = self.state {
150            if self.buf.remaining() < HEADER_SIZE {
151                return Ok(None);
152            }
153
154            let compression_encoding = match self.buf.get_u8() {
155                0 => None,
156                1 => {
157                    {
158                        if self.encoding.is_some() {
159                            self.encoding
160                        } else {
161                            // https://grpc.github.io/grpc/core/md_doc_compression.html
162                            // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding
163                            // entry different from identity in its metadata MUST fail with INTERNAL status,
164                            // its associated description indicating the invalid Compressed-Flag condition.
165                            return Err(Status::internal( "protocol error: received message with compressed-flag but no grpc-encoding was specified"));
166                        }
167                    }
168                }
169                f => {
170                    trace!("unexpected compression flag");
171                    let message = if let Direction::Response(status) = self.direction {
172                        format!(
173                            "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}",
174                            f, status
175                        )
176                    } else {
177                        format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f)
178                    };
179                    return Err(Status::internal(message));
180                }
181            };
182
183            let len = self.buf.get_u32() as usize;
184            let limit = self
185                .max_message_size
186                .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE);
187            if len > limit {
188                return Err(Status::out_of_range(
189                    format!(
190                        "Error, decoded message length too large: found {} bytes, the limit is: {} bytes",
191                        len, limit
192                    ),
193                ));
194            }
195
196            self.buf.reserve(len);
197
198            self.state = State::ReadBody {
199                compression: compression_encoding,
200                len,
201            }
202        }
203
204        if let State::ReadBody { len, compression } = self.state {
205            // if we haven't read enough of the message then return and keep
206            // reading
207            if self.buf.remaining() < len || self.buf.len() < len {
208                return Ok(None);
209            }
210
211            let decode_buf = if let Some(encoding) = compression {
212                self.decompress_buf.clear();
213
214                if let Err(err) = decompress(
215                    CompressionSettings {
216                        encoding,
217                        buffer_growth_interval: buffer_settings.buffer_size,
218                    },
219                    &mut self.buf,
220                    &mut self.decompress_buf,
221                    len,
222                ) {
223                    let message = if let Direction::Response(status) = self.direction {
224                        format!(
225                            "Error decompressing: {}, while receiving response with status: {}",
226                            err, status
227                        )
228                    } else {
229                        format!("Error decompressing: {}, while sending request", err)
230                    };
231                    return Err(Status::internal(message));
232                }
233                let decompressed_len = self.decompress_buf.len();
234                DecodeBuf::new(&mut self.decompress_buf, decompressed_len)
235            } else {
236                DecodeBuf::new(&mut self.buf, len)
237            };
238
239            return Ok(Some(decode_buf));
240        }
241
242        Ok(None)
243    }
244
245    // Returns Some(()) if data was found or None if the loop in `poll_next` should break
246    fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<()>, Status>> {
247        let chunk = match ready!(Pin::new(&mut self.body).poll_frame(cx)) {
248            Some(Ok(d)) => Some(d),
249            Some(Err(status)) => {
250                if self.direction == Direction::Request && status.code() == Code::Cancelled {
251                    return Poll::Ready(Ok(None));
252                }
253
254                let _ = std::mem::replace(&mut self.state, State::Error(Some(status.clone())));
255                debug!("decoder inner stream error: {:?}", status);
256                return Poll::Ready(Err(status));
257            }
258            None => None,
259        };
260
261        Poll::Ready(if let Some(frame) = chunk {
262            match frame {
263                frame if frame.is_data() => {
264                    self.buf.put(frame.into_data().unwrap());
265                    Ok(Some(()))
266                }
267                frame if frame.is_trailers() => {
268                    match &mut self.trailers {
269                        Some(trailers) => {
270                            trailers.extend(frame.into_trailers().unwrap());
271                        }
272                        None => {
273                            self.trailers = Some(frame.into_trailers().unwrap());
274                        }
275                    }
276
277                    Ok(None)
278                }
279                frame => panic!("unexpected frame: {:?}", frame),
280            }
281        } else {
282            // FIXME: improve buf usage.
283            if self.buf.has_remaining() {
284                trace!("unexpected EOF decoding stream, state: {:?}", self.state);
285                Err(Status::internal("Unexpected EOF decoding stream."))
286            } else {
287                Ok(None)
288            }
289        })
290    }
291
292    fn response(&mut self) -> Result<(), Status> {
293        if let Direction::Response(status) = self.direction {
294            if let Err(Some(e)) = crate::status::infer_grpc_status(self.trailers.as_ref(), status) {
295                // If the trailers contain a grpc-status, then we should return that as the error
296                // and otherwise stop the stream (by taking the error state)
297                self.trailers.take();
298                return Err(e);
299            }
300        }
301        Ok(())
302    }
303}
304
305impl<T> Streaming<T> {
306    /// Fetch the next message from this stream.
307    ///
308    /// # Return value
309    ///
310    /// - `Result::Err(val)` means a gRPC error was sent by the sender instead
311    ///   of a valid response message. Refer to [`Status::code`] and
312    ///   [`Status::message`] to examine possible error causes.
313    ///
314    /// - `Result::Ok(None)` means the stream was closed by the sender and no
315    ///   more messages will be delivered. Further attempts to call
316    ///   [`Streaming::message`] will result in the same return value.
317    ///
318    /// - `Result::Ok(Some(val))` means the sender streamed a valid response
319    ///   message `val`.
320    ///
321    /// ```rust
322    /// # use tonic::{Streaming, Status, codec::Decoder};
323    /// # use std::fmt::Debug;
324    /// # async fn next_message_ex<T, D>(mut request: Streaming<T>) -> Result<(), Status>
325    /// # where T: Debug,
326    /// # D: Decoder<Item = T, Error = Status> + Send  + 'static,
327    /// # {
328    /// if let Some(next_message) = request.message().await? {
329    ///     println!("{:?}", next_message);
330    /// }
331    /// # Ok(())
332    /// # }
333    /// ```
334    pub async fn message(&mut self) -> Result<Option<T>, Status> {
335        match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
336            Some(Ok(m)) => Ok(Some(m)),
337            Some(Err(e)) => Err(e),
338            None => Ok(None),
339        }
340    }
341
342    /// Fetch the trailing metadata.
343    ///
344    /// This will drain the stream of all its messages to receive the trailing
345    /// metadata. If [`Streaming::message`] returns `None` then this function
346    /// will not need to poll for trailers since the body was totally consumed.
347    ///
348    /// ```rust
349    /// # use tonic::{Streaming, Status};
350    /// # async fn trailers_ex<T>(mut request: Streaming<T>) -> Result<(), Status> {
351    /// if let Some(metadata) = request.trailers().await? {
352    ///     println!("{:?}", metadata);
353    /// }
354    /// # Ok(())
355    /// # }
356    /// ```
357    pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
358        // Shortcut to see if we already pulled the trailers in the stream step
359        // we need to do that so that the stream can error on trailing grpc-status
360        if let Some(trailers) = self.inner.trailers.take() {
361            return Ok(Some(MetadataMap::from_headers(trailers)));
362        }
363
364        // To fetch the trailers we must clear the body and drop it.
365        while self.message().await?.is_some() {}
366
367        // Since we call poll_trailers internally on poll_next we need to
368        // check if it got cached again.
369        if let Some(trailers) = self.inner.trailers.take() {
370            return Ok(Some(MetadataMap::from_headers(trailers)));
371        }
372
373        // We've polled through all the frames, and still no trailers, return None
374        Ok(None)
375    }
376
377    fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
378        match self.inner.decode_chunk(self.decoder.buffer_settings())? {
379            Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? {
380                Some(msg) => {
381                    self.inner.state = State::ReadHeader;
382                    Ok(Some(msg))
383                }
384                None => Ok(None),
385            },
386            None => Ok(None),
387        }
388    }
389}
390
391impl<T> Stream for Streaming<T> {
392    type Item = Result<T, Status>;
393
394    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
395        loop {
396            // When the stream encounters an error yield that error once and then on subsequent
397            // calls to poll_next return Poll::Ready(None) indicating that the stream has been
398            // fully exhausted.
399            if let State::Error(status) = &mut self.inner.state {
400                return Poll::Ready(status.take().map(Err));
401            }
402
403            if let Some(item) = self.decode_chunk()? {
404                return Poll::Ready(Some(Ok(item)));
405            }
406
407            match ready!(self.inner.poll_frame(cx))? {
408                Some(()) => (),
409                None => break,
410            }
411        }
412
413        Poll::Ready(match self.inner.response() {
414            Ok(()) => None,
415            Err(err) => Some(Err(err)),
416        })
417    }
418}
419
420impl<T> fmt::Debug for Streaming<T> {
421    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
422        f.debug_struct("Streaming").finish()
423    }
424}
425
// Compile-time check (test builds only) that `Streaming` is `Send`, with `()` standing in for
// the message type.
#[cfg(test)]
static_assertions::assert_impl_all!(Streaming<()>: std::marker::Send);