1use std::{
2 collections::VecDeque,
3 fmt::Debug,
4 num::NonZeroUsize,
5 sync::Arc,
6 task::{Context, Poll},
7 time::{Duration, Instant},
8};
9
10use alloy_json_rpc::{RequestPacket, ResponsePacket};
11use derive_more::{Deref, DerefMut};
12use futures::{stream::FuturesUnordered, StreamExt};
13use parking_lot::RwLock;
14use tower::{Layer, Service};
15use tracing::trace;
16
17use crate::{TransportError, TransportErrorKind, TransportFut};
18
/// Weight of a transport's recent success ratio in its overall score.
const STABILITY_WEIGHT: f64 = 0.7;
/// Weight of a transport's latency score, `1 / (1 + avg_latency_secs)`, in its overall score.
const LATENCY_WEIGHT: f64 = 0.3;
/// Maximum number of recent samples kept in each transport's success and latency windows.
const DEFAULT_SAMPLE_COUNT: usize = 10;
/// Number of top-ranked transports queried per request by `FallbackLayer::default`.
const DEFAULT_ACTIVE_TRANSPORT_COUNT: usize = 3;
24
/// A [`Service`] that sends each request to several transports at once and returns the
/// first successful response.
///
/// Each transport is scored from its recent success ratio and average latency. The
/// `active_transport_count` highest-scoring transports are queried in parallel. Cloning the
/// service is cheap: the transport list and its metrics are shared through an [`Arc`].
#[derive(Debug, Clone)]
pub struct FallbackService<S> {
    /// Transports with their shared performance metrics, in construction order.
    transports: Arc<Vec<ScoredTransport<S>>>,
    /// How many of the best-scoring transports each request is sent to.
    active_transport_count: usize,
}
37
38impl<S: Clone> FallbackService<S> {
39 pub fn new(transports: Vec<S>, active_transport_count: usize) -> Self {
44 let scored_transports = transports
45 .into_iter()
46 .enumerate()
47 .map(|(id, transport)| ScoredTransport::new(id, transport))
48 .collect::<Vec<_>>();
49
50 Self { transports: Arc::new(scored_transports), active_transport_count }
51 }
52
53 fn log_transport_rankings(&self)
55 where
56 S: Debug,
57 {
58 let mut transports = (*self.transports).clone();
59 transports.sort_by(|a, b| b.cmp(a));
60
61 trace!(
62 target: "alloy_fallback_transport_rankings",
63 "Current transport rankings:"
64 );
65 for (idx, transport) in transports.iter().enumerate() {
66 trace!(
67 target: "alloy_fallback_transport_rankings",
68 " #{}: Transport[{}] - {}", idx + 1, transport.id, transport.metrics_summary()
69 );
70 }
71 }
72}
73
74impl<S> FallbackService<S>
75where
76 S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
77 + Send
78 + Clone
79 + Debug
80 + 'static,
81{
82 async fn make_request(&self, req: RequestPacket) -> Result<ResponsePacket, TransportError> {
94 let top_transports = {
96 let mut transports_clone = (*self.transports).clone();
98 transports_clone.sort_by(|a, b| b.cmp(a));
99 transports_clone.into_iter().take(self.active_transport_count).collect::<Vec<_>>()
100 };
101
102 let mut futures = FuturesUnordered::new();
104
105 for transport in top_transports {
107 let req_clone = req.clone();
108 let mut transport_clone = transport.clone();
109
110 let future = async move {
111 let start = Instant::now();
112 let result = transport_clone.call(req_clone).await;
113 trace!(
114 "Transport[{}] completed: latency={:?}, status={}",
115 transport_clone.id,
116 start.elapsed(),
117 if result.is_ok() { "success" } else { "fail" }
118 );
119
120 (result, transport_clone, start.elapsed())
121 };
122
123 futures.push(future);
124 }
125
126 let mut last_error = None;
128
129 while let Some((result, transport, duration)) = futures.next().await {
130 match result {
131 Ok(response) => {
132 transport.track_success(duration);
134
135 self.log_transport_rankings();
136
137 return Ok(response);
138 }
139 Err(error) => {
140 transport.track_failure();
142
143 last_error = Some(error);
144 }
145 }
146 }
147
148 Err(last_error.unwrap_or_else(|| {
149 TransportErrorKind::custom_str("All transport futures failed to complete")
150 }))
151 }
152}
153
154impl<S> Service<RequestPacket> for FallbackService<S>
155where
156 S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
157 + Send
158 + Sync
159 + Clone
160 + Debug
161 + 'static,
162{
163 type Response = ResponsePacket;
164 type Error = TransportError;
165 type Future = TransportFut<'static>;
166
167 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168 Poll::Ready(Ok(()))
170 }
171
172 fn call(&mut self, req: RequestPacket) -> Self::Future {
173 let this = self.clone();
174 Box::pin(async move { this.make_request(req).await })
175 }
176}
177
/// A [`Layer`] that turns a `Vec` of transports into a [`FallbackService`].
///
/// Each request is sent to the `active_transport_count` best-scoring transports in
/// parallel. The count defaults to `DEFAULT_ACTIVE_TRANSPORT_COUNT` and can be changed with
/// [`FallbackLayer::with_active_transport_count`].
#[derive(Debug, Clone)]
pub struct FallbackLayer {
    /// Number of top-ranked transports each request is sent to.
    active_transport_count: usize,
}
201
202impl FallbackLayer {
203 pub const fn with_active_transport_count(mut self, count: NonZeroUsize) -> Self {
205 self.active_transport_count = count.get();
206 self
207 }
208}
209
210impl<S> Layer<Vec<S>> for FallbackLayer
211where
212 S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
213 + Send
214 + Clone
215 + Debug
216 + 'static,
217{
218 type Service = FallbackService<S>;
219
220 fn layer(&self, inner: Vec<S>) -> Self::Service {
221 FallbackService::new(inner, self.active_transport_count)
222 }
223}
224
225impl Default for FallbackLayer {
226 fn default() -> Self {
227 Self { active_transport_count: DEFAULT_ACTIVE_TRANSPORT_COUNT }
228 }
229}
230
/// A transport paired with an identifier and the metrics used to rank it.
///
/// Derefs (mutably too) to the inner transport, so it can be called like `S` directly.
/// Clones share the same metrics through the [`Arc`]. Results recorded on a clone
/// therefore update the score seen by every other clone.
#[derive(Debug, Clone, Deref, DerefMut)]
struct ScoredTransport<S> {
    /// The wrapped transport; target of `Deref`/`DerefMut`.
    #[deref]
    #[deref_mut]
    transport: S,
    /// Identifier used in log output; `FallbackService::new` assigns the transport's index.
    id: usize,
    /// Sliding-window metrics, shared between clones of this transport.
    metrics: Arc<RwLock<TransportMetrics>>,
}
254
255impl<S> ScoredTransport<S> {
256 fn new(id: usize, transport: S) -> Self {
258 Self { id, transport, metrics: Arc::new(Default::default()) }
259 }
260
261 fn score(&self) -> f64 {
263 let metrics = self.metrics.read();
264 metrics.calculate_score()
265 }
266
267 fn metrics_summary(&self) -> String {
269 let metrics = self.metrics.read();
270 metrics.get_summary()
271 }
272
273 fn track_success(&self, duration: Duration) {
275 let mut metrics = self.metrics.write();
276 metrics.track_success(duration);
277 }
278
279 fn track_failure(&self) {
281 let mut metrics = self.metrics.write();
282 metrics.track_failure();
283 }
284}
285
286impl<S> PartialEq for ScoredTransport<S> {
287 fn eq(&self, other: &Self) -> bool {
288 self.score().eq(&other.score())
289 }
290}
291
292impl<S> Eq for ScoredTransport<S> {}
293
294#[expect(clippy::non_canonical_partial_ord_impl)]
295impl<S> PartialOrd for ScoredTransport<S> {
296 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
297 self.score().partial_cmp(&other.score())
298 }
299}
300
301impl<S> Ord for ScoredTransport<S> {
302 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
303 self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
304 }
305}
306
/// Rolling performance statistics for a single transport.
#[derive(Debug)]
struct TransportMetrics {
    /// Latencies of the most recent successful requests (at most `DEFAULT_SAMPLE_COUNT`).
    latencies: VecDeque<Duration>,
    /// Outcomes (`true` = success) of the most recent requests (at most `DEFAULT_SAMPLE_COUNT`).
    successes: VecDeque<bool>,
    /// Time of the last recorded outcome; set to the creation time initially.
    last_update: Instant,
    /// Number of outcomes ever recorded, not limited to the sample window.
    total_requests: u64,
    /// Number of successes ever recorded, not limited to the sample window.
    successful_requests: u64,
}
321
322impl TransportMetrics {
323 fn track_success(&mut self, duration: Duration) {
325 self.total_requests += 1;
326 self.successful_requests += 1;
327 self.last_update = Instant::now();
328
329 self.latencies.push_back(duration);
331 self.successes.push_back(true);
332
333 while self.latencies.len() > DEFAULT_SAMPLE_COUNT {
335 self.latencies.pop_front();
336 }
337 while self.successes.len() > DEFAULT_SAMPLE_COUNT {
338 self.successes.pop_front();
339 }
340 }
341
342 fn track_failure(&mut self) {
344 self.total_requests += 1;
345 self.last_update = Instant::now();
346
347 self.successes.push_back(false);
349
350 while self.successes.len() > DEFAULT_SAMPLE_COUNT {
352 self.successes.pop_front();
353 }
354 }
355
356 fn calculate_score(&self) -> f64 {
358 if self.successes.is_empty() {
360 return 0.0;
361 }
362
363 let success_count = self.successes.iter().filter(|&&s| s).count();
365 let stability_score = success_count as f64 / self.successes.len() as f64;
366
367 let latency_score = if !self.latencies.is_empty() {
369 let avg_latency = self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
370 / self.latencies.len() as f64;
371
372 1.0 / (1.0 + avg_latency)
374 } else {
375 0.0
376 };
377
378 (stability_score * STABILITY_WEIGHT) + (latency_score * LATENCY_WEIGHT)
380 }
381
382 fn get_summary(&self) -> String {
384 let success_rate = if !self.successes.is_empty() {
385 let success_count = self.successes.iter().filter(|&&s| s).count();
386 success_count as f64 / self.successes.len() as f64
387 } else {
388 0.0
389 };
390
391 let avg_latency = if !self.latencies.is_empty() {
392 self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
393 / self.latencies.len() as f64
394 } else {
395 0.0
396 };
397
398 format!(
399 "success_rate: {:.2}%, avg_latency: {:.2}ms, samples: {}, score: {:.4}",
400 success_rate * 100.0,
401 avg_latency * 1000.0,
402 self.successes.len(),
403 self.calculate_score()
404 )
405 }
406}
407
408impl Default for TransportMetrics {
409 fn default() -> Self {
410 Self {
411 latencies: VecDeque::new(),
412 successes: VecDeque::new(),
413 last_update: Instant::now(),
414 total_requests: 0,
415 successful_requests: 0,
416 }
417 }
418}