1use super::compression::{decompress, CompressionEncoding, CompressionSettings};
2use super::{BufferSettings, DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE};
3use crate::{body::BoxBody, metadata::MetadataMap, Code, Status};
4use bytes::{Buf, BufMut, BytesMut};
5use http::{HeaderMap, StatusCode};
6use http_body::Body;
7use http_body_util::BodyExt;
8use std::{
9 fmt, future,
10 pin::Pin,
11 task::ready,
12 task::{Context, Poll},
13};
14use tokio_stream::Stream;
15use tracing::{debug, trace};
16
/// A stream of messages of type `T` decoded from an HTTP body.
///
/// Message framing (header parsing, buffering, decompression) is handled by `inner`; each
/// complete message payload is then handed to `decoder`, which produces the `T`.
pub struct Streaming<T> {
    /// Turns a framed (and, if needed, decompressed) payload into a `T`.
    decoder: Box<dyn Decoder<Item = T, Error = Status> + Send + 'static>,
    /// Non-generic framing state machine over the body.
    inner: StreamingInner,
}
25
/// The non-generic part of `Streaming<T>`: reads length-prefixed messages out of the body,
/// keeping track of partially received data, trailers and compression.
struct StreamingInner {
    /// The body, with data chunks flattened to `Bytes` and errors mapped to `Status`.
    body: BoxBody,
    /// Current position of the framing state machine.
    state: State,
    /// Which side of the RPC is being decoded; affects error messages, cancellation
    /// handling and end-of-stream status inference.
    direction: Direction,
    /// Body bytes received but not yet consumed by framing/decoding.
    buf: BytesMut,
    /// Trailers received so far, if any.
    trailers: Option<HeaderMap>,
    /// Scratch buffer that receives the decompressed payload of the current message.
    decompress_buf: BytesMut,
    /// Compression encoding configured for this stream; required to read messages whose
    /// compressed flag is set.
    encoding: Option<CompressionEncoding>,
    /// Maximum accepted message length in bytes; `None` means `DEFAULT_MAX_RECV_MESSAGE_SIZE`.
    max_message_size: Option<usize>,
}
36
// Unconditionally `Unpin`, whatever `T` is: no field is structurally pinned (the decoder is
// boxed and the body is polled via `Pin::new`), so `poll_next` can treat `Pin<&mut Self>`
// like a plain `&mut Self` when reaching into the fields.
impl<T> Unpin for Streaming<T> {}
38
/// Position of the message-framing state machine.
#[derive(Debug, Clone)]
enum State {
    /// Waiting for the next message header (compressed flag followed by a length).
    ReadHeader,
    /// A header has been parsed; waiting until `len` payload bytes are buffered. The payload
    /// is decompressed with `compression` when it is `Some`.
    ReadBody {
        compression: Option<CompressionEncoding>,
        len: usize,
    },
    /// The stream has failed. Holds an error that has yet to be yielded, or `None` once it
    /// has been reported; `poll_next` then returns end-of-stream.
    Error(Option<Status>),
}
48
/// Which side of the RPC a `Streaming` decodes.
///
/// Used in error messages, to treat a cancelled request body as a clean end of stream, and
/// (for `Response`) to infer the final gRPC status once the body has ended.
#[derive(Debug, PartialEq, Eq)]
enum Direction {
    /// Decoding an incoming request body.
    Request,
    /// Decoding a response that arrived with this HTTP status code.
    Response(StatusCode),
    /// Decoding a response for which no gRPC status is inferred at end of stream.
    EmptyResponse,
}
55
56impl<T> Streaming<T> {
57 pub fn new_response<B, D>(
60 decoder: D,
61 body: B,
62 status_code: StatusCode,
63 encoding: Option<CompressionEncoding>,
64 max_message_size: Option<usize>,
65 ) -> Self
66 where
67 B: Body + Send + 'static,
68 B::Error: Into<crate::Error>,
69 D: Decoder<Item = T, Error = Status> + Send + 'static,
70 {
71 Self::new(
72 decoder,
73 body,
74 Direction::Response(status_code),
75 encoding,
76 max_message_size,
77 )
78 }
79
80 pub fn new_empty<B, D>(decoder: D, body: B) -> Self
82 where
83 B: Body + Send + 'static,
84 B::Error: Into<crate::Error>,
85 D: Decoder<Item = T, Error = Status> + Send + 'static,
86 {
87 Self::new(decoder, body, Direction::EmptyResponse, None, None)
88 }
89
90 pub fn new_request<B, D>(
93 decoder: D,
94 body: B,
95 encoding: Option<CompressionEncoding>,
96 max_message_size: Option<usize>,
97 ) -> Self
98 where
99 B: Body + Send + 'static,
100 B::Error: Into<crate::Error>,
101 D: Decoder<Item = T, Error = Status> + Send + 'static,
102 {
103 Self::new(
104 decoder,
105 body,
106 Direction::Request,
107 encoding,
108 max_message_size,
109 )
110 }
111
112 fn new<B, D>(
113 decoder: D,
114 body: B,
115 direction: Direction,
116 encoding: Option<CompressionEncoding>,
117 max_message_size: Option<usize>,
118 ) -> Self
119 where
120 B: Body + Send + 'static,
121 B::Error: Into<crate::Error>,
122 D: Decoder<Item = T, Error = Status> + Send + 'static,
123 {
124 let buffer_size = decoder.buffer_settings().buffer_size;
125 Self {
126 decoder: Box::new(decoder),
127 inner: StreamingInner {
128 body: body
129 .map_frame(|frame| frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())))
130 .map_err(|err| Status::map_error(err.into()))
131 .boxed_unsync(),
132 state: State::ReadHeader,
133 direction,
134 buf: BytesMut::with_capacity(buffer_size),
135 trailers: None,
136 decompress_buf: BytesMut::new(),
137 encoding,
138 max_message_size,
139 },
140 }
141 }
142}
143
144impl StreamingInner {
145 fn decode_chunk(
146 &mut self,
147 buffer_settings: BufferSettings,
148 ) -> Result<Option<DecodeBuf<'_>>, Status> {
149 if let State::ReadHeader = self.state {
150 if self.buf.remaining() < HEADER_SIZE {
151 return Ok(None);
152 }
153
154 let compression_encoding = match self.buf.get_u8() {
155 0 => None,
156 1 => {
157 {
158 if self.encoding.is_some() {
159 self.encoding
160 } else {
161 return Err(Status::internal( "protocol error: received message with compressed-flag but no grpc-encoding was specified"));
166 }
167 }
168 }
169 f => {
170 trace!("unexpected compression flag");
171 let message = if let Direction::Response(status) = self.direction {
172 format!(
173 "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}",
174 f, status
175 )
176 } else {
177 format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f)
178 };
179 return Err(Status::internal(message));
180 }
181 };
182
183 let len = self.buf.get_u32() as usize;
184 let limit = self
185 .max_message_size
186 .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE);
187 if len > limit {
188 return Err(Status::out_of_range(
189 format!(
190 "Error, decoded message length too large: found {} bytes, the limit is: {} bytes",
191 len, limit
192 ),
193 ));
194 }
195
196 self.buf.reserve(len);
197
198 self.state = State::ReadBody {
199 compression: compression_encoding,
200 len,
201 }
202 }
203
204 if let State::ReadBody { len, compression } = self.state {
205 if self.buf.remaining() < len || self.buf.len() < len {
208 return Ok(None);
209 }
210
211 let decode_buf = if let Some(encoding) = compression {
212 self.decompress_buf.clear();
213
214 if let Err(err) = decompress(
215 CompressionSettings {
216 encoding,
217 buffer_growth_interval: buffer_settings.buffer_size,
218 },
219 &mut self.buf,
220 &mut self.decompress_buf,
221 len,
222 ) {
223 let message = if let Direction::Response(status) = self.direction {
224 format!(
225 "Error decompressing: {}, while receiving response with status: {}",
226 err, status
227 )
228 } else {
229 format!("Error decompressing: {}, while sending request", err)
230 };
231 return Err(Status::internal(message));
232 }
233 let decompressed_len = self.decompress_buf.len();
234 DecodeBuf::new(&mut self.decompress_buf, decompressed_len)
235 } else {
236 DecodeBuf::new(&mut self.buf, len)
237 };
238
239 return Ok(Some(decode_buf));
240 }
241
242 Ok(None)
243 }
244
245 fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<()>, Status>> {
247 let chunk = match ready!(Pin::new(&mut self.body).poll_frame(cx)) {
248 Some(Ok(d)) => Some(d),
249 Some(Err(status)) => {
250 if self.direction == Direction::Request && status.code() == Code::Cancelled {
251 return Poll::Ready(Ok(None));
252 }
253
254 let _ = std::mem::replace(&mut self.state, State::Error(Some(status.clone())));
255 debug!("decoder inner stream error: {:?}", status);
256 return Poll::Ready(Err(status));
257 }
258 None => None,
259 };
260
261 Poll::Ready(if let Some(frame) = chunk {
262 match frame {
263 frame if frame.is_data() => {
264 self.buf.put(frame.into_data().unwrap());
265 Ok(Some(()))
266 }
267 frame if frame.is_trailers() => {
268 match &mut self.trailers {
269 Some(trailers) => {
270 trailers.extend(frame.into_trailers().unwrap());
271 }
272 None => {
273 self.trailers = Some(frame.into_trailers().unwrap());
274 }
275 }
276
277 Ok(None)
278 }
279 frame => panic!("unexpected frame: {:?}", frame),
280 }
281 } else {
282 if self.buf.has_remaining() {
284 trace!("unexpected EOF decoding stream, state: {:?}", self.state);
285 Err(Status::internal("Unexpected EOF decoding stream."))
286 } else {
287 Ok(None)
288 }
289 })
290 }
291
292 fn response(&mut self) -> Result<(), Status> {
293 if let Direction::Response(status) = self.direction {
294 if let Err(Some(e)) = crate::status::infer_grpc_status(self.trailers.as_ref(), status) {
295 self.trailers.take();
298 return Err(e);
299 }
300 }
301 Ok(())
302 }
303}
304
305impl<T> Streaming<T> {
306 pub async fn message(&mut self) -> Result<Option<T>, Status> {
335 match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
336 Some(Ok(m)) => Ok(Some(m)),
337 Some(Err(e)) => Err(e),
338 None => Ok(None),
339 }
340 }
341
342 pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
358 if let Some(trailers) = self.inner.trailers.take() {
361 return Ok(Some(MetadataMap::from_headers(trailers)));
362 }
363
364 while self.message().await?.is_some() {}
366
367 if let Some(trailers) = self.inner.trailers.take() {
370 return Ok(Some(MetadataMap::from_headers(trailers)));
371 }
372
373 Ok(None)
375 }
376
377 fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
378 match self.inner.decode_chunk(self.decoder.buffer_settings())? {
379 Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? {
380 Some(msg) => {
381 self.inner.state = State::ReadHeader;
382 Ok(Some(msg))
383 }
384 None => Ok(None),
385 },
386 None => Ok(None),
387 }
388 }
389}
390
391impl<T> Stream for Streaming<T> {
392 type Item = Result<T, Status>;
393
394 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
395 loop {
396 if let State::Error(status) = &mut self.inner.state {
400 return Poll::Ready(status.take().map(Err));
401 }
402
403 if let Some(item) = self.decode_chunk()? {
404 return Poll::Ready(Some(Ok(item)));
405 }
406
407 match ready!(self.inner.poll_frame(cx))? {
408 Some(()) => (),
409 None => break,
410 }
411 }
412
413 Poll::Ready(match self.inner.response() {
414 Ok(()) => None,
415 Err(err) => Some(Err(err)),
416 })
417 }
418}
419
420impl<T> fmt::Debug for Streaming<T> {
421 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
422 f.debug_struct("Streaming").finish()
423 }
424}
425
// Compile-time check (test builds only) that `Streaming<()>` is `Send`, which the `Send`
// bounds on the boxed decoder and body are meant to guarantee.
#[cfg(test)]
static_assertions::assert_impl_all!(Streaming<()>: Send);