scylla_cql/frame/request/
mod.rs

1pub mod auth_response;
2pub mod batch;
3pub mod execute;
4pub mod options;
5pub mod prepare;
6pub mod query;
7pub mod register;
8pub mod startup;
9
10use batch::BatchTypeParseError;
11use thiserror::Error;
12
13use crate::serialize::row::SerializedValues;
14use crate::Consistency;
15use bytes::Bytes;
16
17pub use auth_response::AuthResponse;
18pub use batch::Batch;
19pub use execute::Execute;
20pub use options::Options;
21pub use prepare::Prepare;
22pub use query::Query;
23pub use startup::Startup;
24
25use self::batch::BatchStatement;
26
27use super::frame_errors::{CqlRequestSerializationError, LowLevelDeserializationError};
28use super::types::SerialConsistency;
29use super::TryFromPrimitiveError;
30
/// Possible requests sent by the client.
///
/// A fieldless tag naming the kind of a request; its `Display` impl renders
/// each variant as an upper-case name (e.g. `AUTH_RESPONSE`).
///
/// Derives `PartialEq`/`Eq`/`Hash` so kinds can be compared directly and used
/// as keys in hashed collections.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum CqlRequestKind {
    Startup,
    AuthResponse,
    Options,
    Query,
    Prepare,
    Execute,
    Batch,
    Register,
}
44
45impl std::fmt::Display for CqlRequestKind {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        let kind_str = match self {
48            CqlRequestKind::Startup => "STARTUP",
49            CqlRequestKind::AuthResponse => "AUTH_RESPONSE",
50            CqlRequestKind::Options => "OPTIONS",
51            CqlRequestKind::Query => "QUERY",
52            CqlRequestKind::Prepare => "PREPARE",
53            CqlRequestKind::Execute => "EXECUTE",
54            CqlRequestKind::Batch => "BATCH",
55            CqlRequestKind::Register => "REGISTER",
56        };
57
58        f.write_str(kind_str)
59    }
60}
61
/// Opcodes of client requests, each variant carrying its numeric opcode value.
///
/// `#[repr(u8)]` makes `opcode as u8` yield that value. `Hash` is derived
/// alongside `Eq`/`Ord` so opcodes can also be used in hashed collections.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum RequestOpcode {
    Startup = 0x01,
    Options = 0x05,
    Query = 0x07,
    Prepare = 0x09,
    Execute = 0x0A,
    Register = 0x0B,
    Batch = 0x0D,
    AuthResponse = 0x0F,
}
74
75impl TryFrom<u8> for RequestOpcode {
76    type Error = TryFromPrimitiveError<u8>;
77
78    fn try_from(value: u8) -> Result<Self, Self::Error> {
79        match value {
80            0x01 => Ok(Self::Startup),
81            0x05 => Ok(Self::Options),
82            0x07 => Ok(Self::Query),
83            0x09 => Ok(Self::Prepare),
84            0x0A => Ok(Self::Execute),
85            0x0B => Ok(Self::Register),
86            0x0D => Ok(Self::Batch),
87            0x0F => Ok(Self::AuthResponse),
88            _ => Err(TryFromPrimitiveError {
89                enum_name: "RequestOpcode",
90                primitive: value,
91            }),
92        }
93    }
94}
95
96pub trait SerializableRequest {
97    const OPCODE: RequestOpcode;
98
99    fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError>;
100
101    fn to_bytes(&self) -> Result<Bytes, CqlRequestSerializationError> {
102        let mut v = Vec::new();
103        self.serialize(&mut v)?;
104        Ok(v.into())
105    }
106}
107
108/// Not intended for driver's direct usage (as driver has no interest in deserialising CQL requests),
109/// but very useful for testing (e.g. asserting that the sent requests have proper parameters set).
110pub trait DeserializableRequest: SerializableRequest + Sized {
111    fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError>;
112}
113
114/// An error type returned by [`DeserializableRequest::deserialize`].
115/// This is not intended for driver's direct usage. It's a testing utility,
116/// mainly used by `scylla-proxy` crate.
117#[doc(hidden)]
118#[derive(Debug, Error)]
119pub enum RequestDeserializationError {
120    #[error("Low level deser error: {0}")]
121    LowLevelDeserialization(#[from] LowLevelDeserializationError),
122    #[error("Io error: {0}")]
123    IoError(#[from] std::io::Error),
124    #[error("Specified flags are not recognised: {:02x}", flags)]
125    UnknownFlags { flags: u8 },
126    #[error("Named values in frame are currently unsupported")]
127    NamedValuesUnsupported,
128    #[error("Expected SerialConsistency, got regular Consistency: {0}")]
129    ExpectedSerialConsistency(Consistency),
130    #[error(transparent)]
131    BatchTypeParse(#[from] BatchTypeParseError),
132    #[error("Unexpected batch statement kind: {0}")]
133    UnexpectedBatchStatementKind(u8),
134}
135
136#[non_exhaustive] // TODO: add remaining request types
137pub enum Request<'r> {
138    Query(Query<'r>),
139    Execute(Execute<'r>),
140    Batch(Batch<'r, BatchStatement<'r>, Vec<SerializedValues>>),
141}
142
143impl Request<'_> {
144    pub fn deserialize(
145        buf: &mut &[u8],
146        opcode: RequestOpcode,
147    ) -> Result<Self, RequestDeserializationError> {
148        match opcode {
149            RequestOpcode::Query => Query::deserialize(buf).map(Self::Query),
150            RequestOpcode::Execute => Execute::deserialize(buf).map(Self::Execute),
151            RequestOpcode::Batch => Batch::deserialize(buf).map(Self::Batch),
152            _ => unimplemented!(
153                "Deserialization of opcode {:?} is not yet supported",
154                opcode
155            ),
156        }
157    }
158
159    /// Retrieves consistency from request frame, if present.
160    pub fn get_consistency(&self) -> Option<Consistency> {
161        match self {
162            Request::Query(q) => Some(q.parameters.consistency),
163            Request::Execute(e) => Some(e.parameters.consistency),
164            Request::Batch(b) => Some(b.consistency),
165            #[allow(unreachable_patterns)] // until other opcodes are supported
166            _ => None,
167        }
168    }
169
170    /// Retrieves serial consistency from request frame.
171    pub fn get_serial_consistency(&self) -> Option<Option<SerialConsistency>> {
172        match self {
173            Request::Query(q) => Some(q.parameters.serial_consistency),
174            Request::Execute(e) => Some(e.parameters.serial_consistency),
175            Request::Batch(b) => Some(b.serial_consistency),
176            #[allow(unreachable_patterns)] // until other opcodes are supported
177            _ => None,
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use std::{borrow::Cow, ops::Deref};

    use bytes::Bytes;

    use crate::serialize::row::SerializedValues;
    use crate::{
        frame::{
            request::{
                batch::{Batch, BatchStatement, BatchType},
                execute::Execute,
                query::{Query, QueryParameters},
                DeserializableRequest, RequestDeserializationError, SerializableRequest,
            },
            response::result::{ColumnType, NativeType},
            types::{self, SerialConsistency},
        },
        Consistency,
    };

    use super::query::PagingState;

    /// Flag bit assumed to be unassigned for both QUERY and BATCH requests
    /// (true at the time of writing these tests).
    const UNKNOWN_FLAG: u8 = 0x80;

    /// Asserts that `result` was rejected specifically because of unrecognised
    /// flags, rather than failing for some unrelated parsing reason.
    fn assert_rejected_for_unknown_flags<T: std::fmt::Debug>(
        result: Result<T, RequestDeserializationError>,
    ) {
        match result {
            Err(RequestDeserializationError::UnknownFlags { .. }) => {}
            other => panic!("expected UnknownFlags error, got {:?}", other),
        }
    }

    #[test]
    fn request_ser_de_identity() {
        // Query
        let contents = Cow::Borrowed("SELECT host_id from system.peers");
        let parameters = QueryParameters {
            consistency: Consistency::All,
            serial_consistency: Some(SerialConsistency::Serial),
            timestamp: None,
            page_size: Some(323),
            paging_state: PagingState::new_from_raw_bytes(&[2_u8, 1, 3, 7] as &[u8]),
            skip_metadata: false,
            values: {
                let mut vals = SerializedValues::new();
                vals.add_value(&2137, &ColumnType::Native(NativeType::Int))
                    .unwrap();
                Cow::Owned(vals)
            },
        };
        let query = Query {
            contents,
            parameters,
        };

        {
            let mut buf = Vec::new();
            query.serialize(&mut buf).unwrap();

            let query_deserialized = Query::deserialize(&mut &buf[..]).unwrap();
            assert_eq!(&query_deserialized, &query);
        }

        // Execute
        let id: Bytes = vec![2, 4, 5, 2, 6, 7, 3, 1].into();
        let parameters = QueryParameters {
            consistency: Consistency::Any,
            serial_consistency: None,
            timestamp: Some(3423434),
            page_size: None,
            paging_state: PagingState::start(),
            skip_metadata: false,
            values: {
                let mut vals = SerializedValues::new();
                vals.add_value(&42, &ColumnType::Native(NativeType::Int))
                    .unwrap();
                vals.add_value(&2137, &ColumnType::Native(NativeType::Int))
                    .unwrap();
                Cow::Owned(vals)
            },
        };
        let execute = Execute { id, parameters };
        {
            let mut buf = Vec::new();
            execute.serialize(&mut buf).unwrap();

            let execute_deserialized = Execute::deserialize(&mut &buf[..]).unwrap();
            assert_eq!(&execute_deserialized, &execute);
        }

        // Batch
        let statements = vec![
            BatchStatement::Query {
                text: query.contents,
            },
            BatchStatement::Prepared {
                id: Cow::Borrowed(&execute.id),
            },
        ];
        let batch = Batch {
            statements: Cow::Owned(statements),
            batch_type: BatchType::Logged,
            consistency: Consistency::EachQuorum,
            serial_consistency: Some(SerialConsistency::LocalSerial),
            timestamp: Some(32432),

            // Not execute's values, because named values are not supported in batches.
            values: vec![
                query.parameters.values.deref().clone(),
                query.parameters.values.deref().clone(),
            ],
        };
        {
            let mut buf = Vec::new();
            batch.serialize(&mut buf).unwrap();

            let batch_deserialized = Batch::deserialize(&mut &buf[..]).unwrap();
            assert_eq!(&batch_deserialized, &batch);
        }
    }

    #[test]
    fn deser_rejects_unknown_flags() {
        // Query
        let contents = Cow::Borrowed("SELECT host_id from system.peers");
        let parameters = QueryParameters {
            consistency: Default::default(),
            serial_consistency: Some(SerialConsistency::LocalSerial),
            timestamp: None,
            page_size: None,
            paging_state: PagingState::start(),
            skip_metadata: false,
            values: Cow::Borrowed(SerializedValues::EMPTY),
        };
        let query = Query {
            contents: contents.clone(),
            parameters,
        };

        {
            let mut buf = Vec::new();
            query.serialize(&mut buf).unwrap();

            // Sanity check: query deserializes to the equivalent.
            let query_deserialized = Query::deserialize(&mut &buf[..]).unwrap();
            assert_eq!(&query_deserialized.contents, &query.contents);
            assert_eq!(&query_deserialized.parameters, &query.parameters);

            // Now modify flags by adding an unknown one.
            // Find flags in buffer:
            let mut buf_ptr = buf.as_slice();
            let serialised_contents = types::read_long_string(&mut buf_ptr).unwrap();
            assert_eq!(serialised_contents, contents);

            // Now buf_ptr points at consistency.
            let consistency = types::read_consistency(&mut buf_ptr).unwrap();
            assert_eq!(consistency, Consistency::default());

            // Now buf_ptr points at flags, but it is immutable. Get mutable reference into the buffer.
            let flags_idx = buf.len() - buf_ptr.len();
            let flags_mut = &mut buf[flags_idx];

            *flags_mut |= UNKNOWN_FLAG;

            // An unknown flag must lead to frame rejection, as unknown flags can be
            // new protocol extensions with different semantics. Check that the
            // rejection is due to the flags and not some unrelated parse failure.
            assert_rejected_for_unknown_flags(Query::deserialize(&mut &buf[..]));
        }

        // Batch
        let statements = vec![BatchStatement::Query {
            text: query.contents,
        }];
        let batch = Batch {
            statements: Cow::Owned(statements),
            batch_type: BatchType::Logged,
            consistency: Consistency::EachQuorum,
            serial_consistency: None,
            timestamp: None,

            values: vec![query.parameters.values.deref().clone()],
        };
        {
            let mut buf = Vec::new();
            batch.serialize(&mut buf).unwrap();

            // Sanity check: batch deserializes to the equivalent.
            let batch_deserialized = Batch::deserialize(&mut &buf[..]).unwrap();
            assert_eq!(batch, batch_deserialized);

            // Now modify flags by adding an unknown one.
            // There is neither a timestamp nor a serial consistency, so the flags
            // byte is the last byte in the buffer.
            let buf_len = buf.len();
            let flags_mut = &mut buf[buf_len - 1];
            *flags_mut |= UNKNOWN_FLAG;

            // An unknown flag must lead to frame rejection, as unknown flags can be
            // new protocol extensions with different semantics.
            assert_rejected_for_unknown_flags(Batch::deserialize(&mut &buf[..]));
        }
    }
}