1use std::{
2 marker::PhantomData,
3 task::{Context, Poll},
4};
5
6use tower_layer::Layer;
7use tower_service::Service;
8
9use crate::server::NamedService;
10
/// A service wrapper that keeps reporting the [`NamedService`] name of the type `T` it was
/// built around.
///
/// `inner` is the (layered) service that actually handles requests. `T` is only a type-level
/// marker, stored as [`PhantomData`]. Nothing of type `T` is kept at runtime.
pub struct Layered<S, T> {
    inner: S,
    _ty: PhantomData<T>,
}

// `Clone` and `Debug` are implemented by hand instead of derived. `#[derive]` would add
// `T: Clone` / `T: Debug` bounds even though `T` only appears inside `PhantomData`. That would
// make `Layered` non-cloneable / non-printable whenever the marker type is, regardless of `S`.
impl<S: Clone, T> Clone for Layered<S, T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _ty: PhantomData,
        }
    }
}

impl<S: std::fmt::Debug, T> std::fmt::Debug for Layered<S, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the previously derived output: `Layered { inner: .., _ty: .. }`.
        f.debug_struct("Layered")
            .field("inner", &self.inner)
            .field("_ty", &self._ty)
            .finish()
    }
}
17
18impl<S, T: NamedService> NamedService for Layered<S, T> {
19 const NAME: &'static str = T::NAME;
20}
21
22impl<Req, S, T> Service<Req> for Layered<S, T>
23where
24 S: Service<Req>,
25{
26 type Response = S::Response;
27 type Error = S::Error;
28 type Future = S::Future;
29
30 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
31 self.inner.poll_ready(cx)
32 }
33
34 fn call(&mut self, req: Req) -> Self::Future {
35 self.inner.call(req)
36 }
37}
38
39pub trait LayerExt<L>: sealed::Sealed {
41 fn named_layer<S>(&self, service: S) -> Layered<L::Service, S>
43 where
44 L: Layer<S>;
45}
46
47impl<L> LayerExt<L> for L {
48 fn named_layer<S>(&self, service: S) -> Layered<<L>::Service, S>
49 where
50 L: Layer<S>,
51 {
52 Layered {
53 inner: self.layer(service),
54 _ty: PhantomData,
55 }
56 }
57}
58
mod sealed {
    //! Private supertrait for `LayerExt`. The module is not `pub`, so `Sealed` cannot be named
    //! from outside the parent module.

    /// Supertrait bound on `LayerExt`.
    pub trait Sealed {}

    // Blanket impl: every (sized) type is `Sealed`, so the blanket `LayerExt<L> for L` impl
    // applies to any `L`.
    // NOTE(review): because *every* type implements `Sealed`, this bound alone does not appear
    // to stop downstream crates from writing `impl LayerExt<X> for Y` (with `X != Y`) — confirm
    // whether real sealing was intended.
    impl<Any> Sealed for Any {}
}
63
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal service type whose only job is to carry a known `NamedService::NAME`.
    #[derive(Debug, Default)]
    struct TestService;

    const TEST_SERVICE_NAME: &str = "test-service-name";

    impl NamedService for TestService {
        const NAME: &'static str = TEST_SERVICE_NAME;
    }

    /// Reads `NAME` through a generic `NamedService` bound.
    fn name_of<S: NamedService>(_: &S) -> &'static str {
        S::NAME
    }

    #[test]
    fn named_service_is_propagated_to_layered() {
        use std::time::Duration;
        use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer};

        // One layer directly over the named service.
        let timed = TimeoutLayer::new(Duration::from_secs(5)).named_layer(TestService::default());
        assert_eq!(name_of(&timed), TEST_SERVICE_NAME);

        // A second layer stacked on top of an already-`Layered` value keeps the same name.
        let limited = ConcurrencyLimitLayer::new(3).named_layer(timed);
        assert_eq!(name_of(&limited), TEST_SERVICE_NAME);
    }
}