// alloy_transport/layers/retry.rs

1use crate::{
2    error::{RpcErrorExt, TransportError, TransportErrorKind},
3    TransportFut,
4};
5use alloy_json_rpc::{RequestPacket, ResponsePacket};
6use core::fmt;
7use std::{
8    sync::{
9        atomic::{AtomicU32, Ordering},
10        Arc,
11    },
12    task::{Context, Poll},
13    time::Duration,
14};
15use tower::{Layer, Service};
16use tracing::trace;
17
18#[cfg(target_family = "wasm")]
19use wasmtimer::tokio::sleep;
20
21#[cfg(not(target_family = "wasm"))]
22use tokio::time::sleep;
23
/// Default average cost of a request, in compute units (CU).
const DEFAULT_AVG_COST: u64 = 17;
26
/// A Transport Layer that is responsible for retrying requests based on the
/// error type. See [`TransportError`].
///
/// [`TransportError`]: crate::error::TransportError
#[derive(Debug, Clone)]
pub struct RetryBackoffLayer<P: RetryPolicy = RateLimitRetryPolicy> {
    /// The maximum number of retries for rate limit errors
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds
    initial_backoff: u64,
    /// The number of compute units per second for this provider
    compute_units_per_second: u64,
    /// The average cost of a request in compute units. Defaults to [`DEFAULT_AVG_COST`]
    avg_cost: u64,
    /// The [`RetryPolicy`] to use. Defaults to [`RateLimitRetryPolicy`]
    policy: P,
}
44
45impl RetryBackoffLayer {
46    /// Creates a new retry layer with the given parameters and the default [RateLimitRetryPolicy].
47    pub const fn new(
48        max_rate_limit_retries: u32,
49        initial_backoff: u64,
50        compute_units_per_second: u64,
51    ) -> Self {
52        Self {
53            max_rate_limit_retries,
54            initial_backoff,
55            compute_units_per_second,
56            avg_cost: DEFAULT_AVG_COST,
57            policy: RateLimitRetryPolicy,
58        }
59    }
60
61    /// Set the average cost of a request. Defaults to `17` CU
62    /// The cost of requests are usually weighted and can vary from 10 CU to several 100 CU,
63    /// cheaper requests are more common some example alchemy
64    /// weights:
65    /// - `eth_getStorageAt`: 17
66    /// - `eth_getBlockByNumber`: 16
67    /// - `eth_newFilter`: 20
68    ///
69    /// (coming from forking mode) assuming here that storage request will be the
70    /// driver for Rate limits we choose `17` as the average cost
71    /// of any request
72    pub const fn with_avg_unit_cost(mut self, avg_cost: u64) -> Self {
73        self.avg_cost = avg_cost;
74        self
75    }
76}
77
78impl<P: RetryPolicy> RetryBackoffLayer<P> {
79    /// Creates a new retry layer with the given parameters and [RetryPolicy].
80    pub const fn new_with_policy(
81        max_rate_limit_retries: u32,
82        initial_backoff: u64,
83        compute_units_per_second: u64,
84        policy: P,
85    ) -> Self {
86        Self {
87            max_rate_limit_retries,
88            initial_backoff,
89            compute_units_per_second,
90            policy,
91            avg_cost: DEFAULT_AVG_COST,
92        }
93    }
94}
95
/// [RateLimitRetryPolicy] implements [RetryPolicy] to determine whether to retry depending on the
/// err.
///
/// This policy delegates the decision to the error itself via `RpcErrorExt`
/// (`is_retryable` / `backoff_hint`).
#[derive(Debug, Copy, Clone, Default)]
#[non_exhaustive]
pub struct RateLimitRetryPolicy;
101
impl RateLimitRetryPolicy {
    /// Creates a new [`RetryPolicy`] that in addition to this policy respects the given closure
    /// function for detecting if an error should be retried.
    ///
    /// The resulting [`OrRetryPolicyFn`] retries when either this policy or
    /// the closure considers the error retryable.
    pub fn or<F>(self, f: F) -> OrRetryPolicyFn<Self>
    where
        F: Fn(&TransportError) -> bool + Send + Sync + 'static,
    {
        OrRetryPolicyFn::new(self, f)
    }
}
112
/// [RetryPolicy] defines the logic for deciding which [TransportError] instances
/// the client should retry the request for and try to recover from.
pub trait RetryPolicy: Send + Sync + std::fmt::Debug {
    /// Whether to retry the request based on the given `error`
    fn should_retry(&self, error: &TransportError) -> bool;

    /// Providers may include the `backoff` in the error response directly.
    ///
    /// Returns `None` when the error carries no explicit backoff; callers fall
    /// back to their own backoff computation in that case.
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration>;
}
122
impl RetryPolicy for RateLimitRetryPolicy {
    /// Delegates to the error's own retryability check (`RpcErrorExt::is_retryable`).
    fn should_retry(&self, error: &TransportError) -> bool {
        error.is_retryable()
    }

    /// Provides a backoff hint if the error response contains it
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration> {
        error.backoff_hint()
    }
}
133
/// A [`RetryPolicy`] that supports an additional closure for deciding if an error should be
/// retried.
#[derive(Clone)]
pub struct OrRetryPolicyFn<P = RateLimitRetryPolicy> {
    /// Extra user-supplied predicate; the error is retried when this returns `true`.
    inner: Arc<dyn Fn(&TransportError) -> bool + Send + Sync>,
    /// The base policy consulted in addition to `inner`.
    base: P,
}
141
142impl<P> OrRetryPolicyFn<P> {
143    /// Creates a new instance with the given base policy and the given closure
144    pub fn new<F>(base: P, or: F) -> Self
145    where
146        F: Fn(&TransportError) -> bool + Send + Sync + 'static,
147    {
148        Self { inner: Arc::new(or), base }
149    }
150}
151
152impl<P: RetryPolicy> RetryPolicy for OrRetryPolicyFn<P> {
153    fn should_retry(&self, error: &TransportError) -> bool {
154        self.inner.as_ref()(error) || self.base.should_retry(error)
155    }
156
157    fn backoff_hint(&self, error: &TransportError) -> Option<Duration> {
158        self.base.backoff_hint(error)
159    }
160}
161
162impl<P: fmt::Debug> fmt::Debug for OrRetryPolicyFn<P> {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        f.debug_struct("OrRetryPolicyFn")
165            .field("base", &self.base)
166            .field("inner", &"{{..}}")
167            .finish_non_exhaustive()
168    }
169}
170
171impl<S, P: RetryPolicy + Clone> Layer<S> for RetryBackoffLayer<P> {
172    type Service = RetryBackoffService<S, P>;
173
174    fn layer(&self, inner: S) -> Self::Service {
175        RetryBackoffService {
176            inner,
177            policy: self.policy.clone(),
178            max_rate_limit_retries: self.max_rate_limit_retries,
179            initial_backoff: self.initial_backoff,
180            compute_units_per_second: self.compute_units_per_second,
181            requests_enqueued: Arc::new(AtomicU32::new(0)),
182            avg_cost: self.avg_cost,
183        }
184    }
185}
186
/// A Tower Service used by the RetryBackoffLayer that is responsible for retrying requests based
/// on the error type. See [TransportError] and [RateLimitRetryPolicy].
#[derive(Debug, Clone)]
pub struct RetryBackoffService<S, P: RetryPolicy = RateLimitRetryPolicy> {
    /// The inner service
    inner: S,
    /// The [RetryPolicy] to use.
    policy: P,
    /// The maximum number of retries for rate limit errors
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds
    initial_backoff: u64,
    /// The number of compute units per second for this service
    compute_units_per_second: u64,
    /// The number of requests currently enqueued.
    ///
    /// Shared via [`Arc`], so clones of this service count against the same queue.
    requests_enqueued: Arc<AtomicU32>,
    /// The average cost of a request in compute units (heuristic).
    avg_cost: u64,
}
206
impl<S, P: RetryPolicy> RetryBackoffService<S, P> {
    /// Returns the configured initial backoff (stored as milliseconds) as a [`Duration`].
    const fn initial_backoff(&self) -> Duration {
        Duration::from_millis(self.initial_backoff)
    }
}
212
impl<S, P> Service<RequestPacket> for RetryBackoffService<S, P>
where
    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
        + Send
        + 'static
        + Clone,
    P: RetryPolicy + Clone + 'static,
{
    type Response = ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Our middleware doesn't care about backpressure, so it's ready as long
        // as the inner service is ready.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: RequestPacket) -> Self::Future {
        // Standard tower pattern: move out the inner service that `poll_ready`
        // was driven on and leave a fresh clone in its place, so the returned
        // future owns a ready service.
        let inner = self.inner.clone();
        let this = self.clone();
        let mut inner = std::mem::replace(&mut self.inner, inner);
        Box::pin(async move {
            // Snapshot of how many requests were already in flight when this
            // one was issued; used to weight the compute-budget wait below.
            let ahead_in_queue = this.requests_enqueued.fetch_add(1, Ordering::SeqCst) as u64;
            let mut rate_limit_retry_number: u32 = 0;
            loop {
                let err;
                let res = inner.call(request.clone()).await;

                match res {
                    Ok(res) => {
                        // A transport-level success can still carry a JSON-RPC
                        // error payload; treat that as an error for retry purposes.
                        if let Some(e) = res.as_error() {
                            err = TransportError::ErrorResp(e.clone())
                        } else {
                            // Success: leave the queue and return the response.
                            this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                            return Ok(res);
                        }
                    }
                    Err(e) => err = e,
                }

                let should_retry = this.policy.should_retry(&err);
                if should_retry {
                    rate_limit_retry_number += 1;
                    if rate_limit_retry_number > this.max_rate_limit_retries {
                        // Retry budget exhausted: leave the queue and surface
                        // the last error, wrapped with context.
                        this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                        return Err(TransportErrorKind::custom_str(&format!(
                            "Max retries exceeded {err}"
                        )));
                    }
                    trace!(%err, "retrying request");

                    let current_queued_reqs = this.requests_enqueued.load(Ordering::SeqCst) as u64;

                    // try to extract the requested backoff from the error or compute the next
                    // backoff based on retry count
                    let backoff_hint = this.policy.backoff_hint(&err);
                    let next_backoff = backoff_hint.unwrap_or_else(|| this.initial_backoff());

                    // Additional wait (whole seconds) so queued requests don't
                    // blow the provider's compute-unit budget.
                    let seconds_to_wait_for_compute_budget = compute_unit_offset_in_secs(
                        this.avg_cost,
                        this.compute_units_per_second,
                        current_queued_reqs,
                        ahead_in_queue,
                    );
                    let total_backoff = next_backoff
                        + std::time::Duration::from_secs(seconds_to_wait_for_compute_budget);

                    trace!(
                        total_backoff_millis = total_backoff.as_millis(),
                        budget_backoff_millis = seconds_to_wait_for_compute_budget * 1000,
                        default_backoff_millis = next_backoff.as_millis(),
                        backoff_hint_millis = backoff_hint.map(|d| d.as_millis()),
                        "(all in ms) backing off due to rate limit"
                    );

                    sleep(total_backoff).await;
                } else {
                    // Non-retryable error: leave the queue and propagate as-is.
                    this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                    return Err(err);
                }
            }
        })
    }
}
298
/// Calculates an offset in seconds by taking into account the number of currently queued requests,
/// number of requests that were ahead in the queue when the request was first issued, the average
/// cost of a weighted request (heuristic), and the number of available compute units per second.
///
/// Returns the number of seconds (the unit the remote endpoint measures compute budget) a request
/// is supposed to wait to not get rate limited. The budget per second is
/// `compute_units_per_second`; assuming an average cost of `avg_cost` this allows (in theory)
/// `compute_units_per_second / avg_cost` requests per second without getting rate limited.
/// By taking into account the number of concurrent requests and the position in the queue when
/// the request was first issued, this determines the number of seconds a request is supposed to
/// wait, if at all.
fn compute_unit_offset_in_secs(
    avg_cost: u64,
    compute_units_per_second: u64,
    current_queued_requests: u64,
    ahead_in_queue: u64,
) -> u64 {
    // How many requests fit into one second of compute budget, at least 1.
    // `checked_div` avoids a panic when `avg_cost == 0` (settable via
    // `with_avg_unit_cost`): a zero cost is treated as unbounded capacity,
    // which yields an offset of 0 below.
    let request_capacity_per_second =
        compute_units_per_second.checked_div(avg_cost).unwrap_or(u64::MAX).max(1);
    if current_queued_requests > request_capacity_per_second {
        // Wait proportionally to the queue depth, capped by this request's
        // position in the queue when it was first issued.
        current_queued_requests.min(ahead_in_queue).saturating_div(request_capacity_per_second)
    } else {
        // The whole queue fits in one second of budget: no extra wait.
        0
    }
}
323
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_units_per_second() {
        // Nothing queued: no extra wait.
        let offset = compute_unit_offset_in_secs(17, 10, 0, 0);
        assert_eq!(offset, 0);
        // Queue exceeds per-second capacity (10 / 17 clamps to 1 req/s):
        // wait one second per queued request ahead.
        let offset = compute_unit_offset_in_secs(17, 10, 2, 2);
        assert_eq!(offset, 2);
        // Queue fits within capacity (170 / 17 = 10 req/s): no wait.
        let offset = compute_unit_offset_in_secs(17, 170, 5, 5);
        assert_eq!(offset, 0);
        // The offset is bounded by the position in queue at issue time.
        let offset = compute_unit_offset_in_secs(17, 10, 5, 1);
        assert_eq!(offset, 1);
    }
}