alloy_transport/
dual.rs

1use crate::{TransportError, TransportFut};
2use alloy_json_rpc::{RequestPacket, ResponsePacket};
3use tower::Service;
4
5/// Trait that determines how to dispatch a request given two transports.
6pub trait DualTransportHandler<L, R> {
7    /// The type of the future returned by the transport.
8    fn call(&self, request: RequestPacket, left: L, right: R) -> TransportFut<'static>;
9}
10
11impl<F, L, R> DualTransportHandler<L, R> for F
12where
13    F: Fn(RequestPacket, L, R) -> TransportFut<'static> + Send + Sync,
14{
15    fn call(&self, request: RequestPacket, left: L, right: R) -> TransportFut<'static> {
16        (self)(request, left, right)
17    }
18}
19
/// A transport that dispatches requests to one of two inner transports based on a handler.
///
/// This type lets RPC clients choose between two different transports at runtime, on a
/// per-request basis. The decision is delegated to a [`DualTransportHandler`], which receives
/// each request together with both inner transports.
///
/// `DualTransport` is [`Send`] and [`Sync`] whenever `L`, `R` and `H` are. It implements
/// [`Service`] for [`RequestPacket`]s, which is how it is plugged in as a transport.
///
/// This is useful for building clients that abstract over multiple backends or protocols,
/// routing requests flexibly without committing to a single transport implementation.
///
/// Higher-level types can use [`DualTransport`] internally to support multiple transport
/// strategies.
#[derive(Debug, Clone)]
pub struct DualTransport<L, R, H> {
    /// The left transport.
    left: L,
    /// The right transport.
    right: R,
    /// The handler that decides which transport to use for each request.
    handler: H,
}
40
41impl<L, R, H> DualTransport<L, R, H> {
42    /// Instantiate a new dual transport from a suitable transport.
43    pub const fn new(left: L, right: R, handler: H) -> Self {
44        Self { left, right, handler }
45    }
46
47    /// Create a new dual transport with a function handler.
48    pub const fn new_handler<F>(left: L, right: R, f: F) -> DualTransport<L, R, F>
49    where
50        F: Fn(RequestPacket, L, R) -> TransportFut<'static> + Send + Sync,
51    {
52        DualTransport { left, right, handler: f }
53    }
54}
55
56impl<L, R, H> Service<RequestPacket> for DualTransport<L, R, H>
57where
58    L: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
59        + Send
60        + Sync
61        + Clone
62        + 'static,
63    R: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
64        + Send
65        + Sync
66        + Clone
67        + 'static,
68    H: DualTransportHandler<L, R> + 'static,
69{
70    type Response = ResponsePacket;
71    type Error = TransportError;
72    type Future = TransportFut<'static>;
73
74    #[inline]
75    fn poll_ready(
76        &mut self,
77        cx: &mut std::task::Context<'_>,
78    ) -> std::task::Poll<Result<(), Self::Error>> {
79        match (self.left.poll_ready(cx), self.right.poll_ready(cx)) {
80            (std::task::Poll::Ready(Ok(())), std::task::Poll::Ready(Ok(()))) => {
81                std::task::Poll::Ready(Ok(()))
82            }
83            (std::task::Poll::Ready(Err(e)), _) => std::task::Poll::Ready(Err(e)),
84            (_, std::task::Poll::Ready(Err(e))) => std::task::Poll::Ready(Err(e)),
85            _ => std::task::Poll::Pending,
86        }
87    }
88
89    #[inline]
90    fn call(&mut self, req: RequestPacket) -> Self::Future {
91        self.handler.call(req, self.left.clone(), self.right.clone())
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_json_rpc::{Id, Request, Response, ResponsePayload};
    use alloy_primitives::B256;
    use serde_json::value::RawValue;
    use std::task::{Context, Poll};

    /// Helper function that turns a closure into an alloy transport service.
    fn request_fn<T>(f: T) -> RequestFn<T>
    where
        T: FnMut(RequestPacket) -> TransportFut<'static>,
    {
        RequestFn { f }
    }

    /// A minimal closure-backed [`Service`] that is always ready.
    #[derive(Copy, Clone)]
    struct RequestFn<T> {
        f: T,
    }

    impl<T> Service<RequestPacket> for RequestFn<T>
    where
        T: FnMut(RequestPacket) -> TransportFut<'static>,
    {
        type Response = ResponsePacket;
        type Error = TransportError;
        type Future = TransportFut<'static>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), TransportError>> {
            Ok(()).into()
        }

        fn call(&mut self, req: RequestPacket) -> Self::Future {
            (self.f)(req)
        }
    }

    /// Builds a successful single response whose result is the zero hash.
    fn make_hash_response() -> ResponsePacket {
        ResponsePacket::Single(Response {
            id: Id::Number(0),
            payload: ResponsePayload::Success(
                RawValue::from_string(serde_json::to_string(&B256::ZERO).unwrap()).unwrap(),
            ),
        })
    }

    /// Builds a single `test` request packet with the given numeric id.
    fn make_request(id: u64) -> RequestPacket {
        RequestPacket::Single(
            Request::new("test", Id::Number(id), None::<&'static RawValue>)
                .try_into()
                .expect("Failed to serialize request"),
        )
    }

    #[tokio::test]
    async fn test_dual_transport() {
        let (left_tx, mut left_rx) = tokio::sync::mpsc::unbounded_channel();
        let (right_tx, mut right_rx) = tokio::sync::mpsc::unbounded_channel();

        // Each inner transport records the requests it receives on its own channel.
        let left = request_fn(move |request: RequestPacket| {
            let tx = left_tx.clone();
            Box::pin(async move {
                tx.send(request).unwrap();
                Ok::<_, TransportError>(make_hash_response())
            })
        });

        let right = request_fn(move |request: RequestPacket| {
            let tx = right_tx.clone();
            Box::pin(async move {
                tx.send(request).unwrap();
                Ok::<_, TransportError>(make_hash_response())
            })
        });

        // Route even ids to the left transport and odd ids to the right one.
        let handler = |req: RequestPacket, mut left: RequestFn<_>, mut right: RequestFn<_>| {
            let id = match &req {
                RequestPacket::Single(req) => req.id().as_number().unwrap_or(0),
                RequestPacket::Batch(reqs) => {
                    reqs.first().and_then(|r| r.id().as_number()).unwrap_or(0)
                }
            };

            if id % 2 == 0 {
                left.call(req)
            } else {
                right.call(req)
            }
        };

        let mut dual_transport = DualTransport::new(left, right, handler);

        // Honour the tower contract: wait for readiness before each call.
        std::future::poll_fn(|cx| dual_transport.poll_ready(cx)).await.unwrap();
        dual_transport.call(make_request(2)).await.unwrap();

        match left_rx.try_recv().expect("even id should be routed to the left transport") {
            RequestPacket::Single(req) => assert_eq!(*req.id(), Id::Number(2)),
            _ => panic!("Expected Single RequestPacket with id 2, but got something else"),
        }
        assert!(right_rx.try_recv().is_err(), "Right transport received a request for even ID");

        std::future::poll_fn(|cx| dual_transport.poll_ready(cx)).await.unwrap();
        dual_transport.call(make_request(1)).await.unwrap();

        assert!(left_rx.try_recv().is_err(), "Received unexpected request for odd ID");
        match right_rx.try_recv().expect("odd id should be routed to the right transport") {
            RequestPacket::Single(req) => assert_eq!(*req.id(), Id::Number(1)),
            _ => panic!("Expected Single RequestPacket with id 1, but got something else"),
        }
    }
}