1pub mod frame_errors;
12pub mod protocol_features;
13pub mod request;
14pub mod response;
15pub mod server_event_type;
16pub mod types;
17
18use bytes::{Buf, BufMut, Bytes};
19use frame_errors::{
20 CqlRequestSerializationError, FrameBodyExtensionsParseError, FrameHeaderParseError,
21};
22use thiserror::Error;
23use tokio::io::{AsyncRead, AsyncReadExt};
24use uuid::Uuid;
25
26use std::fmt::Display;
27use std::str::FromStr;
28use std::sync::Arc;
29use std::{collections::HashMap, convert::TryFrom};
30
31use request::SerializableRequest;
32use response::ResponseOpcode;
33
/// Size in bytes of a frame header: version (1) + flags (1) + stream id (2)
/// + opcode (1) + body length (4).
const HEADER_SIZE: usize = 1 + 1 + 2 + 1 + 4;
35
/// Bit masks for the flags byte of a frame header.
pub mod flag {
    /// The frame body is compressed with the negotiated compression algorithm.
    pub const COMPRESSION: u8 = 1 << 0;

    /// On a request: tracing is requested. On a response: the body starts
    /// with a tracing id (`[uuid]`).
    pub const TRACING: u8 = 1 << 1;

    /// The body carries a custom payload (`[bytes map]`).
    pub const CUSTOM_PAYLOAD: u8 = 1 << 2;

    /// The response body carries server warnings (`[string list]`).
    pub const WARNING: u8 = 1 << 3;
}
51
/// Authenticator kinds known to the driver.
///
/// NOTE(review): the variant names look like they mirror server-side
/// authenticator class names; confirm against the code that produces them.
#[derive(Clone, Debug, PartialEq, Eq)]
// Every variant ends in `Authenticator`, which `clippy::enum_variant_names`
// would flag; the expectation documents that this naming is intentional.
#[expect(clippy::enum_variant_names)]
pub enum Authenticator {
    AllowAllAuthenticator,
    PasswordAuthenticator,
    CassandraPasswordAuthenticator,
    CassandraAllowAllAuthenticator,
    ScyllaTransitionalAuthenticator,
}
64
/// Compression algorithm applied to frame bodies.
///
/// Variant order is significant: it defines the derived `Ord`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum Compression {
    /// LZ4; the compressed body is prefixed with its uncompressed length.
    Lz4,
    /// Snappy (raw format).
    Snappy,
}

impl Compression {
    /// Returns the lowercase name of the algorithm: `"lz4"` or `"snappy"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Lz4 => "lz4",
            Self::Snappy => "snappy",
        }
    }
}
83
84#[derive(Error, Debug, Clone)]
86#[error("Unknown compression: {name}")]
87pub struct CompressionFromStrError {
88 name: String,
89}
90
91impl FromStr for Compression {
92 type Err = CompressionFromStrError;
93
94 fn from_str(s: &str) -> Result<Self, Self::Err> {
95 match s {
96 "lz4" => Ok(Self::Lz4),
97 "snappy" => Ok(Self::Snappy),
98 other => Err(Self::Err {
99 name: other.to_owned(),
100 }),
101 }
102 }
103}
104
105impl Display for Compression {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 f.write_str(self.as_str())
108 }
109}
110
/// A complete request frame: the header followed by the (possibly
/// compressed) body.
pub struct SerializedRequest {
    /// Raw frame bytes; the first `HEADER_SIZE` bytes form the header.
    data: Vec<u8>,
}
118
119impl SerializedRequest {
120 pub fn make<R: SerializableRequest>(
127 req: &R,
128 compression: Option<Compression>,
129 tracing: bool,
130 ) -> Result<SerializedRequest, CqlRequestSerializationError> {
131 let mut flags = 0;
132 let mut data = vec![0; HEADER_SIZE];
133
134 if let Some(compression) = compression {
135 flags |= flag::COMPRESSION;
136 let body = req.to_bytes()?;
137 compress_append(&body, compression, &mut data)?;
138 } else {
139 req.serialize(&mut data)?;
140 }
141
142 if tracing {
143 flags |= flag::TRACING;
144 }
145
146 data[0] = 4; data[1] = flags;
148 data[4] = R::OPCODE as u8;
150
151 let req_size = (data.len() - HEADER_SIZE) as u32;
152 data[5..9].copy_from_slice(&req_size.to_be_bytes());
153
154 Ok(Self { data })
155 }
156
157 pub fn set_stream(&mut self, stream: i16) {
161 self.data[2..4].copy_from_slice(&stream.to_be_bytes());
162 }
163
164 pub fn get_data(&self) -> &[u8] {
166 &self.data[..]
167 }
168}
169
/// Frame header fields other than the opcode and the body length.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct FrameParams {
    /// The raw version byte. For frames read by `read_response_frame`, this
    /// still includes the response direction bit (`0x80`).
    pub version: u8,

    /// Bitmask of the `flag` constants.
    pub flags: u8,

    /// The stream id from the header.
    pub stream: i16,
}

impl Default for FrameParams {
    /// Version 4, no flags set, stream 0.
    fn default() -> Self {
        FrameParams {
            version: 4,
            flags: 0,
            stream: 0,
        }
    }
}
195
196pub async fn read_response_frame(
199 reader: &mut (impl AsyncRead + Unpin),
200) -> Result<(FrameParams, ResponseOpcode, Bytes), FrameHeaderParseError> {
201 let mut raw_header = [0u8; HEADER_SIZE];
202 reader
203 .read_exact(&mut raw_header[..])
204 .await
205 .map_err(FrameHeaderParseError::HeaderIoError)?;
206
207 let mut buf = &raw_header[..];
208
209 let version = buf.get_u8();
211 if version & 0x80 != 0x80 {
212 return Err(FrameHeaderParseError::FrameFromClient);
213 }
214 if version & 0x7F != 0x04 {
215 return Err(FrameHeaderParseError::VersionNotSupported(version & 0x7f));
216 }
217
218 let flags = buf.get_u8();
219 let stream = buf.get_i16();
220
221 let frame_params = FrameParams {
222 version,
223 flags,
224 stream,
225 };
226
227 let opcode = ResponseOpcode::try_from(buf.get_u8())?;
228
229 let length = buf.get_u32() as usize;
231
232 let mut raw_body = Vec::with_capacity(length).limit(length);
233 while raw_body.has_remaining_mut() {
234 let n = reader.read_buf(&mut raw_body).await.map_err(|err| {
235 FrameHeaderParseError::BodyChunkIoError(raw_body.remaining_mut(), err)
236 })?;
237 if n == 0 {
238 return Err(FrameHeaderParseError::ConnectionClosed(
240 raw_body.remaining_mut(),
241 length,
242 ));
243 }
244 }
245
246 Ok((frame_params, opcode, raw_body.into_inner().into()))
247}
248
249pub struct ResponseBodyWithExtensions {
253 pub trace_id: Option<Uuid>,
258
259 pub warnings: Vec<String>,
261
262 pub custom_payload: Option<HashMap<String, Bytes>>,
265
266 pub body: Bytes,
268}
269
270pub fn parse_response_body_extensions(
273 flags: u8,
274 compression: Option<Compression>,
275 mut body: Bytes,
276) -> Result<ResponseBodyWithExtensions, FrameBodyExtensionsParseError> {
277 if flags & flag::COMPRESSION != 0 {
278 if let Some(compression) = compression {
279 body = decompress(&body, compression)?.into();
280 } else {
281 return Err(FrameBodyExtensionsParseError::NoCompressionNegotiated);
282 }
283 }
284
285 let trace_id = if flags & flag::TRACING != 0 {
286 let buf = &mut &*body;
287 let trace_id =
288 types::read_uuid(buf).map_err(FrameBodyExtensionsParseError::TraceIdParse)?;
289 body.advance(16);
290 Some(trace_id)
291 } else {
292 None
293 };
294
295 let warnings = if flags & flag::WARNING != 0 {
296 let body_len = body.len();
297 let buf = &mut &*body;
298 let warnings = types::read_string_list(buf)
299 .map_err(FrameBodyExtensionsParseError::WarningsListParse)?;
300 let buf_len = buf.len();
301 body.advance(body_len - buf_len);
302 warnings
303 } else {
304 Vec::new()
305 };
306
307 let custom_payload = if flags & flag::CUSTOM_PAYLOAD != 0 {
308 let body_len = body.len();
309 let buf = &mut &*body;
310 let payload_map = types::read_bytes_map(buf)
311 .map_err(FrameBodyExtensionsParseError::CustomPayloadMapParse)?;
312 let buf_len = buf.len();
313 body.advance(body_len - buf_len);
314 Some(payload_map)
315 } else {
316 None
317 };
318
319 Ok(ResponseBodyWithExtensions {
320 trace_id,
321 warnings,
322 custom_payload,
323 body,
324 })
325}
326
327pub fn compress_append(
330 uncomp_body: &[u8],
331 compression: Compression,
332 out: &mut Vec<u8>,
333) -> Result<(), CqlRequestSerializationError> {
334 match compression {
335 Compression::Lz4 => {
336 let uncomp_len = uncomp_body.len() as u32;
337 let tmp = lz4_flex::compress(uncomp_body);
338 out.reserve_exact(std::mem::size_of::<u32>() + tmp.len());
339 out.put_u32(uncomp_len);
340 out.extend_from_slice(&tmp[..]);
341 Ok(())
342 }
343 Compression::Snappy => {
344 let old_size = out.len();
345 out.resize(old_size + snap::raw::max_compress_len(uncomp_body.len()), 0);
346 let compressed_size = snap::raw::Encoder::new()
347 .compress(uncomp_body, &mut out[old_size..])
348 .map_err(|err| CqlRequestSerializationError::SnapCompressError(Arc::new(err)))?;
349 out.truncate(old_size + compressed_size);
350 Ok(())
351 }
352 }
353}
354
355pub fn decompress(
358 mut comp_body: &[u8],
359 compression: Compression,
360) -> Result<Vec<u8>, FrameBodyExtensionsParseError> {
361 match compression {
362 Compression::Lz4 => {
363 let uncomp_len = comp_body.get_u32() as usize;
364 let uncomp_body = lz4_flex::decompress(comp_body, uncomp_len)
365 .map_err(|err| FrameBodyExtensionsParseError::Lz4DecompressError(Arc::new(err)))?;
366 Ok(uncomp_body)
367 }
368 Compression::Snappy => snap::raw::Decoder::new()
369 .decompress_vec(comp_body)
370 .map_err(|err| FrameBodyExtensionsParseError::SnapDecompressError(Arc::new(err))),
371 }
372}
373
374#[derive(Error, Debug, Clone, PartialEq, Eq)]
376#[error("No discrimant in enum `{enum_name}` matches the value `{primitive:?}`")]
377pub struct TryFromPrimitiveError<T: Copy + std::fmt::Debug> {
378 enum_name: &'static str,
379 primitive: T,
380}
381
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_lz4_compress() {
        // Existing bytes in the output buffer must be preserved. The compressed
        // data (big-endian length prefix + LZ4 block) is appended after them.
        let mut buffer = b"Hello".to_vec();
        compress_append(b", World!", Compression::Lz4, &mut buffer).unwrap();

        let expected: Vec<u8> = vec![
            72, 101, 108, 108, 111, // "Hello"
            0, 0, 0, 8, // uncompressed length, big-endian
            128, 44, 32, 87, 111, 114, 108, 100, 33, // LZ4 block
        ];
        assert_eq!(expected, buffer);
    }

    #[test]
    fn test_lz4_decompress() {
        let original = "Hello, World!".repeat(100);
        let mut compressed = Vec::new();
        compress_append(original.as_bytes(), Compression::Lz4, &mut compressed).unwrap();
        assert_eq!(32, compressed.len());

        let roundtrip = decompress(&compressed, Compression::Lz4).unwrap();
        assert_eq!(original.as_bytes(), roundtrip);
    }
}