tonic/service/
recover_error.rs

//! Middleware which recovers from service errors by turning them into responses built from a `Status`.
2
3use std::{
4    fmt,
5    future::Future,
6    pin::Pin,
7    task::{ready, Context, Poll},
8};
9
10use http::Response;
11use pin_project::pin_project;
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::Status;
16
/// Layer that wraps services in the [`RecoverError`] middleware.
///
/// The layer carries no configuration: [`RecoverErrorLayer::new`] and [`Default::default`]
/// produce the same value.
#[derive(Debug, Default, Clone)]
pub struct RecoverErrorLayer {
    // Private unit field: prevents construction via a struct literal outside this module.
    _priv: (),
}

impl RecoverErrorLayer {
    /// Returns a new `RecoverErrorLayer`.
    pub fn new() -> Self {
        Self::default()
    }
}
29
30impl<S> Layer<S> for RecoverErrorLayer {
31    type Service = RecoverError<S>;
32
33    fn layer(&self, inner: S) -> Self::Service {
34        RecoverError::new(inner)
35    }
36}
37
/// Middleware that tries to recover from errors returned by the wrapped service, turning them
/// into responses built from a [`Status`].
#[derive(Debug, Clone)]
pub struct RecoverError<S> {
    /// The wrapped service.
    inner: S,
}

impl<S> RecoverError<S> {
    /// Wraps `inner` in the `RecoverError` middleware.
    pub fn new(inner: S) -> Self {
        RecoverError { inner }
    }
}
51
52impl<S, Req, ResBody> Service<Req> for RecoverError<S>
53where
54    S: Service<Req, Response = Response<ResBody>>,
55    S::Error: Into<crate::BoxError>,
56{
57    type Response = Response<ResponseBody<ResBody>>;
58    type Error = crate::BoxError;
59    type Future = ResponseFuture<S::Future>;
60
61    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62        self.inner.poll_ready(cx).map_err(Into::into)
63    }
64
65    fn call(&mut self, req: Req) -> Self::Future {
66        ResponseFuture {
67            inner: self.inner.call(req),
68        }
69    }
70}
71
72/// Response future for [`RecoverError`].
73#[pin_project]
74pub struct ResponseFuture<F> {
75    #[pin]
76    inner: F,
77}
78
79impl<F> fmt::Debug for ResponseFuture<F> {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        f.debug_struct("ResponseFuture").finish()
82    }
83}
84
85impl<F, E, ResBody> Future for ResponseFuture<F>
86where
87    F: Future<Output = Result<Response<ResBody>, E>>,
88    E: Into<crate::BoxError>,
89{
90    type Output = Result<Response<ResponseBody<ResBody>>, crate::BoxError>;
91
92    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
93        match ready!(self.project().inner.poll(cx)) {
94            Ok(response) => {
95                let response = response.map(ResponseBody::full);
96                Poll::Ready(Ok(response))
97            }
98            Err(err) => match Status::try_from_error(err.into()) {
99                Ok(status) => {
100                    let (parts, ()) = status.into_http::<()>().into_parts();
101                    let res = Response::from_parts(parts, ResponseBody::empty());
102                    Poll::Ready(Ok(res))
103                }
104                Err(err) => Poll::Ready(Err(err)),
105            },
106        }
107    }
108}
109
110/// Response body for [`RecoverError`].
111#[pin_project]
112pub struct ResponseBody<B> {
113    #[pin]
114    inner: Option<B>,
115}
116
117impl<B> fmt::Debug for ResponseBody<B> {
118    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119        f.debug_struct("ResponseBody").finish()
120    }
121}
122
123impl<B> ResponseBody<B> {
124    fn full(inner: B) -> Self {
125        Self { inner: Some(inner) }
126    }
127
128    const fn empty() -> Self {
129        Self { inner: None }
130    }
131}
132
133impl<B> http_body::Body for ResponseBody<B>
134where
135    B: http_body::Body,
136{
137    type Data = B::Data;
138    type Error = B::Error;
139
140    fn poll_frame(
141        self: Pin<&mut Self>,
142        cx: &mut Context<'_>,
143    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
144        match self.project().inner.as_pin_mut() {
145            Some(b) => b.poll_frame(cx),
146            None => Poll::Ready(None),
147        }
148    }
149
150    fn is_end_stream(&self) -> bool {
151        match &self.inner {
152            Some(b) => b.is_end_stream(),
153            None => true,
154        }
155    }
156
157    fn size_hint(&self) -> http_body::SizeHint {
158        match &self.inner {
159            Some(body) => body.size_hint(),
160            None => http_body::SizeHint::with_exact(0),
161        }
162    }
163}