// tonic/server/grpc.rs

1use crate::codec::compression::{
2    CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
3};
4use crate::codec::EncodeBody;
5use crate::metadata::GRPC_CONTENT_TYPE;
6use crate::{
7    body::Body,
8    codec::{Codec, Streaming},
9    server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
10    Request, Status,
11};
12use http_body::Body as HttpBody;
13use std::{fmt, pin::pin};
14use tokio_stream::{Stream, StreamExt};
15
// Unwraps a `Result<_, Status>`-shaped expression: evaluates to the `Ok`
// payload, or returns early from the enclosing function with the error
// converted into an HTTP response via its `into_http()` method.
macro_rules! t {
    ($result:expr) => {
        match $result {
            Err(err) => return err.into_http(),
            Ok(inner) => inner,
        }
    };
}
24
25/// A gRPC Server handler.
26///
27/// This will wrap some inner [`Codec`] and provide utilities to handle
28/// inbound unary, client side streaming, server side streaming, and
29/// bi-directional streaming.
30///
31/// Each request handler method accepts some service that implements the
32/// corresponding service trait and a http request that contains some body that
33/// implements some [`Body`].
34pub struct Grpc<T> {
35    codec: T,
36    /// Which compression encodings does the server accept for requests?
37    accept_compression_encodings: EnabledCompressionEncodings,
38    /// Which compression encodings might the server use for responses.
39    send_compression_encodings: EnabledCompressionEncodings,
40    /// Limits the maximum size of a decoded message.
41    max_decoding_message_size: Option<usize>,
42    /// Limits the maximum size of an encoded message.
43    max_encoding_message_size: Option<usize>,
44}
45
46impl<T> Grpc<T>
47where
48    T: Codec,
49{
50    /// Creates a new gRPC server with the provided [`Codec`].
51    pub fn new(codec: T) -> Self {
52        Self {
53            codec,
54            accept_compression_encodings: EnabledCompressionEncodings::default(),
55            send_compression_encodings: EnabledCompressionEncodings::default(),
56            max_decoding_message_size: None,
57            max_encoding_message_size: None,
58        }
59    }
60
61    /// Enable accepting compressed requests.
62    ///
63    /// If a request with an unsupported encoding is received the server will respond with
64    /// [`Code::UnUnimplemented`](crate::Code).
65    ///
66    /// # Example
67    ///
68    /// The most common way of using this is through a server generated by tonic-build:
69    ///
70    /// ```rust
71    /// # enum CompressionEncoding { Gzip }
72    /// # struct Svc;
73    /// # struct ExampleServer<T>(T);
74    /// # impl<T> ExampleServer<T> {
75    /// #     fn new(svc: T) -> Self { Self(svc) }
76    /// #     fn accept_compressed(self, _: CompressionEncoding) -> Self { self }
77    /// # }
78    /// # #[tonic::async_trait]
79    /// # trait Example {}
80    ///
81    /// #[tonic::async_trait]
82    /// impl Example for Svc {
83    ///     // ...
84    /// }
85    ///
86    /// let service = ExampleServer::new(Svc).accept_compressed(CompressionEncoding::Gzip);
87    /// ```
88    pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
89        self.accept_compression_encodings.enable(encoding);
90        self
91    }
92
93    /// Enable sending compressed responses.
94    ///
95    /// Requires the client to also support receiving compressed responses.
96    ///
97    /// # Example
98    ///
99    /// The most common way of using this is through a server generated by tonic-build:
100    ///
101    /// ```rust
102    /// # enum CompressionEncoding { Gzip }
103    /// # struct Svc;
104    /// # struct ExampleServer<T>(T);
105    /// # impl<T> ExampleServer<T> {
106    /// #     fn new(svc: T) -> Self { Self(svc) }
107    /// #     fn send_compressed(self, _: CompressionEncoding) -> Self { self }
108    /// # }
109    /// # #[tonic::async_trait]
110    /// # trait Example {}
111    ///
112    /// #[tonic::async_trait]
113    /// impl Example for Svc {
114    ///     // ...
115    /// }
116    ///
117    /// let service = ExampleServer::new(Svc).send_compressed(CompressionEncoding::Gzip);
118    /// ```
119    pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
120        self.send_compression_encodings.enable(encoding);
121        self
122    }
123
124    /// Limits the maximum size of a decoded message.
125    ///
126    /// # Example
127    ///
128    /// The most common way of using this is through a server generated by tonic-build:
129    ///
130    /// ```rust
131    /// # struct Svc;
132    /// # struct ExampleServer<T>(T);
133    /// # impl<T> ExampleServer<T> {
134    /// #     fn new(svc: T) -> Self { Self(svc) }
135    /// #     fn max_decoding_message_size(self, _: usize) -> Self { self }
136    /// # }
137    /// # #[tonic::async_trait]
138    /// # trait Example {}
139    ///
140    /// #[tonic::async_trait]
141    /// impl Example for Svc {
142    ///     // ...
143    /// }
144    ///
145    /// // Set the limit to 2MB, Defaults to 4MB.
146    /// let limit = 2 * 1024 * 1024;
147    /// let service = ExampleServer::new(Svc).max_decoding_message_size(limit);
148    /// ```
149    pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
150        self.max_decoding_message_size = Some(limit);
151        self
152    }
153
154    /// Limits the maximum size of a encoded message.
155    ///
156    /// # Example
157    ///
158    /// The most common way of using this is through a server generated by tonic-build:
159    ///
160    /// ```rust
161    /// # struct Svc;
162    /// # struct ExampleServer<T>(T);
163    /// # impl<T> ExampleServer<T> {
164    /// #     fn new(svc: T) -> Self { Self(svc) }
165    /// #     fn max_encoding_message_size(self, _: usize) -> Self { self }
166    /// # }
167    /// # #[tonic::async_trait]
168    /// # trait Example {}
169    ///
170    /// #[tonic::async_trait]
171    /// impl Example for Svc {
172    ///     // ...
173    /// }
174    ///
175    /// // Set the limit to 2MB, Defaults to 4MB.
176    /// let limit = 2 * 1024 * 1024;
177    /// let service = ExampleServer::new(Svc).max_encoding_message_size(limit);
178    /// ```
179    pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
180        self.max_encoding_message_size = Some(limit);
181        self
182    }
183
184    #[doc(hidden)]
185    pub fn apply_compression_config(
186        mut self,
187        accept_encodings: EnabledCompressionEncodings,
188        send_encodings: EnabledCompressionEncodings,
189    ) -> Self {
190        for &encoding in CompressionEncoding::ENCODINGS {
191            if accept_encodings.is_enabled(encoding) {
192                self = self.accept_compressed(encoding);
193            }
194            if send_encodings.is_enabled(encoding) {
195                self = self.send_compressed(encoding);
196            }
197        }
198
199        self
200    }
201
202    #[doc(hidden)]
203    pub fn apply_max_message_size_config(
204        mut self,
205        max_decoding_message_size: Option<usize>,
206        max_encoding_message_size: Option<usize>,
207    ) -> Self {
208        if let Some(limit) = max_decoding_message_size {
209            self = self.max_decoding_message_size(limit);
210        }
211        if let Some(limit) = max_encoding_message_size {
212            self = self.max_encoding_message_size(limit);
213        }
214
215        self
216    }
217
218    /// Handle a single unary gRPC request.
219    pub async fn unary<S, B>(
220        &mut self,
221        mut service: S,
222        req: http::Request<B>,
223    ) -> http::Response<Body>
224    where
225        S: UnaryService<T::Decode, Response = T::Encode>,
226        B: HttpBody + Send + 'static,
227        B::Error: Into<crate::BoxError> + Send,
228    {
229        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
230            req.headers(),
231            self.send_compression_encodings,
232        );
233
234        let request = match self.map_request_unary(req).await {
235            Ok(r) => r,
236            Err(status) => {
237                return self.map_response::<tokio_stream::Once<Result<T::Encode, Status>>>(
238                    Err(status),
239                    accept_encoding,
240                    SingleMessageCompressionOverride::default(),
241                    self.max_encoding_message_size,
242                );
243            }
244        };
245
246        let response = service
247            .call(request)
248            .await
249            .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
250
251        let compression_override = compression_override_from_response(&response);
252
253        self.map_response(
254            response,
255            accept_encoding,
256            compression_override,
257            self.max_encoding_message_size,
258        )
259    }
260
261    /// Handle a server side streaming request.
262    pub async fn server_streaming<S, B>(
263        &mut self,
264        mut service: S,
265        req: http::Request<B>,
266    ) -> http::Response<Body>
267    where
268        S: ServerStreamingService<T::Decode, Response = T::Encode>,
269        S::ResponseStream: Send + 'static,
270        B: HttpBody + Send + 'static,
271        B::Error: Into<crate::BoxError> + Send,
272    {
273        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
274            req.headers(),
275            self.send_compression_encodings,
276        );
277
278        let request = match self.map_request_unary(req).await {
279            Ok(r) => r,
280            Err(status) => {
281                return self.map_response::<S::ResponseStream>(
282                    Err(status),
283                    accept_encoding,
284                    SingleMessageCompressionOverride::default(),
285                    self.max_encoding_message_size,
286                );
287            }
288        };
289
290        let response = service.call(request).await;
291
292        self.map_response(
293            response,
294            accept_encoding,
295            // disabling compression of individual stream items must be done on
296            // the items themselves
297            SingleMessageCompressionOverride::default(),
298            self.max_encoding_message_size,
299        )
300    }
301
302    /// Handle a client side streaming gRPC request.
303    pub async fn client_streaming<S, B>(
304        &mut self,
305        mut service: S,
306        req: http::Request<B>,
307    ) -> http::Response<Body>
308    where
309        S: ClientStreamingService<T::Decode, Response = T::Encode>,
310        B: HttpBody + Send + 'static,
311        B::Error: Into<crate::BoxError> + Send + 'static,
312    {
313        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
314            req.headers(),
315            self.send_compression_encodings,
316        );
317
318        let request = t!(self.map_request_streaming(req));
319
320        let response = service
321            .call(request)
322            .await
323            .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
324
325        let compression_override = compression_override_from_response(&response);
326
327        self.map_response(
328            response,
329            accept_encoding,
330            compression_override,
331            self.max_encoding_message_size,
332        )
333    }
334
335    /// Handle a bi-directional streaming gRPC request.
336    pub async fn streaming<S, B>(
337        &mut self,
338        mut service: S,
339        req: http::Request<B>,
340    ) -> http::Response<Body>
341    where
342        S: StreamingService<T::Decode, Response = T::Encode> + Send,
343        S::ResponseStream: Send + 'static,
344        B: HttpBody + Send + 'static,
345        B::Error: Into<crate::BoxError> + Send,
346    {
347        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
348            req.headers(),
349            self.send_compression_encodings,
350        );
351
352        let request = t!(self.map_request_streaming(req));
353
354        let response = service.call(request).await;
355
356        self.map_response(
357            response,
358            accept_encoding,
359            SingleMessageCompressionOverride::default(),
360            self.max_encoding_message_size,
361        )
362    }
363
364    async fn map_request_unary<B>(
365        &mut self,
366        request: http::Request<B>,
367    ) -> Result<Request<T::Decode>, Status>
368    where
369        B: HttpBody + Send + 'static,
370        B::Error: Into<crate::BoxError> + Send,
371    {
372        let request_compression_encoding = self.request_encoding_if_supported(&request)?;
373
374        let (parts, body) = request.into_parts();
375
376        let mut stream = pin!(Streaming::new_request(
377            self.codec.decoder(),
378            body,
379            request_compression_encoding,
380            self.max_decoding_message_size,
381        ));
382
383        let message = stream
384            .try_next()
385            .await?
386            .ok_or_else(|| Status::internal("Missing request message."))?;
387
388        let mut req = Request::from_http_parts(parts, message);
389
390        if let Some(trailers) = stream.trailers().await? {
391            req.metadata_mut().merge(trailers);
392        }
393
394        Ok(req)
395    }
396
397    fn map_request_streaming<B>(
398        &mut self,
399        request: http::Request<B>,
400    ) -> Result<Request<Streaming<T::Decode>>, Status>
401    where
402        B: HttpBody + Send + 'static,
403        B::Error: Into<crate::BoxError> + Send,
404    {
405        let encoding = self.request_encoding_if_supported(&request)?;
406
407        let request = request.map(|body| {
408            Streaming::new_request(
409                self.codec.decoder(),
410                body,
411                encoding,
412                self.max_decoding_message_size,
413            )
414        });
415
416        Ok(Request::from_http(request))
417    }
418
419    fn map_response<B>(
420        &mut self,
421        response: Result<crate::Response<B>, Status>,
422        accept_encoding: Option<CompressionEncoding>,
423        compression_override: SingleMessageCompressionOverride,
424        max_message_size: Option<usize>,
425    ) -> http::Response<Body>
426    where
427        B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,
428    {
429        let response = t!(response);
430
431        let (mut parts, body) = response.into_http().into_parts();
432
433        // Set the content type
434        parts
435            .headers
436            .insert(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
437
438        #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
439        if let Some(encoding) = accept_encoding {
440            // Set the content encoding
441            parts.headers.insert(
442                crate::codec::compression::ENCODING_HEADER,
443                encoding.into_header_value(),
444            );
445        }
446
447        let body = EncodeBody::new_server(
448            self.codec.encoder(),
449            body,
450            accept_encoding,
451            compression_override,
452            max_message_size,
453        );
454
455        http::Response::from_parts(parts, Body::new(body))
456    }
457
458    fn request_encoding_if_supported<B>(
459        &self,
460        request: &http::Request<B>,
461    ) -> Result<Option<CompressionEncoding>, Status> {
462        CompressionEncoding::from_encoding_header(
463            request.headers(),
464            self.accept_compression_encodings,
465        )
466    }
467}
468
469impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
470    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471        f.debug_struct("Grpc")
472            .field("codec", &self.codec)
473            .field(
474                "accept_compression_encodings",
475                &self.accept_compression_encodings,
476            )
477            .field(
478                "send_compression_encodings",
479                &self.send_compression_encodings,
480            )
481            .finish()
482    }
483}
484
485fn compression_override_from_response<B, E>(
486    res: &Result<crate::Response<B>, E>,
487) -> SingleMessageCompressionOverride {
488    res.as_ref()
489        .ok()
490        .and_then(|response| {
491            response
492                .extensions()
493                .get::<SingleMessageCompressionOverride>()
494                .copied()
495        })
496        .unwrap_or_default()
497}