1use crate::{client::RpcClientInner, ClientRef};
2use alloy_json_rpc::{
3 transform_response, try_deserialize_ok, Id, Request, RequestPacket, ResponsePacket, RpcRecv,
4 RpcSend, SerializedRequest,
5};
6use alloy_primitives::map::HashMap;
7use alloy_transport::{
8 BoxTransport, TransportError, TransportErrorKind, TransportFut, TransportResult,
9};
10use futures::FutureExt;
11use pin_project::pin_project;
12use serde_json::value::RawValue;
13use std::{
14 borrow::Cow,
15 future::{Future, IntoFuture},
16 marker::PhantomData,
17 pin::Pin,
18 task::{
19 self, ready,
20 Poll::{self, Ready},
21 },
22};
23use tokio::sync::oneshot;
24use tower::Service;
25
/// Sending half of the one-shot channel on which the result for a single request (the raw JSON
/// value, or a transport error) is delivered.
26pub(crate) type Channel = oneshot::Sender<TransportResult<Box<RawValue>>>;
/// Response channels that are still pending, keyed by the [`Id`] of the request they belong to.
27pub(crate) type ChannelMap = HashMap<Id, Channel>;
28
/// A batch of JSON-RPC requests that is accumulated and then handed to the transport as a single
/// [`RequestPacket`].
///
/// Building the batch sends nothing: the requests are only dispatched once the batch has been
/// converted into a [`BatchFuture`] (via [`BatchRequest::send`] or `IntoFuture`) and that future
/// is polled.
#[derive(Debug)]
32#[must_use = "A BatchRequest does nothing unless sent via `send_batch` and `.await`"]
33pub struct BatchRequest<'a> {
    /// The client used to build requests; its transport is cloned when the batch is sent.
34 transport: ClientRef<'a>,
36
    /// The serialized requests queued so far.
37 requests: RequestPacket,
39
    /// One response channel per queued request, keyed by the request's id.
40 channels: ChannelMap,
42}
43
/// Future that resolves to the response of one request that was queued in a batch.
///
/// The raw result arrives over a one-shot channel, is deserialized into `Resp`, and is then
/// passed through `map` to produce the final `Output`.
44#[must_use = "A Waiter does nothing unless the corresponding BatchRequest is sent via `send_batch` and `.await`, AND the Waiter is awaited."]
46#[pin_project]
47#[derive(Debug)]
48pub struct Waiter<Resp, Output = Resp, Map = fn(Resp) -> Output> {
    /// Receiving half of the channel on which the raw result for this request is delivered.
49 #[pin]
50 rx: oneshot::Receiver<TransportResult<Box<RawValue>>>,
    /// Mapping applied to a successfully deserialized response. It is taken (leaving `None`)
    /// once a result has been received from the channel.
51 map: Option<Map>,
    /// Marker that ties the otherwise unused `Resp` and `Output` parameters to the type.
52 _resp: PhantomData<fn() -> (Output, Resp)>,
53}
54
55impl<Resp, Output, Map> Waiter<Resp, Output, Map> {
56 pub fn map_resp<NewOutput, NewMap>(self, map: NewMap) -> Waiter<Resp, NewOutput, NewMap>
68 where
69 NewMap: FnOnce(Resp) -> NewOutput,
70 {
71 Waiter { rx: self.rx, map: Some(map), _resp: PhantomData }
72 }
73}
74
75impl<Resp> From<oneshot::Receiver<TransportResult<Box<RawValue>>>> for Waiter<Resp> {
76 fn from(rx: oneshot::Receiver<TransportResult<Box<RawValue>>>) -> Self {
77 Self { rx, map: Some(std::convert::identity), _resp: PhantomData }
78 }
79}
80
81impl<Resp, Output, Map> std::future::Future for Waiter<Resp, Output, Map>
82where
83 Resp: RpcRecv,
84 Map: FnOnce(Resp) -> Output,
85{
86 type Output = TransportResult<Output>;
87
88 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
89 let this = self.get_mut();
90
91 match ready!(this.rx.poll_unpin(cx)) {
92 Ok(resp) => {
93 let resp: Result<Resp, _> = try_deserialize_ok(resp);
94 Ready(resp.map(this.map.take().expect("polled after completion")))
95 }
96 Err(e) => Poll::Ready(Err(TransportErrorKind::custom(e))),
97 }
98 }
99}
100
/// Future that drives a batch through its transport and routes every response to the channel
/// registered for its request id.
///
/// It is a state machine; each variant is one stage of sending the batch.
101#[pin_project::pin_project(project = CallStateProj)]
102#[expect(unnameable_types, missing_debug_implementations)]
103pub enum BatchFuture {
    /// The batch has been assembled but not yet handed to the transport.
104 Prepared {
        /// Transport the batch will be sent through.
105 transport: BoxTransport,
        /// The requests to send.
106 requests: RequestPacket,
        /// Channels on which the individual responses will be delivered.
107 channels: ChannelMap,
108 },
    /// An error that is returned the next time the future is polled.
109 SerError(Option<TransportError>),
    /// The batch has been handed to the transport and its response is pending.
110 AwaitingResponse {
        /// Channels on which the individual responses will be delivered.
111 channels: ChannelMap,
        /// The in-flight transport call.
112 #[pin]
113 fut: TransportFut<'static>,
114 },
    /// Terminal state; polling the future again in this state panics.
115 Complete,
116}
117
118impl<'a> BatchRequest<'a> {
119 pub fn new(transport: &'a RpcClientInner) -> Self {
121 Self {
122 transport,
123 requests: RequestPacket::Batch(Vec::with_capacity(10)),
124 channels: HashMap::with_capacity_and_hasher(10, Default::default()),
125 }
126 }
127
128 fn push_raw(
129 &mut self,
130 request: SerializedRequest,
131 ) -> oneshot::Receiver<TransportResult<Box<RawValue>>> {
132 let (tx, rx) = oneshot::channel();
133 self.channels.insert(request.id().clone(), tx);
134 self.requests.push(request);
135 rx
136 }
137
138 fn push<Params: RpcSend, Resp: RpcRecv>(
139 &mut self,
140 request: Request<Params>,
141 ) -> TransportResult<Waiter<Resp>> {
142 let ser = request.serialize().map_err(TransportError::ser_err)?;
143 Ok(self.push_raw(ser).into())
144 }
145
146 pub fn add_call<Params: RpcSend, Resp: RpcRecv>(
152 &mut self,
153 method: impl Into<Cow<'static, str>>,
154 params: &Params,
155 ) -> TransportResult<Waiter<Resp>> {
156 let request = self.transport.make_request(method, Cow::Borrowed(params));
157 self.push(request)
158 }
159
160 pub fn send(self) -> BatchFuture {
162 BatchFuture::Prepared {
163 transport: self.transport.transport.clone(),
164 requests: self.requests,
165 channels: self.channels,
166 }
167 }
168}
169
170impl IntoFuture for BatchRequest<'_> {
171 type Output = <BatchFuture as Future>::Output;
172 type IntoFuture = BatchFuture;
173
174 fn into_future(self) -> Self::IntoFuture {
175 self.send()
176 }
177}
178
179impl BatchFuture {
180 fn poll_prepared(
181 mut self: Pin<&mut Self>,
182 cx: &mut task::Context<'_>,
183 ) -> Poll<<Self as Future>::Output> {
184 let CallStateProj::Prepared { transport, requests, channels } = self.as_mut().project()
185 else {
186 unreachable!("Called poll_prepared in incorrect state")
187 };
188
189 if let Err(e) = task::ready!(transport.poll_ready(cx)) {
190 self.set(Self::Complete);
191 return Poll::Ready(Err(e));
192 }
193
194 let channels = std::mem::take(channels);
197 let req = std::mem::replace(requests, RequestPacket::Batch(Vec::new()));
198
199 let fut = transport.call(req);
200 self.set(Self::AwaitingResponse { channels, fut });
201 cx.waker().wake_by_ref();
202 Poll::Pending
203 }
204
205 fn poll_awaiting_response(
206 mut self: Pin<&mut Self>,
207 cx: &mut task::Context<'_>,
208 ) -> Poll<<Self as Future>::Output> {
209 let CallStateProj::AwaitingResponse { channels, fut } = self.as_mut().project() else {
210 unreachable!("Called poll_awaiting_response in incorrect state")
211 };
212
213 let responses = match ready!(fut.poll(cx)) {
215 Ok(responses) => responses,
216 Err(e) => {
217 self.set(Self::Complete);
218 return Poll::Ready(Err(e));
219 }
220 };
221
222 match responses {
224 ResponsePacket::Single(single) => {
225 if let Some(tx) = channels.remove(&single.id) {
226 let _ = tx.send(transform_response(single));
227 }
228 }
229 ResponsePacket::Batch(responses) => {
230 for response in responses {
231 if let Some(tx) = channels.remove(&response.id) {
232 let _ = tx.send(transform_response(response));
233 }
234 }
235 }
236 }
237
238 for (id, tx) in channels.drain() {
241 let _ = tx.send(Err(TransportErrorKind::missing_batch_response(id)));
242 }
243
244 self.set(Self::Complete);
245 Poll::Ready(Ok(()))
246 }
247
248 fn poll_ser_error(
249 mut self: Pin<&mut Self>,
250 _cx: &mut task::Context<'_>,
251 ) -> Poll<<Self as Future>::Output> {
252 let e = if let CallStateProj::SerError(e) = self.as_mut().project() {
253 e.take().expect("no error")
254 } else {
255 unreachable!("Called poll_ser_error in incorrect state")
256 };
257
258 self.set(Self::Complete);
259 Poll::Ready(Err(e))
260 }
261}
262
263impl Future for BatchFuture {
264 type Output = TransportResult<()>;
265
266 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
267 if matches!(*self.as_mut(), Self::Prepared { .. }) {
268 return self.poll_prepared(cx);
269 }
270
271 if matches!(*self.as_mut(), Self::AwaitingResponse { .. }) {
272 return self.poll_awaiting_response(cx);
273 }
274
275 if matches!(*self.as_mut(), Self::SerError(_)) {
276 return self.poll_ser_error(cx);
277 }
278
279 panic!("Called poll on CallState in invalid state")
280 }
281}