1use crate::{metadata::MetadataValue, Status};
2use bytes::{Buf, BufMut, BytesMut};
3#[cfg(feature = "gzip")]
4use flate2::read::{GzDecoder, GzEncoder};
5#[cfg(feature = "deflate")]
6use flate2::read::{ZlibDecoder, ZlibEncoder};
7use std::{borrow::Cow, fmt};
8#[cfg(feature = "zstd")]
9use zstd::stream::read::{Decoder, Encoder};
10
11pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
12pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
13
14#[derive(Debug, Default, Clone, Copy)]
18pub struct EnabledCompressionEncodings {
19 inner: [Option<CompressionEncoding>; 3],
20}
21
22impl EnabledCompressionEncodings {
23 pub fn enable(&mut self, encoding: CompressionEncoding) {
27 for e in self.inner.iter_mut() {
28 match e {
29 Some(e) if *e == encoding => return,
30 None => {
31 *e = Some(encoding);
32 return;
33 }
34 _ => continue,
35 }
36 }
37 }
38
39 pub fn pop(&mut self) -> Option<CompressionEncoding> {
41 self.inner
42 .iter_mut()
43 .rev()
44 .find(|entry| entry.is_some())?
45 .take()
46 }
47
48 pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
49 let mut value = BytesMut::new();
50 for encoding in self.inner.into_iter().flatten() {
51 value.put_slice(encoding.as_str().as_bytes());
52 value.put_u8(b',');
53 }
54
55 if value.is_empty() {
56 return None;
57 }
58
59 value.put_slice(b"identity");
60 Some(http::HeaderValue::from_maybe_shared(value).unwrap())
61 }
62
63 pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
65 self.inner.contains(&Some(encoding))
66 }
67
68 pub fn is_empty(&self) -> bool {
70 self.inner.iter().all(|e| e.is_none())
71 }
72}
73
74#[derive(Clone, Copy, Debug, PartialEq, Eq)]
75pub(crate) struct CompressionSettings {
76 pub(crate) encoding: CompressionEncoding,
77 pub(crate) buffer_growth_interval: usize,
80}
81
/// A compression algorithm that can be applied to gRPC messages.
///
/// Each variant only exists when its matching cargo feature is enabled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// gzip compression, advertised as `gzip`.
    #[cfg(feature = "gzip")]
    Gzip,
    /// zlib ("deflate") compression, advertised as `deflate`.
    #[cfg(feature = "deflate")]
    Deflate,
    /// Zstandard compression, advertised as `zstd`.
    #[cfg(feature = "zstd")]
    Zstd,
}
96
97impl CompressionEncoding {
98 pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
99 #[cfg(feature = "gzip")]
100 CompressionEncoding::Gzip,
101 #[cfg(feature = "deflate")]
102 CompressionEncoding::Deflate,
103 #[cfg(feature = "zstd")]
104 CompressionEncoding::Zstd,
105 ];
106
107 pub(crate) fn from_accept_encoding_header(
109 map: &http::HeaderMap,
110 enabled_encodings: EnabledCompressionEncodings,
111 ) -> Option<Self> {
112 if enabled_encodings.is_empty() {
113 return None;
114 }
115
116 let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
117 let header_value_str = header_value.to_str().ok()?;
118
119 split_by_comma(header_value_str).find_map(|value| match value {
120 #[cfg(feature = "gzip")]
121 "gzip" => Some(CompressionEncoding::Gzip),
122 #[cfg(feature = "deflate")]
123 "deflate" => Some(CompressionEncoding::Deflate),
124 #[cfg(feature = "zstd")]
125 "zstd" => Some(CompressionEncoding::Zstd),
126 _ => None,
127 })
128 }
129
130 pub(crate) fn from_encoding_header(
132 map: &http::HeaderMap,
133 enabled_encodings: EnabledCompressionEncodings,
134 ) -> Result<Option<Self>, Status> {
135 let Some(header_value) = map.get(ENCODING_HEADER) else {
136 return Ok(None);
137 };
138
139 match header_value.as_bytes() {
140 #[cfg(feature = "gzip")]
141 b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
142 Ok(Some(CompressionEncoding::Gzip))
143 }
144 #[cfg(feature = "deflate")]
145 b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => {
146 Ok(Some(CompressionEncoding::Deflate))
147 }
148 #[cfg(feature = "zstd")]
149 b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
150 Ok(Some(CompressionEncoding::Zstd))
151 }
152 b"identity" => Ok(None),
153 other => {
154 let other = match std::str::from_utf8(other) {
155 Ok(s) => Cow::Borrowed(s),
156 Err(_) => Cow::Owned(format!("{other:?}")),
157 };
158
159 let mut status = Status::unimplemented(format!(
160 "Content is compressed with `{other}` which isn't supported"
161 ));
162
163 let header_value = enabled_encodings
164 .into_accept_encoding_header_value()
165 .map(MetadataValue::unchecked_from_header_value)
166 .unwrap_or_else(|| MetadataValue::from_static("identity"));
167 status
168 .metadata_mut()
169 .insert(ACCEPT_ENCODING_HEADER, header_value);
170
171 Err(status)
172 }
173 }
174 }
175
176 pub(crate) fn as_str(self) -> &'static str {
177 match self {
178 #[cfg(feature = "gzip")]
179 CompressionEncoding::Gzip => "gzip",
180 #[cfg(feature = "deflate")]
181 CompressionEncoding::Deflate => "deflate",
182 #[cfg(feature = "zstd")]
183 CompressionEncoding::Zstd => "zstd",
184 }
185 }
186
187 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
188 pub(crate) fn into_header_value(self) -> http::HeaderValue {
189 http::HeaderValue::from_static(self.as_str())
190 }
191}
192
193impl fmt::Display for CompressionEncoding {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.write_str(self.as_str())
196 }
197}
198
/// Splits `s` on commas and trims surrounding whitespace from every piece.
///
/// Empty pieces (for example from `"a,,b"` or an empty input) are yielded as
/// empty strings rather than skipped.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
202
203#[allow(unused_variables, unreachable_code)]
206pub(crate) fn compress(
207 settings: CompressionSettings,
208 decompressed_buf: &mut BytesMut,
209 out_buf: &mut BytesMut,
210 len: usize,
211) -> Result<(), std::io::Error> {
212 let buffer_growth_interval = settings.buffer_growth_interval;
213 let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
214 out_buf.reserve(capacity);
215
216 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
217 let mut out_writer = out_buf.writer();
218
219 match settings.encoding {
220 #[cfg(feature = "gzip")]
221 CompressionEncoding::Gzip => {
222 let mut gzip_encoder = GzEncoder::new(
223 &decompressed_buf[0..len],
224 flate2::Compression::new(6),
226 );
227 std::io::copy(&mut gzip_encoder, &mut out_writer)?;
228 }
229 #[cfg(feature = "deflate")]
230 CompressionEncoding::Deflate => {
231 let mut deflate_encoder = ZlibEncoder::new(
232 &decompressed_buf[0..len],
233 flate2::Compression::new(6),
235 );
236 std::io::copy(&mut deflate_encoder, &mut out_writer)?;
237 }
238 #[cfg(feature = "zstd")]
239 CompressionEncoding::Zstd => {
240 let mut zstd_encoder = Encoder::new(
241 &decompressed_buf[0..len],
242 zstd::DEFAULT_COMPRESSION_LEVEL,
244 )?;
245 std::io::copy(&mut zstd_encoder, &mut out_writer)?;
246 }
247 }
248
249 decompressed_buf.advance(len);
250
251 Ok(())
252}
253
254#[allow(unused_variables, unreachable_code)]
256pub(crate) fn decompress(
257 settings: CompressionSettings,
258 compressed_buf: &mut BytesMut,
259 out_buf: &mut BytesMut,
260 len: usize,
261) -> Result<(), std::io::Error> {
262 let buffer_growth_interval = settings.buffer_growth_interval;
263 let estimate_decompressed_len = len * 2;
264 let capacity =
265 ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval;
266 out_buf.reserve(capacity);
267
268 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
269 let mut out_writer = out_buf.writer();
270
271 match settings.encoding {
272 #[cfg(feature = "gzip")]
273 CompressionEncoding::Gzip => {
274 let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
275 std::io::copy(&mut gzip_decoder, &mut out_writer)?;
276 }
277 #[cfg(feature = "deflate")]
278 CompressionEncoding::Deflate => {
279 let mut deflate_decoder = ZlibDecoder::new(&compressed_buf[0..len]);
280 std::io::copy(&mut deflate_decoder, &mut out_writer)?;
281 }
282 #[cfg(feature = "zstd")]
283 CompressionEncoding::Zstd => {
284 let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
285 std::io::copy(&mut zstd_decoder, &mut out_writer)?;
286 }
287 }
288
289 compressed_buf.advance(len);
290
291 Ok(())
292}
293
// NOTE(review): the per-variant meanings below are inferred from the variant
// names; the code that consumes this override is not visible here — confirm.
/// A per-message override of compression behaviour.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SingleMessageCompressionOverride {
    /// Keep whatever compression is otherwise configured. This is the default.
    #[default]
    Inherit,
    /// Send this message without compression.
    Disable,
}
306
#[cfg(test)]
mod tests {
    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
    use http::HeaderValue;

    use super::*;

    /// Renders the `grpc-accept-encoding` value for an explicit slot layout.
    fn render(inner: [Option<CompressionEncoding>; 3]) -> Option<http::HeaderValue> {
        EnabledCompressionEncodings { inner }.into_accept_encoding_header_value()
    }

    #[test]
    fn convert_none_into_header_value() {
        assert!(render([None; 3]).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        let expected = HeaderValue::from_static("gzip,identity");
        let gzip = Some(CompressionEncoding::Gzip);

        // The slot an encoding occupies must not change the rendered value.
        assert_eq!(render([gzip, None, None]).unwrap(), expected);
        assert_eq!(render([None, None, gzip]).unwrap(), expected);
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        let expected = HeaderValue::from_static("zstd,identity");
        let zstd = Some(CompressionEncoding::Zstd);

        // The slot an encoding occupies must not change the rendered value.
        assert_eq!(render([zstd, None, None]).unwrap(), expected);
        assert_eq!(render([None, None, zstd]).unwrap(), expected);
    }

    #[test]
    #[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))]
    fn convert_compression_encodings_into_header_value() {
        let gzip = Some(CompressionEncoding::Gzip);
        let deflate = Some(CompressionEncoding::Deflate);
        let zstd = Some(CompressionEncoding::Zstd);

        // Encodings are listed in slot order, followed by `identity`.
        assert_eq!(
            render([gzip, deflate, zstd]).unwrap(),
            HeaderValue::from_static("gzip,deflate,zstd,identity"),
        );
        assert_eq!(
            render([zstd, deflate, gzip]).unwrap(),
            HeaderValue::from_static("zstd,deflate,gzip,identity"),
        );
    }
}