1use crate::{request::SanitizeHeaders, Status};
6use pin_project::pin_project;
7use std::{
8 fmt,
9 future::Future,
10 pin::Pin,
11 task::{Context, Poll},
12};
13use tower_layer::Layer;
14use tower_service::Service;
15
/// A hook run on each request handled by an [`InterceptedService`], before the
/// request is forwarded to the wrapped service.
///
/// The interceptor only ever sees the request's metadata and extensions: the
/// message body is detached beforehand and re-attached afterwards, so the body
/// type parameter is always `()`.
///
/// * Returning `Ok(request)` forwards the (possibly modified) metadata and
///   extensions to the inner service.
/// * Returning `Err(status)` skips the inner service entirely; the `Status` is
///   converted into the HTTP response returned to the caller.
///
/// Any `FnMut(crate::Request<()>) -> Result<crate::Request<()>, Status>`
/// closure implements this trait via a blanket impl.
pub trait Interceptor {
    /// Inspect, rewrite, or reject an incoming (body-less) request.
    fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>;
}
45
46impl<F> Interceptor for F
47where
48 F: FnMut(crate::Request<()>) -> Result<crate::Request<()>, Status>,
49{
50 fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status> {
51 self(request)
52 }
53}
54
/// A `tower` layer that wraps every service it is applied to in an
/// [`InterceptedService`] using a clone of the stored interceptor.
#[derive(Clone, Copy, Debug)]
pub struct InterceptorLayer<I> {
    /// Interceptor cloned into each service produced by this layer.
    interceptor: I,
}
62
63impl<I> InterceptorLayer<I> {
64 pub fn new(interceptor: I) -> Self {
68 Self { interceptor }
69 }
70}
71
72impl<S, I> Layer<S> for InterceptorLayer<I>
73where
74 I: Clone,
75{
76 type Service = InterceptedService<S, I>;
77
78 fn layer(&self, service: S) -> Self::Service {
79 InterceptedService::new(service, self.interceptor.clone())
80 }
81}
82
/// A service that runs an interceptor on each request before delegating to
/// `inner`. `Debug` is implemented by hand so that `I` need not be `Debug`.
#[derive(Copy, Clone)]
pub struct InterceptedService<S, I> {
    /// The wrapped service that receives requests the interceptor accepts.
    inner: S,
    /// Hook consulted for every request before `inner` is called.
    interceptor: I,
}
91
92impl<S, I> InterceptedService<S, I> {
93 pub fn new(service: S, interceptor: I) -> Self {
96 Self {
97 inner: service,
98 interceptor,
99 }
100 }
101}
102
103impl<S, I> fmt::Debug for InterceptedService<S, I>
104where
105 S: fmt::Debug,
106{
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 f.debug_struct("InterceptedService")
109 .field("inner", &self.inner)
110 .field("f", &format_args!("{}", std::any::type_name::<I>()))
111 .finish()
112 }
113}
114
115impl<S, I, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, I>
116where
117 S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
118 I: Interceptor,
119{
120 type Response = http::Response<ResponseBody<ResBody>>;
121 type Error = S::Error;
122 type Future = ResponseFuture<S::Future>;
123
124 #[inline]
125 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126 self.inner.poll_ready(cx)
127 }
128
129 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
130 let uri = req.uri().clone();
136 let method = req.method().clone();
137 let version = req.version();
138 let req = crate::Request::from_http(req);
139 let (metadata, extensions, msg) = req.into_parts();
140
141 match self
142 .interceptor
143 .call(crate::Request::from_parts(metadata, extensions, ()))
144 {
145 Ok(req) => {
146 let (metadata, extensions, _) = req.into_parts();
147 let req = crate::Request::from_parts(metadata, extensions, msg);
148 let req = req.into_http(uri, method, version, SanitizeHeaders::No);
149 ResponseFuture::future(self.inner.call(req))
150 }
151 Err(status) => ResponseFuture::status(status),
152 }
153 }
154}
155
156impl<S, I> crate::server::NamedService for InterceptedService<S, I>
158where
159 S: crate::server::NamedService,
160{
161 const NAME: &'static str = S::NAME;
162}
163
/// Future returned by [`InterceptedService`]: either the wrapped service's
/// response future, or an immediately-ready response built from the `Status`
/// the interceptor rejected the request with.
#[pin_project]
#[derive(Debug)]
pub struct ResponseFuture<F> {
    // Which of the two outcomes this future represents.
    #[pin]
    kind: Kind<F>,
}
171
172impl<F> ResponseFuture<F> {
173 fn future(future: F) -> Self {
174 Self {
175 kind: Kind::Future(future),
176 }
177 }
178
179 fn status(status: Status) -> Self {
180 Self {
181 kind: Kind::Status(Some(status)),
182 }
183 }
184}
185
/// The two possible states of a [`ResponseFuture`].
#[pin_project(project = KindProj)]
#[derive(Debug)]
enum Kind<F> {
    /// The interceptor accepted the request; this is the inner service's future.
    Future(#[pin] F),
    /// The interceptor rejected the request. The `Option` is taken (left as
    /// `None`) when the status is turned into a response on poll.
    Status(Option<Status>),
}
192
193impl<F, E, B> Future for ResponseFuture<F>
194where
195 F: Future<Output = Result<http::Response<B>, E>>,
196{
197 type Output = Result<http::Response<ResponseBody<B>>, E>;
198
199 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
200 match self.project().kind.project() {
201 KindProj::Future(future) => future.poll(cx).map_ok(|res| res.map(ResponseBody::wrap)),
202 KindProj::Status(status) => {
203 let (parts, ()) = status.take().unwrap().into_http::<()>().into_parts();
204 let response = http::Response::from_parts(parts, ResponseBody::<B>::empty());
205 Poll::Ready(Ok(response))
206 }
207 }
208 }
209}
210
/// Response body produced by [`InterceptedService`]: either the inner
/// service's body, or an empty body for responses generated from an
/// interceptor's `Status`.
#[pin_project]
#[derive(Debug)]
pub struct ResponseBody<B> {
    // Which body this wrapper delegates to.
    #[pin]
    kind: ResponseBodyKind<B>,
}
218
/// The two possible sources of a [`ResponseBody`].
#[pin_project(project = ResponseBodyKindProj)]
#[derive(Debug)]
enum ResponseBodyKind<B> {
    /// No data at all; used when the interceptor rejected the request.
    Empty,
    /// Delegates every call to the inner service's body.
    Wrap(#[pin] B),
}
225
226impl<B> ResponseBody<B> {
227 fn new(kind: ResponseBodyKind<B>) -> Self {
228 Self { kind }
229 }
230
231 fn empty() -> Self {
232 Self::new(ResponseBodyKind::Empty)
233 }
234
235 fn wrap(body: B) -> Self {
236 Self::new(ResponseBodyKind::Wrap(body))
237 }
238}
239
240impl<B: http_body::Body> http_body::Body for ResponseBody<B> {
241 type Data = B::Data;
242 type Error = B::Error;
243
244 fn poll_frame(
245 self: Pin<&mut Self>,
246 cx: &mut Context<'_>,
247 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
248 match self.project().kind.project() {
249 ResponseBodyKindProj::Empty => Poll::Ready(None),
250 ResponseBodyKindProj::Wrap(body) => body.poll_frame(cx),
251 }
252 }
253
254 fn size_hint(&self) -> http_body::SizeHint {
255 match &self.kind {
256 ResponseBodyKind::Empty => http_body::SizeHint::with_exact(0),
257 ResponseBodyKind::Wrap(body) => body.size_hint(),
258 }
259 }
260
261 fn is_end_stream(&self) -> bool {
262 match &self.kind {
263 ResponseBodyKind::Empty => true,
264 ResponseBodyKind::Wrap(body) => body.is_end_stream(),
265 }
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use tower::ServiceExt;

    #[tokio::test]
    async fn doesnt_remove_headers_from_requests() {
        // The leaf service must still see the header after interception.
        let leaf = tower::service_fn(|req: http::Request<()>| async move {
            let agent = req
                .headers()
                .get("user-agent")
                .expect("missing in leaf service");
            assert_eq!(agent, "test-tonic");

            Ok::<_, Status>(http::Response::new(()))
        });

        // The interceptor must see the header as metadata.
        let intercepted = InterceptedService::new(leaf, |req: crate::Request<()>| {
            let agent = req
                .metadata()
                .get("user-agent")
                .expect("missing in interceptor");
            assert_eq!(agent, "test-tonic");

            Ok(req)
        });

        let req = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(())
            .unwrap();

        intercepted.oneshot(req).await.unwrap();
    }

    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).into_http::<()>();

        let leaf = tower::service_fn(|_: http::Request<()>| async {
            Ok::<_, Status>(http::Response::new(()))
        });

        // Reject every request; the status must come back as a normal response.
        let intercepted = InterceptedService::new(leaf, |_: crate::Request<()>| {
            Err(Status::permission_denied(message))
        });

        let req = http::Request::builder().body(()).unwrap();
        let actual = intercepted.oneshot(req).await.unwrap();

        assert_eq!(expected.status(), actual.status());
        assert_eq!(expected.version(), actual.version());
        assert_eq!(expected.headers(), actual.headers());
    }

    #[tokio::test]
    async fn doesnt_change_http_method() {
        let leaf = tower::service_fn(|req: http::Request<()>| async move {
            assert_eq!(req.method(), http::Method::OPTIONS);

            Ok::<_, hyper::Error>(hyper::Response::new(()))
        });

        // `Ok` acts as a pass-through interceptor that accepts every request.
        let intercepted = InterceptedService::new(leaf, Ok);

        let req = http::Request::builder()
            .method(http::Method::OPTIONS)
            .body(())
            .unwrap();

        intercepted.oneshot(req).await.unwrap();
    }
}