1use crate::{
2 error::{RpcErrorExt, TransportError, TransportErrorKind},
3 TransportFut,
4};
5use alloy_json_rpc::{RequestPacket, ResponsePacket};
6use std::{
7 sync::{
8 atomic::{AtomicU32, Ordering},
9 Arc,
10 },
11 task::{Context, Poll},
12 time::Duration,
13};
14use tower::{Layer, Service};
15use tracing::trace;
16
17#[cfg(target_family = "wasm")]
18use wasmtimer::tokio::sleep;
19
20#[cfg(not(target_family = "wasm"))]
21use tokio::time::sleep;
22
/// Compute-unit cost assumed for one average request until a caller overrides it with
/// `RetryBackoffLayer::with_avg_unit_cost`.
const DEFAULT_AVG_COST: u64 = 17;
25
26#[derive(Debug, Clone)]
31pub struct RetryBackoffLayer<P: RetryPolicy = RateLimitRetryPolicy> {
32 max_rate_limit_retries: u32,
34 initial_backoff: u64,
36 compute_units_per_second: u64,
38 avg_cost: u64,
40 policy: P,
42}
43
44impl RetryBackoffLayer {
45 pub const fn new(
47 max_rate_limit_retries: u32,
48 initial_backoff: u64,
49 compute_units_per_second: u64,
50 ) -> Self {
51 Self {
52 max_rate_limit_retries,
53 initial_backoff,
54 compute_units_per_second,
55 avg_cost: DEFAULT_AVG_COST,
56 policy: RateLimitRetryPolicy,
57 }
58 }
59
60 pub const fn with_avg_unit_cost(mut self, avg_cost: u64) {
72 self.avg_cost = avg_cost;
73 }
74}
75
76impl<P: RetryPolicy> RetryBackoffLayer<P> {
77 pub const fn new_with_policy(
79 max_rate_limit_retries: u32,
80 initial_backoff: u64,
81 compute_units_per_second: u64,
82 policy: P,
83 ) -> Self {
84 Self {
85 max_rate_limit_retries,
86 initial_backoff,
87 compute_units_per_second,
88 policy,
89 avg_cost: DEFAULT_AVG_COST,
90 }
91 }
92}
93
/// The default [`RetryPolicy`]: retries whatever the transport error itself reports as
/// retryable, and uses the backoff hint carried by the error when there is one.
#[derive(Clone, Copy, Debug, Default)]
#[non_exhaustive]
pub struct RateLimitRetryPolicy;
99
100pub trait RetryPolicy: Send + Sync + std::fmt::Debug {
103 fn should_retry(&self, error: &TransportError) -> bool;
105
106 fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration>;
108}
109
110impl RetryPolicy for RateLimitRetryPolicy {
111 fn should_retry(&self, error: &TransportError) -> bool {
112 error.is_retryable()
113 }
114
115 fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration> {
117 error.backoff_hint()
118 }
119}
120
121impl<S, P: RetryPolicy + Clone> Layer<S> for RetryBackoffLayer<P> {
122 type Service = RetryBackoffService<S, P>;
123
124 fn layer(&self, inner: S) -> Self::Service {
125 RetryBackoffService {
126 inner,
127 policy: self.policy.clone(),
128 max_rate_limit_retries: self.max_rate_limit_retries,
129 initial_backoff: self.initial_backoff,
130 compute_units_per_second: self.compute_units_per_second,
131 requests_enqueued: Arc::new(AtomicU32::new(0)),
132 avg_cost: self.avg_cost,
133 }
134 }
135}
136
137#[derive(Debug, Clone)]
140pub struct RetryBackoffService<S, P: RetryPolicy = RateLimitRetryPolicy> {
141 inner: S,
143 policy: P,
145 max_rate_limit_retries: u32,
147 initial_backoff: u64,
149 compute_units_per_second: u64,
151 requests_enqueued: Arc<AtomicU32>,
153 avg_cost: u64,
155}
156
157impl<S, P: RetryPolicy> RetryBackoffService<S, P> {
158 const fn initial_backoff(&self) -> Duration {
159 Duration::from_millis(self.initial_backoff)
160 }
161}
162
163impl<S, P> Service<RequestPacket> for RetryBackoffService<S, P>
164where
165 S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
166 + Send
167 + 'static
168 + Clone,
169 P: RetryPolicy + Clone + 'static,
170{
171 type Response = ResponsePacket;
172 type Error = TransportError;
173 type Future = TransportFut<'static>;
174
175 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
176 self.inner.poll_ready(cx)
179 }
180
181 fn call(&mut self, request: RequestPacket) -> Self::Future {
182 let inner = self.inner.clone();
183 let this = self.clone();
184 let mut inner = std::mem::replace(&mut self.inner, inner);
185 Box::pin(async move {
186 let ahead_in_queue = this.requests_enqueued.fetch_add(1, Ordering::SeqCst) as u64;
187 let mut rate_limit_retry_number: u32 = 0;
188 loop {
189 let err;
190 let res = inner.call(request.clone()).await;
191
192 match res {
193 Ok(res) => {
194 if let Some(e) = res.as_error() {
195 err = TransportError::ErrorResp(e.clone())
196 } else {
197 this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
198 return Ok(res);
199 }
200 }
201 Err(e) => err = e,
202 }
203
204 let should_retry = this.policy.should_retry(&err);
205 if should_retry {
206 rate_limit_retry_number += 1;
207 if rate_limit_retry_number > this.max_rate_limit_retries {
208 return Err(TransportErrorKind::custom_str(&format!(
209 "Max retries exceeded {err}"
210 )));
211 }
212 trace!(%err, "retrying request");
213
214 let current_queued_reqs = this.requests_enqueued.load(Ordering::SeqCst) as u64;
215
216 let backoff_hint = this.policy.backoff_hint(&err);
219 let next_backoff = backoff_hint.unwrap_or_else(|| this.initial_backoff());
220
221 let seconds_to_wait_for_compute_budget = compute_unit_offset_in_secs(
222 this.avg_cost,
223 this.compute_units_per_second,
224 current_queued_reqs,
225 ahead_in_queue,
226 );
227 let total_backoff = next_backoff
228 + std::time::Duration::from_secs(seconds_to_wait_for_compute_budget);
229
230 trace!(
231 total_backoff_millis = total_backoff.as_millis(),
232 budget_backoff_millis = seconds_to_wait_for_compute_budget * 1000,
233 default_backoff_millis = next_backoff.as_millis(),
234 backoff_hint_millis = backoff_hint.map(|d| d.as_millis()),
235 "(all in ms) backing off due to rate limit"
236 );
237
238 sleep(total_backoff).await;
239 } else {
240 this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
241 return Err(err);
242 }
243 }
244 })
245 }
246}
247
/// Returns how many extra whole seconds a retrying request should wait so that the queue of
/// pending requests fits within the endpoint's compute-unit budget.
///
/// * `avg_cost` - estimated compute units consumed by one request.
/// * `compute_units_per_second` - compute units the endpoint allows per second.
/// * `current_queued_requests` - requests currently registered as in flight.
/// * `ahead_in_queue` - requests that were already in flight when this one started.
///
/// The per-second request capacity is `compute_units_per_second / avg_cost`, clamped to at
/// least 1. An `avg_cost` of 0 means requests are treated as free: the capacity is unbounded and
/// no extra wait is added. If the queue does not exceed the capacity, the offset is 0. Otherwise
/// it is `min(current_queued_requests, ahead_in_queue) / capacity`.
fn compute_unit_offset_in_secs(
    avg_cost: u64,
    compute_units_per_second: u64,
    current_queued_requests: u64,
    ahead_in_queue: u64,
) -> u64 {
    // `saturating_div` panics on a zero divisor; `checked_div` lets a zero cost mean "unbounded".
    let request_capacity_per_second =
        compute_units_per_second.checked_div(avg_cost).unwrap_or(u64::MAX).max(1);
    if current_queued_requests <= request_capacity_per_second {
        return 0;
    }
    // The capacity is at least 1, so this division cannot panic.
    current_queued_requests.min(ahead_in_queue) / request_capacity_per_second
}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_units_per_second() {
        // An empty queue never needs extra delay.
        let offset = compute_unit_offset_in_secs(17, 10, 0, 0);
        assert_eq!(offset, 0);
        // A budget below one request per second is clamped to a capacity of 1.
        let offset = compute_unit_offset_in_secs(17, 10, 2, 2);
        assert_eq!(offset, 2);
    }

    #[test]
    fn test_compute_units_offset_boundaries() {
        // capacity = 100 / 10 = 10 requests per second.
        // A queue exactly at capacity adds no delay.
        assert_eq!(compute_unit_offset_in_secs(10, 100, 10, 15), 0);
        // Over capacity: the delay is bounded by the requests ahead of this one...
        assert_eq!(compute_unit_offset_in_secs(10, 100, 20, 15), 1);
        assert_eq!(compute_unit_offset_in_secs(10, 100, 50, 35), 3);
        // ...and by the current queue length when that is smaller.
        assert_eq!(compute_unit_offset_in_secs(10, 100, 30, 90), 3);
    }
}