1pub mod frame_errors;
2pub mod protocol_features;
3pub mod request;
4pub mod response;
5pub mod server_event_type;
6pub mod types;
7
8use bytes::{Buf, BufMut, Bytes};
9use frame_errors::{
10 CqlRequestSerializationError, FrameBodyExtensionsParseError, FrameHeaderParseError,
11};
12use thiserror::Error;
13use tokio::io::{AsyncRead, AsyncReadExt};
14use uuid::Uuid;
15
16use std::fmt::Display;
17use std::sync::Arc;
18use std::{collections::HashMap, convert::TryFrom};
19
20use request::SerializableRequest;
21use response::ResponseOpcode;
22
/// Size in bytes of a protocol-v4 frame header:
/// version (1) + flags (1) + stream id (2) + opcode (1) + body length (4).
const HEADER_SIZE: usize = 1 + 1 + 2 + 1 + 4;

// Bits of the header's flags byte (the second header byte).

/// The frame body is compressed with the negotiated algorithm.
const FLAG_COMPRESSION: u8 = 0b0000_0001;
/// Requests: tracing is requested. Responses: the body starts with a trace id.
const FLAG_TRACING: u8 = 0b0000_0010;
/// The body carries a custom payload (a bytes map).
const FLAG_CUSTOM_PAYLOAD: u8 = 0b0000_0100;
/// The body carries server warnings (a string list).
const FLAG_WARNING: u8 = 0b0000_1000;
30
/// Authenticator kinds recognized by the driver.
///
/// NOTE(review): the variant names appear to mirror server-side authenticator
/// class names — confirm against the code that maps server strings to them.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Authenticator {
    AllowAllAuthenticator,
    PasswordAuthenticator,
    CassandraPasswordAuthenticator,
    CassandraAllowAllAuthenticator,
    ScyllaTransitionalAuthenticator,
}
40
/// Body compression algorithm that can be negotiated for a connection.
///
/// Variant order matters: the derived `Ord` sorts `Lz4` before `Snappy`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum Compression {
    /// LZ4 block compression; the compressed body is prefixed with the
    /// big-endian uncompressed length.
    Lz4,
    /// Raw (unframed) Snappy compression.
    Snappy,
}

impl Compression {
    /// Returns the lowercase algorithm name: `"lz4"` or `"snappy"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Lz4 => "lz4",
            Self::Snappy => "snappy",
        }
    }
}

impl Display for Compression {
    /// Writes exactly the text returned by [`Compression::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = Compression::as_str(self);
        f.write_str(name)
    }
}
64
65pub struct SerializedRequest {
66 data: Vec<u8>,
67}
68
69impl SerializedRequest {
70 pub fn make<R: SerializableRequest>(
71 req: &R,
72 compression: Option<Compression>,
73 tracing: bool,
74 ) -> Result<SerializedRequest, CqlRequestSerializationError> {
75 let mut flags = 0;
76 let mut data = vec![0; HEADER_SIZE];
77
78 if let Some(compression) = compression {
79 flags |= FLAG_COMPRESSION;
80 let body = req.to_bytes()?;
81 compress_append(&body, compression, &mut data)?;
82 } else {
83 req.serialize(&mut data)?;
84 }
85
86 if tracing {
87 flags |= FLAG_TRACING;
88 }
89
90 data[0] = 4; data[1] = flags;
92 data[4] = R::OPCODE as u8;
94
95 let req_size = (data.len() - HEADER_SIZE) as u32;
96 data[5..9].copy_from_slice(&req_size.to_be_bytes());
97
98 Ok(Self { data })
99 }
100
101 pub fn set_stream(&mut self, stream: i16) {
102 self.data[2..4].copy_from_slice(&stream.to_be_bytes());
103 }
104
105 pub fn get_data(&self) -> &[u8] {
106 &self.data[..]
107 }
108}
109
/// Header fields of a frame: the version byte, the flags byte and the stream id.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct FrameParams {
    /// Raw version byte. Frames parsed by `read_response_frame` keep the
    /// 0x80 response-direction bit set here.
    pub version: u8,
    /// Header flag bits (see the `FLAG_*` constants).
    pub flags: u8,
    /// Stream id taken from the header.
    pub stream: i16,
}

impl Default for FrameParams {
    /// Version byte 4 (direction bit clear), no flags, stream 0.
    fn default() -> Self {
        FrameParams {
            stream: 0,
            flags: 0,
            version: 4,
        }
    }
}
126}
127
128pub async fn read_response_frame(
129 reader: &mut (impl AsyncRead + Unpin),
130) -> Result<(FrameParams, ResponseOpcode, Bytes), FrameHeaderParseError> {
131 let mut raw_header = [0u8; HEADER_SIZE];
132 reader
133 .read_exact(&mut raw_header[..])
134 .await
135 .map_err(FrameHeaderParseError::HeaderIoError)?;
136
137 let mut buf = &raw_header[..];
138
139 let version = buf.get_u8();
141 if version & 0x80 != 0x80 {
142 return Err(FrameHeaderParseError::FrameFromClient);
143 }
144 if version & 0x7F != 0x04 {
145 return Err(FrameHeaderParseError::VersionNotSupported(version & 0x7f));
146 }
147
148 let flags = buf.get_u8();
149 let stream = buf.get_i16();
150
151 let frame_params = FrameParams {
152 version,
153 flags,
154 stream,
155 };
156
157 let opcode = ResponseOpcode::try_from(buf.get_u8())?;
158
159 let length = buf.get_u32() as usize;
161
162 let mut raw_body = Vec::with_capacity(length).limit(length);
163 while raw_body.has_remaining_mut() {
164 let n = reader.read_buf(&mut raw_body).await.map_err(|err| {
165 FrameHeaderParseError::BodyChunkIoError(raw_body.remaining_mut(), err)
166 })?;
167 if n == 0 {
168 return Err(FrameHeaderParseError::ConnectionClosed(
170 raw_body.remaining_mut(),
171 length,
172 ));
173 }
174 }
175
176 Ok((frame_params, opcode, raw_body.into_inner().into()))
177}
178
179pub struct ResponseBodyWithExtensions {
180 pub trace_id: Option<Uuid>,
181 pub warnings: Vec<String>,
182 pub body: Bytes,
183 pub custom_payload: Option<HashMap<String, Bytes>>,
184}
185
186pub fn parse_response_body_extensions(
187 flags: u8,
188 compression: Option<Compression>,
189 mut body: Bytes,
190) -> Result<ResponseBodyWithExtensions, FrameBodyExtensionsParseError> {
191 if flags & FLAG_COMPRESSION != 0 {
192 if let Some(compression) = compression {
193 body = decompress(&body, compression)?.into();
194 } else {
195 return Err(FrameBodyExtensionsParseError::NoCompressionNegotiated);
196 }
197 }
198
199 let trace_id = if flags & FLAG_TRACING != 0 {
200 let buf = &mut &*body;
201 let trace_id =
202 types::read_uuid(buf).map_err(FrameBodyExtensionsParseError::TraceIdParse)?;
203 body.advance(16);
204 Some(trace_id)
205 } else {
206 None
207 };
208
209 let warnings = if flags & FLAG_WARNING != 0 {
210 let body_len = body.len();
211 let buf = &mut &*body;
212 let warnings = types::read_string_list(buf)
213 .map_err(FrameBodyExtensionsParseError::WarningsListParse)?;
214 let buf_len = buf.len();
215 body.advance(body_len - buf_len);
216 warnings
217 } else {
218 Vec::new()
219 };
220
221 let custom_payload = if flags & FLAG_CUSTOM_PAYLOAD != 0 {
222 let body_len = body.len();
223 let buf = &mut &*body;
224 let payload_map = types::read_bytes_map(buf)
225 .map_err(FrameBodyExtensionsParseError::CustomPayloadMapParse)?;
226 let buf_len = buf.len();
227 body.advance(body_len - buf_len);
228 Some(payload_map)
229 } else {
230 None
231 };
232
233 Ok(ResponseBodyWithExtensions {
234 trace_id,
235 warnings,
236 body,
237 custom_payload,
238 })
239}
240
241fn compress_append(
242 uncomp_body: &[u8],
243 compression: Compression,
244 out: &mut Vec<u8>,
245) -> Result<(), CqlRequestSerializationError> {
246 match compression {
247 Compression::Lz4 => {
248 let uncomp_len = uncomp_body.len() as u32;
249 let tmp = lz4_flex::compress(uncomp_body);
250 out.reserve_exact(std::mem::size_of::<u32>() + tmp.len());
251 out.put_u32(uncomp_len);
252 out.extend_from_slice(&tmp[..]);
253 Ok(())
254 }
255 Compression::Snappy => {
256 let old_size = out.len();
257 out.resize(old_size + snap::raw::max_compress_len(uncomp_body.len()), 0);
258 let compressed_size = snap::raw::Encoder::new()
259 .compress(uncomp_body, &mut out[old_size..])
260 .map_err(|err| CqlRequestSerializationError::SnapCompressError(Arc::new(err)))?;
261 out.truncate(old_size + compressed_size);
262 Ok(())
263 }
264 }
265}
266
267fn decompress(
268 mut comp_body: &[u8],
269 compression: Compression,
270) -> Result<Vec<u8>, FrameBodyExtensionsParseError> {
271 match compression {
272 Compression::Lz4 => {
273 let uncomp_len = comp_body.get_u32() as usize;
274 let uncomp_body = lz4_flex::decompress(comp_body, uncomp_len)
275 .map_err(|err| FrameBodyExtensionsParseError::Lz4DecompressError(Arc::new(err)))?;
276 Ok(uncomp_body)
277 }
278 Compression::Snappy => snap::raw::Decoder::new()
279 .decompress_vec(comp_body)
280 .map_err(|err| FrameBodyExtensionsParseError::SnapDecompressError(Arc::new(err))),
281 }
282}
283
284#[derive(Error, Debug, Clone, PartialEq, Eq)]
286#[error("No discrimant in enum `{enum_name}` matches the value `{primitive:?}`")]
287pub struct TryFromPrimitiveError<T: Copy + std::fmt::Debug> {
288 enum_name: &'static str,
289 primitive: T,
290}
291
#[cfg(test)]
mod test {
    use super::*;

    /// LZ4 compression appends `[u32 BE uncompressed length][LZ4 block]` after
    /// whatever is already in the output buffer.
    #[test]
    fn test_lz4_compress() {
        let mut out = b"Hello".to_vec();

        compress_append(b", World!", Compression::Lz4, &mut out).unwrap();

        // Existing prefix, then the big-endian length (8), then an LZ4 block made
        // of one token (0x80: literal length 8 in the high nibble) followed by
        // the eight literal bytes.
        let mut expected = b"Hello".to_vec();
        expected.extend_from_slice(&8u32.to_be_bytes());
        expected.push(0x80);
        expected.extend_from_slice(b", World!");
        assert_eq!(expected, out);
    }

    /// LZ4 compress/decompress round-trips, and the highly repetitive input
    /// compresses to 32 bytes including the 4-byte length prefix.
    #[test]
    fn test_lz4_decompress() {
        let original = "Hello, World!".repeat(100);
        let mut compressed = Vec::new();
        compress_append(original.as_bytes(), Compression::Lz4, &mut compressed).unwrap();

        let restored = decompress(&compressed, Compression::Lz4).unwrap();

        assert_eq!(32, compressed.len());
        assert_eq!(original.as_bytes(), restored);
    }
}
319}