alloy_transport/layers/
retry.rs

use crate::{
    error::{RpcErrorExt, TransportError, TransportErrorKind},
    TransportFut,
};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use std::{
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};
use tower::{Layer, Service};
use tracing::trace;

#[cfg(target_family = "wasm")]
use wasmtimer::tokio::sleep;

#[cfg(not(target_family = "wasm"))]
use tokio::time::sleep;

/// The default average cost of a request in compute units (CU).
const DEFAULT_AVG_COST: u64 = 17u64;

/// A Transport Layer that is responsible for retrying requests based on the
/// error type. See [`TransportError`].
///
/// [`TransportError`]: crate::error::TransportError
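///
/// # Example
///
/// A minimal sketch of wrapping an existing transport service with this layer; the
/// numbers are illustrative rather than provider defaults, and `inner_transport` is a
/// placeholder for any transport service:
///
/// ```ignore
/// use tower::Layer;
///
/// // Up to 10 retries on rate-limit errors, 200ms initial backoff,
/// // and an assumed budget of 330 compute units per second.
/// let layer = RetryBackoffLayer::new(10, 200, 330);
/// // `inner_transport` is a placeholder for any transport service.
/// let service = layer.layer(inner_transport);
/// ```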
#[derive(Debug, Clone)]
pub struct RetryBackoffLayer<P: RetryPolicy = RateLimitRetryPolicy> {
    /// The maximum number of retries for rate limit errors
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds
    initial_backoff: u64,
    /// The number of compute units per second for this provider
    compute_units_per_second: u64,
    /// The average cost of a request. Defaults to [DEFAULT_AVG_COST]
    avg_cost: u64,
    /// The [RetryPolicy] to use. Defaults to [RateLimitRetryPolicy]
    policy: P,
}

impl RetryBackoffLayer {
    /// Creates a new retry layer with the given parameters and the default [RateLimitRetryPolicy].
    pub const fn new(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            avg_cost: DEFAULT_AVG_COST,
            policy: RateLimitRetryPolicy,
        }
    }

    /// Set the average cost of a request, in compute units (CU). Defaults to `17` CU.
    ///
    /// Request costs are usually weighted and can vary from 10 CU to several hundred CU;
    /// cheaper requests are more common. Some example Alchemy weights:
    /// - `eth_getStorageAt`: 17
    /// - `eth_getBlockByNumber`: 16
    /// - `eth_newFilter`: 20
    ///
    /// Assuming that storage requests (the common case when forking) drive the rate limit,
    /// `17` is chosen as the average cost of any request.
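    ///
    /// # Example
    ///
    /// A minimal sketch (the `330` CU/s budget is an illustrative value, not a provider
    /// default):
    ///
    /// ```ignore
    /// use alloy_transport::layers::RetryBackoffLayer;
    ///
    /// // Treat every request as roughly 26 CU instead of the default 17.
    /// let retry = RetryBackoffLayer::new(10, 200, 330).with_avg_unit_cost(26);
    /// ```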
    pub const fn with_avg_unit_cost(mut self, avg_cost: u64) -> Self {
        self.avg_cost = avg_cost;
        self
    }
}

impl<P: RetryPolicy> RetryBackoffLayer<P> {
    /// Creates a new retry layer with the given parameters and [RetryPolicy].
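    ///
    /// See the example on [RetryPolicy] for wiring a custom policy into this constructor.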
    pub const fn new_with_policy(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
        policy: P,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            policy,
            avg_cost: DEFAULT_AVG_COST,
        }
    }
}

/// [RateLimitRetryPolicy] implements [RetryPolicy] to decide whether a request should be retried
/// depending on the error.
#[derive(Debug, Copy, Clone, Default)]
#[non_exhaustive]
pub struct RateLimitRetryPolicy;

/// [RetryPolicy] defines the logic for deciding which [TransportError] instances the client
/// should retry and attempt to recover from.
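///
/// # Example
///
/// A minimal sketch of a custom policy; `AlwaysRetryPolicy` is a hypothetical type used only
/// for illustration, and the import paths assume the crate's public re-exports:
///
/// ```ignore
/// use alloy_transport::{
///     layers::{RetryBackoffLayer, RetryPolicy},
///     TransportError,
/// };
///
/// #[derive(Debug, Clone, Copy)]
/// struct AlwaysRetryPolicy;
///
/// impl RetryPolicy for AlwaysRetryPolicy {
///     // Retry every failed request, regardless of the error kind.
///     fn should_retry(&self, _error: &TransportError) -> bool {
///         true
///     }
///
///     // Never suggest a server-provided backoff; the layer then falls back to its
///     // configured initial backoff.
///     fn backoff_hint(&self, _error: &TransportError) -> Option<std::time::Duration> {
///         None
///     }
/// }
///
/// // Plug the custom policy into the retry layer (illustrative parameters).
/// let layer = RetryBackoffLayer::new_with_policy(10, 200, 330, AlwaysRetryPolicy);
/// ```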
pub trait RetryPolicy: Send + Sync + std::fmt::Debug {
    /// Whether to retry the request based on the given `error`
    fn should_retry(&self, error: &TransportError) -> bool;

    /// Providers may include the `backoff` in the error response directly
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration>;
}

impl RetryPolicy for RateLimitRetryPolicy {
    fn should_retry(&self, error: &TransportError) -> bool {
        error.is_retryable()
    }

    /// Provides a backoff hint if the error response contains it
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration> {
        error.backoff_hint()
    }
}

impl<S, P: RetryPolicy + Clone> Layer<S> for RetryBackoffLayer<P> {
    type Service = RetryBackoffService<S, P>;

    fn layer(&self, inner: S) -> Self::Service {
        RetryBackoffService {
            inner,
            policy: self.policy.clone(),
            max_rate_limit_retries: self.max_rate_limit_retries,
            initial_backoff: self.initial_backoff,
            compute_units_per_second: self.compute_units_per_second,
            requests_enqueued: Arc::new(AtomicU32::new(0)),
            avg_cost: self.avg_cost,
        }
    }
}

/// A Tower Service used by the RetryBackoffLayer that is responsible for retrying requests based
/// on the error type. See [TransportError] and [RateLimitRetryPolicy].
#[derive(Debug, Clone)]
pub struct RetryBackoffService<S, P: RetryPolicy = RateLimitRetryPolicy> {
    /// The inner service
    inner: S,
    /// The [RetryPolicy] to use.
    policy: P,
    /// The maximum number of retries for rate limit errors
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds
    initial_backoff: u64,
    /// The number of compute units per second for this service
    compute_units_per_second: u64,
    /// The number of requests currently enqueued
    requests_enqueued: Arc<AtomicU32>,
    /// The average cost of a request.
    avg_cost: u64,
}

impl<S, P: RetryPolicy> RetryBackoffService<S, P> {
    const fn initial_backoff(&self) -> Duration {
        Duration::from_millis(self.initial_backoff)
    }
}

impl<S, P> Service<RequestPacket> for RetryBackoffService<S, P>
where
    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
        + Send
        + 'static
        + Clone,
    P: RetryPolicy + Clone + 'static,
{
    type Response = ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Our middleware doesn't care about backpressure, so it's ready as long
        // as the inner service is ready.
        self.inner.poll_ready(cx)
    }
    fn call(&mut self, request: RequestPacket) -> Self::Future {
        // Standard tower pattern: clone the service and swap the clone into `self`, so the
        // instance that was polled ready is the one driving this request.
        let inner = self.inner.clone();
        let this = self.clone();
        let mut inner = std::mem::replace(&mut self.inner, inner);
        Box::pin(async move {
            let ahead_in_queue = this.requests_enqueued.fetch_add(1, Ordering::SeqCst) as u64;
            let mut rate_limit_retry_number: u32 = 0;
            loop {
                let err;
                let res = inner.call(request.clone()).await;

                match res {
                    Ok(res) => {
                        if let Some(e) = res.as_error() {
                            err = TransportError::ErrorResp(e.clone())
                        } else {
                            this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                            return Ok(res);
                        }
                    }
                    Err(e) => err = e,
                }

                let should_retry = this.policy.should_retry(&err);
                if should_retry {
                    rate_limit_retry_number += 1;
                    if rate_limit_retry_number > this.max_rate_limit_retries {
                        return Err(TransportErrorKind::custom_str(&format!(
                            "Max retries exceeded {err}"
                        )));
                    }
                    trace!(%err, "retrying request");

                    let current_queued_reqs = this.requests_enqueued.load(Ordering::SeqCst) as u64;

                    // try to extract the requested backoff from the error or compute the next
                    // backoff based on retry count
                    let backoff_hint = this.policy.backoff_hint(&err);
                    let next_backoff = backoff_hint.unwrap_or_else(|| this.initial_backoff());

                    let seconds_to_wait_for_compute_budget = compute_unit_offset_in_secs(
                        this.avg_cost,
                        this.compute_units_per_second,
                        current_queued_reqs,
                        ahead_in_queue,
                    );
                    let total_backoff = next_backoff
                        + std::time::Duration::from_secs(seconds_to_wait_for_compute_budget);

                    trace!(
                        total_backoff_millis = total_backoff.as_millis(),
                        budget_backoff_millis = seconds_to_wait_for_compute_budget * 1000,
                        default_backoff_millis = next_backoff.as_millis(),
                        backoff_hint_millis = backoff_hint.map(|d| d.as_millis()),
                        "(all in ms) backing off due to rate limit"
                    );

                    sleep(total_backoff).await;
                } else {
                    this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                    return Err(err);
                }
            }
        })
    }
}

/// Calculates an offset in seconds by taking into account the number of currently queued requests,
/// the number of requests that were ahead in the queue when the request was first issued, the
/// average cost of a weighted request (heuristic), and the number of available compute units per
/// second.
///
/// Returns the number of seconds (the unit in which the remote endpoint measures its compute
/// budget) a request should wait so that it does not get rate limited. The budget per second is
/// `compute_units_per_second`; assuming an average cost of `avg_cost`, this allows (in theory)
/// `compute_units_per_second / avg_cost` requests per second without getting rate limited.
/// The number of concurrent requests and the position in the queue when the request was first
/// issued determine how many seconds, if any, the request should wait.
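///
/// # Example
///
/// With an illustrative `avg_cost = 17` and `compute_units_per_second = 330`, the capacity is
/// `330 / 17 = 19` requests per second. If 40 requests are currently queued and this request had
/// 35 requests ahead of it when first issued, the offset is `min(40, 35) / 19 = 1` second; with
/// 19 or fewer queued requests the offset is `0`.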
fn compute_unit_offset_in_secs(
    avg_cost: u64,
    compute_units_per_second: u64,
    current_queued_requests: u64,
    ahead_in_queue: u64,
) -> u64 {
    let request_capacity_per_second = compute_units_per_second.saturating_div(avg_cost).max(1);
    if current_queued_requests > request_capacity_per_second {
        current_queued_requests.min(ahead_in_queue).saturating_div(request_capacity_per_second)
    } else {
        0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_units_per_second() {
        let offset = compute_unit_offset_in_secs(17, 10, 0, 0);
        assert_eq!(offset, 0);
        let offset = compute_unit_offset_in_secs(17, 10, 2, 2);
        assert_eq!(offset, 2);
    }
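
    // An additional sanity check with an illustrative, larger budget: at 330 CU/s and an
    // average cost of 17 CU the capacity is 330 / 17 = 19 requests per second, so a queue
    // within capacity waits zero seconds, while 35 requests ahead in a queue of 40 wait
    // one second (min(40, 35) / 19 = 1).
    #[test]
    fn test_compute_units_offset_with_larger_budget() {
        let offset = compute_unit_offset_in_secs(17, 330, 10, 10);
        assert_eq!(offset, 0);
        let offset = compute_unit_offset_in_secs(17, 330, 40, 35);
        assert_eq!(offset, 1);
    }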
}