//! alloy_transport/layers/fallback.rs — transparent transport failover layer.

1use std::{
2    collections::VecDeque,
3    fmt::Debug,
4    num::NonZeroUsize,
5    sync::Arc,
6    task::{Context, Poll},
7    time::{Duration, Instant},
8};
9
10use alloy_json_rpc::{RequestPacket, ResponsePacket};
11use derive_more::{Deref, DerefMut};
12use futures::{stream::FuturesUnordered, StreamExt};
13use parking_lot::RwLock;
14use tower::{Layer, Service};
15use tracing::trace;
16
17use crate::{TransportError, TransportErrorKind, TransportFut};
18
// Constants for the transport ranking algorithm
/// Fraction of a transport's score contributed by its recent success rate.
const STABILITY_WEIGHT: f64 = 0.7;
/// Fraction of a transport's score contributed by its normalized average latency.
const LATENCY_WEIGHT: f64 = 0.3;
/// Number of recent samples kept per transport for latency/success tracking.
const DEFAULT_SAMPLE_COUNT: usize = 10;
/// Default number of transports queried in parallel for each request.
const DEFAULT_ACTIVE_TRANSPORT_COUNT: usize = 3;
24
/// The [`FallbackService`] consumes multiple transports and is able to
/// query them in parallel, returning the first successful response.
///
/// The service ranks transports based on latency and stability metrics,
/// and will attempt to always use the best available transports.
#[derive(Debug, Clone)]
pub struct FallbackService<S> {
    /// The list of transports to use; scores update through the interior
    /// `RwLock` in each [`ScoredTransport`], so the shared `Arc` is never mutated.
    transports: Arc<Vec<ScoredTransport<S>>>,
    /// The maximum number of transports to use in parallel per request
    active_transport_count: usize,
}
37
38impl<S: Clone> FallbackService<S> {
39    /// Create a new fallback service from a list of transports.
40    ///
41    /// The `active_transport_count` parameter controls how many transports are used for requests
42    /// at any one time.
43    pub fn new(transports: Vec<S>, active_transport_count: usize) -> Self {
44        let scored_transports = transports
45            .into_iter()
46            .enumerate()
47            .map(|(id, transport)| ScoredTransport::new(id, transport))
48            .collect::<Vec<_>>();
49
50        Self { transports: Arc::new(scored_transports), active_transport_count }
51    }
52
53    /// Log the current ranking of transports
54    fn log_transport_rankings(&self)
55    where
56        S: Debug,
57    {
58        let mut transports = (*self.transports).clone();
59        transports.sort_by(|a, b| b.cmp(a));
60
61        trace!(
62            target: "alloy_fallback_transport_rankings",
63            "Current transport rankings:"
64        );
65        for (idx, transport) in transports.iter().enumerate() {
66            trace!(
67                target: "alloy_fallback_transport_rankings",
68                "  #{}: Transport[{}] - {}", idx + 1, transport.id, transport.metrics_summary()
69            );
70        }
71    }
72}
73
74impl<S> FallbackService<S>
75where
76    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
77        + Send
78        + Clone
79        + Debug
80        + 'static,
81{
82    /// Make a request to the fallback service middleware.
83    ///
84    /// Here is a high-level overview of how requests are handled:
85    ///
86    /// - At the start of each request, we sort transports by score
87    /// - We take the top `self.active_transport_count` and call them in parallel
88    /// - If any of them succeeds, we update the transport scores and return the response
89    /// - If all transports fail, we update the scores and return the last error that occurred
90    ///
91    /// This strategy allows us to always make requests to the best available transports
92    /// while keeping them available.
93    async fn make_request(&self, req: RequestPacket) -> Result<ResponsePacket, TransportError> {
94        // Get the top transports to use for this request
95        let top_transports = {
96            // Clone the vec, sort it, and take the top `self.active_transport_count`
97            let mut transports_clone = (*self.transports).clone();
98            transports_clone.sort_by(|a, b| b.cmp(a));
99            transports_clone.into_iter().take(self.active_transport_count).collect::<Vec<_>>()
100        };
101
102        // Create a collection of future requests
103        let mut futures = FuturesUnordered::new();
104
105        // Launch requests to all active transports in parallel
106        for transport in top_transports {
107            let req_clone = req.clone();
108            let mut transport_clone = transport.clone();
109
110            let future = async move {
111                let start = Instant::now();
112                let result = transport_clone.call(req_clone).await;
113                trace!(
114                    "Transport[{}] completed: latency={:?}, status={}",
115                    transport_clone.id,
116                    start.elapsed(),
117                    if result.is_ok() { "success" } else { "fail" }
118                );
119
120                (result, transport_clone, start.elapsed())
121            };
122
123            futures.push(future);
124        }
125
126        // Wait for the first successful response or until all fail
127        let mut last_error = None;
128
129        while let Some((result, transport, duration)) = futures.next().await {
130            match result {
131                Ok(response) => {
132                    // Record success
133                    transport.track_success(duration);
134
135                    self.log_transport_rankings();
136
137                    return Ok(response);
138                }
139                Err(error) => {
140                    // Record failure
141                    transport.track_failure();
142
143                    last_error = Some(error);
144                }
145            }
146        }
147
148        Err(last_error.unwrap_or_else(|| {
149            TransportErrorKind::custom_str("All transport futures failed to complete")
150        }))
151    }
152}
153
154impl<S> Service<RequestPacket> for FallbackService<S>
155where
156    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
157        + Send
158        + Sync
159        + Clone
160        + Debug
161        + 'static,
162{
163    type Response = ResponsePacket;
164    type Error = TransportError;
165    type Future = TransportFut<'static>;
166
167    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168        // Service is always ready
169        Poll::Ready(Ok(()))
170    }
171
172    fn call(&mut self, req: RequestPacket) -> Self::Future {
173        let this = self.clone();
174        Box::pin(async move { this.make_request(req).await })
175    }
176}
177
/// Fallback layer for transparent transport failover. This layer will
/// consume a list of transports to provide better availability and
/// reliability.
///
/// The [`FallbackService`] will attempt to make requests to multiple
/// transports in parallel, and return the first successful response.
///
/// If all transports fail, the fallback service will return an error.
///
/// # Automatic Transport Ranking
///
/// Each transport is automatically ranked based on latency & stability
/// using a weighted algorithm. By default:
///
/// - Stability (success rate) is weighted at 70%
/// - Latency (response time) is weighted at 30%
/// - The `active_transport_count` parameter controls how many transports are queried at any one
///   time.
#[derive(Debug, Clone)]
pub struct FallbackLayer {
    /// The maximum number of transports to use in parallel
    /// (defaults to `DEFAULT_ACTIVE_TRANSPORT_COUNT` via `Default`)
    active_transport_count: usize,
}
201
202impl FallbackLayer {
203    /// Set the number of active transports to use (must be greater than 0)
204    pub const fn with_active_transport_count(mut self, count: NonZeroUsize) -> Self {
205        self.active_transport_count = count.get();
206        self
207    }
208}
209
210impl<S> Layer<Vec<S>> for FallbackLayer
211where
212    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
213        + Send
214        + Clone
215        + Debug
216        + 'static,
217{
218    type Service = FallbackService<S>;
219
220    fn layer(&self, inner: Vec<S>) -> Self::Service {
221        FallbackService::new(inner, self.active_transport_count)
222    }
223}
224
225impl Default for FallbackLayer {
226    fn default() -> Self {
227        Self { active_transport_count: DEFAULT_ACTIVE_TRANSPORT_COUNT }
228    }
229}
230
/// A scored transport that can be ordered in a heap.
///
/// The transport is scored every time it is used according to
/// a simple weighted algorithm that favors latency and stability.
///
/// The score is calculated as follows (by default):
///
/// - Stability (success rate) is weighted at 70%
/// - Latency (response time) is weighted at 30%
///
/// The score is then used to determine which transport to use next in
/// the [`FallbackService`].
#[derive(Debug, Clone, Deref, DerefMut)]
struct ScoredTransport<S> {
    /// The transport itself; `Deref`/`DerefMut` expose it so the wrapper can
    /// be `call`ed like the inner service
    #[deref]
    #[deref_mut]
    transport: S,
    /// Unique identifier for the transport (its index in the original list)
    id: usize,
    /// Metrics for the transport, shared across clones so every clone
    /// records into (and scores from) the same history
    metrics: Arc<RwLock<TransportMetrics>>,
}
254
255impl<S> ScoredTransport<S> {
256    /// Create a new scored transport
257    fn new(id: usize, transport: S) -> Self {
258        Self { id, transport, metrics: Arc::new(Default::default()) }
259    }
260
261    /// Returns the current score of the transport based on the weighted algorithm.
262    fn score(&self) -> f64 {
263        let metrics = self.metrics.read();
264        metrics.calculate_score()
265    }
266
267    /// Get metrics summary for debugging
268    fn metrics_summary(&self) -> String {
269        let metrics = self.metrics.read();
270        metrics.get_summary()
271    }
272
273    /// Track a successful request and its latency.
274    fn track_success(&self, duration: Duration) {
275        let mut metrics = self.metrics.write();
276        metrics.track_success(duration);
277    }
278
279    /// Track a failed request.
280    fn track_failure(&self) {
281        let mut metrics = self.metrics.write();
282        metrics.track_failure();
283    }
284}
285
286impl<S> PartialEq for ScoredTransport<S> {
287    fn eq(&self, other: &Self) -> bool {
288        self.score().eq(&other.score())
289    }
290}
291
292impl<S> Eq for ScoredTransport<S> {}
293
294#[expect(clippy::non_canonical_partial_ord_impl)]
295impl<S> PartialOrd for ScoredTransport<S> {
296    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
297        self.score().partial_cmp(&other.score())
298    }
299}
300
301impl<S> Ord for ScoredTransport<S> {
302    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
303        self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
304    }
305}
306
307/// Represents performance metrics for a transport.
308#[derive(Debug)]
309struct TransportMetrics {
310    // Latency history - tracks last N responses
311    latencies: VecDeque<Duration>,
312    // Success history - tracks last N successes (true) or failures (false)
313    successes: VecDeque<bool>,
314    // Last time this transport was checked/used
315    last_update: Instant,
316    // Total number of requests made to this transport
317    total_requests: u64,
318    // Total number of successful requests
319    successful_requests: u64,
320}
321
322impl TransportMetrics {
323    /// Track a successful request and its latency.
324    fn track_success(&mut self, duration: Duration) {
325        self.total_requests += 1;
326        self.successful_requests += 1;
327        self.last_update = Instant::now();
328
329        // Add to sample windows
330        self.latencies.push_back(duration);
331        self.successes.push_back(true);
332
333        // Limit to sample count
334        while self.latencies.len() > DEFAULT_SAMPLE_COUNT {
335            self.latencies.pop_front();
336        }
337        while self.successes.len() > DEFAULT_SAMPLE_COUNT {
338            self.successes.pop_front();
339        }
340    }
341
342    /// Track a failed request.
343    fn track_failure(&mut self) {
344        self.total_requests += 1;
345        self.last_update = Instant::now();
346
347        // Add to sample windows (no latency for failures)
348        self.successes.push_back(false);
349
350        // Limit to sample count
351        while self.successes.len() > DEFAULT_SAMPLE_COUNT {
352            self.successes.pop_front();
353        }
354    }
355
356    /// Calculate weighted score based on stability and latency
357    fn calculate_score(&self) -> f64 {
358        // If no data yet, return initial neutral score
359        if self.successes.is_empty() {
360            return 0.0;
361        }
362
363        // Calculate stability score (percentage of successful requests)
364        let success_count = self.successes.iter().filter(|&&s| s).count();
365        let stability_score = success_count as f64 / self.successes.len() as f64;
366
367        // Calculate latency score (lower is better)
368        let latency_score = if !self.latencies.is_empty() {
369            let avg_latency = self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
370                / self.latencies.len() as f64;
371
372            // Normalize latency score (1.0 for 0ms, approaches 0.0 as latency increases)
373            1.0 / (1.0 + avg_latency)
374        } else {
375            0.0
376        };
377
378        // Apply weights to calculate final score
379        (stability_score * STABILITY_WEIGHT) + (latency_score * LATENCY_WEIGHT)
380    }
381
382    /// Get a summary of metrics for debugging
383    fn get_summary(&self) -> String {
384        let success_rate = if !self.successes.is_empty() {
385            let success_count = self.successes.iter().filter(|&&s| s).count();
386            success_count as f64 / self.successes.len() as f64
387        } else {
388            0.0
389        };
390
391        let avg_latency = if !self.latencies.is_empty() {
392            self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
393                / self.latencies.len() as f64
394        } else {
395            0.0
396        };
397
398        format!(
399            "success_rate: {:.2}%, avg_latency: {:.2}ms, samples: {}, score: {:.4}",
400            success_rate * 100.0,
401            avg_latency * 1000.0,
402            self.successes.len(),
403            self.calculate_score()
404        )
405    }
406}
407
408impl Default for TransportMetrics {
409    fn default() -> Self {
410        Self {
411            latencies: VecDeque::new(),
412            successes: VecDeque::new(),
413            last_update: Instant::now(),
414            total_requests: 0,
415            successful_requests: 0,
416        }
417    }
418}