reqwest/async_impl/
decoder.rs

1use std::fmt;
2#[cfg(any(
3    feature = "gzip",
4    feature = "zstd",
5    feature = "brotli",
6    feature = "deflate"
7))]
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{ready, Context, Poll};
11
12#[cfg(any(
13    feature = "gzip",
14    feature = "zstd",
15    feature = "brotli",
16    feature = "deflate"
17))]
18use futures_util::stream::Fuse;
19
20#[cfg(feature = "gzip")]
21use async_compression::tokio::bufread::GzipDecoder;
22
23#[cfg(feature = "brotli")]
24use async_compression::tokio::bufread::BrotliDecoder;
25
26#[cfg(feature = "zstd")]
27use async_compression::tokio::bufread::ZstdDecoder;
28
29#[cfg(feature = "deflate")]
30use async_compression::tokio::bufread::ZlibDecoder;
31
32#[cfg(any(
33    feature = "gzip",
34    feature = "zstd",
35    feature = "brotli",
36    feature = "deflate",
37    feature = "blocking",
38))]
39use futures_core::Stream;
40
41use bytes::Bytes;
42use http::HeaderMap;
43use hyper::body::Body as HttpBody;
44use hyper::body::Frame;
45
46#[cfg(any(
47    feature = "gzip",
48    feature = "brotli",
49    feature = "zstd",
50    feature = "deflate"
51))]
52use tokio_util::codec::{BytesCodec, FramedRead};
53#[cfg(any(
54    feature = "gzip",
55    feature = "brotli",
56    feature = "zstd",
57    feature = "deflate"
58))]
59use tokio_util::io::StreamReader;
60
61use super::body::ResponseBody;
62
/// Which response content codings are advertised (see `Accepts::as_str`) and
/// decoded (see `Decoder::detect`).
///
/// Each flag only exists when its compression feature is compiled in.
#[derive(Clone, Copy, Debug)]
pub(super) struct Accepts {
    /// Accept and decode `gzip`.
    #[cfg(feature = "gzip")]
    pub(super) gzip: bool,
    /// Accept and decode `br` (brotli).
    #[cfg(feature = "brotli")]
    pub(super) brotli: bool,
    /// Accept and decode `zstd`.
    #[cfg(feature = "zstd")]
    pub(super) zstd: bool,
    /// Accept and decode `deflate`.
    #[cfg(feature = "deflate")]
    pub(super) deflate: bool,
}
74
75impl Accepts {
76    pub fn none() -> Self {
77        Self {
78            #[cfg(feature = "gzip")]
79            gzip: false,
80            #[cfg(feature = "brotli")]
81            brotli: false,
82            #[cfg(feature = "zstd")]
83            zstd: false,
84            #[cfg(feature = "deflate")]
85            deflate: false,
86        }
87    }
88}
89
/// A response decompressor over a non-blocking stream of chunks.
///
/// The inner decoder may be constructed asynchronously: compressed bodies start
/// in `Inner::Pending` and switch to their concrete decoder on first poll.
/// Frames are produced through the `HttpBody` implementation.
pub(crate) struct Decoder {
    /// Current decoding state.
    inner: Inner,
}
96
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
/// The body's data-chunk stream, made peekable so the first chunk (or EOF) can
/// be inspected before a decompressor is built around it.
type PeekableIoStream = futures_util::stream::Peekable<IoStream>;
104
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
/// Byte-reader view of [`PeekableIoStream`]; this is what the
/// `async_compression` `bufread` decoders read compressed input from.
type PeekableIoStreamReader = StreamReader<PeekableIoStream, Bytes>;
112
/// The decoding state of a [`Decoder`].
///
/// Every compressed variant wraps the same pipeline: peekable body stream ->
/// `StreamReader` -> decompressor -> `FramedRead` with `BytesCodec` -> `Fuse`.
enum Inner {
    /// A `PlainText` decoder just returns the response content as is.
    PlainText(ResponseBody),

    /// A `Gzip` decoder will uncompress the gzipped response content before returning it.
    #[cfg(feature = "gzip")]
    Gzip(Pin<Box<Fuse<FramedRead<GzipDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

    /// A `Brotli` decoder will uncompress the brotlied response content before returning it.
    #[cfg(feature = "brotli")]
    Brotli(Pin<Box<Fuse<FramedRead<BrotliDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

    /// A `Zstd` decoder will uncompress the zstd compressed response content before returning it.
    #[cfg(feature = "zstd")]
    Zstd(Pin<Box<Fuse<FramedRead<ZstdDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

    /// A `Deflate` decoder will uncompress the deflated response content before returning it.
    #[cfg(feature = "deflate")]
    Deflate(Pin<Box<Fuse<FramedRead<ZlibDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

    /// A decoder that doesn't have a value yet.
    ///
    /// It resolves to one of the variants above once the first chunk (or EOF)
    /// of the body has been observed.
    #[cfg(any(
        feature = "brotli",
        feature = "zstd",
        feature = "gzip",
        feature = "deflate"
    ))]
    Pending(Pin<Box<Pending>>),
}
142
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
/// A future that polls the response body for its first chunk or EOF, to decide
/// whether to build the decompressor named by the `DecoderType` or to fall back
/// to an empty plain-text body.
struct Pending(PeekableIoStream, DecoderType);
151
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
/// Adapts an `HttpBody` (by default the raw `ResponseBody`) into a `Stream` of
/// data chunks with `std::io::Error` errors; non-data frames are skipped.
pub(crate) struct IoStream<B = ResponseBody>(B);
160
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
/// Which decompressor a [`Pending`] decoder builds once the body is known to be
/// non-empty.
enum DecoderType {
    #[cfg(feature = "gzip")]
    Gzip,
    #[cfg(feature = "brotli")]
    Brotli,
    #[cfg(feature = "zstd")]
    Zstd,
    #[cfg(feature = "deflate")]
    Deflate,
}
177
178impl fmt::Debug for Decoder {
179    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
180        f.debug_struct("Decoder").finish()
181    }
182}
183
184impl Decoder {
185    #[cfg(feature = "blocking")]
186    pub(crate) fn empty() -> Decoder {
187        Decoder {
188            inner: Inner::PlainText(empty()),
189        }
190    }
191
192    #[cfg(feature = "blocking")]
193    pub(crate) fn into_stream(self) -> IoStream<Self> {
194        IoStream(self)
195    }
196
197    /// A plain text decoder.
198    ///
199    /// This decoder will emit the underlying chunks as-is.
200    fn plain_text(body: ResponseBody) -> Decoder {
201        Decoder {
202            inner: Inner::PlainText(body),
203        }
204    }
205
206    /// A gzip decoder.
207    ///
208    /// This decoder will buffer and decompress chunks that are gzipped.
209    #[cfg(feature = "gzip")]
210    fn gzip(body: ResponseBody) -> Decoder {
211        use futures_util::StreamExt;
212
213        Decoder {
214            inner: Inner::Pending(Box::pin(Pending(
215                IoStream(body).peekable(),
216                DecoderType::Gzip,
217            ))),
218        }
219    }
220
221    /// A brotli decoder.
222    ///
223    /// This decoder will buffer and decompress chunks that are brotlied.
224    #[cfg(feature = "brotli")]
225    fn brotli(body: ResponseBody) -> Decoder {
226        use futures_util::StreamExt;
227
228        Decoder {
229            inner: Inner::Pending(Box::pin(Pending(
230                IoStream(body).peekable(),
231                DecoderType::Brotli,
232            ))),
233        }
234    }
235
236    /// A zstd decoder.
237    ///
238    /// This decoder will buffer and decompress chunks that are zstd compressed.
239    #[cfg(feature = "zstd")]
240    fn zstd(body: ResponseBody) -> Decoder {
241        use futures_util::StreamExt;
242
243        Decoder {
244            inner: Inner::Pending(Box::pin(Pending(
245                IoStream(body).peekable(),
246                DecoderType::Zstd,
247            ))),
248        }
249    }
250
251    /// A deflate decoder.
252    ///
253    /// This decoder will buffer and decompress chunks that are deflated.
254    #[cfg(feature = "deflate")]
255    fn deflate(body: ResponseBody) -> Decoder {
256        use futures_util::StreamExt;
257
258        Decoder {
259            inner: Inner::Pending(Box::pin(Pending(
260                IoStream(body).peekable(),
261                DecoderType::Deflate,
262            ))),
263        }
264    }
265
266    #[cfg(any(
267        feature = "brotli",
268        feature = "zstd",
269        feature = "gzip",
270        feature = "deflate"
271    ))]
272    fn detect_encoding(headers: &mut HeaderMap, encoding_str: &str) -> bool {
273        use http::header::{CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
274        use log::warn;
275
276        let mut is_content_encoded = {
277            headers
278                .get_all(CONTENT_ENCODING)
279                .iter()
280                .any(|enc| enc == encoding_str)
281                || headers
282                    .get_all(TRANSFER_ENCODING)
283                    .iter()
284                    .any(|enc| enc == encoding_str)
285        };
286        if is_content_encoded {
287            if let Some(content_length) = headers.get(CONTENT_LENGTH) {
288                if content_length == "0" {
289                    warn!("{encoding_str} response with content-length of 0");
290                    is_content_encoded = false;
291                }
292            }
293        }
294        if is_content_encoded {
295            headers.remove(CONTENT_ENCODING);
296            headers.remove(CONTENT_LENGTH);
297        }
298        is_content_encoded
299    }
300
301    /// Constructs a Decoder from a hyper request.
302    ///
303    /// A decoder is just a wrapper around the hyper request that knows
304    /// how to decode the content body of the request.
305    ///
306    /// Uses the correct variant by inspecting the Content-Encoding header.
307    pub(super) fn detect(
308        _headers: &mut HeaderMap,
309        body: ResponseBody,
310        _accepts: Accepts,
311    ) -> Decoder {
312        #[cfg(feature = "gzip")]
313        {
314            if _accepts.gzip && Decoder::detect_encoding(_headers, "gzip") {
315                return Decoder::gzip(body);
316            }
317        }
318
319        #[cfg(feature = "brotli")]
320        {
321            if _accepts.brotli && Decoder::detect_encoding(_headers, "br") {
322                return Decoder::brotli(body);
323            }
324        }
325
326        #[cfg(feature = "zstd")]
327        {
328            if _accepts.zstd && Decoder::detect_encoding(_headers, "zstd") {
329                return Decoder::zstd(body);
330            }
331        }
332
333        #[cfg(feature = "deflate")]
334        {
335            if _accepts.deflate && Decoder::detect_encoding(_headers, "deflate") {
336                return Decoder::deflate(body);
337            }
338        }
339
340        Decoder::plain_text(body)
341    }
342}
343
/// Exposes the (possibly decompressed) response as an `http_body` body.
impl HttpBody for Decoder {
    type Data = Bytes;
    type Error = crate::Error;

    /// Polls the next frame of the decoded body.
    ///
    /// - `Pending`: drives the setup future. Once it resolves, `self.inner` is
    ///   replaced with the chosen variant and that variant is polled. A setup
    ///   error is returned via `crate::error::decode_io`.
    /// - `PlainText`: forwards frames of the underlying body unchanged and maps
    ///   body errors through `crate::error::decode`.
    /// - Compressed variants: each decompressed chunk is returned as a data
    ///   frame. When the decompressor reports end of stream, the remaining
    ///   underlying stream is drained by `poll_inner_should_be_empty`, which
    ///   reports an error if any non-empty chunk is left over. Non-data frames
    ///   of the underlying body never reach here because `IoStream` skips them.
    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        match self.inner {
            #[cfg(any(
                feature = "brotli",
                feature = "zstd",
                feature = "gzip",
                feature = "deflate"
            ))]
            Inner::Pending(ref mut future) => match Pin::new(future).poll(cx) {
                Poll::Ready(Ok(inner)) => {
                    // Swap in the concrete decoder and poll it immediately so
                    // this call still makes progress.
                    self.inner = inner;
                    self.poll_frame(cx)
                }
                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(crate::error::decode_io(e)))),
                Poll::Pending => Poll::Pending,
            },
            Inner::PlainText(ref mut body) => match ready!(Pin::new(body).poll_frame(cx)) {
                Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
                Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode(err)))),
                None => Poll::Ready(None),
            },
            #[cfg(feature = "gzip")]
            Inner::Gzip(ref mut decoder) => {
                // `BytesCodec` yields mutable buffers; `freeze` turns each into `Bytes`.
                match ready!(Pin::new(&mut *decoder).poll_next(cx)) {
                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
                    None => {
                        // poll inner connection until EOF after gzip stream is finished
                        // (Fuse -> FramedRead -> decoder -> StreamReader -> peekable stream)
                        poll_inner_should_be_empty(
                            decoder.get_mut().get_mut().get_mut().get_mut(),
                            cx,
                        )
                    }
                }
            }
            #[cfg(feature = "brotli")]
            Inner::Brotli(ref mut decoder) => {
                match ready!(Pin::new(&mut *decoder).poll_next(cx)) {
                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
                    None => {
                        // poll inner connection until EOF after brotli stream is finished
                        poll_inner_should_be_empty(
                            decoder.get_mut().get_mut().get_mut().get_mut(),
                            cx,
                        )
                    }
                }
            }
            #[cfg(feature = "zstd")]
            Inner::Zstd(ref mut decoder) => {
                match ready!(Pin::new(&mut *decoder).poll_next(cx)) {
                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
                    None => {
                        // poll inner connection until EOF after zstd stream is finished
                        poll_inner_should_be_empty(
                            decoder.get_mut().get_mut().get_mut().get_mut(),
                            cx,
                        )
                    }
                }
            }
            #[cfg(feature = "deflate")]
            Inner::Deflate(ref mut decoder) => {
                match ready!(Pin::new(&mut *decoder).poll_next(cx)) {
                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
                    None => {
                        // poll inner connection until EOF after deflate stream is finished
                        poll_inner_should_be_empty(
                            decoder.get_mut().get_mut().get_mut().get_mut(),
                            cx,
                        )
                    }
                }
            }
        }
    }

    /// Size hint of the decoded body.
    ///
    /// Only plain text can report the underlying body's hint; the decompressed
    /// size of the other variants is not known in advance.
    fn size_hint(&self) -> http_body::SizeHint {
        match self.inner {
            Inner::PlainText(ref body) => HttpBody::size_hint(body),
            // the rest are "unknown", so default
            #[cfg(any(
                feature = "brotli",
                feature = "zstd",
                feature = "gzip",
                feature = "deflate"
            ))]
            _ => http_body::SizeHint::default(),
        }
    }
}
445
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
/// Drains the underlying body after a decompressor has finished.
///
/// Empty chunks are skipped. Reaching EOF ends the body (`Ready(None)`). A
/// non-empty chunk means data followed the compressed stream and is reported
/// as a decode error. An I/O error is reported via `crate::error::decode_io`.
fn poll_inner_should_be_empty(
    inner: &mut PeekableIoStream,
    cx: &mut Context,
) -> Poll<Option<Result<Frame<Bytes>, crate::Error>>> {
    let mut stream = Pin::new(inner);
    let outcome = loop {
        match ready!(stream.as_mut().poll_next(cx)) {
            None => break None,
            Some(Err(io_err)) => break Some(Err(crate::error::decode_io(io_err))),
            // Zero-length chunks carry no data; keep looking.
            Some(Ok(chunk)) if chunk.is_empty() => {}
            Some(Ok(_)) => {
                break Some(Err(crate::error::decode(
                    "there are extra bytes after body has been decompressed",
                )))
            }
        }
    };
    Poll::Ready(outcome)
}
473
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
/// Builds a `ResponseBody` that yields no frames.
fn empty() -> ResponseBody {
    use http_body_util::BodyExt;

    // `Empty` cannot fail: its error type is uninhabited, so the empty match
    // converts it into whatever error type `ResponseBody` carries.
    http_body_util::combinators::BoxBody::new(
        http_body_util::Empty::new().map_err(|infallible| match infallible {}),
    )
}
485
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
/// Resolves a [`Pending`] decoder into a concrete [`Inner`] variant.
impl Future for Pending {
    type Output = Result<Inner, std::io::Error>;

    /// Waits for the first item of the body stream:
    /// - EOF (empty body): resolves to plain text over an empty body, so no
    ///   decompressor is built;
    /// - an error: that error is taken out of the stream and returned;
    /// - a chunk: the whole peekable stream, peeked chunk included, is moved
    ///   into the decompressor selected by the `DecoderType`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        use futures_util::StreamExt;

        match ready!(Pin::new(&mut self.0).poll_peek(cx)) {
            Some(Ok(_)) => {
                // fallthrough
            }
            Some(Err(_e)) => {
                // error was just a ref, so we need to really poll to move it
                return Poll::Ready(Err(ready!(Pin::new(&mut self.0).poll_next(cx))
                    .expect("just peeked Some")
                    .unwrap_err()));
            }
            None => return Poll::Ready(Ok(Inner::PlainText(empty()))),
        };

        // `self` is only borrowed, so take ownership of the stream by swapping
        // in an empty placeholder.
        let _body = std::mem::replace(&mut self.0, IoStream(empty()).peekable());

        match self.1 {
            #[cfg(feature = "brotli")]
            DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin(
                FramedRead::new(
                    BrotliDecoder::new(StreamReader::new(_body)),
                    BytesCodec::new(),
                )
                .fuse(),
            )))),
            #[cfg(feature = "zstd")]
            DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(
                FramedRead::new(
                    {
                        // Enable multi-member (multi-frame) decoding for zstd.
                        let mut d = ZstdDecoder::new(StreamReader::new(_body));
                        d.multiple_members(true);
                        d
                    },
                    BytesCodec::new(),
                )
                .fuse(),
            )))),
            #[cfg(feature = "gzip")]
            DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(
                FramedRead::new(
                    GzipDecoder::new(StreamReader::new(_body)),
                    BytesCodec::new(),
                )
                .fuse(),
            )))),
            #[cfg(feature = "deflate")]
            DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin(
                FramedRead::new(
                    ZlibDecoder::new(StreamReader::new(_body)),
                    BytesCodec::new(),
                )
                .fuse(),
            )))),
        }
    }
}
553
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
impl<B> Stream for IoStream<B>
where
    B: HttpBody<Data = Bytes> + Unpin,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Item = Result<Bytes, std::io::Error>;

    /// Yields the next data chunk of the wrapped body.
    ///
    /// Frames that are not data frames are dropped and polling continues.
    /// Body errors are converted with `crate::error::into_io`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        loop {
            let frame = match ready!(Pin::new(&mut self.0).poll_frame(cx)) {
                None => return Poll::Ready(None),
                Some(Err(err)) => {
                    return Poll::Ready(Some(Err(crate::error::into_io(err.into()))));
                }
                Some(Ok(frame)) => frame,
            };

            if let Ok(chunk) = frame.into_data() {
                return Poll::Ready(Some(Ok(chunk)));
            }
            // Not a data frame: go round again for the next one.
        }
    }
}
585
586// ===== impl Accepts =====
587
588impl Accepts {
589    /*
590    pub(super) fn none() -> Self {
591        Accepts {
592            #[cfg(feature = "gzip")]
593            gzip: false,
594            #[cfg(feature = "brotli")]
595            brotli: false,
596            #[cfg(feature = "zstd")]
597            zstd: false,
598            #[cfg(feature = "deflate")]
599            deflate: false,
600        }
601    }
602    */
603
604    pub(super) const fn as_str(&self) -> Option<&'static str> {
605        match (
606            self.is_gzip(),
607            self.is_brotli(),
608            self.is_zstd(),
609            self.is_deflate(),
610        ) {
611            (true, true, true, true) => Some("gzip, br, zstd, deflate"),
612            (true, true, false, true) => Some("gzip, br, deflate"),
613            (true, true, true, false) => Some("gzip, br, zstd"),
614            (true, true, false, false) => Some("gzip, br"),
615            (true, false, true, true) => Some("gzip, zstd, deflate"),
616            (true, false, false, true) => Some("gzip, deflate"),
617            (false, true, true, true) => Some("br, zstd, deflate"),
618            (false, true, false, true) => Some("br, deflate"),
619            (true, false, true, false) => Some("gzip, zstd"),
620            (true, false, false, false) => Some("gzip"),
621            (false, true, true, false) => Some("br, zstd"),
622            (false, true, false, false) => Some("br"),
623            (false, false, true, true) => Some("zstd, deflate"),
624            (false, false, true, false) => Some("zstd"),
625            (false, false, false, true) => Some("deflate"),
626            (false, false, false, false) => None,
627        }
628    }
629
630    const fn is_gzip(&self) -> bool {
631        #[cfg(feature = "gzip")]
632        {
633            self.gzip
634        }
635
636        #[cfg(not(feature = "gzip"))]
637        {
638            false
639        }
640    }
641
642    const fn is_brotli(&self) -> bool {
643        #[cfg(feature = "brotli")]
644        {
645            self.brotli
646        }
647
648        #[cfg(not(feature = "brotli"))]
649        {
650            false
651        }
652    }
653
654    const fn is_zstd(&self) -> bool {
655        #[cfg(feature = "zstd")]
656        {
657            self.zstd
658        }
659
660        #[cfg(not(feature = "zstd"))]
661        {
662            false
663        }
664    }
665
666    const fn is_deflate(&self) -> bool {
667        #[cfg(feature = "deflate")]
668        {
669            self.deflate
670        }
671
672        #[cfg(not(feature = "deflate"))]
673        {
674            false
675        }
676    }
677}
678
679impl Default for Accepts {
680    fn default() -> Accepts {
681        Accepts {
682            #[cfg(feature = "gzip")]
683            gzip: true,
684            #[cfg(feature = "brotli")]
685            brotli: true,
686            #[cfg(feature = "zstd")]
687            zstd: true,
688            #[cfg(feature = "deflate")]
689            deflate: true,
690        }
691    }
692}
693
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `Accepts::as_str` against a header value assembled flag by flag,
    /// for every combination of the four encoding flags.
    #[test]
    fn accepts_as_str() {
        // Reference implementation: join the enabled encodings in the order
        // `as_str` lists them.
        fn format_accept_encoding(accepts: &Accepts) -> String {
            let candidates = [
                (accepts.is_gzip(), "gzip"),
                (accepts.is_brotli(), "br"),
                (accepts.is_zstd(), "zstd"),
                (accepts.is_deflate(), "deflate"),
            ];
            candidates
                .iter()
                .filter(|(enabled, _)| *enabled)
                .map(|(_, name)| *name)
                .collect::<Vec<&str>>()
                .join(", ")
        }

        // Each of the low four bits of `mask` drives one flag. When a
        // compression feature is disabled its field does not exist, so the
        // matching local is never used; the `allow` silences that warning.
        #[allow(unused_variables)]
        for mask in 0u8..16 {
            let gzip = (mask & 0b0001) != 0;
            let brotli = (mask & 0b0010) != 0;
            let zstd = (mask & 0b0100) != 0;
            let deflate = (mask & 0b1000) != 0;

            let accepts = Accepts {
                #[cfg(feature = "gzip")]
                gzip,
                #[cfg(feature = "brotli")]
                brotli,
                #[cfg(feature = "zstd")]
                zstd,
                #[cfg(feature = "deflate")]
                deflate,
            };

            let expected = format_accept_encoding(&accepts);
            assert_eq!(accepts.as_str().unwrap_or(""), expected.as_str());
        }
    }
}