scylla_cql/frame/
mod.rs

1//! Abstractions of the CQL wire protocol:
2//! - request and response frames' representation and ser/de;
3//! - frame header and body;
4//! - serialization and deserialization of low-level CQL protocol types;
5//! - protocol features negotiation;
6//! - compression, tracing, custom payload support;
7//! - consistency levels;
8//! - errors that can occur during the above operations.
9//!
10
11pub mod frame_errors;
12pub mod protocol_features;
13pub mod request;
14pub mod response;
15pub mod server_event_type;
16pub mod types;
17
18use bytes::{Buf, BufMut, Bytes};
19use frame_errors::{
20    CqlRequestSerializationError, FrameBodyExtensionsParseError, FrameHeaderParseError,
21};
22use thiserror::Error;
23use tokio::io::{AsyncRead, AsyncReadExt};
24use uuid::Uuid;
25
26use std::fmt::Display;
27use std::str::FromStr;
28use std::sync::Arc;
29use std::{collections::HashMap, convert::TryFrom};
30
31use request::SerializableRequest;
32use response::ResponseOpcode;
33
/// Size in bytes of a frame header:
/// version (1) + flags (1) + stream (2) + opcode (1) + body length (4).
const HEADER_SIZE: usize = 1 + 1 + 2 + 1 + 4;
35
pub mod flag {
    //! Bit masks for the `flags` byte of a frame header.
    //!
    //! Each constant occupies a distinct bit, so several flags can be combined
    //! with bitwise OR and tested with bitwise AND.

    /// Set when the frame body is compressed.
    pub const COMPRESSION: u8 = 0b0000_0001;

    /// Set when the frame carries a tracing ID.
    pub const TRACING: u8 = 0b0000_0010;

    /// Set when the frame carries a custom payload.
    pub const CUSTOM_PAYLOAD: u8 = 0b0000_0100;

    /// Set when the frame carries warnings.
    pub const WARNING: u8 = 0b0000_1000;
}
51
/// The set of authenticators supported by ScyllaDB.
///
/// Each variant names one authenticator; the variants carry no data.
// The lint fires because every variant shares the "Authenticator" suffix.
// TODO(2.0): Remove the "Authenticator" postfix from variants.
#[expect(clippy::enum_variant_names)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Authenticator {
    AllowAllAuthenticator,
    PasswordAuthenticator,
    CassandraPasswordAuthenticator,
    CassandraAllowAllAuthenticator,
    ScyllaTransitionalAuthenticator,
}
64
/// Compression algorithm applied to frame bodies on the wire.
///
/// The derived ordering follows declaration order (`Lz4` sorts before `Snappy`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Compression {
    /// The LZ4 algorithm.
    Lz4,
    /// The Snappy algorithm.
    Snappy,
}
73
74impl Compression {
75    /// Returns the string representation of the compression algorithm.
76    pub fn as_str(&self) -> &'static str {
77        match self {
78            Compression::Lz4 => "lz4",
79            Compression::Snappy => "snappy",
80        }
81    }
82}
83
84/// Unknown compression.
85#[derive(Error, Debug, Clone)]
86#[error("Unknown compression: {name}")]
87pub struct CompressionFromStrError {
88    name: String,
89}
90
91impl FromStr for Compression {
92    type Err = CompressionFromStrError;
93
94    fn from_str(s: &str) -> Result<Self, Self::Err> {
95        match s {
96            "lz4" => Ok(Self::Lz4),
97            "snappy" => Ok(Self::Snappy),
98            other => Err(Self::Err {
99                name: other.to_owned(),
100            }),
101        }
102    }
103}
104
105impl Display for Compression {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        f.write_str(self.as_str())
108    }
109}
110
/// A CQL request frame that has been serialized and is almost ready for sending.
///
/// It differs from a complete frame only in that the stream number has not been
/// filled in yet; `set_stream` writes it just before the frame is sent.
pub struct SerializedRequest {
    // Header followed by the (possibly compressed) body.
    data: Vec<u8>,
}
118
119impl SerializedRequest {
120    /// Creates a new serialized request frame from a request object.
121    ///
122    /// # Parameters
123    /// - `req`: The request object to serialize. Must implement `SerializableRequest`.
124    /// - `compression`: An optional compression algorithm to use for the request body.
125    /// - `tracing`: A boolean indicating whether to request tracing information in the response.
126    pub fn make<R: SerializableRequest>(
127        req: &R,
128        compression: Option<Compression>,
129        tracing: bool,
130    ) -> Result<SerializedRequest, CqlRequestSerializationError> {
131        let mut flags = 0;
132        let mut data = vec![0; HEADER_SIZE];
133
134        if let Some(compression) = compression {
135            flags |= flag::COMPRESSION;
136            let body = req.to_bytes()?;
137            compress_append(&body, compression, &mut data)?;
138        } else {
139            req.serialize(&mut data)?;
140        }
141
142        if tracing {
143            flags |= flag::TRACING;
144        }
145
146        data[0] = 4; // We only support version 4 for now
147        data[1] = flags;
148        // Leave space for the stream number
149        data[4] = R::OPCODE as u8;
150
151        let req_size = (data.len() - HEADER_SIZE) as u32;
152        data[5..9].copy_from_slice(&req_size.to_be_bytes());
153
154        Ok(Self { data })
155    }
156
157    /// Sets the stream number for this request frame.
158    /// Intended to be called before sending the request,
159    /// once a stream ID has been assigned.
160    pub fn set_stream(&mut self, stream: i16) {
161        self.data[2..4].copy_from_slice(&stream.to_be_bytes());
162    }
163
164    /// Returns the serialized frame data, including the header and body.
165    pub fn get_data(&self) -> &[u8] {
166        &self.data[..]
167    }
168}
169
/// The frame header fields that do not depend on the request/response type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FrameParams {
    /// Protocol version of the frame. Currently, only version 4 is supported.
    /// The most significant bit (0x80) has a special meaning:
    /// it tells whether the frame comes from the client or from the server.
    pub version: u8,

    /// Frame flags, signalling features such as compression, tracing, etc.
    pub flags: u8,

    /// Stream ID of the frame; it is what lets requests and responses be
    /// matched on a multiplexed connection.
    pub stream: i16,
}
185
186impl Default for FrameParams {
187    fn default() -> Self {
188        Self {
189            version: 0x04,
190            flags: 0x00,
191            stream: 0,
192        }
193    }
194}
195
196/// Reads a response frame from the provided reader (usually, a socket).
197/// Then parses and validates the frame header and extracts the body.
198pub async fn read_response_frame(
199    reader: &mut (impl AsyncRead + Unpin),
200) -> Result<(FrameParams, ResponseOpcode, Bytes), FrameHeaderParseError> {
201    let mut raw_header = [0u8; HEADER_SIZE];
202    reader
203        .read_exact(&mut raw_header[..])
204        .await
205        .map_err(FrameHeaderParseError::HeaderIoError)?;
206
207    let mut buf = &raw_header[..];
208
209    // TODO: Validate version
210    let version = buf.get_u8();
211    if version & 0x80 != 0x80 {
212        return Err(FrameHeaderParseError::FrameFromClient);
213    }
214    if version & 0x7F != 0x04 {
215        return Err(FrameHeaderParseError::VersionNotSupported(version & 0x7f));
216    }
217
218    let flags = buf.get_u8();
219    let stream = buf.get_i16();
220
221    let frame_params = FrameParams {
222        version,
223        flags,
224        stream,
225    };
226
227    let opcode = ResponseOpcode::try_from(buf.get_u8())?;
228
229    // TODO: Guard from frames that are too large
230    let length = buf.get_u32() as usize;
231
232    let mut raw_body = Vec::with_capacity(length).limit(length);
233    while raw_body.has_remaining_mut() {
234        let n = reader.read_buf(&mut raw_body).await.map_err(|err| {
235            FrameHeaderParseError::BodyChunkIoError(raw_body.remaining_mut(), err)
236        })?;
237        if n == 0 {
238            // EOF, too early
239            return Err(FrameHeaderParseError::ConnectionClosed(
240                raw_body.remaining_mut(),
241                length,
242            ));
243        }
244    }
245
246    Ok((frame_params, opcode, raw_body.into_inner().into()))
247}
248
249/// Represents the already parsed response body extensions,
250/// including trace ID, warnings, and custom payload,
251/// and the remaining body raw data.
252pub struct ResponseBodyWithExtensions {
253    /// The trace ID if tracing was requested in the request.
254    ///
255    /// This can be used to issue a follow-up request to the server
256    /// to get detailed tracing information about the request.
257    pub trace_id: Option<Uuid>,
258
259    /// Warnings returned by the server, if any.
260    pub warnings: Vec<String>,
261
262    /// Custom payload (see [the CQL protocol description of the feature](https://github.com/apache/cassandra/blob/a39f3b066f010d465a1be1038d5e06f1e31b0391/doc/native_protocol_v4.spec#L276))
263    /// returned by the server, if any.
264    pub custom_payload: Option<HashMap<String, Bytes>>,
265
266    /// The remaining body data after parsing the extensions.
267    pub body: Bytes,
268}
269
270/// Decompresses the response body if compression is enabled,
271/// and parses any extensions like trace ID, warnings, and custom payload.
272pub fn parse_response_body_extensions(
273    flags: u8,
274    compression: Option<Compression>,
275    mut body: Bytes,
276) -> Result<ResponseBodyWithExtensions, FrameBodyExtensionsParseError> {
277    if flags & flag::COMPRESSION != 0 {
278        if let Some(compression) = compression {
279            body = decompress(&body, compression)?.into();
280        } else {
281            return Err(FrameBodyExtensionsParseError::NoCompressionNegotiated);
282        }
283    }
284
285    let trace_id = if flags & flag::TRACING != 0 {
286        let buf = &mut &*body;
287        let trace_id =
288            types::read_uuid(buf).map_err(FrameBodyExtensionsParseError::TraceIdParse)?;
289        body.advance(16);
290        Some(trace_id)
291    } else {
292        None
293    };
294
295    let warnings = if flags & flag::WARNING != 0 {
296        let body_len = body.len();
297        let buf = &mut &*body;
298        let warnings = types::read_string_list(buf)
299            .map_err(FrameBodyExtensionsParseError::WarningsListParse)?;
300        let buf_len = buf.len();
301        body.advance(body_len - buf_len);
302        warnings
303    } else {
304        Vec::new()
305    };
306
307    let custom_payload = if flags & flag::CUSTOM_PAYLOAD != 0 {
308        let body_len = body.len();
309        let buf = &mut &*body;
310        let payload_map = types::read_bytes_map(buf)
311            .map_err(FrameBodyExtensionsParseError::CustomPayloadMapParse)?;
312        let buf_len = buf.len();
313        body.advance(body_len - buf_len);
314        Some(payload_map)
315    } else {
316        None
317    };
318
319    Ok(ResponseBodyWithExtensions {
320        trace_id,
321        warnings,
322        custom_payload,
323        body,
324    })
325}
326
327/// Compresses the request body using the specified compression algorithm,
328/// appending the compressed data to the provided output buffer.
329pub fn compress_append(
330    uncomp_body: &[u8],
331    compression: Compression,
332    out: &mut Vec<u8>,
333) -> Result<(), CqlRequestSerializationError> {
334    match compression {
335        Compression::Lz4 => {
336            let uncomp_len = uncomp_body.len() as u32;
337            let tmp = lz4_flex::compress(uncomp_body);
338            out.reserve_exact(std::mem::size_of::<u32>() + tmp.len());
339            out.put_u32(uncomp_len);
340            out.extend_from_slice(&tmp[..]);
341            Ok(())
342        }
343        Compression::Snappy => {
344            let old_size = out.len();
345            out.resize(old_size + snap::raw::max_compress_len(uncomp_body.len()), 0);
346            let compressed_size = snap::raw::Encoder::new()
347                .compress(uncomp_body, &mut out[old_size..])
348                .map_err(|err| CqlRequestSerializationError::SnapCompressError(Arc::new(err)))?;
349            out.truncate(old_size + compressed_size);
350            Ok(())
351        }
352    }
353}
354
355/// Deompresses the response body using the specified compression algorithm
356/// and returns the decompressed data as an owned buffer.
357pub fn decompress(
358    mut comp_body: &[u8],
359    compression: Compression,
360) -> Result<Vec<u8>, FrameBodyExtensionsParseError> {
361    match compression {
362        Compression::Lz4 => {
363            let uncomp_len = comp_body.get_u32() as usize;
364            let uncomp_body = lz4_flex::decompress(comp_body, uncomp_len)
365                .map_err(|err| FrameBodyExtensionsParseError::Lz4DecompressError(Arc::new(err)))?;
366            Ok(uncomp_body)
367        }
368        Compression::Snappy => snap::raw::Decoder::new()
369            .decompress_vec(comp_body)
370            .map_err(|err| FrameBodyExtensionsParseError::SnapDecompressError(Arc::new(err))),
371    }
372}
373
374/// An error type for parsing an enum value from a primitive.
375#[derive(Error, Debug, Clone, PartialEq, Eq)]
376#[error("No discrimant in enum `{enum_name}` matches the value `{primitive:?}`")]
377pub struct TryFromPrimitiveError<T: Copy + std::fmt::Debug> {
378    enum_name: &'static str,
379    primitive: T,
380}
381
#[cfg(test)]
mod test {
    use super::*;

    /// LZ4 output is appended after whatever the buffer already holds:
    /// the existing bytes, then the 4-byte uncompressed length, then the block.
    #[test]
    fn test_lz4_compress() {
        let mut buffer = b"Hello".to_vec();

        compress_append(b", World!", Compression::Lz4, &mut buffer).unwrap();

        let expected: Vec<u8> = vec![
            72, 101, 108, 108, 111, 0, 0, 0, 8, 128, 44, 32, 87, 111, 114, 108, 100, 33,
        ];
        assert_eq!(expected, buffer);
    }

    /// Compressing and then decompressing with LZ4 gives back the input.
    #[test]
    fn test_lz4_decompress() {
        let original = "Hello, World!".repeat(100);
        let mut compressed = Vec::new();

        compress_append(original.as_bytes(), Compression::Lz4, &mut compressed).unwrap();
        let roundtripped = decompress(&compressed, Compression::Lz4).unwrap();

        assert_eq!(32, compressed.len());
        assert_eq!(original.as_bytes(), roundtripped);
    }
}