1use super::compression::{CompressionEncoding, CompressionSettings, decompress};
2use super::{BufferSettings, DEFAULT_MAX_RECV_MESSAGE_SIZE, DecodeBuf, Decoder, HEADER_SIZE};
3use crate::{Code, Status, body::Body, metadata::MetadataMap};
4use bytes::{Buf, BufMut, BytesMut};
5use http::{HeaderMap, StatusCode};
6use http_body::Body as HttpBody;
7use http_body_util::BodyExt;
8use std::{
9 fmt, future,
10 pin::Pin,
11 task::ready,
12 task::{Context, Poll},
13};
14use sync_wrapper::SyncWrapper;
15use tokio_stream::Stream;
16use tracing::{debug, trace};
17
/// Stream of gRPC messages of type `T` decoded from an HTTP body.
///
/// Built with `new_response`, `new_empty` or `new_request`; messages are read through the
/// `Stream` implementation or the `message` helper, and trailing metadata through `trailers`.
pub struct Streaming<T> {
    /// Type-erased decoder that turns each framed message payload into a `T`.
    decoder: SyncWrapper<Box<dyn Decoder<Item = T, Error = Status> + Send + 'static>>,
    /// Message-type-independent state: body, framing state machine and buffers.
    inner: StreamingInner,
}
26
/// The part of a [`Streaming`] that does not depend on the message type: the body being read,
/// the framing state machine and the buffers used while decoding.
struct StreamingInner {
    /// Source of HTTP frames, adapted at construction to `Bytes` data and `Status` errors.
    body: SyncWrapper<Body>,
    /// Current position in the length-prefixed framing.
    state: State,
    /// Whether a request or a response is being decoded (error context, end-of-stream handling).
    direction: Direction,
    /// Bytes received from `body` that have not been consumed yet.
    buf: BytesMut,
    /// Trailers received from `body`, if any; handed out by `Streaming::trailers`.
    trailers: Option<HeaderMap>,
    /// Scratch buffer holding the current message after decompression.
    decompress_buf: BytesMut,
    /// Encoding used for messages whose compressed flag is set; `None` rejects such messages.
    encoding: Option<CompressionEncoding>,
    /// Per-message size limit; `None` falls back to `DEFAULT_MAX_RECV_MESSAGE_SIZE`.
    max_message_size: Option<usize>,
}
37
// `message` calls `Pin::new(&mut *self)` and `poll_next` assigns fields through
// `Pin<&mut Self>`; both require `Streaming<T>: Unpin`, which this impl guarantees for every `T`.
impl<T> Unpin for Streaming<T> {}
39
/// Position of the decoder within the length-prefixed message framing.
#[derive(Debug, Clone)]
enum State {
    /// Waiting for a message header: a one-byte compression flag followed by a `u32` length.
    ReadHeader,
    /// Header parsed; waiting for `len` payload bytes, compressed with `compression` if set.
    ReadBody {
        compression: Option<CompressionEncoding>,
        len: usize,
    },
    /// The stream has failed. `Some` holds an error not yet handed to the caller; `None` means
    /// the failure was already reported and the stream only yields end-of-stream from now on.
    Error(Option<Status>),
}
49
/// What kind of body is being decoded; affects error messages and end-of-stream handling.
#[derive(Debug, PartialEq, Eq)]
enum Direction {
    /// A request body; a `Cancelled` body error is treated as a clean end of stream.
    Request,
    /// A response body received with this HTTP status code, which is combined with the
    /// trailers to infer the final gRPC status.
    Response(StatusCode),
    /// A response stream created by `new_empty` (no encoding, default size limit).
    EmptyResponse,
}
56
57impl<T> Streaming<T> {
58 pub fn new_response<B, D>(
61 decoder: D,
62 body: B,
63 status_code: StatusCode,
64 encoding: Option<CompressionEncoding>,
65 max_message_size: Option<usize>,
66 ) -> Self
67 where
68 B: HttpBody + Send + 'static,
69 B::Error: Into<crate::BoxError>,
70 D: Decoder<Item = T, Error = Status> + Send + 'static,
71 {
72 Self::new(
73 decoder,
74 body,
75 Direction::Response(status_code),
76 encoding,
77 max_message_size,
78 )
79 }
80
81 pub fn new_empty<B, D>(decoder: D, body: B) -> Self
83 where
84 B: HttpBody + Send + 'static,
85 B::Error: Into<crate::BoxError>,
86 D: Decoder<Item = T, Error = Status> + Send + 'static,
87 {
88 Self::new(decoder, body, Direction::EmptyResponse, None, None)
89 }
90
91 pub fn new_request<B, D>(
94 decoder: D,
95 body: B,
96 encoding: Option<CompressionEncoding>,
97 max_message_size: Option<usize>,
98 ) -> Self
99 where
100 B: HttpBody + Send + 'static,
101 B::Error: Into<crate::BoxError>,
102 D: Decoder<Item = T, Error = Status> + Send + 'static,
103 {
104 Self::new(
105 decoder,
106 body,
107 Direction::Request,
108 encoding,
109 max_message_size,
110 )
111 }
112
113 fn new<B, D>(
114 decoder: D,
115 body: B,
116 direction: Direction,
117 encoding: Option<CompressionEncoding>,
118 max_message_size: Option<usize>,
119 ) -> Self
120 where
121 B: HttpBody + Send + 'static,
122 B::Error: Into<crate::BoxError>,
123 D: Decoder<Item = T, Error = Status> + Send + 'static,
124 {
125 let buffer_size = decoder.buffer_settings().buffer_size;
126 Self {
127 decoder: SyncWrapper::new(Box::new(decoder)),
128 inner: StreamingInner {
129 body: SyncWrapper::new(Body::new(
130 body.map_frame(|frame| {
131 frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
132 })
133 .map_err(|err| Status::map_error(err.into())),
134 )),
135 state: State::ReadHeader,
136 direction,
137 buf: BytesMut::with_capacity(buffer_size),
138 trailers: None,
139 decompress_buf: BytesMut::new(),
140 encoding,
141 max_message_size,
142 },
143 }
144 }
145}
146
147impl StreamingInner {
148 fn decode_chunk(
149 &mut self,
150 buffer_settings: BufferSettings,
151 ) -> Result<Option<DecodeBuf<'_>>, Status> {
152 if let State::ReadHeader = self.state {
153 if self.buf.remaining() < HEADER_SIZE {
154 return Ok(None);
155 }
156
157 let compression_encoding = match self.buf.get_u8() {
158 0 => None,
159 1 => {
160 {
161 if self.encoding.is_some() {
162 self.encoding
163 } else {
164 return Err(Status::internal(
169 "protocol error: received message with compressed-flag but no grpc-encoding was specified",
170 ));
171 }
172 }
173 }
174 f => {
175 trace!("unexpected compression flag");
176 let message = if let Direction::Response(status) = self.direction {
177 format!(
178 "protocol error: received message with invalid compression flag: {f} (valid flags are 0 and 1) while receiving response with status: {status}"
179 )
180 } else {
181 format!(
182 "protocol error: received message with invalid compression flag: {f} (valid flags are 0 and 1), while sending request"
183 )
184 };
185 return Err(Status::internal(message));
186 }
187 };
188
189 let len = self.buf.get_u32() as usize;
190 let limit = self
191 .max_message_size
192 .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE);
193 if len > limit {
194 return Err(Status::out_of_range(format!(
195 "Error, decoded message length too large: found {len} bytes, the limit is: {limit} bytes"
196 )));
197 }
198
199 self.buf.reserve(len);
200
201 self.state = State::ReadBody {
202 compression: compression_encoding,
203 len,
204 }
205 }
206
207 if let State::ReadBody { len, compression } = self.state {
208 if self.buf.remaining() < len || self.buf.len() < len {
211 return Ok(None);
212 }
213
214 let decode_buf = if let Some(encoding) = compression {
215 self.decompress_buf.clear();
216 let limit = self
217 .max_message_size
218 .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE);
219 let limited_out_buf = (&mut self.decompress_buf).limit(limit);
220
221 if let Err(err) = decompress(
222 CompressionSettings {
223 encoding,
224 buffer_growth_interval: buffer_settings.buffer_size,
225 },
226 &mut self.buf,
227 limited_out_buf,
228 len,
229 ) {
230 if matches!(err.kind(), std::io::ErrorKind::WriteZero) {
231 return Err(Status::resource_exhausted(format!(
232 "Error decompressing: size limit, of {limit} bytes, exceeded while decompressing message"
233 )));
234 }
235 let message = if let Direction::Response(status) = self.direction {
236 format!(
237 "Error decompressing: {err}, while receiving response with status: {status}"
238 )
239 } else {
240 format!("Error decompressing: {err}, while sending request")
241 };
242 return Err(Status::internal(message));
243 }
244 let decompressed_len = self.decompress_buf.len();
245 DecodeBuf::new(&mut self.decompress_buf, decompressed_len)
246 } else {
247 DecodeBuf::new(&mut self.buf, len)
248 };
249
250 return Ok(Some(decode_buf));
251 }
252
253 Ok(None)
254 }
255
256 fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<()>, Status>> {
258 let frame = match ready!(Pin::new(self.body.get_mut()).poll_frame(cx)) {
259 Some(Ok(frame)) => frame,
260 Some(Err(status)) => {
261 if self.direction == Direction::Request && status.code() == Code::Cancelled {
262 return Poll::Ready(Ok(None));
263 }
264
265 let _ = std::mem::replace(&mut self.state, State::Error(Some(status.clone())));
266 debug!("decoder inner stream error: {:?}", status);
267 return Poll::Ready(Err(status));
268 }
269 None => {
270 return Poll::Ready(if self.buf.has_remaining() {
272 trace!("unexpected EOF decoding stream, state: {:?}", self.state);
273 Err(Status::internal("Unexpected EOF decoding stream."))
274 } else {
275 Ok(None)
276 });
277 }
278 };
279
280 Poll::Ready(if frame.is_data() {
281 self.buf.put(frame.into_data().unwrap());
282 Ok(Some(()))
283 } else if frame.is_trailers() {
284 if let Some(trailers) = &mut self.trailers {
285 trailers.extend(frame.into_trailers().unwrap());
286 } else {
287 self.trailers = Some(frame.into_trailers().unwrap());
288 }
289
290 Ok(None)
291 } else {
292 panic!("unexpected frame: {frame:?}");
293 })
294 }
295
296 fn response(&mut self) -> Result<(), Status> {
297 if let Direction::Response(status) = self.direction {
298 if let Err(Some(e)) = crate::status::infer_grpc_status(self.trailers.as_ref(), status) {
299 self.trailers.take();
302 return Err(e);
303 }
304 }
305 Ok(())
306 }
307}
308
309impl<T> Streaming<T> {
310 pub async fn message(&mut self) -> Result<Option<T>, Status> {
339 match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
340 Some(Ok(m)) => Ok(Some(m)),
341 Some(Err(e)) => Err(e),
342 None => Ok(None),
343 }
344 }
345
346 pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
362 if let Some(trailers) = self.inner.trailers.take() {
365 return Ok(Some(MetadataMap::from_headers(trailers)));
366 }
367
368 while self.message().await?.is_some() {}
370
371 if let Some(trailers) = self.inner.trailers.take() {
374 return Ok(Some(MetadataMap::from_headers(trailers)));
375 }
376
377 Ok(None)
379 }
380
381 fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
382 match self
383 .inner
384 .decode_chunk(self.decoder.get_mut().buffer_settings())?
385 {
386 Some(mut decode_buf) => match self.decoder.get_mut().decode(&mut decode_buf)? {
387 Some(msg) => {
388 self.inner.state = State::ReadHeader;
389 Ok(Some(msg))
390 }
391 None => Ok(None),
392 },
393 None => Ok(None),
394 }
395 }
396}
397
398impl<T> Stream for Streaming<T> {
399 type Item = Result<T, Status>;
400
401 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
402 loop {
403 if let State::Error(status) = &mut self.inner.state {
407 return Poll::Ready(status.take().map(Err));
408 }
409
410 if let Some(item) = self.decode_chunk()? {
411 return Poll::Ready(Some(Ok(item)));
412 }
413
414 if ready!(self.inner.poll_frame(cx))?.is_none() {
415 match self.inner.response() {
416 Ok(()) => return Poll::Ready(None),
417 Err(err) => self.inner.state = State::Error(Some(err)),
418 }
419 }
420 }
421 }
422}
423
424impl<T> fmt::Debug for Streaming<T> {
425 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426 f.debug_struct("Streaming").finish()
427 }
428}
429
// Compile-time check (test builds only) that `Streaming` is `Send + Sync`. The decoder is boxed
// as `dyn Decoder + Send` without `Sync`, so the `SyncWrapper` fields are presumably what make
// the type `Sync`.
#[cfg(test)]
static_assertions::assert_impl_all!(Streaming<()>: Send, Sync);