1use crate::{metadata::MetadataValue, Status};
2use bytes::{Buf, BufMut, BytesMut};
3#[cfg(feature = "gzip")]
4use flate2::read::{GzDecoder, GzEncoder};
5use std::fmt;
6#[cfg(feature = "zstd")]
7use zstd::stream::read::{Decoder, Encoder};
8
9pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
10pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
11
12#[derive(Debug, Default, Clone, Copy)]
16pub struct EnabledCompressionEncodings {
17 inner: [Option<CompressionEncoding>; 2],
18}
19
20impl EnabledCompressionEncodings {
21 pub fn enable(&mut self, encoding: CompressionEncoding) {
25 for e in self.inner.iter_mut() {
26 match e {
27 Some(e) if *e == encoding => return,
28 None => {
29 *e = Some(encoding);
30 return;
31 }
32 _ => continue,
33 }
34 }
35 }
36
37 pub fn pop(&mut self) -> Option<CompressionEncoding> {
39 self.inner
40 .iter_mut()
41 .rev()
42 .find(|entry| entry.is_some())?
43 .take()
44 }
45
46 pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
47 let mut value = BytesMut::new();
48 for encoding in self.inner.into_iter().flatten() {
49 value.put_slice(encoding.as_str().as_bytes());
50 value.put_u8(b',');
51 }
52
53 if value.is_empty() {
54 return None;
55 }
56
57 value.put_slice(b"identity");
58 Some(http::HeaderValue::from_maybe_shared(value).unwrap())
59 }
60
61 pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
63 self.inner.contains(&Some(encoding))
64 }
65
66 pub fn is_empty(&self) -> bool {
68 self.inner.iter().all(|e| e.is_none())
69 }
70}
71
72#[derive(Clone, Copy, Debug, PartialEq, Eq)]
73pub(crate) struct CompressionSettings {
74 pub(crate) encoding: CompressionEncoding,
75 pub(crate) buffer_growth_interval: usize,
78}
79
/// A compression algorithm that can be applied to gRPC messages.
///
/// Each variant exists only when its matching cargo feature is enabled; with no
/// compression features the enum has no variants at all.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompressionEncoding {
    /// gzip compression, available with the `gzip` feature.
    #[cfg(feature = "gzip")]
    Gzip,
    /// zstd compression, available with the `zstd` feature.
    #[cfg(feature = "zstd")]
    Zstd,
}
91
92impl CompressionEncoding {
93 pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
94 #[cfg(feature = "gzip")]
95 CompressionEncoding::Gzip,
96 #[cfg(feature = "zstd")]
97 CompressionEncoding::Zstd,
98 ];
99
100 pub(crate) fn from_accept_encoding_header(
102 map: &http::HeaderMap,
103 enabled_encodings: EnabledCompressionEncodings,
104 ) -> Option<Self> {
105 if enabled_encodings.is_empty() {
106 return None;
107 }
108
109 let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
110 let header_value_str = header_value.to_str().ok()?;
111
112 split_by_comma(header_value_str).find_map(|value| match value {
113 #[cfg(feature = "gzip")]
114 "gzip" => Some(CompressionEncoding::Gzip),
115 #[cfg(feature = "zstd")]
116 "zstd" => Some(CompressionEncoding::Zstd),
117 _ => None,
118 })
119 }
120
121 pub(crate) fn from_encoding_header(
123 map: &http::HeaderMap,
124 enabled_encodings: EnabledCompressionEncodings,
125 ) -> Result<Option<Self>, Status> {
126 let Some(header_value) = map.get(ENCODING_HEADER) else {
127 return Ok(None);
128 };
129
130 match header_value.as_bytes() {
131 #[cfg(feature = "gzip")]
132 b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
133 Ok(Some(CompressionEncoding::Gzip))
134 }
135 #[cfg(feature = "zstd")]
136 b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
137 Ok(Some(CompressionEncoding::Zstd))
138 }
139 b"identity" => Ok(None),
140 other => {
141 let other_debug_string;
144
145 let mut status = Status::unimplemented(format!(
146 "Content is compressed with `{}` which isn't supported",
147 match std::str::from_utf8(other) {
148 Ok(s) => s,
149 Err(_) => {
150 other_debug_string = format!("{other:?}");
151 &other_debug_string
152 }
153 }
154 ));
155
156 let header_value = enabled_encodings
157 .into_accept_encoding_header_value()
158 .map(MetadataValue::unchecked_from_header_value)
159 .unwrap_or_else(|| MetadataValue::from_static("identity"));
160 status
161 .metadata_mut()
162 .insert(ACCEPT_ENCODING_HEADER, header_value);
163
164 Err(status)
165 }
166 }
167 }
168
169 pub(crate) fn as_str(self) -> &'static str {
170 match self {
171 #[cfg(feature = "gzip")]
172 CompressionEncoding::Gzip => "gzip",
173 #[cfg(feature = "zstd")]
174 CompressionEncoding::Zstd => "zstd",
175 }
176 }
177
178 #[cfg(any(feature = "gzip", feature = "zstd"))]
179 pub(crate) fn into_header_value(self) -> http::HeaderValue {
180 http::HeaderValue::from_static(self.as_str())
181 }
182}
183
184impl fmt::Display for CompressionEncoding {
185 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186 f.write_str(self.as_str())
187 }
188}
189
/// Splits a comma-separated header value into entries, trimming surrounding whitespace
/// from each one. Empty entries are yielded as `""`.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
193
194#[allow(unused_variables, unreachable_code)]
197pub(crate) fn compress(
198 settings: CompressionSettings,
199 decompressed_buf: &mut BytesMut,
200 out_buf: &mut BytesMut,
201 len: usize,
202) -> Result<(), std::io::Error> {
203 let buffer_growth_interval = settings.buffer_growth_interval;
204 let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
205 out_buf.reserve(capacity);
206
207 #[cfg(any(feature = "gzip", feature = "zstd"))]
208 let mut out_writer = out_buf.writer();
209
210 match settings.encoding {
211 #[cfg(feature = "gzip")]
212 CompressionEncoding::Gzip => {
213 let mut gzip_encoder = GzEncoder::new(
214 &decompressed_buf[0..len],
215 flate2::Compression::new(6),
217 );
218 std::io::copy(&mut gzip_encoder, &mut out_writer)?;
219 }
220 #[cfg(feature = "zstd")]
221 CompressionEncoding::Zstd => {
222 let mut zstd_encoder = Encoder::new(
223 &decompressed_buf[0..len],
224 zstd::DEFAULT_COMPRESSION_LEVEL,
226 )?;
227 std::io::copy(&mut zstd_encoder, &mut out_writer)?;
228 }
229 }
230
231 decompressed_buf.advance(len);
232
233 Ok(())
234}
235
236#[allow(unused_variables, unreachable_code)]
238pub(crate) fn decompress(
239 settings: CompressionSettings,
240 compressed_buf: &mut BytesMut,
241 out_buf: &mut BytesMut,
242 len: usize,
243) -> Result<(), std::io::Error> {
244 let buffer_growth_interval = settings.buffer_growth_interval;
245 let estimate_decompressed_len = len * 2;
246 let capacity =
247 ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval;
248 out_buf.reserve(capacity);
249
250 #[cfg(any(feature = "gzip", feature = "zstd"))]
251 let mut out_writer = out_buf.writer();
252
253 match settings.encoding {
254 #[cfg(feature = "gzip")]
255 CompressionEncoding::Gzip => {
256 let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
257 std::io::copy(&mut gzip_decoder, &mut out_writer)?;
258 }
259 #[cfg(feature = "zstd")]
260 CompressionEncoding::Zstd => {
261 let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
262 std::io::copy(&mut zstd_decoder, &mut out_writer)?;
263 }
264 }
265
266 compressed_buf.advance(len);
267
268 Ok(())
269}
270
/// A per-message override of the compression setting.
///
/// The variants are interpreted by code outside this module.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SingleMessageCompressionOverride {
    /// The default. Presumably defers to the compression configured for the stream.
    #[default]
    Inherit,
    /// Presumably sends this particular message without compression.
    Disable,
}
282
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    /// Renders the accept-encoding header for an explicit slot layout.
    fn header_for(inner: [Option<CompressionEncoding>; 2]) -> Option<HeaderValue> {
        EnabledCompressionEncodings { inner }.into_accept_encoding_header_value()
    }

    #[test]
    fn convert_none_into_header_value() {
        assert!(header_for([None, None]).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        let expected = HeaderValue::from_static("gzip,identity");

        // The slot position must not influence the output.
        for layout in [
            [Some(CompressionEncoding::Gzip), None],
            [None, Some(CompressionEncoding::Gzip)],
        ] {
            assert_eq!(header_for(layout).unwrap(), expected);
        }
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        let expected = HeaderValue::from_static("zstd,identity");

        // The slot position must not influence the output.
        for layout in [
            [Some(CompressionEncoding::Zstd), None],
            [None, Some(CompressionEncoding::Zstd)],
        ] {
            assert_eq!(header_for(layout).unwrap(), expected);
        }
    }

    #[test]
    #[cfg(all(feature = "gzip", feature = "zstd"))]
    fn convert_gzip_and_zstd_into_header_value() {
        use CompressionEncoding::{Gzip, Zstd};

        // Slot order is preserved in the advertised list.
        let cases = [
            ([Some(Gzip), Some(Zstd)], "gzip,zstd,identity"),
            ([Some(Zstd), Some(Gzip)], "zstd,gzip,identity"),
        ];
        for (layout, expected) in cases {
            assert_eq!(
                header_for(layout).unwrap(),
                HeaderValue::from_static(expected),
            );
        }
    }
}