tonic/service/
layered.rs

1use std::{
2    marker::PhantomData,
3    task::{Context, Poll},
4};
5
6use tower_layer::Layer;
7use tower_service::Service;
8
9use crate::server::NamedService;
10
/// A layered service to propagate [`NamedService`] implementation.
///
/// `inner` is the service obtained after applying a layer. `T` is the type
/// whose [`NamedService::NAME`] is forwarded. No value of `T` is stored; it is
/// only tracked through [`PhantomData`].
pub struct Layered<S, T> {
    inner: S,
    _ty: PhantomData<T>,
}

// `Clone` and `Debug` are written by hand instead of derived: `#[derive]`
// would add `T: Clone` / `T: Debug` bounds even though `T` only appears inside
// `PhantomData`, needlessly restricting which `Layered` values can be cloned
// or printed.
impl<S: Clone, T> Clone for Layered<S, T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _ty: PhantomData,
        }
    }
}

impl<S: std::fmt::Debug, T> std::fmt::Debug for Layered<S, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the derived output; `PhantomData<T>` is `Debug` for any `T`.
        f.debug_struct("Layered")
            .field("inner", &self.inner)
            .field("_ty", &self._ty)
            .finish()
    }
}
17
18impl<S, T: NamedService> NamedService for Layered<S, T> {
19    const NAME: &'static str = T::NAME;
20}
21
22impl<Req, S, T> Service<Req> for Layered<S, T>
23where
24    S: Service<Req>,
25{
26    type Response = S::Response;
27    type Error = S::Error;
28    type Future = S::Future;
29
30    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
31        self.inner.poll_ready(cx)
32    }
33
34    fn call(&mut self, req: Req) -> Self::Future {
35        self.inner.call(req)
36    }
37}
38
39/// Extension trait which adds utility methods to types which implement [`tower_layer::Layer`].
40pub trait LayerExt<L>: sealed::Sealed {
41    /// Applies the layer to a service and wraps it in [`Layered`].
42    fn named_layer<S>(&self, service: S) -> Layered<L::Service, S>
43    where
44        L: Layer<S>;
45}
46
47impl<L> LayerExt<L> for L {
48    fn named_layer<S>(&self, service: S) -> Layered<<L>::Service, S>
49    where
50        L: Layer<S>,
51    {
52        Layered {
53            inner: self.layer(service),
54            _ty: PhantomData,
55        }
56    }
57}
58
/// Private home of the supertrait used by [`LayerExt`]; downstream crates
/// cannot name `Sealed` because this module is not public.
mod sealed {
    /// Marker supertrait of `LayerExt`.
    pub trait Sealed {}

    // Every type is `Sealed`, matching the blanket `LayerExt<L> for L` impl.
    // NOTE(review): because the impl covers all types, this does not on its
    // own stop a downstream `impl LayerExt<X> for TheirType` when `X` differs
    // from `TheirType`; keying the seal on the trait parameter (`Sealed<L>`)
    // would close that gap — confirm whether that matters.
    impl<Ty> Sealed for Ty {}
}
63
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer};

    const TEST_SERVICE_NAME: &str = "test-service-name";

    /// Minimal stand-in service whose only purpose is to carry a name.
    #[derive(Debug, Default)]
    struct TestService {}

    impl NamedService for TestService {
        const NAME: &'static str = TEST_SERVICE_NAME;
    }

    /// Only compiles for `S: NamedService`; returns that implementation's `NAME`.
    fn name_of<S: NamedService>(_service: &S) -> &'static str {
        S::NAME
    }

    #[test]
    fn named_service_is_propagated_to_layered() {
        // One layer deep.
        let once = TimeoutLayer::new(Duration::from_secs(5)).named_layer(TestService::default());
        assert_eq!(name_of(&once), TEST_SERVICE_NAME);

        // Layering an already-layered service still forwards the original name.
        let twice = ConcurrencyLimitLayer::new(3).named_layer(once);
        assert_eq!(name_of(&twice), TEST_SERVICE_NAME);
    }
}