1use crate::{TransportError, TransportFut};
2use alloy_json_rpc::{RequestPacket, ResponsePacket};
3use tower::Service;
4
5pub trait DualTransportHandler<L, R> {
7 fn call(&self, request: RequestPacket, left: L, right: R) -> TransportFut<'static>;
9}
10
11impl<F, L, R> DualTransportHandler<L, R> for F
12where
13 F: Fn(RequestPacket, L, R) -> TransportFut<'static> + Send + Sync,
14{
15 fn call(&self, request: RequestPacket, left: L, right: R) -> TransportFut<'static> {
16 (self)(request, left, right)
17 }
18}
19
/// A transport that owns two inner transports and delegates every request to a
/// handler of type `H`, which decides how the two are used.
#[derive(Debug, Clone)]
pub struct DualTransport<L, R, H> {
    /// The first inner transport; passed to the handler as `left`.
    left: L,
    /// The second inner transport; passed to the handler as `right`.
    right: R,
    /// The dispatch strategy that receives each request plus both transports.
    handler: H,
}
40
41impl<L, R, H> DualTransport<L, R, H> {
42 pub const fn new(left: L, right: R, handler: H) -> Self {
44 Self { left, right, handler }
45 }
46
47 pub const fn new_handler<F>(left: L, right: R, f: F) -> DualTransport<L, R, F>
49 where
50 F: Fn(RequestPacket, L, R) -> TransportFut<'static> + Send + Sync,
51 {
52 DualTransport { left, right, handler: f }
53 }
54}
55
56impl<L, R, H> Service<RequestPacket> for DualTransport<L, R, H>
57where
58 L: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
59 + Send
60 + Sync
61 + Clone
62 + 'static,
63 R: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
64 + Send
65 + Sync
66 + Clone
67 + 'static,
68 H: DualTransportHandler<L, R> + 'static,
69{
70 type Response = ResponsePacket;
71 type Error = TransportError;
72 type Future = TransportFut<'static>;
73
74 #[inline]
75 fn poll_ready(
76 &mut self,
77 cx: &mut std::task::Context<'_>,
78 ) -> std::task::Poll<Result<(), Self::Error>> {
79 match (self.left.poll_ready(cx), self.right.poll_ready(cx)) {
80 (std::task::Poll::Ready(Ok(())), std::task::Poll::Ready(Ok(()))) => {
81 std::task::Poll::Ready(Ok(()))
82 }
83 (std::task::Poll::Ready(Err(e)), _) => std::task::Poll::Ready(Err(e)),
84 (_, std::task::Poll::Ready(Err(e))) => std::task::Poll::Ready(Err(e)),
85 _ => std::task::Poll::Pending,
86 }
87 }
88
89 #[inline]
90 fn call(&mut self, req: RequestPacket) -> Self::Future {
91 self.handler.call(req, self.left.clone(), self.right.clone())
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_json_rpc::{Id, Request, Response, ResponsePayload};
    use alloy_primitives::B256;
    use serde_json::value::RawValue;
    use std::task::{Context, Poll};

    /// Wraps a closure into a [`RequestFn`] service.
    fn request_fn<T>(f: T) -> RequestFn<T>
    where
        T: FnMut(RequestPacket) -> TransportFut<'static>,
    {
        RequestFn { f }
    }

    /// Minimal [`Service`] that is always ready and forwards each request to `f`.
    #[derive(Copy, Clone)]
    struct RequestFn<T> {
        f: T,
    }

    impl<T> Service<RequestPacket> for RequestFn<T>
    where
        T: FnMut(RequestPacket) -> TransportFut<'static>,
    {
        type Response = ResponsePacket;
        type Error = TransportError;
        type Future = TransportFut<'static>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), TransportError>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: RequestPacket) -> Self::Future {
            (self.f)(req)
        }
    }

    /// Builds a successful single response (id 0) whose payload is the zero hash.
    fn make_hash_response() -> ResponsePacket {
        ResponsePacket::Single(Response {
            id: Id::Number(0),
            payload: ResponsePayload::Success(
                RawValue::from_string(serde_json::to_string(&B256::ZERO).unwrap()).unwrap(),
            ),
        })
    }

    /// Builds a single `test` request with the given numeric id and no params.
    fn make_request(id: u64) -> RequestPacket {
        RequestPacket::Single(
            Request::new("test", Id::Number(id), None::<&'static RawValue>)
                .try_into()
                .expect("Failed to serialize request"),
        )
    }

    #[tokio::test]
    async fn test_dual_transport() {
        // Each inner transport records what it receives so routing can be
        // checked positively on both sides.
        let (left_tx, mut left_rx) = tokio::sync::mpsc::unbounded_channel();
        let (right_tx, mut right_rx) = tokio::sync::mpsc::unbounded_channel();

        let left = request_fn(move |request: RequestPacket| {
            let tx = left_tx.clone();
            Box::pin(async move {
                tx.send(request).unwrap();
                Ok::<_, TransportError>(make_hash_response())
            })
        });

        let right = request_fn(move |request: RequestPacket| {
            let tx = right_tx.clone();
            Box::pin(async move {
                tx.send(request).unwrap();
                Ok::<_, TransportError>(make_hash_response())
            })
        });

        // Route even ids to the left transport and odd ids to the right one.
        let handler = |req: RequestPacket, mut left: RequestFn<_>, mut right: RequestFn<_>| {
            let id = match &req {
                RequestPacket::Single(req) => req.id().as_number().unwrap_or(0),
                RequestPacket::Batch(reqs) => {
                    reqs.first().map(|r| r.id().as_number().unwrap_or(0)).unwrap_or(0)
                }
            };

            if id % 2 == 0 {
                left.call(req)
            } else {
                right.call(req)
            }
        };

        let mut dual_transport = DualTransport::new(left, right, handler);

        // Even id: must reach the left transport only. `poll_ready` is driven
        // first, as the `Service` contract requires before every `call`.
        std::future::poll_fn(|cx| dual_transport.poll_ready(cx)).await.unwrap();
        let _ = dual_transport.call(make_request(2)).await.unwrap();

        let received =
            left_rx.try_recv().expect("even id should be routed to the left transport");
        match &received {
            RequestPacket::Single(req) => assert_eq!(*req.id(), Id::Number(2)),
            _ => panic!("Expected Single RequestPacket with id 2, but got something else"),
        }
        assert!(right_rx.try_recv().is_err(), "Right transport received an even ID");

        // Odd id: must reach the right transport only.
        std::future::poll_fn(|cx| dual_transport.poll_ready(cx)).await.unwrap();
        let _ = dual_transport.call(make_request(1)).await.unwrap();

        assert!(left_rx.try_recv().is_err(), "Received unexpected request for odd ID");
        let received =
            right_rx.try_recv().expect("odd id should be routed to the right transport");
        match &received {
            RequestPacket::Single(req) => assert_eq!(*req.id(), Id::Number(1)),
            _ => panic!("Expected Single RequestPacket with id 1, but got something else"),
        }
    }
}
196}