alloy_transport/layers/retry.rs

use crate::{
    error::{RpcErrorExt, TransportError, TransportErrorKind},
    TransportFut,
};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use std::{
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};
use tower::{Layer, Service};
use tracing::trace;

#[cfg(target_family = "wasm")]
use wasmtimer::tokio::sleep;

#[cfg(not(target_family = "wasm"))]
use tokio::time::sleep;

/// The default average cost of a request in compute units (CU).
const DEFAULT_AVG_COST: u64 = 17u64;

/// A Transport Layer that is responsible for retrying requests based on the
/// error type. See [`TransportError`].
///
/// [`TransportError`]: crate::error::TransportError
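///
/// # Example
///
/// A minimal construction sketch; the retry count, backoff, and compute-unit budget below are
/// illustrative, and wiring the layer into a client stack is assumed to happen elsewhere (e.g.
/// via `tower::ServiceBuilder` or an RPC client builder):
///
/// ```ignore
/// // Retry rate-limit errors up to 10 times, starting from a 200 ms backoff,
/// // against a provider budget of 330 compute units per second.
/// let retry_layer = RetryBackoffLayer::new(10, 200, 330);
///
/// // Stacked onto some transport service `svc` (assumed to be in scope):
/// let service = tower::ServiceBuilder::new().layer(retry_layer).service(svc);
/// ```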
#[derive(Debug, Clone)]
pub struct RetryBackoffLayer<P: RetryPolicy = RateLimitRetryPolicy> {
    /// The maximum number of retries for rate limit errors
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds
    initial_backoff: u64,
    /// The number of compute units per second for this provider
    compute_units_per_second: u64,
    /// The average cost of a request. Defaults to [`DEFAULT_AVG_COST`]
    avg_cost: u64,
    /// The [`RetryPolicy`] to use. Defaults to [`RateLimitRetryPolicy`]
    policy: P,
}

impl RetryBackoffLayer {
    /// Creates a new retry layer with the given parameters and the default
    /// [`RateLimitRetryPolicy`].
    pub const fn new(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            avg_cost: DEFAULT_AVG_COST,
            policy: RateLimitRetryPolicy,
        }
    }

    /// Sets the average cost of a request. Defaults to `17` CU.
    ///
    /// The cost of requests is usually weighted and can vary from 10 CU to several hundred CU;
    /// cheaper requests are more common. Some example Alchemy weights:
    /// - `eth_getStorageAt`: 17
    /// - `eth_getBlockByNumber`: 16
    /// - `eth_newFilter`: 20
    ///
    /// Assuming (coming from forking mode) that storage requests will be the driver for rate
    /// limits, we choose `17` as the average cost of any request.
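    ///
    /// # Example
    ///
    /// A minimal sketch; the compute-unit figures are illustrative:
    ///
    /// ```ignore
    /// // 10 retries, 200 ms initial backoff, 330 CU/s budget, 26 CU average request cost.
    /// let layer = RetryBackoffLayer::new(10, 200, 330).with_avg_unit_cost(26);
    /// ```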
    pub const fn with_avg_unit_cost(mut self, avg_cost: u64) -> Self {
        self.avg_cost = avg_cost;
        self
    }
}

impl<P: RetryPolicy> RetryBackoffLayer<P> {
    /// Creates a new retry layer with the given parameters and [`RetryPolicy`].
    pub const fn new_with_policy(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
        policy: P,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            policy,
            avg_cost: DEFAULT_AVG_COST,
        }
    }
}

/// [`RateLimitRetryPolicy`] implements [`RetryPolicy`] to determine whether to retry depending on
/// the error.
#[derive(Debug, Copy, Clone, Default)]
#[non_exhaustive]
pub struct RateLimitRetryPolicy;

/// [`RetryPolicy`] defines the logic for deciding which [`TransportError`] instances the client
/// should retry and attempt to recover from.
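///
/// # Example
///
/// A minimal sketch of a custom policy that retries on every error and never supplies a backoff
/// hint (the policy name here is illustrative):
///
/// ```ignore
/// #[derive(Debug, Clone, Copy)]
/// struct AlwaysRetryPolicy;
///
/// impl RetryPolicy for AlwaysRetryPolicy {
///     fn should_retry(&self, _error: &TransportError) -> bool {
///         true
///     }
///
///     fn backoff_hint(&self, _error: &TransportError) -> Option<std::time::Duration> {
///         None
///     }
/// }
/// ```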
pub trait RetryPolicy: Send + Sync + std::fmt::Debug {
    /// Whether to retry the request based on the given `error`
    fn should_retry(&self, error: &TransportError) -> bool;

    /// Providers may include the `backoff` in the error response directly
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration>;
}

impl RetryPolicy for RateLimitRetryPolicy {
    fn should_retry(&self, error: &TransportError) -> bool {
        error.is_retryable()
    }

    /// Provides a backoff hint if the error response contains it
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration> {
        error.backoff_hint()
    }
}

impl<S, P: RetryPolicy + Clone> Layer<S> for RetryBackoffLayer<P> {
    type Service = RetryBackoffService<S, P>;

    fn layer(&self, inner: S) -> Self::Service {
        RetryBackoffService {
            inner,
            policy: self.policy.clone(),
            max_rate_limit_retries: self.max_rate_limit_retries,
            initial_backoff: self.initial_backoff,
            compute_units_per_second: self.compute_units_per_second,
            requests_enqueued: Arc::new(AtomicU32::new(0)),
            avg_cost: self.avg_cost,
        }
    }
}

/// A Tower Service used by the [`RetryBackoffLayer`] that is responsible for retrying requests
/// based on the error type. See [`TransportError`] and [`RateLimitRetryPolicy`].
#[derive(Debug, Clone)]
pub struct RetryBackoffService<S, P: RetryPolicy = RateLimitRetryPolicy> {
    /// The inner service
    inner: S,
    /// The [`RetryPolicy`] to use.
    policy: P,
    /// The maximum number of retries for rate limit errors
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds
    initial_backoff: u64,
    /// The number of compute units per second for this service
    compute_units_per_second: u64,
    /// The number of requests currently enqueued
    requests_enqueued: Arc<AtomicU32>,
    /// The average cost of a request.
    avg_cost: u64,
}

impl<S, P: RetryPolicy> RetryBackoffService<S, P> {
    const fn initial_backoff(&self) -> Duration {
        Duration::from_millis(self.initial_backoff)
    }
}

impl<S, P> Service<RequestPacket> for RetryBackoffService<S, P>
where
    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
        + Send
        + 'static
        + Clone,
    P: RetryPolicy + Clone + 'static,
{
    type Response = ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Our middleware doesn't care about backpressure, so it's ready as long
        // as the inner service is ready.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: RequestPacket) -> Self::Future {
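        // A common Tower pattern: leave a clone of the inner service in `self` and move the
        // instance that `poll_ready` drove to readiness into the future below.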
        let inner = self.inner.clone();
        let this = self.clone();
        let mut inner = std::mem::replace(&mut self.inner, inner);
        Box::pin(async move {
            let ahead_in_queue = this.requests_enqueued.fetch_add(1, Ordering::SeqCst) as u64;
            let mut rate_limit_retry_number: u32 = 0;
            loop {
                let err;
                let res = inner.call(request.clone()).await;

                match res {
                    Ok(res) => {
                        if let Some(e) = res.as_error() {
                            err = TransportError::ErrorResp(e.clone())
                        } else {
                            this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                            return Ok(res);
                        }
                    }
                    Err(e) => err = e,
                }

                let should_retry = this.policy.should_retry(&err);
                if should_retry {
                    rate_limit_retry_number += 1;
                    if rate_limit_retry_number > this.max_rate_limit_retries {
                        return Err(TransportErrorKind::custom_str(&format!(
                            "Max retries exceeded {err}"
                        )));
                    }
                    trace!(%err, "retrying request");

                    let current_queued_reqs = this.requests_enqueued.load(Ordering::SeqCst) as u64;

                    // try to extract the requested backoff from the error, or fall back to the
                    // configured initial backoff
                    let backoff_hint = this.policy.backoff_hint(&err);
                    let next_backoff = backoff_hint.unwrap_or_else(|| this.initial_backoff());

                    let seconds_to_wait_for_compute_budget = compute_unit_offset_in_secs(
                        this.avg_cost,
                        this.compute_units_per_second,
                        current_queued_reqs,
                        ahead_in_queue,
                    );
                    let total_backoff = next_backoff
                        + std::time::Duration::from_secs(seconds_to_wait_for_compute_budget);

                    trace!(
                        total_backoff_millis = total_backoff.as_millis(),
                        budget_backoff_millis = seconds_to_wait_for_compute_budget * 1000,
                        default_backoff_millis = next_backoff.as_millis(),
                        backoff_hint_millis = backoff_hint.map(|d| d.as_millis()),
                        "(all in ms) backing off due to rate limit"
                    );

                    sleep(total_backoff).await;
                } else {
                    this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                    return Err(err);
                }
            }
        })
    }
}

/// Calculates an offset in seconds by taking into account the number of currently queued
/// requests, the number of requests that were ahead in the queue when the request was first
/// issued, the average cost of a weighted request (heuristic), and the number of available
/// compute units per second.
///
/// Returns the number of seconds (the unit in which the remote endpoint measures its compute
/// budget) a request should wait in order to not get rate limited. The budget per second is
/// `compute_units_per_second`; assuming an average cost of `avg_cost`, this allows (in theory)
/// `compute_units_per_second / avg_cost` requests per second without getting rate limited.
/// Taking into account the number of concurrent requests and the position in the queue when the
/// request was first issued, this determines the number of seconds a request should wait, if at
/// all.
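///
/// For example (illustrative numbers): with a budget of 330 CU/s and the default average cost of
/// 17 CU, the capacity is `330 / 17 = 19` requests per second; a request that saw 25 requests
/// ahead of it while 40 are currently queued is offset by `min(40, 25) / 19 = 1` second.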
fn compute_unit_offset_in_secs(
    avg_cost: u64,
    compute_units_per_second: u64,
    current_queued_requests: u64,
    ahead_in_queue: u64,
) -> u64 {
    let request_capacity_per_second = compute_units_per_second.saturating_div(avg_cost).max(1);
    if current_queued_requests > request_capacity_per_second {
        current_queued_requests.min(ahead_in_queue).saturating_div(request_capacity_per_second)
    } else {
        0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_units_per_second() {
        let offset = compute_unit_offset_in_secs(17, 10, 0, 0);
        assert_eq!(offset, 0);
        let offset = compute_unit_offset_in_secs(17, 10, 2, 2);
        assert_eq!(offset, 2);
    }
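
    // A further sketch of the budget arithmetic with illustrative numbers: a budget of 330 CU/s
    // at the default 17 CU average cost gives a capacity of 330 / 17 = 19 requests per second
    // (integer division).
    #[test]
    fn test_compute_unit_offset_with_larger_budget() {
        // Nothing queued: no offset.
        assert_eq!(compute_unit_offset_in_secs(17, 330, 0, 0), 0);
        // 40 queued requests exceed the capacity of 19 per second; a request that had 25
        // requests ahead of it waits min(40, 25) / 19 = 1 second.
        assert_eq!(compute_unit_offset_in_secs(17, 330, 40, 25), 1);
    }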
}