scylla/policies/
speculative_execution.rs1use futures::{
2 future::FutureExt,
3 stream::{FuturesUnordered, StreamExt},
4};
5#[cfg(feature = "metrics")]
6use std::sync::Arc;
7use std::{future::Future, time::Duration};
8use tracing::{trace_span, Instrument};
9
10use crate::errors::{RequestAttemptError, RequestError};
11#[cfg(feature = "metrics")]
12use crate::observability::metrics::Metrics;
13
/// Context handed to every [`SpeculativeExecutionPolicy`] method.
///
/// Marked `#[non_exhaustive]` so that fields can be added later without
/// breaking code outside this crate.
#[non_exhaustive]
pub struct Context {
    /// Metrics handle, present only when the `metrics` feature is enabled.
    ///
    /// [`PercentileSpeculativeExecutionPolicy`] reads latency percentiles
    /// from it to compute its retry interval.
    #[cfg(feature = "metrics")]
    pub metrics: Arc<Metrics>,
}
20
21pub trait SpeculativeExecutionPolicy: std::fmt::Debug + Send + Sync {
24 fn max_retry_count(&self, context: &Context) -> usize;
27
28 fn retry_interval(&self, context: &Context) -> Duration;
30}
31
/// A [`SpeculativeExecutionPolicy`] driven entirely by two fixed values:
/// a cap on the number of speculative executions and a constant delay
/// between them.
#[derive(Clone, Debug)]
pub struct SimpleSpeculativeExecutionPolicy {
    /// Number of speculative executions that may be launched for a request,
    /// in addition to the original execution.
    pub max_retry_count: usize,

    /// Fixed delay to wait before launching each speculative execution.
    pub retry_interval: Duration,
}
43
/// A [`SpeculativeExecutionPolicy`] whose delay between speculative
/// executions is derived from the latency percentiles recorded in the
/// metrics carried by [`Context`].
///
/// Only available when the `metrics` feature is enabled.
#[cfg(feature = "metrics")]
#[derive(Clone, Debug)]
pub struct PercentileSpeculativeExecutionPolicy {
    /// Number of speculative executions that may be launched for a request,
    /// in addition to the original execution.
    pub max_retry_count: usize,

    /// Percentile passed to `Metrics::get_latency_percentile_ms`; the latency
    /// reported for it (in milliseconds) is used as the retry interval.
    ///
    /// NOTE(review): presumably expressed on a 0-100 scale (e.g. `99.0`) —
    /// confirm against `get_latency_percentile_ms`.
    pub percentile: f64,
}
57
58impl SpeculativeExecutionPolicy for SimpleSpeculativeExecutionPolicy {
59 fn max_retry_count(&self, _: &Context) -> usize {
60 self.max_retry_count
61 }
62
63 fn retry_interval(&self, _: &Context) -> Duration {
64 self.retry_interval
65 }
66}
67
#[cfg(feature = "metrics")]
impl SpeculativeExecutionPolicy for PercentileSpeculativeExecutionPolicy {
    /// Returns the configured `max_retry_count` field; the context is not
    /// consulted.
    fn max_retry_count(&self, _context: &Context) -> usize {
        self.max_retry_count
    }

    /// Returns the latency (in milliseconds) that the context's metrics
    /// report for `self.percentile`, converted to a [`Duration`].
    ///
    /// If the metrics lookup fails, a warning is logged and 100 ms is used
    /// instead.
    fn retry_interval(&self, context: &Context) -> Duration {
        let millis = context
            .metrics
            .get_latency_percentile_ms(self.percentile)
            .unwrap_or_else(|e| {
                tracing::warn!(
                    "Failed to get latency percentile ({}), defaulting to 100 ms",
                    e
                );
                100
            });

        Duration::from_millis(millis)
    }
}
89
90fn can_be_ignored<ResT>(result: &Result<ResT, RequestError>) -> bool {
95 match result {
96 Ok(_) => false,
97 #[deny(clippy::wildcard_enum_match_arm)]
101 Err(e) => match e {
102 RequestError::EmptyPlan => false,
105
106 RequestError::RequestTimeout(_) => false,
108
109 RequestError::ConnectionPoolError { .. } => true,
111
112 RequestError::LastAttemptError(e) => {
113 #[deny(clippy::wildcard_enum_match_arm)]
117 match e {
118 RequestAttemptError::SerializationError(_)
120 | RequestAttemptError::CqlRequestSerialization(_)
121 | RequestAttemptError::BodyExtensionsParseError(_)
122 | RequestAttemptError::CqlResultParseError(_)
123 | RequestAttemptError::CqlErrorParseError(_)
124 | RequestAttemptError::UnexpectedResponse(_)
125 | RequestAttemptError::RepreparedIdChanged { .. }
126 | RequestAttemptError::RepreparedIdMissingInBatch
127 | RequestAttemptError::NonfinishedPagingState => false,
128
129 RequestAttemptError::BrokenConnectionError(_)
131 | RequestAttemptError::UnableToAllocStreamId => true,
132
133 RequestAttemptError::DbError(db_error, _) => db_error.can_speculative_retry(),
135 }
136 }
137 },
138 }
139}
140
141const EMPTY_PLAN_ERROR: RequestError = RequestError::EmptyPlan;
142
143pub(crate) async fn execute<QueryFut, ResT>(
144 policy: &dyn SpeculativeExecutionPolicy,
145 context: &Context,
146 query_runner_generator: impl Fn(bool) -> QueryFut,
147) -> Result<ResT, RequestError>
148where
149 QueryFut: Future<Output = Option<Result<ResT, RequestError>>>,
150{
151 let mut retries_remaining = policy.max_retry_count(context);
152 let retry_interval = policy.retry_interval(context);
153
154 let mut async_tasks = FuturesUnordered::new();
155 async_tasks.push(
156 query_runner_generator(false)
157 .instrument(trace_span!("Speculative execution: original query")),
158 );
159
160 let sleep = tokio::time::sleep(retry_interval).fuse();
161 tokio::pin!(sleep);
162
163 let mut last_error = None;
164 loop {
165 futures::select! {
166 _ = &mut sleep => {
167 if retries_remaining > 0 {
168 async_tasks.push(query_runner_generator(true).instrument(trace_span!("Speculative execution", retries_remaining = retries_remaining)));
169 retries_remaining -= 1;
170
171 sleep.set(tokio::time::sleep(retry_interval).fuse());
173 }
174 }
175 res = async_tasks.select_next_some() => {
176 if let Some(r) = res {
177 if !can_be_ignored(&r) {
178 return r;
179 } else {
180 last_error = Some(r)
181 }
182 }
183 if async_tasks.is_empty() && retries_remaining == 0 {
184 return last_error.unwrap_or({
185 Err(EMPTY_PLAN_ERROR)
186 });
187 }
188 }
189 }
190 }
191}