scylla_cql/frame/
mod.rs

1pub mod frame_errors;
2pub mod protocol_features;
3pub mod request;
4pub mod response;
5pub mod server_event_type;
6pub mod types;
7
8use bytes::{Buf, BufMut, Bytes};
9use frame_errors::{
10    CqlRequestSerializationError, FrameBodyExtensionsParseError, FrameHeaderParseError,
11};
12use thiserror::Error;
13use tokio::io::{AsyncRead, AsyncReadExt};
14use uuid::Uuid;
15
16use std::fmt::Display;
17use std::sync::Arc;
18use std::{collections::HashMap, convert::TryFrom};
19
20use request::SerializableRequest;
21use response::ResponseOpcode;
22
/// Size of a frame header in bytes:
/// version (1) + flags (1) + stream id (2) + opcode (1) + body length (4).
const HEADER_SIZE: usize = 1 + 1 + 2 + 1 + 4;

// Bit masks for the flags byte of the frame header.
/// The frame body is compressed with the negotiated algorithm.
const FLAG_COMPRESSION: u8 = 0b0000_0001;
/// Tracing is requested (requests) / a tracing id precedes the body (responses).
const FLAG_TRACING: u8 = 0b0000_0010;
/// A custom payload map is present in the body.
const FLAG_CUSTOM_PAYLOAD: u8 = 0b0000_0100;
/// The response body starts with a list of warnings.
const FLAG_WARNING: u8 = 0b0000_1000;
30
/// Every authenticator kind supported by Scylla, one variant per authenticator.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Authenticator {
    AllowAllAuthenticator,
    PasswordAuthenticator,
    CassandraPasswordAuthenticator,
    CassandraAllowAllAuthenticator,
    ScyllaTransitionalAuthenticator,
}
40
/// The wire protocol compression algorithm.
///
/// Variant order matters: the derived `Ord` ranks `Lz4` before `Snappy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Compression {
    /// LZ4 compression algorithm.
    Lz4,
    /// Snappy compression algorithm.
    Snappy,
}

impl Compression {
    /// Returns the lowercase name of the algorithm: `"lz4"` or `"snappy"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Lz4 => "lz4",
            Self::Snappy => "snappy",
        }
    }
}

impl Display for Compression {
    /// Writes the same lowercase name that [`Compression::as_str`] returns.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
64
65pub struct SerializedRequest {
66    data: Vec<u8>,
67}
68
69impl SerializedRequest {
70    pub fn make<R: SerializableRequest>(
71        req: &R,
72        compression: Option<Compression>,
73        tracing: bool,
74    ) -> Result<SerializedRequest, CqlRequestSerializationError> {
75        let mut flags = 0;
76        let mut data = vec![0; HEADER_SIZE];
77
78        if let Some(compression) = compression {
79            flags |= FLAG_COMPRESSION;
80            let body = req.to_bytes()?;
81            compress_append(&body, compression, &mut data)?;
82        } else {
83            req.serialize(&mut data)?;
84        }
85
86        if tracing {
87            flags |= FLAG_TRACING;
88        }
89
90        data[0] = 4; // We only support version 4 for now
91        data[1] = flags;
92        // Leave space for the stream number
93        data[4] = R::OPCODE as u8;
94
95        let req_size = (data.len() - HEADER_SIZE) as u32;
96        data[5..9].copy_from_slice(&req_size.to_be_bytes());
97
98        Ok(Self { data })
99    }
100
101    pub fn set_stream(&mut self, stream: i16) {
102        self.data[2..4].copy_from_slice(&stream.to_be_bytes());
103    }
104
105    pub fn get_data(&self) -> &[u8] {
106        &self.data[..]
107    }
108}
109
/// Frame header fields that are not determined by the request/response type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FrameParams {
    /// Raw version byte of the header.
    pub version: u8,
    /// Raw flags byte of the header (see the `FLAG_*` masks).
    pub flags: u8,
    /// Stream id of the frame.
    pub stream: i16,
}

impl Default for FrameParams {
    /// Protocol version 4, no flags, stream 0.
    fn default() -> Self {
        FrameParams {
            version: 4,
            flags: 0,
            stream: 0,
        }
    }
}
127
128pub async fn read_response_frame(
129    reader: &mut (impl AsyncRead + Unpin),
130) -> Result<(FrameParams, ResponseOpcode, Bytes), FrameHeaderParseError> {
131    let mut raw_header = [0u8; HEADER_SIZE];
132    reader
133        .read_exact(&mut raw_header[..])
134        .await
135        .map_err(FrameHeaderParseError::HeaderIoError)?;
136
137    let mut buf = &raw_header[..];
138
139    // TODO: Validate version
140    let version = buf.get_u8();
141    if version & 0x80 != 0x80 {
142        return Err(FrameHeaderParseError::FrameFromClient);
143    }
144    if version & 0x7F != 0x04 {
145        return Err(FrameHeaderParseError::VersionNotSupported(version & 0x7f));
146    }
147
148    let flags = buf.get_u8();
149    let stream = buf.get_i16();
150
151    let frame_params = FrameParams {
152        version,
153        flags,
154        stream,
155    };
156
157    let opcode = ResponseOpcode::try_from(buf.get_u8())?;
158
159    // TODO: Guard from frames that are too large
160    let length = buf.get_u32() as usize;
161
162    let mut raw_body = Vec::with_capacity(length).limit(length);
163    while raw_body.has_remaining_mut() {
164        let n = reader.read_buf(&mut raw_body).await.map_err(|err| {
165            FrameHeaderParseError::BodyChunkIoError(raw_body.remaining_mut(), err)
166        })?;
167        if n == 0 {
168            // EOF, too early
169            return Err(FrameHeaderParseError::ConnectionClosed(
170                raw_body.remaining_mut(),
171                length,
172            ));
173        }
174    }
175
176    Ok((frame_params, opcode, raw_body.into_inner().into()))
177}
178
179pub struct ResponseBodyWithExtensions {
180    pub trace_id: Option<Uuid>,
181    pub warnings: Vec<String>,
182    pub body: Bytes,
183    pub custom_payload: Option<HashMap<String, Bytes>>,
184}
185
186pub fn parse_response_body_extensions(
187    flags: u8,
188    compression: Option<Compression>,
189    mut body: Bytes,
190) -> Result<ResponseBodyWithExtensions, FrameBodyExtensionsParseError> {
191    if flags & FLAG_COMPRESSION != 0 {
192        if let Some(compression) = compression {
193            body = decompress(&body, compression)?.into();
194        } else {
195            return Err(FrameBodyExtensionsParseError::NoCompressionNegotiated);
196        }
197    }
198
199    let trace_id = if flags & FLAG_TRACING != 0 {
200        let buf = &mut &*body;
201        let trace_id =
202            types::read_uuid(buf).map_err(FrameBodyExtensionsParseError::TraceIdParse)?;
203        body.advance(16);
204        Some(trace_id)
205    } else {
206        None
207    };
208
209    let warnings = if flags & FLAG_WARNING != 0 {
210        let body_len = body.len();
211        let buf = &mut &*body;
212        let warnings = types::read_string_list(buf)
213            .map_err(FrameBodyExtensionsParseError::WarningsListParse)?;
214        let buf_len = buf.len();
215        body.advance(body_len - buf_len);
216        warnings
217    } else {
218        Vec::new()
219    };
220
221    let custom_payload = if flags & FLAG_CUSTOM_PAYLOAD != 0 {
222        let body_len = body.len();
223        let buf = &mut &*body;
224        let payload_map = types::read_bytes_map(buf)
225            .map_err(FrameBodyExtensionsParseError::CustomPayloadMapParse)?;
226        let buf_len = buf.len();
227        body.advance(body_len - buf_len);
228        Some(payload_map)
229    } else {
230        None
231    };
232
233    Ok(ResponseBodyWithExtensions {
234        trace_id,
235        warnings,
236        body,
237        custom_payload,
238    })
239}
240
241fn compress_append(
242    uncomp_body: &[u8],
243    compression: Compression,
244    out: &mut Vec<u8>,
245) -> Result<(), CqlRequestSerializationError> {
246    match compression {
247        Compression::Lz4 => {
248            let uncomp_len = uncomp_body.len() as u32;
249            let tmp = lz4_flex::compress(uncomp_body);
250            out.reserve_exact(std::mem::size_of::<u32>() + tmp.len());
251            out.put_u32(uncomp_len);
252            out.extend_from_slice(&tmp[..]);
253            Ok(())
254        }
255        Compression::Snappy => {
256            let old_size = out.len();
257            out.resize(old_size + snap::raw::max_compress_len(uncomp_body.len()), 0);
258            let compressed_size = snap::raw::Encoder::new()
259                .compress(uncomp_body, &mut out[old_size..])
260                .map_err(|err| CqlRequestSerializationError::SnapCompressError(Arc::new(err)))?;
261            out.truncate(old_size + compressed_size);
262            Ok(())
263        }
264    }
265}
266
267fn decompress(
268    mut comp_body: &[u8],
269    compression: Compression,
270) -> Result<Vec<u8>, FrameBodyExtensionsParseError> {
271    match compression {
272        Compression::Lz4 => {
273            let uncomp_len = comp_body.get_u32() as usize;
274            let uncomp_body = lz4_flex::decompress(comp_body, uncomp_len)
275                .map_err(|err| FrameBodyExtensionsParseError::Lz4DecompressError(Arc::new(err)))?;
276            Ok(uncomp_body)
277        }
278        Compression::Snappy => snap::raw::Decoder::new()
279            .decompress_vec(comp_body)
280            .map_err(|err| FrameBodyExtensionsParseError::SnapDecompressError(Arc::new(err))),
281    }
282}
283
284/// An error type for parsing an enum value from a primitive.
285#[derive(Error, Debug, Clone, PartialEq, Eq)]
286#[error("No discrimant in enum `{enum_name}` matches the value `{primitive:?}`")]
287pub struct TryFromPrimitiveError<T: Copy + std::fmt::Debug> {
288    enum_name: &'static str,
289    primitive: T,
290}
291
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_lz4_compress() {
        // Bytes already in the output buffer must be preserved.
        let mut out = b"Hello".to_vec();

        compress_append(b", World!", Compression::Lz4, &mut out).unwrap();

        let mut expected = b"Hello".to_vec();
        expected.extend_from_slice(&8u32.to_be_bytes()); // uncompressed length prefix
        expected.push(0x80); // LZ4 token: 8 literal bytes, no match
        expected.extend_from_slice(b", World!");
        assert_eq!(expected, out);
    }

    #[test]
    fn test_lz4_decompress() {
        let original = "Hello, World!".repeat(100);
        let mut compressed = Vec::new();
        compress_append(original.as_bytes(), Compression::Lz4, &mut compressed).unwrap();

        let roundtripped = decompress(&compressed, Compression::Lz4).unwrap();

        // The repetitive input must really shrink: 4-byte length prefix + 28-byte block.
        assert_eq!(compressed.len(), 32);
        assert_eq!(roundtripped, original.as_bytes());
    }
}