scylla_cql/frame/request/
batch.rs

1use bytes::{Buf, BufMut};
2use std::{borrow::Cow, convert::TryInto, num::TryFromIntError};
3use thiserror::Error;
4
5use crate::frame::{
6    frame_errors::CqlRequestSerializationError,
7    request::{RequestOpcode, SerializableRequest},
8    types::{self, SerialConsistency},
9};
10use crate::serialize::{
11    raw_batch::{RawBatchValues, RawBatchValuesIterator},
12    row::SerializedValues,
13    RowWriter, SerializationError,
14};
15
16use super::{DeserializableRequest, RequestDeserializationError};
17
// Batch flags: bits of the flags byte written right after the consistency.

/// Set when the request carries a serial consistency (`0x10`).
const FLAG_WITH_SERIAL_CONSISTENCY: u8 = 1 << 4;
/// Set when the request carries a default timestamp (`0x20`).
const FLAG_WITH_DEFAULT_TIMESTAMP: u8 = 1 << 5;
/// Every flag bit this module understands; any other bit is rejected when deserializing.
const ALL_FLAGS: u8 = FLAG_WITH_DEFAULT_TIMESTAMP | FLAG_WITH_SERIAL_CONSISTENCY;
22
/// A BATCH request: a list of statements, one bound-value list per statement,
/// and batch-wide options.
///
/// When serialized, each statement is converted into a [`BatchStatement`] and
/// paired, in order, with the next value list yielded by `values`.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
pub struct Batch<'b, Statement, Values>
where
    BatchStatement<'b>: From<&'b Statement>,
    Statement: Clone,
    Values: RawBatchValues,
{
    /// Statements in the batch, in the order they are serialized.
    pub statements: Cow<'b, [Statement]>,
    /// Batch type; serialized as the first byte of the request body.
    pub batch_type: BatchType,
    /// Consistency level written after the statements.
    pub consistency: types::Consistency,
    /// Optional serial consistency; when `Some`, the serial-consistency flag
    /// is set and the value is serialized after the flags byte.
    pub serial_consistency: Option<types::SerialConsistency>,
    /// Optional default timestamp; when `Some`, the default-timestamp flag is
    /// set and the value is serialized as a `[long]`.
    pub timestamp: Option<i64>,
    /// Bound values: one value list per statement, consumed in the same order
    /// as `statements`.
    pub values: Values,
}
37
/// The type of a batch.
///
/// The explicit discriminants are the byte used on the wire: serialization
/// writes `batch_type as u8`, and `TryFrom<u8>` maps the same values back.
#[derive(Clone, Copy)]
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
pub enum BatchType {
    /// Wire value `0`.
    Logged = 0,
    /// Wire value `1`.
    Unlogged = 1,
    /// Wire value `2`.
    Counter = 2,
}
46
/// Error returned by `BatchType::try_from(u8)` when the byte is not one of
/// the known batch type values (`0`, `1` or `2`).
#[derive(Debug, Error)]
#[error("Malformed batch type: {value}")]
pub struct BatchTypeParseError {
    /// The rejected byte.
    value: u8,
}
52
53impl TryFrom<u8> for BatchType {
54    type Error = BatchTypeParseError;
55
56    fn try_from(value: u8) -> Result<Self, Self::Error> {
57        match value {
58            0 => Ok(Self::Logged),
59            1 => Ok(Self::Unlogged),
60            2 => Ok(Self::Counter),
61            _ => Err(BatchTypeParseError { value }),
62        }
63    }
64}
65
/// A single statement inside a batch.
///
/// On the wire it is prefixed with a kind byte: `0` for
/// [`BatchStatement::Query`] and `1` for [`BatchStatement::Prepared`].
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)]
pub enum BatchStatement<'a> {
    /// An unprepared statement, serialized as its CQL text (`[long string]`).
    Query { text: Cow<'a, str> },
    /// A prepared statement, serialized as its id (`[short bytes]`).
    Prepared { id: Cow<'a, [u8]> },
}
71
72impl<Statement, Values> Batch<'_, Statement, Values>
73where
74    for<'s> BatchStatement<'s>: From<&'s Statement>,
75    Statement: Clone,
76    Values: RawBatchValues,
77{
78    fn do_serialize(&self, buf: &mut Vec<u8>) -> Result<(), BatchSerializationError> {
79        // Serializing type of batch
80        buf.put_u8(self.batch_type as u8);
81
82        // Serializing queries
83        types::write_short(
84            self.statements
85                .len()
86                .try_into()
87                .map_err(|_| BatchSerializationError::TooManyStatements(self.statements.len()))?,
88            buf,
89        );
90
91        let counts_mismatch_err = |n_value_lists: usize, n_statements: usize| {
92            BatchSerializationError::ValuesAndStatementsLengthMismatch {
93                n_value_lists,
94                n_statements,
95            }
96        };
97        let mut n_serialized_statements = 0usize;
98        let mut value_lists = self.values.batch_values_iter();
99        for (idx, statement) in self.statements.iter().enumerate() {
100            BatchStatement::from(statement)
101                .serialize(buf)
102                .map_err(|err| BatchSerializationError::StatementSerialization {
103                    statement_idx: idx,
104                    error: err,
105                })?;
106
107            // Reserve two bytes for length
108            let length_pos = buf.len();
109            buf.extend_from_slice(&[0, 0]);
110            let mut row_writer = RowWriter::new(buf);
111            value_lists
112                .serialize_next(&mut row_writer)
113                .ok_or_else(|| counts_mismatch_err(idx, self.statements.len()))?
114                .map_err(|err: SerializationError| {
115                    BatchSerializationError::StatementSerialization {
116                        statement_idx: idx,
117                        error: BatchStatementSerializationError::ValuesSerialiation(err),
118                    }
119                })?;
120            // Go back and put the length
121            let count: u16 = match row_writer.value_count().try_into() {
122                Ok(n) => n,
123                Err(_) => {
124                    return Err(BatchSerializationError::StatementSerialization {
125                        statement_idx: idx,
126                        error: BatchStatementSerializationError::TooManyValues(
127                            row_writer.value_count(),
128                        ),
129                    })
130                }
131            };
132            buf[length_pos..length_pos + 2].copy_from_slice(&count.to_be_bytes());
133
134            n_serialized_statements += 1;
135        }
136        // At this point, we have all statements serialized. If any values are still left, we have a mismatch.
137        if value_lists.skip_next().is_some() {
138            return Err(counts_mismatch_err(
139                n_serialized_statements + 1 /*skipped above*/ + value_lists.count(),
140                n_serialized_statements,
141            ));
142        }
143        if n_serialized_statements != self.statements.len() {
144            // We want to check this to avoid propagating an invalid construction of self.statements_count as a
145            // hard-to-debug silent fail
146            return Err(BatchSerializationError::BadBatchConstructed {
147                n_announced_statements: self.statements.len(),
148                n_serialized_statements,
149            });
150        }
151
152        // Serializing consistency
153        types::write_consistency(self.consistency, buf);
154
155        // Serializing flags
156        let mut flags = 0;
157        if self.serial_consistency.is_some() {
158            flags |= FLAG_WITH_SERIAL_CONSISTENCY;
159        }
160        if self.timestamp.is_some() {
161            flags |= FLAG_WITH_DEFAULT_TIMESTAMP;
162        }
163
164        buf.put_u8(flags);
165
166        if let Some(serial_consistency) = self.serial_consistency {
167            types::write_serial_consistency(serial_consistency, buf);
168        }
169        if let Some(timestamp) = self.timestamp {
170            types::write_long(timestamp, buf);
171        }
172
173        Ok(())
174    }
175}
176
177impl<Statement, Values> SerializableRequest for Batch<'_, Statement, Values>
178where
179    for<'s> BatchStatement<'s>: From<&'s Statement>,
180    Statement: Clone,
181    Values: RawBatchValues,
182{
183    const OPCODE: RequestOpcode = RequestOpcode::Batch;
184
185    fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError> {
186        self.do_serialize(buf)?;
187        Ok(())
188    }
189}
190
191impl BatchStatement<'_> {
192    fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
193        let kind = buf.get_u8();
194        match kind {
195            0 => {
196                let text = Cow::Owned(types::read_long_string(buf)?.to_owned());
197                Ok(BatchStatement::Query { text })
198            }
199            1 => {
200                let id = types::read_short_bytes(buf)?.to_vec().into();
201                Ok(BatchStatement::Prepared { id })
202            }
203            _ => Err(RequestDeserializationError::UnexpectedBatchStatementKind(
204                kind,
205            )),
206        }
207    }
208}
209
210impl BatchStatement<'_> {
211    fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
212        match self {
213            Self::Query { text } => {
214                buf.put_u8(0);
215                types::write_long_string(text, buf)
216                    .map_err(BatchStatementSerializationError::StatementStringSerialization)?;
217            }
218            Self::Prepared { id } => {
219                buf.put_u8(1);
220                types::write_short_bytes(id, buf)
221                    .map_err(BatchStatementSerializationError::StatementIdSerialization)?;
222            }
223        }
224
225        Ok(())
226    }
227}
228
229// Disable the lint, if there is more than one lifetime included.
230// Can be removed once https://github.com/rust-lang/rust-clippy/issues/12495 is fixed.
231#[allow(clippy::needless_lifetimes)]
232impl<'s, 'b> From<&'s BatchStatement<'b>> for BatchStatement<'s> {
233    fn from(value: &'s BatchStatement) -> Self {
234        match value {
235            BatchStatement::Query { text } => BatchStatement::Query { text: text.clone() },
236            BatchStatement::Prepared { id } => BatchStatement::Prepared { id: id.clone() },
237        }
238    }
239}
240
241impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<SerializedValues>> {
242    fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
243        let batch_type = buf.get_u8().try_into()?;
244
245        let statements_count: usize = types::read_short(buf)?.into();
246        let statements_with_values = (0..statements_count)
247            .map(|_| {
248                let batch_statement = BatchStatement::deserialize(buf)?;
249
250                // As stated in CQL protocol v4 specification, values names in Batch are broken and should be never used.
251                let values = SerializedValues::new_from_frame(buf)?;
252
253                Ok((batch_statement, values))
254            })
255            .collect::<Result<Vec<_>, RequestDeserializationError>>()?;
256
257        let consistency = types::read_consistency(buf)?;
258
259        let flags = buf.get_u8();
260        let unknown_flags = flags & (!ALL_FLAGS);
261        if unknown_flags != 0 {
262            return Err(RequestDeserializationError::UnknownFlags {
263                flags: unknown_flags,
264            });
265        }
266        let serial_consistency_flag = (flags & FLAG_WITH_SERIAL_CONSISTENCY) != 0;
267        let default_timestamp_flag = (flags & FLAG_WITH_DEFAULT_TIMESTAMP) != 0;
268
269        let serial_consistency = serial_consistency_flag
270            .then(|| types::read_consistency(buf))
271            .transpose()?
272            .map(
273                |consistency| match SerialConsistency::try_from(consistency) {
274                    Ok(serial_consistency) => Ok(serial_consistency),
275                    Err(_) => Err(RequestDeserializationError::ExpectedSerialConsistency(
276                        consistency,
277                    )),
278                },
279            )
280            .transpose()?;
281
282        let timestamp = default_timestamp_flag
283            .then(|| types::read_long(buf))
284            .transpose()?;
285
286        let (statements, values): (Vec<BatchStatement>, Vec<SerializedValues>) =
287            statements_with_values.into_iter().unzip();
288
289        Ok(Self {
290            batch_type,
291            consistency,
292            serial_consistency,
293            timestamp,
294            statements: Cow::Owned(statements),
295            values,
296        })
297    }
298}
299
/// An error type returned when serialization of BATCH request fails.
#[non_exhaustive]
#[derive(Error, Debug, Clone)]
pub enum BatchSerializationError {
    /// Maximum number of batch statements exceeded.
    ///
    /// The statement count is written as a `u16`; the payload is the actual
    /// number of statements in the batch.
    #[error("Too many statements in the batch. Received {0} statements, when u16::MAX is maximum possible value.")]
    TooManyStatements(usize),

    /// Number of batch statements differs from number of provided bound value lists.
    #[error("Number of provided value lists must be equal to number of batch statements (got {n_value_lists} value lists, {n_statements} statements)")]
    ValuesAndStatementsLengthMismatch {
        /// Number of value lists provided by the batch values.
        n_value_lists: usize,
        /// Number of statements in the batch.
        n_statements: usize,
    },

    /// Failed to serialize a statement in the batch.
    #[error("Failed to serialize batch statement. statement idx: {statement_idx}, error: {error}")]
    StatementSerialization {
        /// Zero-based index of the failing statement within the batch.
        statement_idx: usize,
        /// What went wrong with that statement.
        error: BatchStatementSerializationError,
    },

    /// Number of announced batch statements differs from actual number of batch statements.
    ///
    /// NOTE(review): the BATCH serializer in this file derives both counts from
    /// the same statements slice, so this variant appears never to be produced
    /// by it — confirm whether other code constructs it.
    #[error("Invalid Batch constructed: not as many statements serialized as announced (announced: {n_announced_statements}, serialized: {n_serialized_statements})")]
    BadBatchConstructed {
        /// Statement count written in the request header.
        n_announced_statements: usize,
        /// Number of statements actually written.
        n_serialized_statements: usize,
    },
}
329
/// An error type returned when serialization of one of the
/// batch statements fails.
#[non_exhaustive]
#[derive(Error, Debug, Clone)]
pub enum BatchStatementSerializationError {
    /// Failed to serialize the CQL statement string.
    ///
    /// Carries the `TryFromIntError` reported while writing the text as a
    /// `[long string]` (presumably its length did not fit the length prefix).
    #[error("Failed to serialize unprepared statement's content: {0}")]
    StatementStringSerialization(TryFromIntError),

    /// Failed to serialize the prepared statement's id.
    ///
    /// Carries the `TryFromIntError` reported while writing the id as
    /// `[short bytes]` (presumably the id is too long for its length prefix).
    #[error("Malformed prepared statement's id: {0}")]
    StatementIdSerialization(TryFromIntError),

    /// Failed to serialize statement's bound values.
    ///
    /// NOTE(review): the variant name is misspelled ("Serialiation"); renaming
    /// it would break users of this public enum.
    #[error("Failed to serialize statement's values: {0}")]
    ValuesSerialiation(SerializationError),

    /// Too many bound values provided: the per-statement value count is
    /// written as a `u16`. Carries the actual number of values.
    #[error("Too many values provided for the statement: {0}")]
    TooManyValues(usize),
}