tonic/codec/compression.rs

1use crate::{metadata::MetadataValue, Status};
2use bytes::{Buf, BufMut, BytesMut};
3#[cfg(feature = "gzip")]
4use flate2::read::{GzDecoder, GzEncoder};
5use std::fmt;
6#[cfg(feature = "zstd")]
7use zstd::stream::read::{Decoder, Encoder};
8
/// Header naming the encoding a message body was compressed with.
pub(crate) const ENCODING_HEADER: &'static str = "grpc-encoding";
/// Header listing the encodings a peer is willing to receive.
pub(crate) const ACCEPT_ENCODING_HEADER: &'static str = "grpc-accept-encoding";
11
12/// Struct used to configure which encodings are enabled on a server or channel.
13///
14/// Represents an ordered list of compression encodings that are enabled.
15#[derive(Debug, Default, Clone, Copy)]
16pub struct EnabledCompressionEncodings {
17    inner: [Option<CompressionEncoding>; 2],
18}
19
20impl EnabledCompressionEncodings {
21    /// Enable a [`CompressionEncoding`].
22    ///
23    /// Adds the new encoding to the end of the encoding list.
24    pub fn enable(&mut self, encoding: CompressionEncoding) {
25        for e in self.inner.iter_mut() {
26            match e {
27                Some(e) if *e == encoding => return,
28                None => {
29                    *e = Some(encoding);
30                    return;
31                }
32                _ => continue,
33            }
34        }
35    }
36
37    /// Remove the last [`CompressionEncoding`].
38    pub fn pop(&mut self) -> Option<CompressionEncoding> {
39        self.inner
40            .iter_mut()
41            .rev()
42            .find(|entry| entry.is_some())?
43            .take()
44    }
45
46    pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
47        let mut value = BytesMut::new();
48        for encoding in self.inner.into_iter().flatten() {
49            value.put_slice(encoding.as_str().as_bytes());
50            value.put_u8(b',');
51        }
52
53        if value.is_empty() {
54            return None;
55        }
56
57        value.put_slice(b"identity");
58        Some(http::HeaderValue::from_maybe_shared(value).unwrap())
59    }
60
61    /// Check if a [`CompressionEncoding`] is enabled.
62    pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
63        self.inner.contains(&Some(encoding))
64    }
65
66    /// Check if any [`CompressionEncoding`]s are enabled.
67    pub fn is_empty(&self) -> bool {
68        self.inner.iter().all(|e| e.is_none())
69    }
70}
71
72#[derive(Clone, Copy, Debug, PartialEq, Eq)]
73pub(crate) struct CompressionSettings {
74    pub(crate) encoding: CompressionEncoding,
75    /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste.
76    /// The default buffer growth interval is 8 kilobytes.
77    pub(crate) buffer_growth_interval: usize,
78}
79
/// The compression encodings Tonic supports.
///
/// Each variant only exists when the cargo feature of the same name is enabled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// gzip compression (requires the `gzip` feature).
    #[cfg(feature = "gzip")]
    Gzip,
    /// zstd compression (requires the `zstd` feature).
    #[cfg(feature = "zstd")]
    Zstd,
}
91
92impl CompressionEncoding {
93    pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
94        #[cfg(feature = "gzip")]
95        CompressionEncoding::Gzip,
96        #[cfg(feature = "zstd")]
97        CompressionEncoding::Zstd,
98    ];
99
100    /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
101    pub(crate) fn from_accept_encoding_header(
102        map: &http::HeaderMap,
103        enabled_encodings: EnabledCompressionEncodings,
104    ) -> Option<Self> {
105        if enabled_encodings.is_empty() {
106            return None;
107        }
108
109        let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
110        let header_value_str = header_value.to_str().ok()?;
111
112        split_by_comma(header_value_str).find_map(|value| match value {
113            #[cfg(feature = "gzip")]
114            "gzip" => Some(CompressionEncoding::Gzip),
115            #[cfg(feature = "zstd")]
116            "zstd" => Some(CompressionEncoding::Zstd),
117            _ => None,
118        })
119    }
120
121    /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
122    pub(crate) fn from_encoding_header(
123        map: &http::HeaderMap,
124        enabled_encodings: EnabledCompressionEncodings,
125    ) -> Result<Option<Self>, Status> {
126        let Some(header_value) = map.get(ENCODING_HEADER) else {
127            return Ok(None);
128        };
129
130        match header_value.as_bytes() {
131            #[cfg(feature = "gzip")]
132            b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
133                Ok(Some(CompressionEncoding::Gzip))
134            }
135            #[cfg(feature = "zstd")]
136            b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
137                Ok(Some(CompressionEncoding::Zstd))
138            }
139            b"identity" => Ok(None),
140            other => {
141                // NOTE: Workaround for lifetime limitation. Resolved at Rust 1.79.
142                // https://blog.rust-lang.org/2024/06/13/Rust-1.79.0.html#extending-automatic-temporary-lifetime-extension
143                let other_debug_string;
144
145                let mut status = Status::unimplemented(format!(
146                    "Content is compressed with `{}` which isn't supported",
147                    match std::str::from_utf8(other) {
148                        Ok(s) => s,
149                        Err(_) => {
150                            other_debug_string = format!("{other:?}");
151                            &other_debug_string
152                        }
153                    }
154                ));
155
156                let header_value = enabled_encodings
157                    .into_accept_encoding_header_value()
158                    .map(MetadataValue::unchecked_from_header_value)
159                    .unwrap_or_else(|| MetadataValue::from_static("identity"));
160                status
161                    .metadata_mut()
162                    .insert(ACCEPT_ENCODING_HEADER, header_value);
163
164                Err(status)
165            }
166        }
167    }
168
169    pub(crate) fn as_str(self) -> &'static str {
170        match self {
171            #[cfg(feature = "gzip")]
172            CompressionEncoding::Gzip => "gzip",
173            #[cfg(feature = "zstd")]
174            CompressionEncoding::Zstd => "zstd",
175        }
176    }
177
178    #[cfg(any(feature = "gzip", feature = "zstd"))]
179    pub(crate) fn into_header_value(self) -> http::HeaderValue {
180        http::HeaderValue::from_static(self.as_str())
181    }
182}
183
184impl fmt::Display for CompressionEncoding {
185    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186        f.write_str(self.as_str())
187    }
188}
189
/// Split a comma-separated header value into its entries, trimming whitespace
/// around each one. Empty entries are yielded as empty strings.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
193
194/// Compress `len` bytes from `decompressed_buf` into `out_buf`.
195/// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it.
196#[allow(unused_variables, unreachable_code)]
197pub(crate) fn compress(
198    settings: CompressionSettings,
199    decompressed_buf: &mut BytesMut,
200    out_buf: &mut BytesMut,
201    len: usize,
202) -> Result<(), std::io::Error> {
203    let buffer_growth_interval = settings.buffer_growth_interval;
204    let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
205    out_buf.reserve(capacity);
206
207    #[cfg(any(feature = "gzip", feature = "zstd"))]
208    let mut out_writer = out_buf.writer();
209
210    match settings.encoding {
211        #[cfg(feature = "gzip")]
212        CompressionEncoding::Gzip => {
213            let mut gzip_encoder = GzEncoder::new(
214                &decompressed_buf[0..len],
215                // FIXME: support customizing the compression level
216                flate2::Compression::new(6),
217            );
218            std::io::copy(&mut gzip_encoder, &mut out_writer)?;
219        }
220        #[cfg(feature = "zstd")]
221        CompressionEncoding::Zstd => {
222            let mut zstd_encoder = Encoder::new(
223                &decompressed_buf[0..len],
224                // FIXME: support customizing the compression level
225                zstd::DEFAULT_COMPRESSION_LEVEL,
226            )?;
227            std::io::copy(&mut zstd_encoder, &mut out_writer)?;
228        }
229    }
230
231    decompressed_buf.advance(len);
232
233    Ok(())
234}
235
236/// Decompress `len` bytes from `compressed_buf` into `out_buf`.
237#[allow(unused_variables, unreachable_code)]
238pub(crate) fn decompress(
239    settings: CompressionSettings,
240    compressed_buf: &mut BytesMut,
241    out_buf: &mut BytesMut,
242    len: usize,
243) -> Result<(), std::io::Error> {
244    let buffer_growth_interval = settings.buffer_growth_interval;
245    let estimate_decompressed_len = len * 2;
246    let capacity =
247        ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval;
248    out_buf.reserve(capacity);
249
250    #[cfg(any(feature = "gzip", feature = "zstd"))]
251    let mut out_writer = out_buf.writer();
252
253    match settings.encoding {
254        #[cfg(feature = "gzip")]
255        CompressionEncoding::Gzip => {
256            let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
257            std::io::copy(&mut gzip_decoder, &mut out_writer)?;
258        }
259        #[cfg(feature = "zstd")]
260        CompressionEncoding::Zstd => {
261            let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
262            std::io::copy(&mut zstd_decoder, &mut out_writer)?;
263        }
264    }
265
266    compressed_buf.advance(len);
267
268    Ok(())
269}
270
/// Per-message override of the compression configured for a stream.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    #[default]
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}
282
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    /// Builds an `EnabledCompressionEncodings` from explicit slots and renders it
    /// as an accept-encoding header value, which must be present.
    #[cfg(any(feature = "gzip", feature = "zstd"))]
    fn header_for(inner: [Option<CompressionEncoding>; 2]) -> HeaderValue {
        EnabledCompressionEncodings { inner }
            .into_accept_encoding_header_value()
            .unwrap()
    }

    #[test]
    fn convert_none_into_header_value() {
        let disabled = EnabledCompressionEncodings::default();
        assert_eq!(disabled.into_accept_encoding_header_value(), None);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        let expected = HeaderValue::from_static("gzip,identity");

        // The slot the encoding sits in must not change the rendered value.
        assert_eq!(header_for([Some(CompressionEncoding::Gzip), None]), expected);
        assert_eq!(header_for([None, Some(CompressionEncoding::Gzip)]), expected);
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        let expected = HeaderValue::from_static("zstd,identity");

        assert_eq!(header_for([Some(CompressionEncoding::Zstd), None]), expected);
        assert_eq!(header_for([None, Some(CompressionEncoding::Zstd)]), expected);
    }

    #[test]
    #[cfg(all(feature = "gzip", feature = "zstd"))]
    fn convert_gzip_and_zstd_into_header_value() {
        use super::CompressionEncoding::{Gzip, Zstd};

        // Encodings are listed in slot order, followed by `identity`.
        assert_eq!(
            header_for([Some(Gzip), Some(Zstd)]),
            HeaderValue::from_static("gzip,zstd,identity"),
        );
        assert_eq!(
            header_for([Some(Zstd), Some(Gzip)]),
            HeaderValue::from_static("zstd,gzip,identity"),
        );
    }
}