use crate::{
    error::{RpcErrorExt, TransportError, TransportErrorKind},
    TransportFut,
};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use std::{
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};
use tower::{Layer, Service};
use tracing::trace;

#[cfg(target_family = "wasm")]
use wasmtimer::tokio::sleep;

#[cfg(not(target_family = "wasm"))]
use tokio::time::sleep;

/// The default average cost of a request, in compute units.
const DEFAULT_AVG_COST: u64 = 17u64;

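/// A [`Layer`] that retries requests its [`RetryPolicy`] deems retryable, waiting between
/// attempts for the policy's backoff hint (or the configured initial backoff, in
/// milliseconds) plus an extra delay derived from the provider's compute-unit budget.
///
/// A minimal construction sketch; the numbers are illustrative, and wiring the layer into
/// a transport or client stack is provider-specific, so it is omitted here:
///
/// ```ignore
/// // 10 retries on retryable errors, 200 ms initial backoff, 330 compute units per second.
/// let retry_layer = RetryBackoffLayer::new(10, 200, 330);
/// ```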
#[derive(Debug, Clone)]
pub struct RetryBackoffLayer<P: RetryPolicy = RateLimitRetryPolicy> {
    /// The maximum number of retries for rate-limit errors.
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds.
    initial_backoff: u64,
    /// The compute units available per second for this provider.
    compute_units_per_second: u64,
    /// The average cost of a request, in compute units. Defaults to `DEFAULT_AVG_COST`.
    avg_cost: u64,
    /// The retry policy.
    policy: P,
}

impl RetryBackoffLayer {
    /// Creates a new retry layer with the given parameters and the default
    /// [`RateLimitRetryPolicy`].
    pub const fn new(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            avg_cost: DEFAULT_AVG_COST,
            policy: RateLimitRetryPolicy,
        }
    }

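    /// Sets the average cost of a request, in compute units, overriding `DEFAULT_AVG_COST`.
    ///
    /// This value only feeds the heuristic that estimates how long to wait for the
    /// compute-unit budget to recover. A minimal, illustrative sketch (the numbers are not
    /// defaults):
    ///
    /// ```ignore
    /// let layer = RetryBackoffLayer::new(10, 200, 330).with_avg_unit_cost(25);
    /// ```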
    pub const fn with_avg_unit_cost(mut self, avg_cost: u64) -> Self {
        self.avg_cost = avg_cost;
        self
    }
}

impl<P: RetryPolicy> RetryBackoffLayer<P> {
    /// Creates a new retry layer with the given parameters and a custom [`RetryPolicy`].
    pub const fn new_with_policy(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
        policy: P,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            policy,
            avg_cost: DEFAULT_AVG_COST,
        }
    }
}

/// The default [`RetryPolicy`]: it delegates to the error's own notion of retryability and
/// backoff hint.
#[derive(Debug, Copy, Clone, Default)]
#[non_exhaustive]
pub struct RateLimitRetryPolicy;

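/// Decides whether a [`TransportError`] should be retried and, optionally, how long to
/// back off before the next attempt.
///
/// A sketch of a custom policy; `AlwaysRetryPolicy` below is made up for illustration and
/// is not part of this crate:
///
/// ```ignore
/// #[derive(Debug, Clone, Copy)]
/// struct AlwaysRetryPolicy;
///
/// impl RetryPolicy for AlwaysRetryPolicy {
///     fn should_retry(&self, _error: &TransportError) -> bool {
///         // Retry every error, relying on the layer's max-retry cap to stop eventually.
///         true
///     }
///
///     fn backoff_hint(&self, _error: &TransportError) -> Option<std::time::Duration> {
///         // No hint: the layer falls back to its configured initial backoff.
///         None
///     }
/// }
///
/// let layer = RetryBackoffLayer::new_with_policy(5, 150, 330, AlwaysRetryPolicy);
/// ```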
pub trait RetryPolicy: Send + Sync + std::fmt::Debug {
    /// Whether to retry the request given the error it produced.
    fn should_retry(&self, error: &TransportError) -> bool;

    /// Provides a backoff hint carried by the error, if any.
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration>;
}

impl RetryPolicy for RateLimitRetryPolicy {
    fn should_retry(&self, error: &TransportError) -> bool {
        error.is_retryable()
    }

    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration> {
        error.backoff_hint()
    }
}

impl<S, P: RetryPolicy + Clone> Layer<S> for RetryBackoffLayer<P> {
    type Service = RetryBackoffService<S, P>;

    fn layer(&self, inner: S) -> Self::Service {
        RetryBackoffService {
            inner,
            policy: self.policy.clone(),
            max_rate_limit_retries: self.max_rate_limit_retries,
            initial_backoff: self.initial_backoff,
            compute_units_per_second: self.compute_units_per_second,
            requests_enqueued: Arc::new(AtomicU32::new(0)),
            avg_cost: self.avg_cost,
        }
    }
}

/// A [`Service`] that retries requests based on the configured [`RetryPolicy`], created by
/// [`RetryBackoffLayer`].
#[derive(Debug, Clone)]
pub struct RetryBackoffService<S, P: RetryPolicy = RateLimitRetryPolicy> {
    /// The inner service.
    inner: S,
    /// The retry policy.
    policy: P,
    /// The maximum number of retries for rate-limit errors.
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds.
    initial_backoff: u64,
    /// The compute units available per second for this provider.
    compute_units_per_second: u64,
    /// The number of requests currently enqueued, shared across clones.
    requests_enqueued: Arc<AtomicU32>,
    /// The average cost of a request, in compute units.
    avg_cost: u64,
}

impl<S, P: RetryPolicy> RetryBackoffService<S, P> {
    /// Returns the configured initial backoff as a [`Duration`].
    const fn initial_backoff(&self) -> Duration {
        Duration::from_millis(self.initial_backoff)
    }
}

impl<S, P> Service<RequestPacket> for RetryBackoffService<S, P>
where
    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
        + Send
        + 'static
        + Clone,
    P: RetryPolicy + Clone + 'static,
{
    type Response = ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // This layer adds no backpressure of its own; it is ready whenever the inner
        // service is ready.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: RequestPacket) -> Self::Future {
        // Take the inner service that `poll_ready` was called on and leave a fresh clone
        // in its place, a common tower pattern for `Clone` services.
        let inner = self.inner.clone();
        let this = self.clone();
        let mut inner = std::mem::replace(&mut self.inner, inner);
        Box::pin(async move {
            let ahead_in_queue = this.requests_enqueued.fetch_add(1, Ordering::SeqCst) as u64;
            let mut rate_limit_retry_number: u32 = 0;
            loop {
                let err;
                let res = inner.call(request.clone()).await;

                match res {
                    Ok(res) => {
                        if let Some(e) = res.as_error() {
                            err = TransportError::ErrorResp(e.clone())
                        } else {
                            this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                            return Ok(res);
                        }
                    }
                    Err(e) => err = e,
                }

                let should_retry = this.policy.should_retry(&err);
                if should_retry {
                    rate_limit_retry_number += 1;
                    if rate_limit_retry_number > this.max_rate_limit_retries {
                        return Err(TransportErrorKind::custom_str(&format!(
                            "Max retries exceeded {err}"
                        )));
                    }
                    trace!(%err, "retrying request");

                    let current_queued_reqs = this.requests_enqueued.load(Ordering::SeqCst) as u64;

                    // Use the backoff hint carried by the error if available, otherwise
                    // fall back to the configured initial backoff.
                    let backoff_hint = this.policy.backoff_hint(&err);
                    let next_backoff = backoff_hint.unwrap_or_else(|| this.initial_backoff());

                    let seconds_to_wait_for_compute_budget = compute_unit_offset_in_secs(
                        this.avg_cost,
                        this.compute_units_per_second,
                        current_queued_reqs,
                        ahead_in_queue,
                    );
                    let total_backoff = next_backoff
                        + std::time::Duration::from_secs(seconds_to_wait_for_compute_budget);

                    trace!(
                        total_backoff_millis = total_backoff.as_millis(),
                        budget_backoff_millis = seconds_to_wait_for_compute_budget * 1000,
                        default_backoff_millis = next_backoff.as_millis(),
                        backoff_hint_millis = backoff_hint.map(|d| d.as_millis()),
                        "(all in ms) backing off due to rate limit"
                    );

                    sleep(total_backoff).await;
                } else {
                    this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                    return Err(err);
                }
            }
        })
    }
}

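/// Calculates the number of whole seconds to wait for the compute-unit budget to recover,
/// given the average cost of a request, the provider's compute units per second, the number
/// of requests currently queued, and how many requests were ahead of this one when it was
/// first enqueued.
///
/// For example, with a budget of 340 CU/s and an average cost of 17 CU, roughly 20 requests
/// fit in one second; once the queue exceeds that capacity, 100 requests ahead in the queue
/// translate to a 5 second wait.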
fn compute_unit_offset_in_secs(
    avg_cost: u64,
    compute_units_per_second: u64,
    current_queued_requests: u64,
    ahead_in_queue: u64,
) -> u64 {
    let request_capacity_per_second = compute_units_per_second.saturating_div(avg_cost).max(1);
    if current_queued_requests > request_capacity_per_second {
        current_queued_requests.min(ahead_in_queue).saturating_div(request_capacity_per_second)
    } else {
        0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_units_per_second() {
        let offset = compute_unit_offset_in_secs(17, 10, 0, 0);
        assert_eq!(offset, 0);
        let offset = compute_unit_offset_in_secs(17, 10, 2, 2);
        assert_eq!(offset, 2);
    }
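
    // Checks the queue-position clamp: the wait is bounded by how many requests were ahead
    // of this one, divided by the per-second request capacity. The numbers are illustrative.
    #[test]
    fn test_compute_unit_offset_clamps_to_queue_position() {
        // 340 CU/s at 17 CU per request allows ~20 requests per second.
        // 40 requests are queued, but only 10 were ahead of this one: 10 / 20 = 0 seconds.
        let offset = compute_unit_offset_in_secs(17, 340, 40, 10);
        assert_eq!(offset, 0);
        // With 100 requests ahead, the wait grows to 100 / 20 = 5 seconds.
        let offset = compute_unit_offset_in_secs(17, 340, 100, 100);
        assert_eq!(offset, 5);
    }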
}