scylla_cql/frame/
protocol_features.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3
/// Option name under which the LWT metadata-mark extension is looked up and
/// under which it is written back by `ProtocolFeatures::add_startup_options`.
pub const SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION: &str = "SCYLLA_LWT_ADD_METADATA_MARK";
/// Field name (the part before `=`) that carries the LWT bit mask inside the
/// values of the `SCYLLA_LWT_ADD_METADATA_MARK` option.
pub const LWT_OPTIMIZATION_META_BIT_MASK_KEY: &str = "LWT_OPTIMIZATION_META_BIT_MASK";

/// Option name under which the rate-limit-error extension is looked up.
const RATE_LIMIT_ERROR_EXTENSION: &str = "SCYLLA_RATE_LIMIT_ERROR";
/// Option name whose mere presence marks tablets routing (v1) as supported.
const TABLETS_ROUTING_V1_KEY: &str = "TABLETS_ROUTING_V1";
8
/// Set of optional protocol extensions recognised by this module.
///
/// The `Default` value has every extension disabled: both `Option` fields are
/// `None` and `tablets_v1_supported` is `false`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct ProtocolFeatures {
    /// `ERROR_CODE` value found under the `SCYLLA_RATE_LIMIT_ERROR` option,
    /// or `None` when the option is absent or its value is not a valid `i32`.
    pub rate_limit_error: Option<i32>,
    /// Bit mask found under the `SCYLLA_LWT_ADD_METADATA_MARK` option,
    /// or `None` when the option is absent or its value is not a valid `u32`.
    pub lwt_optimization_meta_bit_mask: Option<u32>,
    /// `true` when the `TABLETS_ROUTING_V1` option was present.
    pub tablets_v1_supported: bool,
}
16
// TODO: Log information about options that fail to parse; they are currently dropped silently.
18
19impl ProtocolFeatures {
20    pub fn parse_from_supported(supported: &HashMap<String, Vec<String>>) -> Self {
21        Self {
22            rate_limit_error: Self::maybe_parse_rate_limit_error(supported),
23            lwt_optimization_meta_bit_mask: Self::maybe_parse_lwt_optimization_meta_bit_mask(
24                supported,
25            ),
26            tablets_v1_supported: Self::check_tablets_routing_v1_support(supported),
27        }
28    }
29
30    fn maybe_parse_rate_limit_error(supported: &HashMap<String, Vec<String>>) -> Option<i32> {
31        let vals = supported.get(RATE_LIMIT_ERROR_EXTENSION)?;
32        let code_str = Self::get_cql_extension_field(vals.as_slice(), "ERROR_CODE")?;
33        code_str.parse::<i32>().ok()
34    }
35
36    fn maybe_parse_lwt_optimization_meta_bit_mask(
37        supported: &HashMap<String, Vec<String>>,
38    ) -> Option<u32> {
39        let vals = supported.get(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION)?;
40        let mask_str =
41            Self::get_cql_extension_field(vals.as_slice(), LWT_OPTIMIZATION_META_BIT_MASK_KEY)?;
42        mask_str.parse::<u32>().ok()
43    }
44
45    fn check_tablets_routing_v1_support(supported: &HashMap<String, Vec<String>>) -> bool {
46        supported.contains_key(TABLETS_ROUTING_V1_KEY)
47    }
48
49    // Looks up a field which starts with `key=` and returns the rest
50    fn get_cql_extension_field<'a>(vals: &'a [String], key: &str) -> Option<&'a str> {
51        vals.iter()
52            .find_map(|v| v.as_str().strip_prefix(key)?.strip_prefix('='))
53    }
54
55    pub fn add_startup_options(&self, options: &mut HashMap<Cow<'_, str>, Cow<'_, str>>) {
56        if self.rate_limit_error.is_some() {
57            options.insert(Cow::Borrowed(RATE_LIMIT_ERROR_EXTENSION), Cow::Borrowed(""));
58        }
59        if let Some(mask) = self.lwt_optimization_meta_bit_mask {
60            options.insert(
61                Cow::Borrowed(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION),
62                Cow::Owned(format!("{}={}", LWT_OPTIMIZATION_META_BIT_MASK_KEY, mask)),
63            );
64        }
65
66        if self.tablets_v1_supported {
67            options.insert(Cow::Borrowed(TABLETS_ROUTING_V1_KEY), Cow::Borrowed(""));
68        }
69    }
70
71    pub fn prepared_flags_contain_lwt_mark(&self, flags: u32) -> bool {
72        self.lwt_optimization_meta_bit_mask
73            .map(|mask| (flags & mask) == mask)
74            .unwrap_or(false)
75    }
76}