1use std::collections::HashMap;
2use std::num::NonZeroU16;
3use std::ops::RangeInclusive;
4
5use rand::Rng as _;
6use thiserror::Error;
7
8use super::Token;
9
/// An inclusive range of local ports from which source ports for shard-aware
/// connections are picked.
///
/// Build one with [`ShardAwarePortRange::new`], which rejects empty ranges and
/// ranges that start below port 1024, or use
/// [`ShardAwarePortRange::EPHEMERAL_PORT_RANGE`] (which is also the
/// [`Default`] value).
#[derive(Debug, Clone)]
// Equality is derived for test builds only.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct ShardAwarePortRange(RangeInclusive<u16>);
16
17impl ShardAwarePortRange {
18 pub const EPHEMERAL_PORT_RANGE: Self = Self(49152..=65535);
20
21 #[inline]
28 pub fn new(range: impl Into<RangeInclusive<u16>>) -> Result<Self, InvalidShardAwarePortRange> {
29 let range = range.into();
30 if range.is_empty() || range.start() < &1024 {
31 return Err(InvalidShardAwarePortRange);
32 }
33 Ok(Self(range))
34 }
35}
36
37impl Default for ShardAwarePortRange {
38 fn default() -> Self {
39 Self::EPHEMERAL_PORT_RANGE
40 }
41}
42
/// Error returned by [`ShardAwarePortRange::new`] when the supplied range is
/// empty or starts below port 1024.
#[derive(Debug, Error)]
#[error("Invalid shard-aware local port range")]
pub struct InvalidShardAwarePortRange;
47
/// Shard number, as returned by the [`Sharder`] lookup methods.
pub type Shard = u32;
/// Number of shards; the underlying type guarantees it is never zero.
pub type ShardCount = NonZeroU16;
50
/// Sharding parameters, as carried by the `SCYLLA_SHARD`, `SCYLLA_NR_SHARDS`
/// and `SCYLLA_SHARDING_IGNORE_MSB` entries of an options map (see the
/// `TryFrom` implementation for this type).
#[derive(PartialEq, Eq, Clone, Debug)]
pub(crate) struct ShardInfo {
    /// Shard number (the `SCYLLA_SHARD` entry).
    pub(crate) shard: u16,
    /// Number of shards (the `SCYLLA_NR_SHARDS` entry); never zero.
    pub(crate) nr_shards: ShardCount,
    /// Number of most significant token bits that [`Sharder::shard_of`]
    /// shifts out (the `SCYLLA_SHARDING_IGNORE_MSB` entry).
    pub(crate) msb_ignore: u8,
}
57
/// Maps tokens and source ports to shard numbers for a given shard count.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct Sharder {
    /// Number of shards that tokens and ports are distributed over.
    pub nr_shards: ShardCount,
    /// Number of most significant token bits shifted out by
    /// [`Sharder::shard_of`] before the shard is computed.
    pub msb_ignore: u8,
}
63
64impl std::str::FromStr for Token {
65 type Err = std::num::ParseIntError;
66 fn from_str(s: &str) -> Result<Token, std::num::ParseIntError> {
67 Ok(Token { value: s.parse()? })
68 }
69}
70
71impl ShardInfo {
72 pub(crate) fn new(shard: u16, nr_shards: ShardCount, msb_ignore: u8) -> Self {
73 ShardInfo {
74 shard,
75 nr_shards,
76 msb_ignore,
77 }
78 }
79
80 pub(crate) fn get_sharder(&self) -> Sharder {
81 Sharder::new(self.nr_shards, self.msb_ignore)
82 }
83}
84
85impl Sharder {
86 pub fn new(nr_shards: ShardCount, msb_ignore: u8) -> Self {
87 Sharder {
88 nr_shards,
89 msb_ignore,
90 }
91 }
92
93 pub fn shard_of(&self, token: Token) -> Shard {
94 let mut biased_token = (token.value as u64).wrapping_add(1u64 << 63);
95 biased_token <<= self.msb_ignore;
96 (((biased_token as u128) * (self.nr_shards.get() as u128)) >> 64) as Shard
97 }
98
99 pub fn shard_of_source_port(&self, source_port: u16) -> Shard {
102 (source_port % self.nr_shards.get()) as Shard
103 }
104
105 pub fn draw_source_port_for_shard(&self, shard: Shard) -> u16 {
109 self.draw_source_port_for_shard_from_range(
110 shard,
111 &ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
112 )
113 }
114
115 pub(crate) fn draw_source_port_for_shard_from_range(
119 &self,
120 shard: Shard,
121 port_range: &ShardAwarePortRange,
122 ) -> u16 {
123 assert!(shard < self.nr_shards.get() as u32);
124 let (range_start, range_end) = (port_range.0.start(), port_range.0.end());
125 rand::rng().random_range(
126 (range_start + self.nr_shards.get() - 1)..(range_end - self.nr_shards.get() + 1),
127 ) / self.nr_shards.get()
128 * self.nr_shards.get()
129 + shard as u16
130 }
131
132 pub fn iter_source_ports_for_shard(&self, shard: Shard) -> impl Iterator<Item = u16> {
138 self.iter_source_ports_for_shard_from_range(
139 shard,
140 &ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
141 )
142 }
143
144 pub(crate) fn iter_source_ports_for_shard_from_range(
150 &self,
151 shard: Shard,
152 port_range: &ShardAwarePortRange,
153 ) -> impl Iterator<Item = u16> {
154 assert!(shard < self.nr_shards.get() as u32);
155
156 let (range_start, range_end) = (port_range.0.start(), port_range.0.end());
157
158 let starting_port = self.draw_source_port_for_shard_from_range(shard, port_range);
160
161 let first_valid_port = (range_start + self.nr_shards.get() - 1) / self.nr_shards.get()
164 * self.nr_shards.get()
165 + shard as u16;
166
167 let before_wrap = (starting_port..=*range_end).step_by(self.nr_shards.get().into());
168 let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.get().into());
169
170 before_wrap.chain(after_wrap)
171 }
172}
173
/// Reasons why a [`ShardInfo`] could not be built from an options map.
#[derive(Clone, Error, Debug)]
pub(crate) enum ShardingError {
    /// None of the three sharding entries is present in the options map.
    #[error("Server did not provide any sharding information")]
    NoShardInfo,

    /// At least one, but not all, of the three sharding entries is present.
    #[error("Missing some sharding info parameters")]
    MissingSomeShardInfoParameters,

    /// All three entries are present, but at least one of them has an empty
    /// list of values.
    #[error("Missing some sharding info parameter values")]
    MissingShardInfoParameterValues,

    /// The number of shards was parsed successfully, but it is 0.
    #[error("Sharding info contains an invalid number of shards (0)")]
    ZeroShards,

    /// The first value of some entry is not a valid integer of the expected
    /// type.
    #[error("Failed to parse a sharding info parameter's value: {0}")]
    ParseIntError(#[from] std::num::ParseIntError),
}
197
/// Options-map key holding the shard number.
const SHARD_ENTRY: &str = "SCYLLA_SHARD";
/// Options-map key holding the number of shards.
const NR_SHARDS_ENTRY: &str = "SCYLLA_NR_SHARDS";
/// Options-map key holding the number of most significant token bits to ignore.
const MSB_IGNORE_ENTRY: &str = "SCYLLA_SHARDING_IGNORE_MSB";
201
202impl<'a> TryFrom<&'a HashMap<String, Vec<String>>> for ShardInfo {
203 type Error = ShardingError;
204 fn try_from(options: &'a HashMap<String, Vec<String>>) -> Result<Self, Self::Error> {
205 let shard_entry = options.get(SHARD_ENTRY);
206 let nr_shards_entry = options.get(NR_SHARDS_ENTRY);
207 let msb_ignore_entry = options.get(MSB_IGNORE_ENTRY);
208
209 let (shard_entry, nr_shards_entry, msb_ignore_entry) =
211 match (shard_entry, nr_shards_entry, msb_ignore_entry) {
212 (Some(shard_entry), Some(nr_shards_entry), Some(msb_ignore_entry)) => {
213 (shard_entry, nr_shards_entry, msb_ignore_entry)
214 }
215 (None, None, None) => return Err(ShardingError::NoShardInfo),
217 _ => return Err(ShardingError::MissingSomeShardInfoParameters),
219 };
220
221 let (Some(shard_entry), Some(nr_shards_entry), Some(msb_ignore_entry)) = (
223 shard_entry.first(),
224 nr_shards_entry.first(),
225 msb_ignore_entry.first(),
226 ) else {
227 return Err(ShardingError::MissingShardInfoParameterValues);
228 };
229
230 let shard = shard_entry.parse::<u16>()?;
231 let nr_shards = nr_shards_entry.parse::<u16>()?;
232 let nr_shards = ShardCount::new(nr_shards).ok_or(ShardingError::ZeroShards)?;
233 let msb_ignore = msb_ignore_entry.parse::<u8>()?;
234 Ok(ShardInfo::new(shard, nr_shards, msb_ignore))
235 }
236}
237
#[cfg(test)]
impl ShardInfo {
    /// Stores this info in `options` under the three sharding keys.
    ///
    /// Each field is written as a one-element list holding its string form;
    /// values already stored under those keys are replaced.
    pub(crate) fn add_to_options(&self, options: &mut HashMap<String, Vec<String>>) {
        options.insert(SHARD_ENTRY.to_owned(), vec![self.shard.to_string()]);
        options.insert(NR_SHARDS_ENTRY.to_owned(), vec![self.nr_shards.to_string()]);
        options.insert(MSB_IGNORE_ENTRY.to_owned(), vec![self.msb_ignore.to_string()]);
    }
}
250
#[cfg(test)]
mod tests {
    use crate::routing::{Shard, ShardAwarePortRange};
    use crate::test_utils::setup_tracing;

    use super::Token;
    use super::{ShardCount, Sharder};
    use std::collections::HashSet;

    /// `ShardAwarePortRange::new` accepts the ephemeral range and rejects
    /// empty ranges as well as ranges starting below port 1024.
    #[test]
    fn test_shard_aware_port_range_constructor() {
        setup_tracing();

        // A range equal to the ephemeral one is valid and equals the constant.
        let ephemeral = ShardAwarePortRange::new(49152..=65535).unwrap();
        assert_eq!(ephemeral, ShardAwarePortRange::EPHEMERAL_PORT_RANGE);

        // The range below is reversed (and thus empty) on purpose.
        #[allow(clippy::reversed_empty_ranges)]
        {
            let reversed = ShardAwarePortRange::new(49152..=49151);
            assert!(reversed.is_err());
        }

        // A range starting below port 1024 is rejected.
        let starts_too_low = ShardAwarePortRange::new(0..=65535);
        assert!(starts_too_low.is_err());
    }

    /// `Sharder::shard_of` returns the expected shard for two known tokens.
    #[test]
    fn test_shard_of() {
        setup_tracing();

        let sharder = Sharder::new(ShardCount::new(4).unwrap(), 12);
        for value in [-9219783007514621794, 9222582454147032830] {
            assert_eq!(sharder.shard_of(Token { value }), 3);
        }
    }

    /// The source port iterators yield, for every shard, each matching port
    /// exactly once and nothing else.
    #[test]
    fn test_iter_source_ports_for_shard() {
        setup_tracing();

        fn test_helper<F, I>(nr_shards: u16, port_range: ShardAwarePortRange, get_iter: F)
        where
            F: Fn(&Sharder, Shard) -> I,
            I: Iterator<Item = u16>,
        {
            let (range_first, range_last) = (*port_range.0.start(), *port_range.0.end());
            // Smallest multiple of `nr_shards` not below the start of the range.
            let first_multiple = (range_first + nr_shards - 1) / nr_shards * nr_shards;

            let sharder = Sharder::new(ShardCount::new(nr_shards).unwrap(), 12);

            for shard in 0..nr_shards {
                // `first_multiple` is divisible by `nr_shards`, so adding
                // `shard` gives the lowest port from there that maps to it.
                let lowest_port = first_multiple + shard;
                let expected_port_count = usize::from((range_last - lowest_port) / nr_shards + 1);

                let mut seen_ports: HashSet<u16> = HashSet::new();
                for port in get_iter(&sharder, Shard::from(shard)) {
                    // No port may be returned twice.
                    let newly_seen = seen_ports.insert(port);
                    assert!(newly_seen);
                    // Every port must map to the requested shard.
                    assert!(port % nr_shards == shard);
                }

                // Every matching port must have been returned.
                assert_eq!(seen_ports.len(), expected_port_count);
            }
        }

        // Default (ephemeral) port range.
        test_helper(
            4,
            ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
            |sharder, shard| sharder.iter_source_ports_for_shard(shard),
        );

        // Custom port range.
        let port_range = ShardAwarePortRange::new(21371..=42424).unwrap();
        test_helper(4, port_range.clone(), |sharder, shard| {
            sharder.iter_source_ports_for_shard_from_range(shard, &port_range)
        });
    }
}