scylla/routing/
sharding.rs

1use std::collections::HashMap;
2use std::num::NonZeroU16;
3use std::ops::RangeInclusive;
4
5use rand::Rng as _;
6use thiserror::Error;
7
8use super::Token;
9
/// Inclusive range of local ports from which shard-aware connections may be opened.
///
/// A valid range is a non-empty sub-range of [1024, 65535]; see
/// [`ShardAwarePortRange::new()`] for the validation rules.
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct ShardAwarePortRange(RangeInclusive<u16>);
16
17impl ShardAwarePortRange {
18    /// The default shard-aware local port range - [49152, 65535].
19    pub const EPHEMERAL_PORT_RANGE: Self = Self(49152..=65535);
20
21    /// Creates a new `ShardAwarePortRange` with the given range.
22    ///
23    /// The error is returned in two cases:
24    /// 1. Provided range is empty (`end` < `start`).
25    /// 2. Provided range starts at a port lower than 1024. Ports 0-1023 are reserved and
26    ///    should not be used by application.
27    #[inline]
28    pub fn new(range: impl Into<RangeInclusive<u16>>) -> Result<Self, InvalidShardAwarePortRange> {
29        let range = range.into();
30        if range.is_empty() || range.start() < &1024 {
31            return Err(InvalidShardAwarePortRange);
32        }
33        Ok(Self(range))
34    }
35}
36
37impl Default for ShardAwarePortRange {
38    fn default() -> Self {
39        Self::EPHEMERAL_PORT_RANGE
40    }
41}
42
43/// An error returned by [`ShardAwarePortRange::new()`].
44#[derive(Debug, Error)]
45#[error("Invalid shard-aware local port range")]
46pub struct InvalidShardAwarePortRange;
47
/// Index of a shard, as returned by [`Sharder::shard_of`] and
/// [`Sharder::shard_of_source_port`].
pub type Shard = u32;

/// Number of shards; the non-zero integer type rules out a shard count of 0.
pub type ShardCount = NonZeroU16;
50
51#[derive(PartialEq, Eq, Clone, Debug)]
52pub(crate) struct ShardInfo {
53    pub(crate) shard: u16,
54    pub(crate) nr_shards: ShardCount,
55    pub(crate) msb_ignore: u8,
56}
57
58#[derive(PartialEq, Eq, Clone, Debug)]
59pub struct Sharder {
60    pub nr_shards: ShardCount,
61    pub msb_ignore: u8,
62}
63
64impl std::str::FromStr for Token {
65    type Err = std::num::ParseIntError;
66    fn from_str(s: &str) -> Result<Token, std::num::ParseIntError> {
67        Ok(Token { value: s.parse()? })
68    }
69}
70
71impl ShardInfo {
72    pub(crate) fn new(shard: u16, nr_shards: ShardCount, msb_ignore: u8) -> Self {
73        ShardInfo {
74            shard,
75            nr_shards,
76            msb_ignore,
77        }
78    }
79
80    pub(crate) fn get_sharder(&self) -> Sharder {
81        Sharder::new(self.nr_shards, self.msb_ignore)
82    }
83}
84
85impl Sharder {
86    pub fn new(nr_shards: ShardCount, msb_ignore: u8) -> Self {
87        Sharder {
88            nr_shards,
89            msb_ignore,
90        }
91    }
92
93    pub fn shard_of(&self, token: Token) -> Shard {
94        let mut biased_token = (token.value as u64).wrapping_add(1u64 << 63);
95        biased_token <<= self.msb_ignore;
96        (((biased_token as u128) * (self.nr_shards.get() as u128)) >> 64) as Shard
97    }
98
99    /// If we connect to Scylla using Scylla's shard aware port, then Scylla assigns a shard to the
100    /// connection based on the source port. This calculates the assigned shard.
101    pub fn shard_of_source_port(&self, source_port: u16) -> Shard {
102        (source_port % self.nr_shards.get()) as Shard
103    }
104
105    /// Randomly choose a source port `p` such that `shard == shard_of_source_port(p)`.
106    ///
107    /// The port is chosen from ephemeral port range [49152, 65535].
108    pub fn draw_source_port_for_shard(&self, shard: Shard) -> u16 {
109        self.draw_source_port_for_shard_from_range(
110            shard,
111            &ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
112        )
113    }
114
115    /// Randomly choose a source port `p` such that `shard == shard_of_source_port(p)`.
116    ///
117    /// The port is chosen from the provided port range.
118    pub(crate) fn draw_source_port_for_shard_from_range(
119        &self,
120        shard: Shard,
121        port_range: &ShardAwarePortRange,
122    ) -> u16 {
123        assert!(shard < self.nr_shards.get() as u32);
124        let (range_start, range_end) = (port_range.0.start(), port_range.0.end());
125        rand::rng().random_range(
126            (range_start + self.nr_shards.get() - 1)..(range_end - self.nr_shards.get() + 1),
127        ) / self.nr_shards.get()
128            * self.nr_shards.get()
129            + shard as u16
130    }
131
132    /// Returns iterator over source ports `p` such that `shard == shard_of_source_port(p)`.
133    /// Starts at a random port and goes forward by `nr_shards`. After reaching maximum wraps back around.
134    /// Stops once all possible ports have been returned
135    ///
136    /// The ports are chosen from ephemeral port range [49152, 65535].
137    pub fn iter_source_ports_for_shard(&self, shard: Shard) -> impl Iterator<Item = u16> {
138        self.iter_source_ports_for_shard_from_range(
139            shard,
140            &ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
141        )
142    }
143
144    /// Returns iterator over source ports `p` such that `shard == shard_of_source_port(p)`.
145    /// Starts at a random port and goes forward by `nr_shards`. After reaching maximum wraps back around.
146    /// Stops once all possible ports have been returned
147    ///
148    /// The ports are chosen from the provided port range.
149    pub(crate) fn iter_source_ports_for_shard_from_range(
150        &self,
151        shard: Shard,
152        port_range: &ShardAwarePortRange,
153    ) -> impl Iterator<Item = u16> {
154        assert!(shard < self.nr_shards.get() as u32);
155
156        let (range_start, range_end) = (port_range.0.start(), port_range.0.end());
157
158        // Randomly choose a port to start at
159        let starting_port = self.draw_source_port_for_shard_from_range(shard, port_range);
160
161        // Choose smallest available port number to begin at after wrapping
162        // apply the formula from draw_source_port_for_shard for lowest possible gen_range result
163        let first_valid_port = (range_start + self.nr_shards.get() - 1) / self.nr_shards.get()
164            * self.nr_shards.get()
165            + shard as u16;
166
167        let before_wrap = (starting_port..=*range_end).step_by(self.nr_shards.get().into());
168        let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.get().into());
169
170        before_wrap.chain(after_wrap)
171    }
172}
173
174#[derive(Clone, Error, Debug)]
175pub(crate) enum ShardingError {
176    /// This indicates that we are most likely connected to a Cassandra cluster.
177    /// Unless, there is some serious bug in Scylla.
178    #[error("Server did not provide any sharding information")]
179    NoShardInfo,
180
181    /// A bug in scylla. Some of the parameters are present, while others are missing.
182    #[error("Missing some sharding info parameters")]
183    MissingSomeShardInfoParameters,
184
185    /// A bug in Scylla. All parameters are present, but some do not contain any values.
186    #[error("Missing some sharding info parameter values")]
187    MissingShardInfoParameterValues,
188
189    /// A bug in Scylla. Number of shards is equal to zero.
190    #[error("Sharding info contains an invalid number of shards (0)")]
191    ZeroShards,
192
193    /// A bug in Scylla. Failed to parse string to number.
194    #[error("Failed to parse a sharding info parameter's value: {0}")]
195    ParseIntError(#[from] std::num::ParseIntError),
196}
197
/// Option key whose first value is parsed as the shard number.
const SHARD_ENTRY: &str = "SCYLLA_SHARD";
/// Option key whose first value is parsed as the number of shards.
const NR_SHARDS_ENTRY: &str = "SCYLLA_NR_SHARDS";
/// Option key whose first value is parsed as the number of ignored most significant bits.
const MSB_IGNORE_ENTRY: &str = "SCYLLA_SHARDING_IGNORE_MSB";
201
202impl<'a> TryFrom<&'a HashMap<String, Vec<String>>> for ShardInfo {
203    type Error = ShardingError;
204    fn try_from(options: &'a HashMap<String, Vec<String>>) -> Result<Self, Self::Error> {
205        let shard_entry = options.get(SHARD_ENTRY);
206        let nr_shards_entry = options.get(NR_SHARDS_ENTRY);
207        let msb_ignore_entry = options.get(MSB_IGNORE_ENTRY);
208
209        // Unwrap entries.
210        let (shard_entry, nr_shards_entry, msb_ignore_entry) =
211            match (shard_entry, nr_shards_entry, msb_ignore_entry) {
212                (Some(shard_entry), Some(nr_shards_entry), Some(msb_ignore_entry)) => {
213                    (shard_entry, nr_shards_entry, msb_ignore_entry)
214                }
215                // All parameters are missing - most likely a Cassandra cluster.
216                (None, None, None) => return Err(ShardingError::NoShardInfo),
217                // At least one of the parameters is present, but some are missing. A bug in Scylla.
218                _ => return Err(ShardingError::MissingSomeShardInfoParameters),
219            };
220
221        // Further unwrap entries (they should be the first entries of their corresponding Vecs).
222        let (Some(shard_entry), Some(nr_shards_entry), Some(msb_ignore_entry)) = (
223            shard_entry.first(),
224            nr_shards_entry.first(),
225            msb_ignore_entry.first(),
226        ) else {
227            return Err(ShardingError::MissingShardInfoParameterValues);
228        };
229
230        let shard = shard_entry.parse::<u16>()?;
231        let nr_shards = nr_shards_entry.parse::<u16>()?;
232        let nr_shards = ShardCount::new(nr_shards).ok_or(ShardingError::ZeroShards)?;
233        let msb_ignore = msb_ignore_entry.parse::<u8>()?;
234        Ok(ShardInfo::new(shard, nr_shards, msb_ignore))
235    }
236}
237
#[cfg(test)]
impl ShardInfo {
    /// Test helper: stores this info in `options` under the three sharding keys,
    /// each as a single-element value list, replacing any previous entries.
    pub(crate) fn add_to_options(&self, options: &mut HashMap<String, Vec<String>>) {
        let entries = [
            (SHARD_ENTRY, self.shard.to_string()),
            (NR_SHARDS_ENTRY, self.nr_shards.to_string()),
            (MSB_IGNORE_ENTRY, self.msb_ignore.to_string()),
        ];
        for (key, value) in entries {
            options.insert(key.to_owned(), vec![value]);
        }
    }
}
250
#[cfg(test)]
mod tests {
    use crate::routing::{Shard, ShardAwarePortRange};
    use crate::test_utils::setup_tracing;

    use super::Token;
    use super::{ShardCount, Sharder};
    use std::collections::HashSet;

    /// The constructor accepts the ephemeral range and rejects empty or too-low ranges.
    #[test]
    fn test_shard_aware_port_range_constructor() {
        setup_tracing();

        // A valid range is accepted as-is.
        let valid = ShardAwarePortRange::new(49152..=65535).unwrap();
        assert_eq!(valid, ShardAwarePortRange::EPHEMERAL_PORT_RANGE);

        // An empty range is rejected.
        #[allow(clippy::reversed_empty_ranges)]
        {
            let empty = ShardAwarePortRange::new(49152..=49151);
            assert!(empty.is_err());
        }

        // A range starting in the reserved ports is rejected.
        let too_low = ShardAwarePortRange::new(0..=65535);
        assert!(too_low.is_err());
    }

    /// Tokens are mapped to the expected shards.
    #[test]
    fn test_shard_of() {
        setup_tracing();

        /* Test values taken from the gocql driver.  */
        let sharder = Sharder::new(ShardCount::new(4).unwrap(), 12);
        for value in [-9219783007514621794_i64, 9222582454147032830_i64] {
            assert_eq!(sharder.shard_of(Token { value }), 3);
        }
    }

    /// The port iterators return every matching port exactly once.
    #[test]
    fn test_iter_source_ports_for_shard() {
        setup_tracing();

        fn test_helper<F, I>(nr_shards: u16, port_range: ShardAwarePortRange, get_iter: F)
        where
            F: Fn(&Sharder, Shard) -> I,
            I: Iterator<Item = u16>,
        {
            let max_port_num = *port_range.0.end();
            // Lowest multiple of `nr_shards` that lies within the range.
            let min_port_num = (*port_range.0.start() + nr_shards - 1) / nr_shards * nr_shards;

            let sharder = Sharder::new(ShardCount::new(nr_shards).unwrap(), 12);

            // Test for each shard
            for shard in 0..nr_shards {
                // Find lowest port for this shard
                let lowest_port = (min_port_num..=max_port_num)
                    .find(|port| port % nr_shards == shard)
                    .expect("the range contains a port for every shard");

                // Find total number of ports the iterator should return
                let possible_ports_number =
                    usize::from((max_port_num - lowest_port) / nr_shards + 1);

                let mut returned_ports: HashSet<u16> = HashSet::new();
                for port in get_iter(&sharder, shard.into()) {
                    assert!(returned_ports.insert(port)); // No port occurs two times
                    assert_eq!(port % nr_shards, shard); // Each port maps to this shard
                }

                // Numbers of ports returned matches the expected value
                assert_eq!(returned_ports.len(), possible_ports_number);
            }
        }

        // Test of public method (with default range)
        test_helper(
            4,
            ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
            |sharder, shard| sharder.iter_source_ports_for_shard(shard),
        );

        // Test of private method with some custom port range.
        let port_range = ShardAwarePortRange::new(21371..=42424).unwrap();
        test_helper(4, port_range.clone(), |sharder, shard| {
            sharder.iter_source_ports_for_shard_from_range(shard, &port_range)
        });
    }
}