1use crate::codec::compression::{
2 CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
3};
4use crate::codec::EncodeBody;
5use crate::metadata::GRPC_CONTENT_TYPE;
6use crate::{
7 body::Body,
8 codec::{Codec, Streaming},
9 server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
10 Request, Status,
11};
12use http_body::Body as HttpBody;
13use std::{fmt, pin::pin};
14use tokio_stream::{Stream, StreamExt};
15
// Unwraps an `Ok` value, or returns early from the enclosing function with the
// error turned into the final response via its `into_http()` method.
//
// The handlers below use this so that any failed step (e.g. an unsupported
// request encoding) immediately becomes the HTTP response for the call.
macro_rules! t {
    ($res:expr) => {
        match $res {
            Err(err) => return err.into_http(),
            Ok(ok) => ok,
        }
    };
}
24
25pub struct Grpc<T> {
35 codec: T,
36 accept_compression_encodings: EnabledCompressionEncodings,
38 send_compression_encodings: EnabledCompressionEncodings,
40 max_decoding_message_size: Option<usize>,
42 max_encoding_message_size: Option<usize>,
44}
45
46impl<T> Grpc<T>
47where
48 T: Codec,
49{
50 pub fn new(codec: T) -> Self {
52 Self {
53 codec,
54 accept_compression_encodings: EnabledCompressionEncodings::default(),
55 send_compression_encodings: EnabledCompressionEncodings::default(),
56 max_decoding_message_size: None,
57 max_encoding_message_size: None,
58 }
59 }
60
61 pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
89 self.accept_compression_encodings.enable(encoding);
90 self
91 }
92
93 pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
120 self.send_compression_encodings.enable(encoding);
121 self
122 }
123
124 pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
150 self.max_decoding_message_size = Some(limit);
151 self
152 }
153
154 pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
180 self.max_encoding_message_size = Some(limit);
181 self
182 }
183
184 #[doc(hidden)]
185 pub fn apply_compression_config(
186 mut self,
187 accept_encodings: EnabledCompressionEncodings,
188 send_encodings: EnabledCompressionEncodings,
189 ) -> Self {
190 for &encoding in CompressionEncoding::ENCODINGS {
191 if accept_encodings.is_enabled(encoding) {
192 self = self.accept_compressed(encoding);
193 }
194 if send_encodings.is_enabled(encoding) {
195 self = self.send_compressed(encoding);
196 }
197 }
198
199 self
200 }
201
202 #[doc(hidden)]
203 pub fn apply_max_message_size_config(
204 mut self,
205 max_decoding_message_size: Option<usize>,
206 max_encoding_message_size: Option<usize>,
207 ) -> Self {
208 if let Some(limit) = max_decoding_message_size {
209 self = self.max_decoding_message_size(limit);
210 }
211 if let Some(limit) = max_encoding_message_size {
212 self = self.max_encoding_message_size(limit);
213 }
214
215 self
216 }
217
218 pub async fn unary<S, B>(
220 &mut self,
221 mut service: S,
222 req: http::Request<B>,
223 ) -> http::Response<Body>
224 where
225 S: UnaryService<T::Decode, Response = T::Encode>,
226 B: HttpBody + Send + 'static,
227 B::Error: Into<crate::BoxError> + Send,
228 {
229 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
230 req.headers(),
231 self.send_compression_encodings,
232 );
233
234 let request = match self.map_request_unary(req).await {
235 Ok(r) => r,
236 Err(status) => {
237 return self.map_response::<tokio_stream::Once<Result<T::Encode, Status>>>(
238 Err(status),
239 accept_encoding,
240 SingleMessageCompressionOverride::default(),
241 self.max_encoding_message_size,
242 );
243 }
244 };
245
246 let response = service
247 .call(request)
248 .await
249 .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
250
251 let compression_override = compression_override_from_response(&response);
252
253 self.map_response(
254 response,
255 accept_encoding,
256 compression_override,
257 self.max_encoding_message_size,
258 )
259 }
260
261 pub async fn server_streaming<S, B>(
263 &mut self,
264 mut service: S,
265 req: http::Request<B>,
266 ) -> http::Response<Body>
267 where
268 S: ServerStreamingService<T::Decode, Response = T::Encode>,
269 S::ResponseStream: Send + 'static,
270 B: HttpBody + Send + 'static,
271 B::Error: Into<crate::BoxError> + Send,
272 {
273 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
274 req.headers(),
275 self.send_compression_encodings,
276 );
277
278 let request = match self.map_request_unary(req).await {
279 Ok(r) => r,
280 Err(status) => {
281 return self.map_response::<S::ResponseStream>(
282 Err(status),
283 accept_encoding,
284 SingleMessageCompressionOverride::default(),
285 self.max_encoding_message_size,
286 );
287 }
288 };
289
290 let response = service.call(request).await;
291
292 self.map_response(
293 response,
294 accept_encoding,
295 SingleMessageCompressionOverride::default(),
298 self.max_encoding_message_size,
299 )
300 }
301
302 pub async fn client_streaming<S, B>(
304 &mut self,
305 mut service: S,
306 req: http::Request<B>,
307 ) -> http::Response<Body>
308 where
309 S: ClientStreamingService<T::Decode, Response = T::Encode>,
310 B: HttpBody + Send + 'static,
311 B::Error: Into<crate::BoxError> + Send + 'static,
312 {
313 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
314 req.headers(),
315 self.send_compression_encodings,
316 );
317
318 let request = t!(self.map_request_streaming(req));
319
320 let response = service
321 .call(request)
322 .await
323 .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
324
325 let compression_override = compression_override_from_response(&response);
326
327 self.map_response(
328 response,
329 accept_encoding,
330 compression_override,
331 self.max_encoding_message_size,
332 )
333 }
334
335 pub async fn streaming<S, B>(
337 &mut self,
338 mut service: S,
339 req: http::Request<B>,
340 ) -> http::Response<Body>
341 where
342 S: StreamingService<T::Decode, Response = T::Encode> + Send,
343 S::ResponseStream: Send + 'static,
344 B: HttpBody + Send + 'static,
345 B::Error: Into<crate::BoxError> + Send,
346 {
347 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
348 req.headers(),
349 self.send_compression_encodings,
350 );
351
352 let request = t!(self.map_request_streaming(req));
353
354 let response = service.call(request).await;
355
356 self.map_response(
357 response,
358 accept_encoding,
359 SingleMessageCompressionOverride::default(),
360 self.max_encoding_message_size,
361 )
362 }
363
364 async fn map_request_unary<B>(
365 &mut self,
366 request: http::Request<B>,
367 ) -> Result<Request<T::Decode>, Status>
368 where
369 B: HttpBody + Send + 'static,
370 B::Error: Into<crate::BoxError> + Send,
371 {
372 let request_compression_encoding = self.request_encoding_if_supported(&request)?;
373
374 let (parts, body) = request.into_parts();
375
376 let mut stream = pin!(Streaming::new_request(
377 self.codec.decoder(),
378 body,
379 request_compression_encoding,
380 self.max_decoding_message_size,
381 ));
382
383 let message = stream
384 .try_next()
385 .await?
386 .ok_or_else(|| Status::internal("Missing request message."))?;
387
388 let mut req = Request::from_http_parts(parts, message);
389
390 if let Some(trailers) = stream.trailers().await? {
391 req.metadata_mut().merge(trailers);
392 }
393
394 Ok(req)
395 }
396
397 fn map_request_streaming<B>(
398 &mut self,
399 request: http::Request<B>,
400 ) -> Result<Request<Streaming<T::Decode>>, Status>
401 where
402 B: HttpBody + Send + 'static,
403 B::Error: Into<crate::BoxError> + Send,
404 {
405 let encoding = self.request_encoding_if_supported(&request)?;
406
407 let request = request.map(|body| {
408 Streaming::new_request(
409 self.codec.decoder(),
410 body,
411 encoding,
412 self.max_decoding_message_size,
413 )
414 });
415
416 Ok(Request::from_http(request))
417 }
418
419 fn map_response<B>(
420 &mut self,
421 response: Result<crate::Response<B>, Status>,
422 accept_encoding: Option<CompressionEncoding>,
423 compression_override: SingleMessageCompressionOverride,
424 max_message_size: Option<usize>,
425 ) -> http::Response<Body>
426 where
427 B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,
428 {
429 let response = t!(response);
430
431 let (mut parts, body) = response.into_http().into_parts();
432
433 parts
435 .headers
436 .insert(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
437
438 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
439 if let Some(encoding) = accept_encoding {
440 parts.headers.insert(
442 crate::codec::compression::ENCODING_HEADER,
443 encoding.into_header_value(),
444 );
445 }
446
447 let body = EncodeBody::new_server(
448 self.codec.encoder(),
449 body,
450 accept_encoding,
451 compression_override,
452 max_message_size,
453 );
454
455 http::Response::from_parts(parts, Body::new(body))
456 }
457
458 fn request_encoding_if_supported<B>(
459 &self,
460 request: &http::Request<B>,
461 ) -> Result<Option<CompressionEncoding>, Status> {
462 CompressionEncoding::from_encoding_header(
463 request.headers(),
464 self.accept_compression_encodings,
465 )
466 }
467}
468
469impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
470 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471 f.debug_struct("Grpc")
472 .field("codec", &self.codec)
473 .field(
474 "accept_compression_encodings",
475 &self.accept_compression_encodings,
476 )
477 .field(
478 "send_compression_encodings",
479 &self.send_compression_encodings,
480 )
481 .finish()
482 }
483}
484
485fn compression_override_from_response<B, E>(
486 res: &Result<crate::Response<B>, E>,
487) -> SingleMessageCompressionOverride {
488 res.as_ref()
489 .ok()
490 .and_then(|response| {
491 response
492 .extensions()
493 .get::<SingleMessageCompressionOverride>()
494 .copied()
495 })
496 .unwrap_or_default()
497}