scylla/policies/load_balancing/
plan.rs

1use rand::{rng, Rng};
2use tracing::error;
3
4use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
5use crate::cluster::ClusterState;
6use crate::routing::Shard;
7
/// Progress of a [`Plan`] through the targets offered by the load balancing policy.
enum PlanState<'a> {
    /// Initial state set by `Plan::new`: the policy has not been asked for any target yet.
    Created,
    /// Neither `pick()` nor `fallback()` produced a target, so the plan is empty.
    ///
    /// This always means an abnormal situation: it means that no nodes satisfied
    /// locality/node filter requirements.
    PickedNone,
    /// `pick()` produced this target and it has already been yielded;
    /// `fallback()` has not been called yet.
    Picked((NodeRef<'a>, Option<Shard>)),
    /// The fallback plan is being consumed.
    Fallback {
        /// Remaining targets, as returned by the policy's `fallback()`.
        iter: FallbackPlan<'a>,
        /// The target that was already yielded as the first one (either the picked target
        /// or the first fallback target); it is skipped whenever `iter` yields it again.
        target_to_filter_out: (NodeRef<'a>, Option<Shard>),
    },
}
17
/// The list of targets constituting the query plan. Target here is a pair `(NodeRef<'a>, Shard)`.
///
/// The plan is partly lazily computed, with the first target computed
/// eagerly in the first place and the remaining targets computed on-demand
/// (all at once).
/// This significantly reduces the allocation overhead on "the happy path"
/// (when the first target successfully handles the request).
///
/// `Plan` implements `Iterator<Item=(NodeRef<'a>, Shard)>` but LoadBalancingPolicy
/// returns `Option<Shard>` instead of `Shard` both in `pick` and in `fallback`.
/// `Plan` handles the `None` case by using a random shard for a given node.
/// There is currently no way to configure the RNG used by `Plan`.
/// If you don't want `Plan` to randomize shards or you want to control the RNG,
/// use a custom LBP that will always return non-`None` shards.
/// Example of LBP that always uses shard 0, preventing `Plan` from using random numbers:
///
/// ```
/// # use std::sync::Arc;
/// # use scylla::cluster::NodeRef;
/// # use scylla::cluster::ClusterState;
/// # use scylla::policies::load_balancing::FallbackPlan;
/// # use scylla::policies::load_balancing::LoadBalancingPolicy;
/// # use scylla::policies::load_balancing::RoutingInfo;
/// # use scylla::routing::Shard;
///
/// #[derive(Debug)]
/// struct NonRandomLBP {
///     inner: Arc<dyn LoadBalancingPolicy>,
/// }
/// impl LoadBalancingPolicy for NonRandomLBP {
///     fn pick<'a>(
///         &'a self,
///         info: &'a RoutingInfo,
///         cluster: &'a ClusterState,
///     ) -> Option<(NodeRef<'a>, Option<Shard>)> {
///         self.inner
///             .pick(info, cluster)
///             .map(|(node, shard)| (node, shard.or(Some(0))))
///     }
///
///     fn fallback<'a>(&'a self, info: &'a RoutingInfo, cluster: &'a ClusterState) -> FallbackPlan<'a> {
///         Box::new(self.inner
///             .fallback(info, cluster)
///             .map(|(node, shard)| (node, shard.or(Some(0)))))
///     }
///
///     fn name(&self) -> String {
///         "NonRandomLBP".to_string()
///     }
/// }
/// ```
pub struct Plan<'a> {
    // Policy whose `pick()` and `fallback()` supply the targets of this plan.
    policy: &'a dyn LoadBalancingPolicy,
    // Routing information handed to every `pick()`/`fallback()` call.
    routing_info: &'a RoutingInfo<'a>,
    // Cluster state handed to every `pick()`/`fallback()` call.
    cluster: &'a ClusterState,

    // How far the iteration has progressed; starts as `PlanState::Created`.
    state: PlanState<'a>,
}
76
77impl<'a> Plan<'a> {
78    pub fn new(
79        policy: &'a dyn LoadBalancingPolicy,
80        routing_info: &'a RoutingInfo<'a>,
81        cluster: &'a ClusterState,
82    ) -> Self {
83        Self {
84            policy,
85            routing_info,
86            cluster,
87            state: PlanState::Created,
88        }
89    }
90
91    fn with_random_shard_if_unknown(
92        (node, shard): (NodeRef<'_>, Option<Shard>),
93    ) -> (NodeRef<'_>, Shard) {
94        (
95            node,
96            shard.unwrap_or_else(|| {
97                let nr_shards = node
98                    .sharder()
99                    .map(|sharder| sharder.nr_shards.get())
100                    .unwrap_or(1);
101                rng().random_range(0..nr_shards).into()
102            }),
103        )
104    }
105}
106
107impl<'a> Iterator for Plan<'a> {
108    type Item = (NodeRef<'a>, Shard);
109
110    fn next(&mut self) -> Option<Self::Item> {
111        match &mut self.state {
112            PlanState::Created => {
113                let picked = self.policy.pick(self.routing_info, self.cluster);
114                if let Some(picked) = picked {
115                    self.state = PlanState::Picked(picked);
116                    Some(Self::with_random_shard_if_unknown(picked))
117                } else {
118                    // `pick()` returned None, which semantically means that a first node cannot be computed _cheaply_.
119                    // This, however, does not imply that fallback would return an empty plan, too.
120                    // For instance, as a side effect of LWT optimisation in Default Policy, pick() may return None
121                    // when the primary replica is down. `fallback()` will nevertheless return the remaining replicas,
122                    // if there are such.
123                    let mut iter = self.policy.fallback(self.routing_info, self.cluster);
124                    let first_fallback_node = iter.next();
125                    if let Some(node) = first_fallback_node {
126                        self.state = PlanState::Fallback {
127                            iter,
128                            target_to_filter_out: node,
129                        };
130                        Some(Self::with_random_shard_if_unknown(node))
131                    } else {
132                        error!("Load balancing policy returned an empty plan! The query cannot be executed. Routing info: {:?}", self.routing_info);
133                        self.state = PlanState::PickedNone;
134                        None
135                    }
136                }
137            }
138            PlanState::Picked(node) => {
139                self.state = PlanState::Fallback {
140                    iter: self.policy.fallback(self.routing_info, self.cluster),
141                    target_to_filter_out: *node,
142                };
143
144                self.next()
145            }
146            PlanState::Fallback {
147                iter,
148                target_to_filter_out: node_to_filter_out,
149            } => {
150                for node in iter {
151                    if node == *node_to_filter_out {
152                        continue;
153                    } else {
154                        return Some(Self::with_random_shard_if_unknown(node));
155                    }
156                }
157
158                None
159            }
160            PlanState::PickedNone => None,
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, str::FromStr, sync::Arc};

    use crate::{
        cluster::{Node, NodeAddr},
        routing::locator::test::{create_locator, mock_metadata_for_token_aware_tests},
        test_utils::setup_tracing,
    };

    use super::*;

    /// The single target (a test node with an explicit shard) that the mock policy
    /// offers through `fallback()`.
    fn expected_nodes() -> Vec<(Arc<Node>, Shard)> {
        let address = SocketAddr::from_str("127.0.0.1:9042").unwrap();
        let node = Node::new_for_test(None, Some(NodeAddr::Translatable(address)), None, None);

        vec![(Arc::new(node), 42)]
    }

    /// Mock policy whose `pick()` never yields a target, while `fallback()`
    /// yields all of `expected_nodes`.
    #[derive(Debug)]
    struct PickingNonePolicy {
        expected_nodes: Vec<(Arc<Node>, Shard)>,
    }

    impl LoadBalancingPolicy for PickingNonePolicy {
        fn pick<'a>(
            &'a self,
            _query: &'a RoutingInfo,
            _cluster: &'a ClusterState,
        ) -> Option<(NodeRef<'a>, Option<Shard>)> {
            None
        }

        fn fallback<'a>(
            &'a self,
            _query: &'a RoutingInfo,
            _cluster: &'a ClusterState,
        ) -> FallbackPlan<'a> {
            let targets = self
                .expected_nodes
                .iter()
                .map(|(node, shard)| (node, Some(*shard)));

            Box::new(targets)
        }

        fn name(&self) -> String {
            String::from("PickingNone")
        }
    }

    /// Even when `pick()` returns `None`, the plan must still consult `fallback()`
    /// and yield exactly the targets it provides.
    #[tokio::test]
    async fn plan_calls_fallback_even_if_pick_returned_none() {
        setup_tracing();

        let policy = PickingNonePolicy {
            expected_nodes: expected_nodes(),
        };
        let metadata = mock_metadata_for_token_aware_tests();
        let cluster_state = ClusterState {
            known_peers: Default::default(),
            all_nodes: Default::default(),
            keyspaces: Default::default(),
            locator: create_locator(&metadata),
        };
        let routing_info = RoutingInfo::default();

        let yielded_targets: Vec<(Arc<Node>, Shard)> =
            Plan::new(&policy, &routing_info, &cluster_state)
                .map(|(node, shard)| (Arc::clone(node), shard))
                .collect();

        assert_eq!(yielded_targets, policy.expected_nodes);
    }
}
241}