Skip to main content

tonic/codec/
decode.rs

1use super::compression::{CompressionEncoding, CompressionSettings, decompress};
2use super::{BufferSettings, DEFAULT_MAX_RECV_MESSAGE_SIZE, DecodeBuf, Decoder, HEADER_SIZE};
3use crate::{Code, Status, body::Body, metadata::MetadataMap};
4use bytes::{Buf, BufMut, BytesMut};
5use http::{HeaderMap, StatusCode};
6use http_body::Body as HttpBody;
7use http_body_util::BodyExt;
8use std::{
9    fmt, future,
10    pin::Pin,
11    task::ready,
12    task::{Context, Poll},
13};
14use sync_wrapper::SyncWrapper;
15use tokio_stream::Stream;
16use tracing::{debug, trace};
17
/// Streaming requests and responses.
///
/// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface
/// to fetch the message stream and trailing metadata
pub struct Streaming<T> {
    // Type-erased decoder that turns framed payloads into `T` values. Boxing it keeps
    // `Streaming` generic over the message type only, not over the decoder type.
    decoder: SyncWrapper<Box<dyn Decoder<Item = T, Error = Status> + Send + 'static>>,
    // Body, buffers and framing state. None of it depends on `T`.
    inner: StreamingInner,
}
26
// The `T`-independent half of `Streaming`: it pulls frames from the body, splits the byte
// stream into length-prefixed messages and collects trailers.
struct StreamingInner {
    // The wrapped body that `poll_frame` polls for data and trailer frames.
    body: SyncWrapper<Body>,
    // Position within the current message (header / payload), or the terminal error state.
    state: State,
    // What is being decoded; it selects the wording of error messages and whether
    // `response` infers a status at end of stream.
    direction: Direction,
    // Bytes received from the body that have not been consumed by decoding yet.
    buf: BytesMut,
    // Trailers gathered from trailer frames, if any were seen.
    trailers: Option<HeaderMap>,
    // Scratch buffer that receives the decompressed payload of a compressed message.
    // It is cleared before each decompression.
    decompress_buf: BytesMut,
    // Encoding used for messages whose compressed flag is set; `None` makes such
    // messages a protocol error.
    encoding: Option<CompressionEncoding>,
    // Upper bound on the message length; `DEFAULT_MAX_RECV_MESSAGE_SIZE` applies when `None`.
    max_message_size: Option<usize>,
}
37
// Marks `Streaming<T>` as `Unpin` for every `T`, with no bound on `T`. `poll_next` and
// `message` depend on this: they go through `Pin<&mut Self>` / `Pin::new` to get plain
// mutable access to the fields.
impl<T> Unpin for Streaming<T> {}
39
// Decoding state machine for the length-prefixed message framing.
#[derive(Debug, Clone)]
enum State {
    // Waiting until `HEADER_SIZE` bytes are buffered: a one-byte compression flag followed
    // by a `u32` payload length.
    ReadHeader,
    // The header has been parsed; waiting until `len` payload bytes are buffered.
    ReadBody {
        // Encoding to decompress the payload with, or `None` for an uncompressed message.
        compression: Option<CompressionEncoding>,
        // Payload length announced by the header, in bytes.
        len: usize,
    },
    // Terminal state. Holds the status that `poll_next` still has to yield; `poll_next`
    // takes it out, leaving `None`, after which the stream reports that it is exhausted.
    Error(Option<Status>),
}
49
// Which kind of body a `Streaming` decodes.
#[derive(Debug, PartialEq, Eq)]
enum Direction {
    // A request body (set by `Streaming::new_request`).
    Request,
    // A response body, together with the HTTP status code given to
    // `Streaming::new_response`. The code is passed to `infer_grpc_status` at end of
    // stream and is included in error messages.
    Response(StatusCode),
    // A response without content (set by `Streaming::new_empty`).
    EmptyResponse,
}
56
57impl<T> Streaming<T> {
58    /// Create a new streaming response in the grpc response format for decoding a response [Body]
59    /// into message of type T
60    pub fn new_response<B, D>(
61        decoder: D,
62        body: B,
63        status_code: StatusCode,
64        encoding: Option<CompressionEncoding>,
65        max_message_size: Option<usize>,
66    ) -> Self
67    where
68        B: HttpBody + Send + 'static,
69        B::Error: Into<crate::BoxError>,
70        D: Decoder<Item = T, Error = Status> + Send + 'static,
71    {
72        Self::new(
73            decoder,
74            body,
75            Direction::Response(status_code),
76            encoding,
77            max_message_size,
78        )
79    }
80
81    /// Create empty response. For creating responses that have no content (headers + trailers only)
82    pub fn new_empty<B, D>(decoder: D, body: B) -> Self
83    where
84        B: HttpBody + Send + 'static,
85        B::Error: Into<crate::BoxError>,
86        D: Decoder<Item = T, Error = Status> + Send + 'static,
87    {
88        Self::new(decoder, body, Direction::EmptyResponse, None, None)
89    }
90
91    /// Create a new streaming request in the grpc response format for decoding a request [Body]
92    /// into message of type T
93    pub fn new_request<B, D>(
94        decoder: D,
95        body: B,
96        encoding: Option<CompressionEncoding>,
97        max_message_size: Option<usize>,
98    ) -> Self
99    where
100        B: HttpBody + Send + 'static,
101        B::Error: Into<crate::BoxError>,
102        D: Decoder<Item = T, Error = Status> + Send + 'static,
103    {
104        Self::new(
105            decoder,
106            body,
107            Direction::Request,
108            encoding,
109            max_message_size,
110        )
111    }
112
113    fn new<B, D>(
114        decoder: D,
115        body: B,
116        direction: Direction,
117        encoding: Option<CompressionEncoding>,
118        max_message_size: Option<usize>,
119    ) -> Self
120    where
121        B: HttpBody + Send + 'static,
122        B::Error: Into<crate::BoxError>,
123        D: Decoder<Item = T, Error = Status> + Send + 'static,
124    {
125        let buffer_size = decoder.buffer_settings().buffer_size;
126        Self {
127            decoder: SyncWrapper::new(Box::new(decoder)),
128            inner: StreamingInner {
129                body: SyncWrapper::new(Body::new(
130                    body.map_frame(|frame| {
131                        frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
132                    })
133                    .map_err(|err| Status::map_error(err.into())),
134                )),
135                state: State::ReadHeader,
136                direction,
137                buf: BytesMut::with_capacity(buffer_size),
138                trailers: None,
139                decompress_buf: BytesMut::new(),
140                encoding,
141                max_message_size,
142            },
143        }
144    }
145}
146
147impl StreamingInner {
148    fn decode_chunk(
149        &mut self,
150        buffer_settings: BufferSettings,
151    ) -> Result<Option<DecodeBuf<'_>>, Status> {
152        if let State::ReadHeader = self.state {
153            if self.buf.remaining() < HEADER_SIZE {
154                return Ok(None);
155            }
156
157            let compression_encoding = match self.buf.get_u8() {
158                0 => None,
159                1 => {
160                    {
161                        if self.encoding.is_some() {
162                            self.encoding
163                        } else {
164                            // https://grpc.github.io/grpc/core/md_doc_compression.html
165                            // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding
166                            // entry different from identity in its metadata MUST fail with INTERNAL status,
167                            // its associated description indicating the invalid Compressed-Flag condition.
168                            return Err(Status::internal(
169                                "protocol error: received message with compressed-flag but no grpc-encoding was specified",
170                            ));
171                        }
172                    }
173                }
174                f => {
175                    trace!("unexpected compression flag");
176                    let message = if let Direction::Response(status) = self.direction {
177                        format!(
178                            "protocol error: received message with invalid compression flag: {f} (valid flags are 0 and 1) while receiving response with status: {status}"
179                        )
180                    } else {
181                        format!(
182                            "protocol error: received message with invalid compression flag: {f} (valid flags are 0 and 1), while sending request"
183                        )
184                    };
185                    return Err(Status::internal(message));
186                }
187            };
188
189            let len = self.buf.get_u32() as usize;
190            let limit = self
191                .max_message_size
192                .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE);
193            if len > limit {
194                return Err(Status::out_of_range(format!(
195                    "Error, decoded message length too large: found {len} bytes, the limit is: {limit} bytes"
196                )));
197            }
198
199            self.buf.reserve(len);
200
201            self.state = State::ReadBody {
202                compression: compression_encoding,
203                len,
204            }
205        }
206
207        if let State::ReadBody { len, compression } = self.state {
208            // if we haven't read enough of the message then return and keep
209            // reading
210            if self.buf.remaining() < len || self.buf.len() < len {
211                return Ok(None);
212            }
213
214            let decode_buf = if let Some(encoding) = compression {
215                self.decompress_buf.clear();
216                let limit = self
217                    .max_message_size
218                    .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE);
219                let limited_out_buf = (&mut self.decompress_buf).limit(limit);
220
221                if let Err(err) = decompress(
222                    CompressionSettings {
223                        encoding,
224                        buffer_growth_interval: buffer_settings.buffer_size,
225                    },
226                    &mut self.buf,
227                    limited_out_buf,
228                    len,
229                ) {
230                    if matches!(err.kind(), std::io::ErrorKind::WriteZero) {
231                        return Err(Status::resource_exhausted(format!(
232                            "Error decompressing: size limit, of {limit} bytes, exceeded while decompressing message"
233                        )));
234                    }
235                    let message = if let Direction::Response(status) = self.direction {
236                        format!(
237                            "Error decompressing: {err}, while receiving response with status: {status}"
238                        )
239                    } else {
240                        format!("Error decompressing: {err}, while sending request")
241                    };
242                    return Err(Status::internal(message));
243                }
244                let decompressed_len = self.decompress_buf.len();
245                DecodeBuf::new(&mut self.decompress_buf, decompressed_len)
246            } else {
247                DecodeBuf::new(&mut self.buf, len)
248            };
249
250            return Ok(Some(decode_buf));
251        }
252
253        Ok(None)
254    }
255
256    // Returns Some(()) if data was found or None if the loop in `poll_next` should break
257    fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<()>, Status>> {
258        let frame = match ready!(Pin::new(self.body.get_mut()).poll_frame(cx)) {
259            Some(Ok(frame)) => frame,
260            Some(Err(status)) => {
261                if self.direction == Direction::Request && status.code() == Code::Cancelled {
262                    return Poll::Ready(Ok(None));
263                }
264
265                let _ = std::mem::replace(&mut self.state, State::Error(Some(status.clone())));
266                debug!("decoder inner stream error: {:?}", status);
267                return Poll::Ready(Err(status));
268            }
269            None => {
270                // FIXME: improve buf usage.
271                return Poll::Ready(if self.buf.has_remaining() {
272                    trace!("unexpected EOF decoding stream, state: {:?}", self.state);
273                    Err(Status::internal("Unexpected EOF decoding stream."))
274                } else {
275                    Ok(None)
276                });
277            }
278        };
279
280        Poll::Ready(if frame.is_data() {
281            self.buf.put(frame.into_data().unwrap());
282            Ok(Some(()))
283        } else if frame.is_trailers() {
284            if let Some(trailers) = &mut self.trailers {
285                trailers.extend(frame.into_trailers().unwrap());
286            } else {
287                self.trailers = Some(frame.into_trailers().unwrap());
288            }
289
290            Ok(None)
291        } else {
292            panic!("unexpected frame: {frame:?}");
293        })
294    }
295
296    fn response(&mut self) -> Result<(), Status> {
297        if let Direction::Response(status) = self.direction {
298            if let Err(Some(e)) = crate::status::infer_grpc_status(self.trailers.as_ref(), status) {
299                // If the trailers contain a grpc-status, then we should return that as the error
300                // and otherwise stop the stream (by taking the error state)
301                self.trailers.take();
302                return Err(e);
303            }
304        }
305        Ok(())
306    }
307}
308
309impl<T> Streaming<T> {
310    /// Fetch the next message from this stream.
311    ///
312    /// # Return value
313    ///
314    /// - `Result::Err(val)` means a gRPC error was sent by the sender instead
315    ///   of a valid response message. Refer to [`Status::code`] and
316    ///   [`Status::message`] to examine possible error causes.
317    ///
318    /// - `Result::Ok(None)` means the stream was closed by the sender and no
319    ///   more messages will be delivered. Further attempts to call
320    ///   [`Streaming::message`] will result in the same return value.
321    ///
322    /// - `Result::Ok(Some(val))` means the sender streamed a valid response
323    ///   message `val`.
324    ///
325    /// ```rust
326    /// # use tonic::{Streaming, Status, codec::Decoder};
327    /// # use std::fmt::Debug;
328    /// # async fn next_message_ex<T, D>(mut request: Streaming<T>) -> Result<(), Status>
329    /// # where T: Debug,
330    /// # D: Decoder<Item = T, Error = Status> + Send  + 'static,
331    /// # {
332    /// if let Some(next_message) = request.message().await? {
333    ///     println!("{:?}", next_message);
334    /// }
335    /// # Ok(())
336    /// # }
337    /// ```
338    pub async fn message(&mut self) -> Result<Option<T>, Status> {
339        match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
340            Some(Ok(m)) => Ok(Some(m)),
341            Some(Err(e)) => Err(e),
342            None => Ok(None),
343        }
344    }
345
346    /// Fetch the trailing metadata.
347    ///
348    /// This will drain the stream of all its messages to receive the trailing
349    /// metadata. If [`Streaming::message`] returns `None` then this function
350    /// will not need to poll for trailers since the body was totally consumed.
351    ///
352    /// ```rust
353    /// # use tonic::{Streaming, Status};
354    /// # async fn trailers_ex<T>(mut request: Streaming<T>) -> Result<(), Status> {
355    /// if let Some(metadata) = request.trailers().await? {
356    ///     println!("{:?}", metadata);
357    /// }
358    /// # Ok(())
359    /// # }
360    /// ```
361    pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
362        // Shortcut to see if we already pulled the trailers in the stream step
363        // we need to do that so that the stream can error on trailing grpc-status
364        if let Some(trailers) = self.inner.trailers.take() {
365            return Ok(Some(MetadataMap::from_headers(trailers)));
366        }
367
368        // To fetch the trailers we must clear the body and drop it.
369        while self.message().await?.is_some() {}
370
371        // Since we call poll_trailers internally on poll_next we need to
372        // check if it got cached again.
373        if let Some(trailers) = self.inner.trailers.take() {
374            return Ok(Some(MetadataMap::from_headers(trailers)));
375        }
376
377        // We've polled through all the frames, and still no trailers, return None
378        Ok(None)
379    }
380
381    fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
382        match self
383            .inner
384            .decode_chunk(self.decoder.get_mut().buffer_settings())?
385        {
386            Some(mut decode_buf) => match self.decoder.get_mut().decode(&mut decode_buf)? {
387                Some(msg) => {
388                    self.inner.state = State::ReadHeader;
389                    Ok(Some(msg))
390                }
391                None => Ok(None),
392            },
393            None => Ok(None),
394        }
395    }
396}
397
398impl<T> Stream for Streaming<T> {
399    type Item = Result<T, Status>;
400
401    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
402        loop {
403            // When the stream encounters an error yield that error once and then on subsequent
404            // calls to poll_next return Poll::Ready(None) indicating that the stream has been
405            // fully exhausted.
406            if let State::Error(status) = &mut self.inner.state {
407                return Poll::Ready(status.take().map(Err));
408            }
409
410            if let Some(item) = self.decode_chunk()? {
411                return Poll::Ready(Some(Ok(item)));
412            }
413
414            if ready!(self.inner.poll_frame(cx))?.is_none() {
415                match self.inner.response() {
416                    Ok(()) => return Poll::Ready(None),
417                    Err(err) => self.inner.state = State::Error(Some(err)),
418                }
419            }
420        }
421    }
422}
423
424impl<T> fmt::Debug for Streaming<T> {
425    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426        f.debug_struct("Streaming").finish()
427    }
428}
429
// Compile-time assertion, built only under `cfg(test)`, that `Streaming<()>` implements
// both `Send` and `Sync`.
#[cfg(test)]
static_assertions::assert_impl_all!(Streaming<()>: Send, Sync);