// tonic/service/interceptor.rs

//! gRPC interceptors, a lightweight kind of middleware.
//!
//! See [`Interceptor`] for more details.
4
5use crate::{request::SanitizeHeaders, Status};
6use pin_project::pin_project;
7use std::{
8    fmt,
9    future::Future,
10    pin::Pin,
11    task::{Context, Poll},
12};
13use tower_layer::Layer;
14use tower_service::Service;
15
16/// A gRPC interceptor.
17///
18/// gRPC interceptors are similar to middleware but have less flexibility. An interceptor allows
19/// you to do two main things, one is to add/remove/check items in the `MetadataMap` of each
20/// request. Two, cancel a request with a `Status`.
21///
22/// Any function that satisfies the bound `FnMut(Request<()>) -> Result<Request<()>, Status>` can be
23/// used as an `Interceptor`.
24///
25/// An interceptor can be used on both the server and client side through the `tonic-build` crate's
26/// generated structs.
27///
28/// See the [interceptor example][example] for more details.
29///
30/// If you need more powerful middleware, [tower] is the recommended approach. You can find
31/// examples of how to use tower with tonic [here][tower-example].
32///
33/// Additionally, interceptors is not the recommended way to add logging to your service. For that
34/// a [tower] middleware is more appropriate since it can also act on the response. For example
35/// tower-http's [`Trace`](https://docs.rs/tower-http/latest/tower_http/trace/index.html)
36/// middleware supports gRPC out of the box.
37///
38/// [tower]: https://crates.io/crates/tower
39/// [example]: https://github.com/hyperium/tonic/tree/master/examples/src/interceptor
40/// [tower-example]: https://github.com/hyperium/tonic/tree/master/examples/src/tower
41pub trait Interceptor {
42    /// Intercept a request before it is sent, optionally cancelling it.
43    fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>;
44}
45
46impl<F> Interceptor for F
47where
48    F: FnMut(crate::Request<()>) -> Result<crate::Request<()>, Status>,
49{
50    fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status> {
51        self(request)
52    }
53}
54
/// A gRPC interceptor that can be used as a [`Layer`].
///
/// Applying this layer to a service wraps it in an [`InterceptedService`] that owns a clone of
/// the interceptor.
///
/// See [`Interceptor`] for more details.
#[derive(Debug, Clone, Copy)]
pub struct InterceptorLayer<I> {
    interceptor: I,
}

impl<I> InterceptorLayer<I> {
    /// Create a new interceptor layer from `interceptor`.
    ///
    /// See [`Interceptor`] for more details.
    pub fn new(interceptor: I) -> Self {
        Self { interceptor }
    }
}
71
72impl<S, I> Layer<S> for InterceptorLayer<I>
73where
74    I: Clone,
75{
76    type Service = InterceptedService<S, I>;
77
78    fn layer(&self, service: S) -> Self::Service {
79        InterceptedService::new(service, self.interceptor.clone())
80    }
81}
82
/// A service wrapped in an interceptor middleware.
///
/// Each request is first handed to the interceptor `I` (with its message hidden). If the
/// interceptor accepts it, the request is forwarded to the inner service `S`.
///
/// See [`Interceptor`] for more details.
#[derive(Clone, Copy)]
pub struct InterceptedService<S, I> {
    inner: S,
    interceptor: I,
}

impl<S, I> InterceptedService<S, I> {
    /// Create a new `InterceptedService` that wraps `service` and intercepts each request with
    /// `interceptor`.
    pub fn new(service: S, interceptor: I) -> Self {
        Self {
            inner: service,
            interceptor,
        }
    }
}
102
103impl<S, I> fmt::Debug for InterceptedService<S, I>
104where
105    S: fmt::Debug,
106{
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        f.debug_struct("InterceptedService")
109            .field("inner", &self.inner)
110            .field("f", &format_args!("{}", std::any::type_name::<I>()))
111            .finish()
112    }
113}
114
115impl<S, I, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, I>
116where
117    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
118    I: Interceptor,
119{
120    type Response = http::Response<ResponseBody<ResBody>>;
121    type Error = S::Error;
122    type Future = ResponseFuture<S::Future>;
123
124    #[inline]
125    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126        self.inner.poll_ready(cx)
127    }
128
129    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
130        // It is bad practice to modify the body (i.e. Message) of the request via an interceptor.
131        // To avoid exposing the body of the request to the interceptor function, we first remove it
132        // here, allow the interceptor to modify the metadata and extensions, and then recreate the
133        // HTTP request with the body. Tonic requests do not preserve the URI, HTTP version, and
134        // HTTP method of the HTTP request, so we extract them here and then add them back in below.
135        let uri = req.uri().clone();
136        let method = req.method().clone();
137        let version = req.version();
138        let req = crate::Request::from_http(req);
139        let (metadata, extensions, msg) = req.into_parts();
140
141        match self
142            .interceptor
143            .call(crate::Request::from_parts(metadata, extensions, ()))
144        {
145            Ok(req) => {
146                let (metadata, extensions, _) = req.into_parts();
147                let req = crate::Request::from_parts(metadata, extensions, msg);
148                let req = req.into_http(uri, method, version, SanitizeHeaders::No);
149                ResponseFuture::future(self.inner.call(req))
150            }
151            Err(status) => ResponseFuture::status(status),
152        }
153    }
154}
155
156// required to use `InterceptedService` with `Router`
157impl<S, I> crate::server::NamedService for InterceptedService<S, I>
158where
159    S: crate::server::NamedService,
160{
161    const NAME: &'static str = S::NAME;
162}
163
164/// Response future for [`InterceptedService`].
165#[pin_project]
166#[derive(Debug)]
167pub struct ResponseFuture<F> {
168    #[pin]
169    kind: Kind<F>,
170}
171
172impl<F> ResponseFuture<F> {
173    fn future(future: F) -> Self {
174        Self {
175            kind: Kind::Future(future),
176        }
177    }
178
179    fn status(status: Status) -> Self {
180        Self {
181            kind: Kind::Status(Some(status)),
182        }
183    }
184}
185
186#[pin_project(project = KindProj)]
187#[derive(Debug)]
188enum Kind<F> {
189    Future(#[pin] F),
190    Status(Option<Status>),
191}
192
193impl<F, E, B> Future for ResponseFuture<F>
194where
195    F: Future<Output = Result<http::Response<B>, E>>,
196{
197    type Output = Result<http::Response<ResponseBody<B>>, E>;
198
199    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
200        match self.project().kind.project() {
201            KindProj::Future(future) => future.poll(cx).map_ok(|res| res.map(ResponseBody::wrap)),
202            KindProj::Status(status) => {
203                let (parts, ()) = status.take().unwrap().into_http::<()>().into_parts();
204                let response = http::Response::from_parts(parts, ResponseBody::<B>::empty());
205                Poll::Ready(Ok(response))
206            }
207        }
208    }
209}
210
211/// Response body for [`InterceptedService`].
212#[pin_project]
213#[derive(Debug)]
214pub struct ResponseBody<B> {
215    #[pin]
216    kind: ResponseBodyKind<B>,
217}
218
219#[pin_project(project = ResponseBodyKindProj)]
220#[derive(Debug)]
221enum ResponseBodyKind<B> {
222    Empty,
223    Wrap(#[pin] B),
224}
225
226impl<B> ResponseBody<B> {
227    fn new(kind: ResponseBodyKind<B>) -> Self {
228        Self { kind }
229    }
230
231    fn empty() -> Self {
232        Self::new(ResponseBodyKind::Empty)
233    }
234
235    fn wrap(body: B) -> Self {
236        Self::new(ResponseBodyKind::Wrap(body))
237    }
238}
239
240impl<B: http_body::Body> http_body::Body for ResponseBody<B> {
241    type Data = B::Data;
242    type Error = B::Error;
243
244    fn poll_frame(
245        self: Pin<&mut Self>,
246        cx: &mut Context<'_>,
247    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
248        match self.project().kind.project() {
249            ResponseBodyKindProj::Empty => Poll::Ready(None),
250            ResponseBodyKindProj::Wrap(body) => body.poll_frame(cx),
251        }
252    }
253
254    fn size_hint(&self) -> http_body::SizeHint {
255        match &self.kind {
256            ResponseBodyKind::Empty => http_body::SizeHint::with_exact(0),
257            ResponseBodyKind::Wrap(body) => body.size_hint(),
258        }
259    }
260
261    fn is_end_stream(&self) -> bool {
262        match &self.kind {
263            ResponseBodyKind::Empty => true,
264            ResponseBodyKind::Wrap(body) => body.is_end_stream(),
265        }
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use tower::ServiceExt;

    #[tokio::test]
    async fn doesnt_remove_headers_from_requests() {
        let svc = tower::service_fn(|request: http::Request<()>| async move {
            assert_eq!(
                request
                    .headers()
                    .get("user-agent")
                    .expect("missing in leaf service"),
                "test-tonic"
            );

            Ok::<_, Status>(http::Response::new(()))
        });

        let svc = InterceptedService::new(svc, |request: crate::Request<()>| {
            assert_eq!(
                request
                    .metadata()
                    .get("user-agent")
                    .expect("missing in interceptor"),
                "test-tonic"
            );

            Ok(request)
        });

        let request = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).into_http::<()>();

        let svc = tower::service_fn(|_: http::Request<()>| async {
            Ok::<_, Status>(http::Response::new(()))
        });

        let svc = InterceptedService::new(svc, |_: crate::Request<()>| {
            Err(Status::permission_denied(message))
        });

        let request = http::Request::builder().body(()).unwrap();
        let response = svc.oneshot(request).await.unwrap();

        assert_eq!(expected.status(), response.status());
        assert_eq!(expected.version(), response.version());
        assert_eq!(expected.headers(), response.headers());
    }

    #[tokio::test]
    async fn doesnt_change_http_method() {
        let svc = tower::service_fn(|request: http::Request<()>| async move {
            assert_eq!(request.method(), http::Method::OPTIONS);

            Ok::<_, hyper::Error>(hyper::Response::new(()))
        });

        let svc = InterceptedService::new(svc, Ok);

        let request = http::Request::builder()
            .method(http::Method::OPTIONS)
            .body(())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    // Metadata and extensions added by the interceptor must survive the round trip back into
    // an HTTP request and reach the inner service.
    #[tokio::test]
    async fn propagates_metadata_and_extensions_set_by_interceptor() {
        #[derive(Clone)]
        struct Marker;

        let svc = tower::service_fn(|request: http::Request<()>| async move {
            assert_eq!(
                request
                    .headers()
                    .get("x-intercepted")
                    .expect("metadata added by interceptor missing in leaf service"),
                "yes"
            );
            assert!(
                request.extensions().get::<Marker>().is_some(),
                "extension added by interceptor missing in leaf service"
            );

            Ok::<_, Status>(http::Response::new(()))
        });

        let svc = InterceptedService::new(svc, |mut request: crate::Request<()>| {
            request.metadata_mut().insert(
                "x-intercepted",
                crate::metadata::AsciiMetadataValue::from_static("yes"),
            );
            request.extensions_mut().insert(Marker);
            Ok::<_, Status>(request)
        });

        let request = http::Request::builder().body(()).unwrap();

        svc.oneshot(request).await.unwrap();
    }

    // `InterceptorLayer` must produce a service that actually runs the interceptor.
    #[tokio::test]
    async fn layer_wraps_service_with_interceptor() {
        let expected = Status::unauthenticated("no token").into_http::<()>();

        let svc = tower::service_fn(|_: http::Request<()>| async {
            Ok::<_, Status>(http::Response::new(()))
        });

        let svc = InterceptorLayer::new(|_: crate::Request<()>| {
            Err::<crate::Request<()>, _>(Status::unauthenticated("no token"))
        })
        .layer(svc);

        let request = http::Request::builder().body(()).unwrap();
        let response = svc.oneshot(request).await.unwrap();

        assert_eq!(expected.status(), response.status());
        assert_eq!(expected.headers(), response.headers());
    }
}