alloy_rpc_client/
batch.rs

1use crate::{client::RpcClientInner, ClientRef};
2use alloy_json_rpc::{
3    transform_response, try_deserialize_ok, Id, Request, RequestPacket, ResponsePacket, RpcRecv,
4    RpcSend, SerializedRequest,
5};
6use alloy_primitives::map::HashMap;
7use alloy_transport::{
8    BoxTransport, TransportError, TransportErrorKind, TransportFut, TransportResult,
9};
10use futures::FutureExt;
11use pin_project::pin_project;
12use serde_json::value::RawValue;
13use std::{
14    borrow::Cow,
15    future::{Future, IntoFuture},
16    marker::PhantomData,
17    pin::Pin,
18    task::{
19        self, ready,
20        Poll::{self, Ready},
21    },
22};
23use tokio::sync::oneshot;
24use tower::Service;
25
/// Sending half of the one-shot channel that delivers one batched request's
/// raw JSON-RPC result (or a transport error) to the matching [`Waiter`].
pub(crate) type Channel = oneshot::Sender<TransportResult<Box<RawValue>>>;
/// The pending response senders of a batch, keyed by request [`Id`].
pub(crate) type ChannelMap = HashMap<Id, Channel>;
28
/// A batch JSON-RPC request, used to bundle requests into a single transport
/// call.
///
/// Requests are added with [`BatchRequest::add_call`], each of which returns a
/// [`Waiter`] for that request's individual response. The batch is turned into
/// a [`BatchFuture`] by [`BatchRequest::send`] (or by awaiting the batch
/// directly, via `IntoFuture`); polling that future sends the packet and routes
/// each response back to its waiter.
#[derive(Debug)]
#[must_use = "A BatchRequest does nothing unless sent via `send_batch` and `.await`"]
pub struct BatchRequest<'a> {
    /// The client whose transport the batch will be sent through.
    transport: ClientRef<'a>,

    /// The serialized requests to be sent, accumulated in a
    /// `RequestPacket::Batch`.
    requests: RequestPacket,

    /// Response senders keyed by request id; each one feeds the `Waiter`
    /// returned when that request was added.
    channels: ChannelMap,
}
43
/// Awaits a single response for a request that has been included in a batch.
///
/// The raw response (or an error) arrives over `rx` once the batch future
/// dispatches it. It is then deserialized as `Resp` and passed through the
/// `Map` function to produce `Output`. Waiters built via
/// `From<oneshot::Receiver<..>>` use the identity function, so `Output == Resp`.
#[must_use = "A Waiter does nothing unless the corresponding BatchRequest is sent via `send_batch` and `.await`, AND the Waiter is awaited."]
#[pin_project]
#[derive(Debug)]
pub struct Waiter<Resp, Output = Resp, Map = fn(Resp) -> Output> {
    /// Receives this request's raw response (or transport error) from the
    /// batch future.
    #[pin]
    rx: oneshot::Receiver<TransportResult<Box<RawValue>>>,
    /// Mapping applied to the deserialized response; taken (left `None`) when
    /// the waiter resolves with a received value.
    map: Option<Map>,
    /// Ties the `Resp` and `Output` type parameters to the struct without
    /// storing values of those types.
    _resp: PhantomData<fn() -> (Output, Resp)>,
}
54
55impl<Resp, Output, Map> Waiter<Resp, Output, Map> {
56    /// Map the response to a different type. This is usable for converting
57    /// the response to a more usable type, e.g. changing `U64` to `u64`.
58    ///
59    /// ## Note
60    ///
61    /// Carefully review the rust documentation on [fn pointers] before passing
62    /// them to this function. Unless the pointer is specifically coerced to a
63    /// `fn(_) -> _`, the `NewMap` will be inferred as that function's unique
64    /// type. This can lead to confusing error messages.
65    ///
66    /// [fn pointers]: https://doc.rust-lang.org/std/primitive.fn.html#creating-function-pointers
67    pub fn map_resp<NewOutput, NewMap>(self, map: NewMap) -> Waiter<Resp, NewOutput, NewMap>
68    where
69        NewMap: FnOnce(Resp) -> NewOutput,
70    {
71        Waiter { rx: self.rx, map: Some(map), _resp: PhantomData }
72    }
73}
74
75impl<Resp> From<oneshot::Receiver<TransportResult<Box<RawValue>>>> for Waiter<Resp> {
76    fn from(rx: oneshot::Receiver<TransportResult<Box<RawValue>>>) -> Self {
77        Self { rx, map: Some(std::convert::identity), _resp: PhantomData }
78    }
79}
80
81impl<Resp, Output, Map> std::future::Future for Waiter<Resp, Output, Map>
82where
83    Resp: RpcRecv,
84    Map: FnOnce(Resp) -> Output,
85{
86    type Output = TransportResult<Output>;
87
88    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
89        let this = self.get_mut();
90
91        match ready!(this.rx.poll_unpin(cx)) {
92            Ok(resp) => {
93                let resp: Result<Resp, _> = try_deserialize_ok(resp);
94                Ready(resp.map(this.map.take().expect("polled after completion")))
95            }
96            Err(e) => Poll::Ready(Err(TransportErrorKind::custom(e))),
97        }
98    }
99}
100
/// Future produced by [`BatchRequest::send`]. It sends the batch over the
/// transport and hands each response to the matching [`Waiter`].
///
/// Resolves to `Ok(())` once responses have been dispatched. If the transport
/// fails to become ready, or the call itself fails, it resolves to that
/// transport error instead.
#[pin_project::pin_project(project = CallStateProj)]
#[expect(unnameable_types, missing_debug_implementations)]
pub enum BatchFuture {
    /// Built but not yet sent; waiting for the transport to become ready.
    Prepared {
        /// Transport that will carry the batch.
        transport: BoxTransport,
        /// Serialized requests to send.
        requests: RequestPacket,
        /// Response senders keyed by request id.
        channels: ChannelMap,
    },
    /// Holds an error that is returned on the next poll. The `Option` lets it
    /// be moved out.
    ///
    /// NOTE(review): presumably a serialization error; this variant is not
    /// constructed anywhere in this file.
    SerError(Option<TransportError>),
    /// Sent; waiting for the transport call to produce the response packet.
    AwaitingResponse {
        /// Response senders keyed by request id.
        channels: ChannelMap,
        /// The in-flight transport call.
        #[pin]
        fut: TransportFut<'static>,
    },
    /// Terminal state; polling the future again panics.
    Complete,
}
117
118impl<'a> BatchRequest<'a> {
119    /// Create a new batch request.
120    pub fn new(transport: &'a RpcClientInner) -> Self {
121        Self {
122            transport,
123            requests: RequestPacket::Batch(Vec::with_capacity(10)),
124            channels: HashMap::with_capacity_and_hasher(10, Default::default()),
125        }
126    }
127
128    fn push_raw(
129        &mut self,
130        request: SerializedRequest,
131    ) -> oneshot::Receiver<TransportResult<Box<RawValue>>> {
132        let (tx, rx) = oneshot::channel();
133        self.channels.insert(request.id().clone(), tx);
134        self.requests.push(request);
135        rx
136    }
137
138    fn push<Params: RpcSend, Resp: RpcRecv>(
139        &mut self,
140        request: Request<Params>,
141    ) -> TransportResult<Waiter<Resp>> {
142        let ser = request.serialize().map_err(TransportError::ser_err)?;
143        Ok(self.push_raw(ser).into())
144    }
145
146    /// Add a call to the batch.
147    ///
148    /// ### Errors
149    ///
150    /// If the request cannot be serialized, this will return an error.
151    pub fn add_call<Params: RpcSend, Resp: RpcRecv>(
152        &mut self,
153        method: impl Into<Cow<'static, str>>,
154        params: &Params,
155    ) -> TransportResult<Waiter<Resp>> {
156        let request = self.transport.make_request(method, Cow::Borrowed(params));
157        self.push(request)
158    }
159
160    /// Send the batch future via its connection.
161    pub fn send(self) -> BatchFuture {
162        BatchFuture::Prepared {
163            transport: self.transport.transport.clone(),
164            requests: self.requests,
165            channels: self.channels,
166        }
167    }
168}
169
170impl IntoFuture for BatchRequest<'_> {
171    type Output = <BatchFuture as Future>::Output;
172    type IntoFuture = BatchFuture;
173
174    fn into_future(self) -> Self::IntoFuture {
175        self.send()
176    }
177}
178
179impl BatchFuture {
180    fn poll_prepared(
181        mut self: Pin<&mut Self>,
182        cx: &mut task::Context<'_>,
183    ) -> Poll<<Self as Future>::Output> {
184        let CallStateProj::Prepared { transport, requests, channels } = self.as_mut().project()
185        else {
186            unreachable!("Called poll_prepared in incorrect state")
187        };
188
189        if let Err(e) = task::ready!(transport.poll_ready(cx)) {
190            self.set(Self::Complete);
191            return Poll::Ready(Err(e));
192        }
193
194        // We only have mut refs, and we want ownership, so we just replace with 0-capacity
195        // collections.
196        let channels = std::mem::take(channels);
197        let req = std::mem::replace(requests, RequestPacket::Batch(Vec::new()));
198
199        let fut = transport.call(req);
200        self.set(Self::AwaitingResponse { channels, fut });
201        cx.waker().wake_by_ref();
202        Poll::Pending
203    }
204
205    fn poll_awaiting_response(
206        mut self: Pin<&mut Self>,
207        cx: &mut task::Context<'_>,
208    ) -> Poll<<Self as Future>::Output> {
209        let CallStateProj::AwaitingResponse { channels, fut } = self.as_mut().project() else {
210            unreachable!("Called poll_awaiting_response in incorrect state")
211        };
212
213        // Has the service responded yet?
214        let responses = match ready!(fut.poll(cx)) {
215            Ok(responses) => responses,
216            Err(e) => {
217                self.set(Self::Complete);
218                return Poll::Ready(Err(e));
219            }
220        };
221
222        // Send all responses via channels
223        match responses {
224            ResponsePacket::Single(single) => {
225                if let Some(tx) = channels.remove(&single.id) {
226                    let _ = tx.send(transform_response(single));
227                }
228            }
229            ResponsePacket::Batch(responses) => {
230                for response in responses {
231                    if let Some(tx) = channels.remove(&response.id) {
232                        let _ = tx.send(transform_response(response));
233                    }
234                }
235            }
236        }
237
238        // Any channels remaining in the map are missing responses.
239        // To avoid hanging futures, we send an error.
240        for (id, tx) in channels.drain() {
241            let _ = tx.send(Err(TransportErrorKind::missing_batch_response(id)));
242        }
243
244        self.set(Self::Complete);
245        Poll::Ready(Ok(()))
246    }
247
248    fn poll_ser_error(
249        mut self: Pin<&mut Self>,
250        _cx: &mut task::Context<'_>,
251    ) -> Poll<<Self as Future>::Output> {
252        let e = if let CallStateProj::SerError(e) = self.as_mut().project() {
253            e.take().expect("no error")
254        } else {
255            unreachable!("Called poll_ser_error in incorrect state")
256        };
257
258        self.set(Self::Complete);
259        Poll::Ready(Err(e))
260    }
261}
262
263impl Future for BatchFuture {
264    type Output = TransportResult<()>;
265
266    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
267        if matches!(*self.as_mut(), Self::Prepared { .. }) {
268            return self.poll_prepared(cx);
269        }
270
271        if matches!(*self.as_mut(), Self::AwaitingResponse { .. }) {
272            return self.poll_awaiting_response(cx);
273        }
274
275        if matches!(*self.as_mut(), Self::SerError(_)) {
276            return self.poll_ser_error(cx);
277        }
278
279        panic!("Called poll on CallState in invalid state")
280    }
281}