scylla_cql/frame/request/
batch.rs1use bytes::{Buf, BufMut};
4use std::{borrow::Cow, convert::TryInto, num::TryFromIntError};
5use thiserror::Error;
6
7use crate::frame::{
8 frame_errors::CqlRequestSerializationError,
9 request::{RequestOpcode, SerializableRequest},
10 types::{self, SerialConsistency},
11};
12use crate::serialize::{
13 raw_batch::{RawBatchValues, RawBatchValuesIterator},
14 row::SerializedValues,
15 RowWriter, SerializationError,
16};
17
18use super::{DeserializableRequest, RequestDeserializationError};
19
/// Flags-byte bit announcing that a serial consistency follows the flags byte.
const FLAG_WITH_SERIAL_CONSISTENCY: u8 = 0b0001_0000;
/// Flags-byte bit announcing that a default timestamp follows the flags byte
/// (after the serial consistency, when that is present too).
const FLAG_WITH_DEFAULT_TIMESTAMP: u8 = 0b0010_0000;
/// Every flag bit this module understands; any other bit set in an incoming
/// request is rejected during deserialization.
const ALL_FLAGS: u8 = FLAG_WITH_DEFAULT_TIMESTAMP | FLAG_WITH_SERIAL_CONSISTENCY;
24
25#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
28pub struct Batch<'b, Statement, Values>
29where
30 BatchStatement<'b>: From<&'b Statement>,
31 Statement: Clone,
32 Values: RawBatchValues,
33{
34 pub statements: Cow<'b, [Statement]>,
36
37 pub batch_type: BatchType,
39
40 pub consistency: types::Consistency,
42
43 pub serial_consistency: Option<types::SerialConsistency>,
45
46 pub timestamp: Option<i64>,
48
49 pub values: Values,
51}
52
/// Kind of a batch. Serialized as a single byte equal to the variant's
/// discriminant (`batch_type as u8`).
#[derive(Clone, Copy)]
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
pub enum BatchType {
    /// Logged batch; encoded as byte `0`.
    Logged = 0x00,

    /// Unlogged batch; encoded as byte `1`.
    Unlogged = 0x01,

    /// Counter batch; encoded as byte `2`.
    Counter = 0x02,
}
73
74#[derive(Debug, Error)]
76#[error("Malformed batch type: {value}")]
77pub struct BatchTypeParseError {
78 value: u8,
79}
80
81impl TryFrom<u8> for BatchType {
82 type Error = BatchTypeParseError;
83
84 fn try_from(value: u8) -> Result<Self, Self::Error> {
85 match value {
86 0 => Ok(Self::Logged),
87 1 => Ok(Self::Unlogged),
88 2 => Ok(Self::Counter),
89 _ => Err(BatchTypeParseError { value }),
90 }
91 }
92}
93
/// One statement of a [`Batch`]: either the CQL text of an unprepared
/// statement or the id of a prepared statement.
///
/// When serialized, a leading kind byte distinguishes the two: `0` for
/// [`BatchStatement::Query`], `1` for [`BatchStatement::Prepared`].
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum BatchStatement<'a> {
    /// Unprepared statement, carried as its query text.
    Query {
        /// The statement's CQL text.
        text: Cow<'a, str>,
    },
    /// Prepared statement, referenced by its id.
    Prepared {
        /// The prepared statement's id.
        id: Cow<'a, [u8]>,
    },
}
108
109impl<Statement, Values> Batch<'_, Statement, Values>
110where
111 for<'s> BatchStatement<'s>: From<&'s Statement>,
112 Statement: Clone,
113 Values: RawBatchValues,
114{
115 fn do_serialize(&self, buf: &mut Vec<u8>) -> Result<(), BatchSerializationError> {
116 buf.put_u8(self.batch_type as u8);
118
119 types::write_short(
121 self.statements
122 .len()
123 .try_into()
124 .map_err(|_| BatchSerializationError::TooManyStatements(self.statements.len()))?,
125 buf,
126 );
127
128 let counts_mismatch_err = |n_value_lists: usize, n_statements: usize| {
129 BatchSerializationError::ValuesAndStatementsLengthMismatch {
130 n_value_lists,
131 n_statements,
132 }
133 };
134 let mut n_serialized_statements = 0usize;
135 let mut value_lists = self.values.batch_values_iter();
136 for (idx, statement) in self.statements.iter().enumerate() {
137 BatchStatement::from(statement)
138 .serialize(buf)
139 .map_err(|err| BatchSerializationError::StatementSerialization {
140 statement_idx: idx,
141 error: err,
142 })?;
143
144 let length_pos = buf.len();
146 buf.extend_from_slice(&[0, 0]);
147 let mut row_writer = RowWriter::new(buf);
148 value_lists
149 .serialize_next(&mut row_writer)
150 .ok_or_else(|| counts_mismatch_err(idx, self.statements.len()))?
151 .map_err(|err: SerializationError| {
152 BatchSerializationError::StatementSerialization {
153 statement_idx: idx,
154 error: BatchStatementSerializationError::ValuesSerialiation(err),
155 }
156 })?;
157 let count: u16 = match row_writer.value_count().try_into() {
159 Ok(n) => n,
160 Err(_) => {
161 return Err(BatchSerializationError::StatementSerialization {
162 statement_idx: idx,
163 error: BatchStatementSerializationError::TooManyValues(
164 row_writer.value_count(),
165 ),
166 })
167 }
168 };
169 buf[length_pos..length_pos + 2].copy_from_slice(&count.to_be_bytes());
170
171 n_serialized_statements += 1;
172 }
173 if value_lists.skip_next().is_some() {
175 return Err(counts_mismatch_err(
176 n_serialized_statements + 1 + value_lists.count(),
177 n_serialized_statements,
178 ));
179 }
180 if n_serialized_statements != self.statements.len() {
181 return Err(BatchSerializationError::BadBatchConstructed {
184 n_announced_statements: self.statements.len(),
185 n_serialized_statements,
186 });
187 }
188
189 types::write_consistency(self.consistency, buf);
191
192 let mut flags = 0;
194 if self.serial_consistency.is_some() {
195 flags |= FLAG_WITH_SERIAL_CONSISTENCY;
196 }
197 if self.timestamp.is_some() {
198 flags |= FLAG_WITH_DEFAULT_TIMESTAMP;
199 }
200
201 buf.put_u8(flags);
202
203 if let Some(serial_consistency) = self.serial_consistency {
204 types::write_serial_consistency(serial_consistency, buf);
205 }
206 if let Some(timestamp) = self.timestamp {
207 types::write_long(timestamp, buf);
208 }
209
210 Ok(())
211 }
212}
213
214impl<Statement, Values> SerializableRequest for Batch<'_, Statement, Values>
215where
216 for<'s> BatchStatement<'s>: From<&'s Statement>,
217 Statement: Clone,
218 Values: RawBatchValues,
219{
220 const OPCODE: RequestOpcode = RequestOpcode::Batch;
221
222 fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError> {
223 self.do_serialize(buf)?;
224 Ok(())
225 }
226}
227
228impl BatchStatement<'_> {
229 fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
230 let kind = buf.get_u8();
231 match kind {
232 0 => {
233 let text = Cow::Owned(types::read_long_string(buf)?.to_owned());
234 Ok(BatchStatement::Query { text })
235 }
236 1 => {
237 let id = types::read_short_bytes(buf)?.to_vec().into();
238 Ok(BatchStatement::Prepared { id })
239 }
240 _ => Err(RequestDeserializationError::UnexpectedBatchStatementKind(
241 kind,
242 )),
243 }
244 }
245}
246
247impl BatchStatement<'_> {
248 fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
249 match self {
250 Self::Query { text } => {
251 buf.put_u8(0);
252 types::write_long_string(text, buf)
253 .map_err(BatchStatementSerializationError::StatementStringSerialization)?;
254 }
255 Self::Prepared { id } => {
256 buf.put_u8(1);
257 types::write_short_bytes(id, buf)
258 .map_err(BatchStatementSerializationError::StatementIdSerialization)?;
259 }
260 }
261
262 Ok(())
263 }
264}
265
266impl<'s, 'b> From<&'s BatchStatement<'b>> for BatchStatement<'s> {
267 fn from(value: &'s BatchStatement) -> Self {
268 match value {
269 BatchStatement::Query { text } => BatchStatement::Query { text: text.clone() },
270 BatchStatement::Prepared { id } => BatchStatement::Prepared { id: id.clone() },
271 }
272 }
273}
274
275impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<SerializedValues>> {
276 fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
277 let batch_type = buf.get_u8().try_into()?;
278
279 let statements_count: usize = types::read_short(buf)?.into();
280 let statements_with_values = (0..statements_count)
281 .map(|_| {
282 let batch_statement = BatchStatement::deserialize(buf)?;
283
284 let values = SerializedValues::new_from_frame(buf)?;
286
287 Ok((batch_statement, values))
288 })
289 .collect::<Result<Vec<_>, RequestDeserializationError>>()?;
290
291 let consistency = types::read_consistency(buf)?;
292
293 let flags = buf.get_u8();
294 let unknown_flags = flags & (!ALL_FLAGS);
295 if unknown_flags != 0 {
296 return Err(RequestDeserializationError::UnknownFlags {
297 flags: unknown_flags,
298 });
299 }
300 let serial_consistency_flag = (flags & FLAG_WITH_SERIAL_CONSISTENCY) != 0;
301 let default_timestamp_flag = (flags & FLAG_WITH_DEFAULT_TIMESTAMP) != 0;
302
303 let serial_consistency = serial_consistency_flag
304 .then(|| types::read_consistency(buf))
305 .transpose()?
306 .map(
307 |consistency| match SerialConsistency::try_from(consistency) {
308 Ok(serial_consistency) => Ok(serial_consistency),
309 Err(_) => Err(RequestDeserializationError::ExpectedSerialConsistency(
310 consistency,
311 )),
312 },
313 )
314 .transpose()?;
315
316 let timestamp = default_timestamp_flag
317 .then(|| types::read_long(buf))
318 .transpose()?;
319
320 let (statements, values): (Vec<BatchStatement>, Vec<SerializedValues>) =
321 statements_with_values.into_iter().unzip();
322
323 Ok(Self {
324 batch_type,
325 consistency,
326 serial_consistency,
327 timestamp,
328 statements: Cow::Owned(statements),
329 values,
330 })
331 }
332}
333
334#[non_exhaustive]
336#[derive(Error, Debug, Clone)]
337pub enum BatchSerializationError {
338 #[error("Too many statements in the batch. Received {0} statements, when u16::MAX is maximum possible value.")]
340 TooManyStatements(usize),
341
342 #[error("Number of provided value lists must be equal to number of batch statements (got {n_value_lists} value lists, {n_statements} statements)")]
344 ValuesAndStatementsLengthMismatch {
345 n_value_lists: usize,
346 n_statements: usize,
347 },
348
349 #[error("Failed to serialize batch statement. statement idx: {statement_idx}, error: {error}")]
351 StatementSerialization {
352 statement_idx: usize,
353 error: BatchStatementSerializationError,
354 },
355
356 #[error("Invalid Batch constructed: not as many statements serialized as announced (announced: {n_announced_statements}, serialized: {n_serialized_statements})")]
358 BadBatchConstructed {
359 n_announced_statements: usize,
360 n_serialized_statements: usize,
361 },
362}
363
364#[non_exhaustive]
367#[derive(Error, Debug, Clone)]
368pub enum BatchStatementSerializationError {
369 #[error("Failed to serialize unprepared statement's content: {0}")]
371 StatementStringSerialization(TryFromIntError),
372
373 #[error("Malformed prepared statement's id: {0}")]
375 StatementIdSerialization(TryFromIntError),
376
377 #[error("Failed to serialize statement's values: {0}")]
379 ValuesSerialiation(SerializationError),
380
381 #[error("Too many values provided for the statement: {0}")]
383 TooManyValues(usize),
384}