scylla_cql/frame/request/
batch.rs

1//! CQL protocol-level representation of a `BATCH` request.
2
3use bytes::{Buf, BufMut};
4use std::{borrow::Cow, convert::TryInto, num::TryFromIntError};
5use thiserror::Error;
6
7use crate::frame::{
8    frame_errors::CqlRequestSerializationError,
9    request::{RequestOpcode, SerializableRequest},
10    types::{self, SerialConsistency},
11};
12use crate::serialize::{
13    raw_batch::{RawBatchValues, RawBatchValuesIterator},
14    row::SerializedValues,
15    RowWriter, SerializationError,
16};
17
18use super::{DeserializableRequest, RequestDeserializationError};
19
// Flags byte of a BATCH request body.
/// Set when a serial consistency level is written after the flags byte.
const FLAG_WITH_SERIAL_CONSISTENCY: u8 = 1 << 4;
/// Set when a client-side default timestamp is written at the end of the body.
const FLAG_WITH_DEFAULT_TIMESTAMP: u8 = 1 << 5;
/// Every flag bit this module understands; any other bit is rejected when deserializing.
const ALL_FLAGS: u8 = FLAG_WITH_SERIAL_CONSISTENCY | FLAG_WITH_DEFAULT_TIMESTAMP;
24
25/// CQL protocol-level representation of a `BATCH` request, used to execute
26/// a batch of statements (prepared, unprepared, or a mix of both).
27#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
28pub struct Batch<'b, Statement, Values>
29where
30    BatchStatement<'b>: From<&'b Statement>,
31    Statement: Clone,
32    Values: RawBatchValues,
33{
34    /// The statements in the batch.
35    pub statements: Cow<'b, [Statement]>,
36
37    /// The type of the batch.
38    pub batch_type: BatchType,
39
40    /// The consistency level for the batch.
41    pub consistency: types::Consistency,
42
43    /// The serial consistency level for the batch, if any.
44    pub serial_consistency: Option<types::SerialConsistency>,
45
46    /// The client-side-assigned timestamp for the batch, if any.
47    pub timestamp: Option<i64>,
48
49    /// The bound values for the batch statements.
50    pub values: Values,
51}
52
/// Kind of a batch. The discriminant of each variant is the byte sent on the wire.
#[derive(Clone, Copy)]
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
pub enum BatchType {
    /// The default mode. All operations are performed as logged, so that every mutation
    /// eventually completes (or none does). See [UNLOGGED](BatchType::Unlogged) for the
    /// trade-offs. A `LOGGED` batch confined to a single partition is converted to an
    /// `UNLOGGED` batch as an optimization.
    Logged = 0x00,

    /// Skip the batch log. By default, ScyllaDB logs a batch so that its operations either
    /// all eventually complete or none do. Note that operations are only isolated within a
    /// single partition. Logging costs performance when a batch spans multiple partitions.
    /// `UNLOGGED` avoids that cost, but a failed batch may then be only partly applied.
    Unlogged = 0x01,

    /// Batch of counter updates. Unlike other updates in ScyllaDB, counter updates are not
    /// idempotent.
    Counter = 0x02,
}
73
74/// Encountered a malformed batch type.
75#[derive(Debug, Error)]
76#[error("Malformed batch type: {value}")]
77pub struct BatchTypeParseError {
78    value: u8,
79}
80
81impl TryFrom<u8> for BatchType {
82    type Error = BatchTypeParseError;
83
84    fn try_from(value: u8) -> Result<Self, Self::Error> {
85        match value {
86            0 => Ok(Self::Logged),
87            1 => Ok(Self::Unlogged),
88            2 => Ok(Self::Counter),
89            _ => Err(BatchTypeParseError { value }),
90        }
91    }
92}
93
/// One entry of a batch: either an unprepared CQL statement string or the ID of a
/// prepared statement. Both payloads are [`Cow`]s, so they may be borrowed or owned.
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)]
pub enum BatchStatement<'a> {
    /// Unprepared statement, sent as its CQL text.
    Query {
        /// The CQL text of the statement.
        text: Cow<'a, str>,
    },
    /// Prepared statement, referenced by its ID.
    Prepared {
        /// ID of the prepared statement.
        id: Cow<'a, [u8]>,
    },
}
108
109impl<Statement, Values> Batch<'_, Statement, Values>
110where
111    for<'s> BatchStatement<'s>: From<&'s Statement>,
112    Statement: Clone,
113    Values: RawBatchValues,
114{
115    fn do_serialize(&self, buf: &mut Vec<u8>) -> Result<(), BatchSerializationError> {
116        // Serializing type of batch
117        buf.put_u8(self.batch_type as u8);
118
119        // Serializing queries
120        types::write_short(
121            self.statements
122                .len()
123                .try_into()
124                .map_err(|_| BatchSerializationError::TooManyStatements(self.statements.len()))?,
125            buf,
126        );
127
128        let counts_mismatch_err = |n_value_lists: usize, n_statements: usize| {
129            BatchSerializationError::ValuesAndStatementsLengthMismatch {
130                n_value_lists,
131                n_statements,
132            }
133        };
134        let mut n_serialized_statements = 0usize;
135        let mut value_lists = self.values.batch_values_iter();
136        for (idx, statement) in self.statements.iter().enumerate() {
137            BatchStatement::from(statement)
138                .serialize(buf)
139                .map_err(|err| BatchSerializationError::StatementSerialization {
140                    statement_idx: idx,
141                    error: err,
142                })?;
143
144            // Reserve two bytes for length
145            let length_pos = buf.len();
146            buf.extend_from_slice(&[0, 0]);
147            let mut row_writer = RowWriter::new(buf);
148            value_lists
149                .serialize_next(&mut row_writer)
150                .ok_or_else(|| counts_mismatch_err(idx, self.statements.len()))?
151                .map_err(|err: SerializationError| {
152                    BatchSerializationError::StatementSerialization {
153                        statement_idx: idx,
154                        error: BatchStatementSerializationError::ValuesSerialiation(err),
155                    }
156                })?;
157            // Go back and put the length
158            let count: u16 = match row_writer.value_count().try_into() {
159                Ok(n) => n,
160                Err(_) => {
161                    return Err(BatchSerializationError::StatementSerialization {
162                        statement_idx: idx,
163                        error: BatchStatementSerializationError::TooManyValues(
164                            row_writer.value_count(),
165                        ),
166                    })
167                }
168            };
169            buf[length_pos..length_pos + 2].copy_from_slice(&count.to_be_bytes());
170
171            n_serialized_statements += 1;
172        }
173        // At this point, we have all statements serialized. If any values are still left, we have a mismatch.
174        if value_lists.skip_next().is_some() {
175            return Err(counts_mismatch_err(
176                n_serialized_statements + 1 /*skipped above*/ + value_lists.count(),
177                n_serialized_statements,
178            ));
179        }
180        if n_serialized_statements != self.statements.len() {
181            // We want to check this to avoid propagating an invalid construction of self.statements_count as a
182            // hard-to-debug silent fail
183            return Err(BatchSerializationError::BadBatchConstructed {
184                n_announced_statements: self.statements.len(),
185                n_serialized_statements,
186            });
187        }
188
189        // Serializing consistency
190        types::write_consistency(self.consistency, buf);
191
192        // Serializing flags
193        let mut flags = 0;
194        if self.serial_consistency.is_some() {
195            flags |= FLAG_WITH_SERIAL_CONSISTENCY;
196        }
197        if self.timestamp.is_some() {
198            flags |= FLAG_WITH_DEFAULT_TIMESTAMP;
199        }
200
201        buf.put_u8(flags);
202
203        if let Some(serial_consistency) = self.serial_consistency {
204            types::write_serial_consistency(serial_consistency, buf);
205        }
206        if let Some(timestamp) = self.timestamp {
207            types::write_long(timestamp, buf);
208        }
209
210        Ok(())
211    }
212}
213
214impl<Statement, Values> SerializableRequest for Batch<'_, Statement, Values>
215where
216    for<'s> BatchStatement<'s>: From<&'s Statement>,
217    Statement: Clone,
218    Values: RawBatchValues,
219{
220    const OPCODE: RequestOpcode = RequestOpcode::Batch;
221
222    fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError> {
223        self.do_serialize(buf)?;
224        Ok(())
225    }
226}
227
228impl BatchStatement<'_> {
229    fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
230        let kind = buf.get_u8();
231        match kind {
232            0 => {
233                let text = Cow::Owned(types::read_long_string(buf)?.to_owned());
234                Ok(BatchStatement::Query { text })
235            }
236            1 => {
237                let id = types::read_short_bytes(buf)?.to_vec().into();
238                Ok(BatchStatement::Prepared { id })
239            }
240            _ => Err(RequestDeserializationError::UnexpectedBatchStatementKind(
241                kind,
242            )),
243        }
244    }
245}
246
247impl BatchStatement<'_> {
248    fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
249        match self {
250            Self::Query { text } => {
251                buf.put_u8(0);
252                types::write_long_string(text, buf)
253                    .map_err(BatchStatementSerializationError::StatementStringSerialization)?;
254            }
255            Self::Prepared { id } => {
256                buf.put_u8(1);
257                types::write_short_bytes(id, buf)
258                    .map_err(BatchStatementSerializationError::StatementIdSerialization)?;
259            }
260        }
261
262        Ok(())
263    }
264}
265
266impl<'s, 'b> From<&'s BatchStatement<'b>> for BatchStatement<'s> {
267    fn from(value: &'s BatchStatement) -> Self {
268        match value {
269            BatchStatement::Query { text } => BatchStatement::Query { text: text.clone() },
270            BatchStatement::Prepared { id } => BatchStatement::Prepared { id: id.clone() },
271        }
272    }
273}
274
275impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<SerializedValues>> {
276    fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
277        let batch_type = buf.get_u8().try_into()?;
278
279        let statements_count: usize = types::read_short(buf)?.into();
280        let statements_with_values = (0..statements_count)
281            .map(|_| {
282                let batch_statement = BatchStatement::deserialize(buf)?;
283
284                // As stated in CQL protocol v4 specification, values names in Batch are broken and should be never used.
285                let values = SerializedValues::new_from_frame(buf)?;
286
287                Ok((batch_statement, values))
288            })
289            .collect::<Result<Vec<_>, RequestDeserializationError>>()?;
290
291        let consistency = types::read_consistency(buf)?;
292
293        let flags = buf.get_u8();
294        let unknown_flags = flags & (!ALL_FLAGS);
295        if unknown_flags != 0 {
296            return Err(RequestDeserializationError::UnknownFlags {
297                flags: unknown_flags,
298            });
299        }
300        let serial_consistency_flag = (flags & FLAG_WITH_SERIAL_CONSISTENCY) != 0;
301        let default_timestamp_flag = (flags & FLAG_WITH_DEFAULT_TIMESTAMP) != 0;
302
303        let serial_consistency = serial_consistency_flag
304            .then(|| types::read_consistency(buf))
305            .transpose()?
306            .map(
307                |consistency| match SerialConsistency::try_from(consistency) {
308                    Ok(serial_consistency) => Ok(serial_consistency),
309                    Err(_) => Err(RequestDeserializationError::ExpectedSerialConsistency(
310                        consistency,
311                    )),
312                },
313            )
314            .transpose()?;
315
316        let timestamp = default_timestamp_flag
317            .then(|| types::read_long(buf))
318            .transpose()?;
319
320        let (statements, values): (Vec<BatchStatement>, Vec<SerializedValues>) =
321            statements_with_values.into_iter().unzip();
322
323        Ok(Self {
324            batch_type,
325            consistency,
326            serial_consistency,
327            timestamp,
328            statements: Cow::Owned(statements),
329            values,
330        })
331    }
332}
333
334/// An error type returned when serialization of BATCH request fails.
335#[non_exhaustive]
336#[derive(Error, Debug, Clone)]
337pub enum BatchSerializationError {
338    /// Maximum number of batch statements exceeded.
339    #[error("Too many statements in the batch. Received {0} statements, when u16::MAX is maximum possible value.")]
340    TooManyStatements(usize),
341
342    /// Number of batch statements differs from number of provided bound value lists.
343    #[error("Number of provided value lists must be equal to number of batch statements (got {n_value_lists} value lists, {n_statements} statements)")]
344    ValuesAndStatementsLengthMismatch {
345        n_value_lists: usize,
346        n_statements: usize,
347    },
348
349    /// Failed to serialize a statement in the batch.
350    #[error("Failed to serialize batch statement. statement idx: {statement_idx}, error: {error}")]
351    StatementSerialization {
352        statement_idx: usize,
353        error: BatchStatementSerializationError,
354    },
355
356    /// Number of announced batch statements differs from actual number of batch statements.
357    #[error("Invalid Batch constructed: not as many statements serialized as announced (announced: {n_announced_statements}, serialized: {n_serialized_statements})")]
358    BadBatchConstructed {
359        n_announced_statements: usize,
360        n_serialized_statements: usize,
361    },
362}
363
364/// An error type returned when serialization of one of the
365/// batch statements fails.
366#[non_exhaustive]
367#[derive(Error, Debug, Clone)]
368pub enum BatchStatementSerializationError {
369    /// Failed to serialize the CQL statement string.
370    #[error("Failed to serialize unprepared statement's content: {0}")]
371    StatementStringSerialization(TryFromIntError),
372
373    /// Maximum value of statement id exceeded.
374    #[error("Malformed prepared statement's id: {0}")]
375    StatementIdSerialization(TryFromIntError),
376
377    /// Failed to serialize statement's bound values.
378    #[error("Failed to serialize statement's values: {0}")]
379    ValuesSerialiation(SerializationError),
380
381    /// Too many bound values provided.
382    #[error("Too many values provided for the statement: {0}")]
383    TooManyValues(usize),
384}