tonic/codec/compression.rs

1use crate::{metadata::MetadataValue, Status};
2use bytes::{Buf, BufMut, BytesMut};
3#[cfg(feature = "gzip")]
4use flate2::read::{GzDecoder, GzEncoder};
5#[cfg(feature = "deflate")]
6use flate2::read::{ZlibDecoder, ZlibEncoder};
7use std::{borrow::Cow, fmt};
8#[cfg(feature = "zstd")]
9use zstd::stream::read::{Decoder, Encoder};
10
/// Metadata key naming the compression applied to the message payloads being sent.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Metadata key listing the compression encodings the sender is willing to receive.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
13
14/// Struct used to configure which encodings are enabled on a server or channel.
15///
16/// Represents an ordered list of compression encodings that are enabled.
17#[derive(Debug, Default, Clone, Copy)]
18pub struct EnabledCompressionEncodings {
19    inner: [Option<CompressionEncoding>; 3],
20}
21
22impl EnabledCompressionEncodings {
23    /// Enable a [`CompressionEncoding`].
24    ///
25    /// Adds the new encoding to the end of the encoding list.
26    pub fn enable(&mut self, encoding: CompressionEncoding) {
27        for e in self.inner.iter_mut() {
28            match e {
29                Some(e) if *e == encoding => return,
30                None => {
31                    *e = Some(encoding);
32                    return;
33                }
34                _ => continue,
35            }
36        }
37    }
38
39    /// Remove the last [`CompressionEncoding`].
40    pub fn pop(&mut self) -> Option<CompressionEncoding> {
41        self.inner
42            .iter_mut()
43            .rev()
44            .find(|entry| entry.is_some())?
45            .take()
46    }
47
48    pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
49        let mut value = BytesMut::new();
50        for encoding in self.inner.into_iter().flatten() {
51            value.put_slice(encoding.as_str().as_bytes());
52            value.put_u8(b',');
53        }
54
55        if value.is_empty() {
56            return None;
57        }
58
59        value.put_slice(b"identity");
60        Some(http::HeaderValue::from_maybe_shared(value).unwrap())
61    }
62
63    /// Check if a [`CompressionEncoding`] is enabled.
64    pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
65        self.inner.contains(&Some(encoding))
66    }
67
68    /// Check if any [`CompressionEncoding`]s are enabled.
69    pub fn is_empty(&self) -> bool {
70        self.inner.iter().all(|e| e.is_none())
71    }
72}
73
74#[derive(Clone, Copy, Debug, PartialEq, Eq)]
75pub(crate) struct CompressionSettings {
76    pub(crate) encoding: CompressionEncoding,
77    /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste.
78    /// The default buffer growth interval is 8 kilobytes.
79    pub(crate) buffer_growth_interval: usize,
80}
81
/// The compression encodings Tonic supports.
///
/// Each variant exists only when the matching cargo feature (`gzip`, `deflate`, `zstd`) is
/// enabled. The enum is `#[non_exhaustive]`, so new encodings can be added later.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// gzip, implemented with `flate2`'s gzip reader adapters.
    #[cfg(feature = "gzip")]
    Gzip,
    /// gRPC `deflate`: zlib-wrapped deflate, implemented with `flate2`'s zlib reader adapters.
    #[cfg(feature = "deflate")]
    Deflate,
    /// Zstandard, implemented with the `zstd` crate's streaming reader adapters.
    #[cfg(feature = "zstd")]
    Zstd,
}
96
97impl CompressionEncoding {
98    pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
99        #[cfg(feature = "gzip")]
100        CompressionEncoding::Gzip,
101        #[cfg(feature = "deflate")]
102        CompressionEncoding::Deflate,
103        #[cfg(feature = "zstd")]
104        CompressionEncoding::Zstd,
105    ];
106
107    /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
108    pub(crate) fn from_accept_encoding_header(
109        map: &http::HeaderMap,
110        enabled_encodings: EnabledCompressionEncodings,
111    ) -> Option<Self> {
112        if enabled_encodings.is_empty() {
113            return None;
114        }
115
116        let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
117        let header_value_str = header_value.to_str().ok()?;
118
119        split_by_comma(header_value_str).find_map(|value| match value {
120            #[cfg(feature = "gzip")]
121            "gzip" => Some(CompressionEncoding::Gzip),
122            #[cfg(feature = "deflate")]
123            "deflate" => Some(CompressionEncoding::Deflate),
124            #[cfg(feature = "zstd")]
125            "zstd" => Some(CompressionEncoding::Zstd),
126            _ => None,
127        })
128    }
129
130    /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
131    pub(crate) fn from_encoding_header(
132        map: &http::HeaderMap,
133        enabled_encodings: EnabledCompressionEncodings,
134    ) -> Result<Option<Self>, Status> {
135        let Some(header_value) = map.get(ENCODING_HEADER) else {
136            return Ok(None);
137        };
138
139        match header_value.as_bytes() {
140            #[cfg(feature = "gzip")]
141            b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
142                Ok(Some(CompressionEncoding::Gzip))
143            }
144            #[cfg(feature = "deflate")]
145            b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => {
146                Ok(Some(CompressionEncoding::Deflate))
147            }
148            #[cfg(feature = "zstd")]
149            b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
150                Ok(Some(CompressionEncoding::Zstd))
151            }
152            b"identity" => Ok(None),
153            other => {
154                let other = match std::str::from_utf8(other) {
155                    Ok(s) => Cow::Borrowed(s),
156                    Err(_) => Cow::Owned(format!("{other:?}")),
157                };
158
159                let mut status = Status::unimplemented(format!(
160                    "Content is compressed with `{other}` which isn't supported"
161                ));
162
163                let header_value = enabled_encodings
164                    .into_accept_encoding_header_value()
165                    .map(MetadataValue::unchecked_from_header_value)
166                    .unwrap_or_else(|| MetadataValue::from_static("identity"));
167                status
168                    .metadata_mut()
169                    .insert(ACCEPT_ENCODING_HEADER, header_value);
170
171                Err(status)
172            }
173        }
174    }
175
176    pub(crate) fn as_str(self) -> &'static str {
177        match self {
178            #[cfg(feature = "gzip")]
179            CompressionEncoding::Gzip => "gzip",
180            #[cfg(feature = "deflate")]
181            CompressionEncoding::Deflate => "deflate",
182            #[cfg(feature = "zstd")]
183            CompressionEncoding::Zstd => "zstd",
184        }
185    }
186
187    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
188    pub(crate) fn into_header_value(self) -> http::HeaderValue {
189        http::HeaderValue::from_static(self.as_str())
190    }
191}
192
193impl fmt::Display for CompressionEncoding {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        f.write_str(self.as_str())
196    }
197}
198
/// Split a header value on `,` and trim surrounding whitespace from every piece.
///
/// Empty pieces, such as one produced by a trailing comma, are yielded as `""`.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
202
203/// Compress `len` bytes from `decompressed_buf` into `out_buf`.
204/// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it.
205#[allow(unused_variables, unreachable_code)]
206pub(crate) fn compress(
207    settings: CompressionSettings,
208    decompressed_buf: &mut BytesMut,
209    out_buf: &mut BytesMut,
210    len: usize,
211) -> Result<(), std::io::Error> {
212    let buffer_growth_interval = settings.buffer_growth_interval;
213    let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
214    out_buf.reserve(capacity);
215
216    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
217    let mut out_writer = out_buf.writer();
218
219    match settings.encoding {
220        #[cfg(feature = "gzip")]
221        CompressionEncoding::Gzip => {
222            let mut gzip_encoder = GzEncoder::new(
223                &decompressed_buf[0..len],
224                // FIXME: support customizing the compression level
225                flate2::Compression::new(6),
226            );
227            std::io::copy(&mut gzip_encoder, &mut out_writer)?;
228        }
229        #[cfg(feature = "deflate")]
230        CompressionEncoding::Deflate => {
231            let mut deflate_encoder = ZlibEncoder::new(
232                &decompressed_buf[0..len],
233                // FIXME: support customizing the compression level
234                flate2::Compression::new(6),
235            );
236            std::io::copy(&mut deflate_encoder, &mut out_writer)?;
237        }
238        #[cfg(feature = "zstd")]
239        CompressionEncoding::Zstd => {
240            let mut zstd_encoder = Encoder::new(
241                &decompressed_buf[0..len],
242                // FIXME: support customizing the compression level
243                zstd::DEFAULT_COMPRESSION_LEVEL,
244            )?;
245            std::io::copy(&mut zstd_encoder, &mut out_writer)?;
246        }
247    }
248
249    decompressed_buf.advance(len);
250
251    Ok(())
252}
253
254/// Decompress `len` bytes from `compressed_buf` into `out_buf`.
255#[allow(unused_variables, unreachable_code)]
256pub(crate) fn decompress(
257    settings: CompressionSettings,
258    compressed_buf: &mut BytesMut,
259    out_buf: &mut BytesMut,
260    len: usize,
261) -> Result<(), std::io::Error> {
262    let buffer_growth_interval = settings.buffer_growth_interval;
263    let estimate_decompressed_len = len * 2;
264    let capacity =
265        ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval;
266    out_buf.reserve(capacity);
267
268    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
269    let mut out_writer = out_buf.writer();
270
271    match settings.encoding {
272        #[cfg(feature = "gzip")]
273        CompressionEncoding::Gzip => {
274            let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
275            std::io::copy(&mut gzip_decoder, &mut out_writer)?;
276        }
277        #[cfg(feature = "deflate")]
278        CompressionEncoding::Deflate => {
279            let mut deflate_decoder = ZlibDecoder::new(&compressed_buf[0..len]);
280            std::io::copy(&mut deflate_decoder, &mut out_writer)?;
281        }
282        #[cfg(feature = "zstd")]
283        CompressionEncoding::Zstd => {
284            let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
285            std::io::copy(&mut zstd_decoder, &mut out_writer)?;
286        }
287    }
288
289    compressed_buf.advance(len);
290
291    Ok(())
292}
293
/// Per-message override of the compression configured for a stream.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SingleMessageCompressionOverride {
    /// Follow the stream's configuration: if the stream is compressed, this message is
    /// compressed as well.
    ///
    /// This is the default.
    #[default]
    Inherit,
    /// Send this message uncompressed, even when compression is enabled on the stream.
    Disable,
}
306
#[cfg(test)]
mod tests {
    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
    use http::HeaderValue;

    use super::*;

    /// Renders the `grpc-accept-encoding` value for an explicit slot layout.
    ///
    /// This lets the tests place encodings in arbitrary slots, including with gaps.
    fn accept_encoding_for(inner: [Option<CompressionEncoding>; 3]) -> Option<http::HeaderValue> {
        EnabledCompressionEncodings { inner }.into_accept_encoding_header_value()
    }

    #[test]
    fn convert_none_into_header_value() {
        assert!(accept_encoding_for([None; 3]).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        const GZIP: HeaderValue = HeaderValue::from_static("gzip,identity");

        let layouts = [
            [Some(CompressionEncoding::Gzip), None, None],
            [None, None, Some(CompressionEncoding::Gzip)],
        ];
        for layout in layouts {
            assert_eq!(accept_encoding_for(layout).unwrap(), GZIP);
        }
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        const ZSTD: HeaderValue = HeaderValue::from_static("zstd,identity");

        let layouts = [
            [Some(CompressionEncoding::Zstd), None, None],
            [None, None, Some(CompressionEncoding::Zstd)],
        ];
        for layout in layouts {
            assert_eq!(accept_encoding_for(layout).unwrap(), ZSTD);
        }
    }

    #[test]
    #[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))]
    fn convert_compression_encodings_into_header_value() {
        let cases = [
            (
                [
                    Some(CompressionEncoding::Gzip),
                    Some(CompressionEncoding::Deflate),
                    Some(CompressionEncoding::Zstd),
                ],
                "gzip,deflate,zstd,identity",
            ),
            (
                [
                    Some(CompressionEncoding::Zstd),
                    Some(CompressionEncoding::Deflate),
                    Some(CompressionEncoding::Gzip),
                ],
                "zstd,deflate,gzip,identity",
            ),
        ];

        for (layout, expected) in cases {
            assert_eq!(
                accept_encoding_for(layout).unwrap(),
                HeaderValue::from_static(expected),
            );
        }
    }
}