use crate::{
    error::{RpcErrorExt, TransportError, TransportErrorKind},
    TransportFut,
};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use core::fmt;
use std::{
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};
use tower::{Layer, Service};
use tracing::trace;

#[cfg(target_family = "wasm")]
use wasmtimer::tokio::sleep;

#[cfg(not(target_family = "wasm"))]
use tokio::time::sleep;
/// The default average cost of a request in compute units.
const DEFAULT_AVG_COST: u64 = 17u64;

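/// A [`tower::Layer`] that retries rate-limited or otherwise retryable
/// requests with backoff, budgeting requests against a provider's compute
/// units per second (CUPS).
///
/// # Example
///
/// An illustrative sketch (the parameter values are assumptions, not
/// defaults):
///
/// ```ignore
/// // At most 10 retries, 100ms initial backoff, and a 330 CUPS budget.
/// let retry_layer = RetryBackoffLayer::new(10, 100, 330);
/// // Compose it onto a transport like any other `tower` layer, e.g. with
/// // `tower::ServiceBuilder::new().layer(retry_layer).service(transport)`.
/// ```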
#[derive(Debug, Clone)]
pub struct RetryBackoffLayer<P: RetryPolicy = RateLimitRetryPolicy> {
    /// The maximum number of retries for errors the policy considers retryable.
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds.
    initial_backoff: u64,
    /// The number of compute units the provider allows per second.
    compute_units_per_second: u64,
    /// The average cost of a request in compute units. Defaults to [`DEFAULT_AVG_COST`].
    avg_cost: u64,
    /// The retry policy deciding which errors are retryable.
    policy: P,
}

impl RetryBackoffLayer {
    /// Creates a new retry layer with the given maximum number of retries,
    /// initial backoff (in milliseconds), and compute units per second, using
    /// the default [`RateLimitRetryPolicy`].
    pub const fn new(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            avg_cost: DEFAULT_AVG_COST,
            policy: RateLimitRetryPolicy,
        }
    }

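    /// Sets the average cost of a request in compute units. Defaults to
    /// [`DEFAULT_AVG_COST`].
    ///
    /// Together with `compute_units_per_second`, this determines the request
    /// capacity per second used by the backoff calculation. An illustrative
    /// sketch (the values are assumptions):
    ///
    /// ```ignore
    /// // A 330 CUPS budget with an average request cost of 26 units
    /// // sustains 330 / 26 = 12 requests per second.
    /// let retry_layer = RetryBackoffLayer::new(10, 100, 330).with_avg_unit_cost(26);
    /// ```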
    pub const fn with_avg_unit_cost(mut self, avg_cost: u64) -> Self {
        self.avg_cost = avg_cost;
        self
    }
}

impl<P: RetryPolicy> RetryBackoffLayer<P> {
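    /// Creates a new retry layer with the given parameters and a custom
    /// [`RetryPolicy`].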
    pub const fn new_with_policy(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
        policy: P,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            policy,
            avg_cost: DEFAULT_AVG_COST,
        }
    }
}

/// The default retry policy: retries errors that [`RpcErrorExt::is_retryable`]
/// reports as retryable, such as rate-limit responses.
#[derive(Debug, Copy, Clone, Default)]
#[non_exhaustive]
pub struct RateLimitRetryPolicy;

impl RateLimitRetryPolicy {
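    /// Combines this policy with an additional predicate: an error is retried
    /// if either the base policy or the closure accepts it. The backoff hint
    /// still comes from the base policy.
    ///
    /// Illustrative sketch (the matched message is an assumption):
    ///
    /// ```ignore
    /// let policy = RateLimitRetryPolicy.or(|err: &TransportError| {
    ///     err.to_string().contains("try again later")
    /// });
    /// ```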
    pub fn or<F>(self, f: F) -> OrRetryPolicyFn<Self>
    where
        F: Fn(&TransportError) -> bool + Send + Sync + 'static,
    {
        OrRetryPolicyFn::new(self, f)
    }
}

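/// Determines whether a [`TransportError`] warrants a retry and, optionally,
/// how long to back off before the next attempt.
///
/// A minimal sketch of a custom implementation (the `AlwaysRetry` type is
/// hypothetical):
///
/// ```ignore
/// #[derive(Debug)]
/// struct AlwaysRetry;
///
/// impl RetryPolicy for AlwaysRetry {
///     fn should_retry(&self, _error: &TransportError) -> bool {
///         true
///     }
///
///     fn backoff_hint(&self, _error: &TransportError) -> Option<std::time::Duration> {
///         None
///     }
/// }
/// ```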
pub trait RetryPolicy: Send + Sync + std::fmt::Debug {
    /// Returns whether the request that produced this error should be retried.
    fn should_retry(&self, error: &TransportError) -> bool;

    /// Provides a backoff hint if the error response contains one.
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration>;
}

impl RetryPolicy for RateLimitRetryPolicy {
    fn should_retry(&self, error: &TransportError) -> bool {
        error.is_retryable()
    }

    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration> {
        error.backoff_hint()
    }
}

/// A [`RetryPolicy`] that accepts an error if either the wrapped base policy
/// or the user-provided closure considers it retryable.
#[derive(Clone)]
pub struct OrRetryPolicyFn<P = RateLimitRetryPolicy> {
    /// The user-provided predicate.
    inner: Arc<dyn Fn(&TransportError) -> bool + Send + Sync>,
    /// The base policy to fall back on.
    base: P,
}

impl<P> OrRetryPolicyFn<P> {
    /// Creates a new [`OrRetryPolicyFn`] from a base policy and a predicate.
    pub fn new<F>(base: P, or: F) -> Self
    where
        F: Fn(&TransportError) -> bool + Send + Sync + 'static,
    {
        Self { inner: Arc::new(or), base }
    }
}

impl<P: RetryPolicy> RetryPolicy for OrRetryPolicyFn<P> {
    fn should_retry(&self, error: &TransportError) -> bool {
        self.inner.as_ref()(error) || self.base.should_retry(error)
    }

    fn backoff_hint(&self, error: &TransportError) -> Option<Duration> {
        self.base.backoff_hint(error)
    }
}

impl<P: fmt::Debug> fmt::Debug for OrRetryPolicyFn<P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The closure is not debuggable; `finish_non_exhaustive` prints `..`
        // in its place.
        f.debug_struct("OrRetryPolicyFn").field("base", &self.base).finish_non_exhaustive()
    }
}

impl<S, P: RetryPolicy + Clone> Layer<S> for RetryBackoffLayer<P> {
    type Service = RetryBackoffService<S, P>;

    fn layer(&self, inner: S) -> Self::Service {
        RetryBackoffService {
            inner,
            policy: self.policy.clone(),
            max_rate_limit_retries: self.max_rate_limit_retries,
            initial_backoff: self.initial_backoff,
            compute_units_per_second: self.compute_units_per_second,
            requests_enqueued: Arc::new(AtomicU32::new(0)),
            avg_cost: self.avg_cost,
        }
    }
}

/// A [`tower::Service`] that retries requests based on the errors returned by
/// the wrapped service, produced by applying a [`RetryBackoffLayer`].
#[derive(Debug, Clone)]
pub struct RetryBackoffService<S, P: RetryPolicy = RateLimitRetryPolicy> {
    /// The inner service to wrap.
    inner: S,
    /// The retry policy.
    policy: P,
    /// The maximum number of retries for retryable errors.
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds.
    initial_backoff: u64,
    /// The number of compute units the provider allows per second.
    compute_units_per_second: u64,
    /// The number of requests currently enqueued, shared across clones.
    requests_enqueued: Arc<AtomicU32>,
    /// The average cost of a request in compute units.
    avg_cost: u64,
}

impl<S, P: RetryPolicy> RetryBackoffService<S, P> {
    /// Returns the initial backoff as a [`Duration`].
    const fn initial_backoff(&self) -> Duration {
        Duration::from_millis(self.initial_backoff)
    }
}

impl<S, P> Service<RequestPacket> for RetryBackoffService<S, P>
where
    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
        + Send
        + 'static
        + Clone,
    P: RetryPolicy + Clone + 'static,
{
    type Response = ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated to the inner service.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: RequestPacket) -> Self::Future {
        // Standard tower pattern: take the service that was polled ready and
        // leave a fresh clone in its place.
        let inner = self.inner.clone();
        let this = self.clone();
        let mut inner = std::mem::replace(&mut self.inner, inner);
        Box::pin(async move {
            let ahead_in_queue = this.requests_enqueued.fetch_add(1, Ordering::SeqCst) as u64;
            let mut rate_limit_retry_number: u32 = 0;
            loop {
                let err;
                let res = inner.call(request.clone()).await;

                match res {
                    Ok(res) => {
                        if let Some(e) = res.as_error() {
                            err = TransportError::ErrorResp(e.clone())
                        } else {
                            this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                            return Ok(res);
                        }
                    }
                    Err(e) => err = e,
                }

                let should_retry = this.policy.should_retry(&err);
                if should_retry {
                    rate_limit_retry_number += 1;
                    if rate_limit_retry_number > this.max_rate_limit_retries {
                        this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                        return Err(TransportErrorKind::custom_str(&format!(
                            "Max retries exceeded {err}"
                        )));
                    }
                    trace!(%err, "retrying request");

                    let current_queued_reqs = this.requests_enqueued.load(Ordering::SeqCst) as u64;

                    // Prefer a backoff hint extracted from the error (e.g. a
                    // retry-after value) over the configured initial backoff.
                    let backoff_hint = this.policy.backoff_hint(&err);
                    let next_backoff = backoff_hint.unwrap_or_else(|| this.initial_backoff());

                    let seconds_to_wait_for_compute_budget = compute_unit_offset_in_secs(
                        this.avg_cost,
                        this.compute_units_per_second,
                        current_queued_reqs,
                        ahead_in_queue,
                    );
                    let total_backoff = next_backoff
                        + std::time::Duration::from_secs(seconds_to_wait_for_compute_budget);

                    trace!(
                        total_backoff_millis = total_backoff.as_millis(),
                        budget_backoff_millis = seconds_to_wait_for_compute_budget * 1000,
                        default_backoff_millis = next_backoff.as_millis(),
                        backoff_hint_millis = backoff_hint.map(|d| d.as_millis()),
                        "(all in ms) backing off due to rate limit"
                    );

                    sleep(total_backoff).await;
                } else {
                    this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                    return Err(err);
                }
            }
        })
    }
}

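/// Calculates how many whole seconds a request should wait for compute-unit
/// budget, given the average request cost, the provider's compute units per
/// second, the number of currently queued requests, and how many requests were
/// ahead of this one in the queue.
///
/// The capacity is `compute_units_per_second / avg_cost` requests per second
/// (at least 1). If more requests are queued than fit in one second, the
/// offset is `min(current_queued_requests, ahead_in_queue) / capacity`;
/// otherwise it is 0. For example, with `avg_cost = 17` and
/// `compute_units_per_second = 10`, the capacity is `max(10 / 17, 1) = 1`, so
/// with 2 requests queued and 2 ahead the offset is `min(2, 2) / 1 = 2`
/// seconds, as in the unit test below.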
fn compute_unit_offset_in_secs(
    avg_cost: u64,
    compute_units_per_second: u64,
    current_queued_requests: u64,
    ahead_in_queue: u64,
) -> u64 {
    // How many requests fit into one second of compute budget (at least one).
    let request_capacity_per_second = compute_units_per_second.saturating_div(avg_cost).max(1);
    if current_queued_requests > request_capacity_per_second {
        current_queued_requests.min(ahead_in_queue).saturating_div(request_capacity_per_second)
    } else {
        0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_units_per_second() {
        let offset = compute_unit_offset_in_secs(17, 10, 0, 0);
        assert_eq!(offset, 0);
        let offset = compute_unit_offset_in_secs(17, 10, 2, 2);
        assert_eq!(offset, 2);
    }
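
    // Added illustrative test: the `or` combinator retries whenever the
    // supplied closure matches, regardless of the base policy's verdict.
    #[test]
    fn test_or_retry_policy_matches_closure() {
        let policy = RateLimitRetryPolicy
            .or(|err: &TransportError| err.to_string().contains("capacity exceeded"));
        let err = TransportErrorKind::custom_str("capacity exceeded");
        assert!(policy.should_retry(&err));
    }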
}