scylla_cql/frame/request/
batch.rs1use bytes::{Buf, BufMut};
2use std::{borrow::Cow, convert::TryInto, num::TryFromIntError};
3use thiserror::Error;
4
5use crate::frame::{
6 frame_errors::CqlRequestSerializationError,
7 request::{RequestOpcode, SerializableRequest},
8 types::{self, SerialConsistency},
9};
10use crate::serialize::{
11 raw_batch::{RawBatchValues, RawBatchValuesIterator},
12 row::SerializedValues,
13 RowWriter, SerializationError,
14};
15
16use super::{DeserializableRequest, RequestDeserializationError};
17
/// Bit of the batch flags byte that is set when a serial consistency is present.
const FLAG_WITH_SERIAL_CONSISTENCY: u8 = 0x10;
/// Bit of the batch flags byte that is set when a default timestamp is present.
const FLAG_WITH_DEFAULT_TIMESTAMP: u8 = 0x20;
/// Union of all flag bits handled in this file; used as a mask to detect unknown bits.
const ALL_FLAGS: u8 = FLAG_WITH_SERIAL_CONSISTENCY | FLAG_WITH_DEFAULT_TIMESTAMP;
22
/// A batch request: a list of statements, the values bound to them, and the
/// options that accompany the batch.
///
/// `statements` and `values` are matched up positionally when the batch is
/// serialized; a difference in their lengths is reported as a serialization
/// error.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
pub struct Batch<'b, Statement, Values>
where
    BatchStatement<'b>: From<&'b Statement>,
    Statement: Clone,
    Values: RawBatchValues,
{
    /// Statements of the batch, either borrowed or owned.
    pub statements: Cow<'b, [Statement]>,
    /// Kind of the batch; serialized as a single byte.
    pub batch_type: BatchType,
    /// Consistency written after the statements.
    pub consistency: types::Consistency,
    /// Optional serial consistency; when present, the corresponding flag bit is set.
    pub serial_consistency: Option<types::SerialConsistency>,
    /// Optional timestamp; when present, the corresponding flag bit is set.
    pub timestamp: Option<i64>,
    /// One list of values per statement, in statement order.
    pub values: Values,
}
37
/// Kind of a batch.
///
/// The explicit discriminants are the byte values used when the type is
/// serialized (`batch_type as u8`) and parsed (`TryFrom<u8>`).
#[derive(Clone, Copy)]
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
pub enum BatchType {
    Logged = 0,
    Unlogged = 1,
    Counter = 2,
}
46
/// Error returned when a byte does not correspond to any [`BatchType`] variant.
#[derive(Debug, Error)]
#[error("Malformed batch type: {value}")]
pub struct BatchTypeParseError {
    /// The byte that could not be interpreted as a batch type.
    value: u8,
}
52
53impl TryFrom<u8> for BatchType {
54 type Error = BatchTypeParseError;
55
56 fn try_from(value: u8) -> Result<Self, Self::Error> {
57 match value {
58 0 => Ok(Self::Logged),
59 1 => Ok(Self::Unlogged),
60 2 => Ok(Self::Counter),
61 _ => Err(BatchTypeParseError { value }),
62 }
63 }
64}
65
/// A single statement of a batch.
///
/// Both payloads are `Cow`s, so a statement may either borrow its contents
/// or own them (as it does after deserialization).
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)]
pub enum BatchStatement<'a> {
    /// A statement given by its query text; serialized with kind byte `0`.
    Query { text: Cow<'a, str> },
    /// A statement given by an id; serialized with kind byte `1`.
    Prepared { id: Cow<'a, [u8]> },
}
71
72impl<Statement, Values> Batch<'_, Statement, Values>
73where
74 for<'s> BatchStatement<'s>: From<&'s Statement>,
75 Statement: Clone,
76 Values: RawBatchValues,
77{
78 fn do_serialize(&self, buf: &mut Vec<u8>) -> Result<(), BatchSerializationError> {
79 buf.put_u8(self.batch_type as u8);
81
82 types::write_short(
84 self.statements
85 .len()
86 .try_into()
87 .map_err(|_| BatchSerializationError::TooManyStatements(self.statements.len()))?,
88 buf,
89 );
90
91 let counts_mismatch_err = |n_value_lists: usize, n_statements: usize| {
92 BatchSerializationError::ValuesAndStatementsLengthMismatch {
93 n_value_lists,
94 n_statements,
95 }
96 };
97 let mut n_serialized_statements = 0usize;
98 let mut value_lists = self.values.batch_values_iter();
99 for (idx, statement) in self.statements.iter().enumerate() {
100 BatchStatement::from(statement)
101 .serialize(buf)
102 .map_err(|err| BatchSerializationError::StatementSerialization {
103 statement_idx: idx,
104 error: err,
105 })?;
106
107 let length_pos = buf.len();
109 buf.extend_from_slice(&[0, 0]);
110 let mut row_writer = RowWriter::new(buf);
111 value_lists
112 .serialize_next(&mut row_writer)
113 .ok_or_else(|| counts_mismatch_err(idx, self.statements.len()))?
114 .map_err(|err: SerializationError| {
115 BatchSerializationError::StatementSerialization {
116 statement_idx: idx,
117 error: BatchStatementSerializationError::ValuesSerialiation(err),
118 }
119 })?;
120 let count: u16 = match row_writer.value_count().try_into() {
122 Ok(n) => n,
123 Err(_) => {
124 return Err(BatchSerializationError::StatementSerialization {
125 statement_idx: idx,
126 error: BatchStatementSerializationError::TooManyValues(
127 row_writer.value_count(),
128 ),
129 })
130 }
131 };
132 buf[length_pos..length_pos + 2].copy_from_slice(&count.to_be_bytes());
133
134 n_serialized_statements += 1;
135 }
136 if value_lists.skip_next().is_some() {
138 return Err(counts_mismatch_err(
139 n_serialized_statements + 1 + value_lists.count(),
140 n_serialized_statements,
141 ));
142 }
143 if n_serialized_statements != self.statements.len() {
144 return Err(BatchSerializationError::BadBatchConstructed {
147 n_announced_statements: self.statements.len(),
148 n_serialized_statements,
149 });
150 }
151
152 types::write_consistency(self.consistency, buf);
154
155 let mut flags = 0;
157 if self.serial_consistency.is_some() {
158 flags |= FLAG_WITH_SERIAL_CONSISTENCY;
159 }
160 if self.timestamp.is_some() {
161 flags |= FLAG_WITH_DEFAULT_TIMESTAMP;
162 }
163
164 buf.put_u8(flags);
165
166 if let Some(serial_consistency) = self.serial_consistency {
167 types::write_serial_consistency(serial_consistency, buf);
168 }
169 if let Some(timestamp) = self.timestamp {
170 types::write_long(timestamp, buf);
171 }
172
173 Ok(())
174 }
175}
176
177impl<Statement, Values> SerializableRequest for Batch<'_, Statement, Values>
178where
179 for<'s> BatchStatement<'s>: From<&'s Statement>,
180 Statement: Clone,
181 Values: RawBatchValues,
182{
183 const OPCODE: RequestOpcode = RequestOpcode::Batch;
184
185 fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError> {
186 self.do_serialize(buf)?;
187 Ok(())
188 }
189}
190
191impl BatchStatement<'_> {
192 fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
193 let kind = buf.get_u8();
194 match kind {
195 0 => {
196 let text = Cow::Owned(types::read_long_string(buf)?.to_owned());
197 Ok(BatchStatement::Query { text })
198 }
199 1 => {
200 let id = types::read_short_bytes(buf)?.to_vec().into();
201 Ok(BatchStatement::Prepared { id })
202 }
203 _ => Err(RequestDeserializationError::UnexpectedBatchStatementKind(
204 kind,
205 )),
206 }
207 }
208}
209
210impl BatchStatement<'_> {
211 fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
212 match self {
213 Self::Query { text } => {
214 buf.put_u8(0);
215 types::write_long_string(text, buf)
216 .map_err(BatchStatementSerializationError::StatementStringSerialization)?;
217 }
218 Self::Prepared { id } => {
219 buf.put_u8(1);
220 types::write_short_bytes(id, buf)
221 .map_err(BatchStatementSerializationError::StatementIdSerialization)?;
222 }
223 }
224
225 Ok(())
226 }
227}
228
229#[allow(clippy::needless_lifetimes)]
232impl<'s, 'b> From<&'s BatchStatement<'b>> for BatchStatement<'s> {
233 fn from(value: &'s BatchStatement) -> Self {
234 match value {
235 BatchStatement::Query { text } => BatchStatement::Query { text: text.clone() },
236 BatchStatement::Prepared { id } => BatchStatement::Prepared { id: id.clone() },
237 }
238 }
239}
240
241impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<SerializedValues>> {
242 fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
243 let batch_type = buf.get_u8().try_into()?;
244
245 let statements_count: usize = types::read_short(buf)?.into();
246 let statements_with_values = (0..statements_count)
247 .map(|_| {
248 let batch_statement = BatchStatement::deserialize(buf)?;
249
250 let values = SerializedValues::new_from_frame(buf)?;
252
253 Ok((batch_statement, values))
254 })
255 .collect::<Result<Vec<_>, RequestDeserializationError>>()?;
256
257 let consistency = types::read_consistency(buf)?;
258
259 let flags = buf.get_u8();
260 let unknown_flags = flags & (!ALL_FLAGS);
261 if unknown_flags != 0 {
262 return Err(RequestDeserializationError::UnknownFlags {
263 flags: unknown_flags,
264 });
265 }
266 let serial_consistency_flag = (flags & FLAG_WITH_SERIAL_CONSISTENCY) != 0;
267 let default_timestamp_flag = (flags & FLAG_WITH_DEFAULT_TIMESTAMP) != 0;
268
269 let serial_consistency = serial_consistency_flag
270 .then(|| types::read_consistency(buf))
271 .transpose()?
272 .map(
273 |consistency| match SerialConsistency::try_from(consistency) {
274 Ok(serial_consistency) => Ok(serial_consistency),
275 Err(_) => Err(RequestDeserializationError::ExpectedSerialConsistency(
276 consistency,
277 )),
278 },
279 )
280 .transpose()?;
281
282 let timestamp = default_timestamp_flag
283 .then(|| types::read_long(buf))
284 .transpose()?;
285
286 let (statements, values): (Vec<BatchStatement>, Vec<SerializedValues>) =
287 statements_with_values.into_iter().unzip();
288
289 Ok(Self {
290 batch_type,
291 consistency,
292 serial_consistency,
293 timestamp,
294 statements: Cow::Owned(statements),
295 values,
296 })
297 }
298}
299
/// Errors that can occur while serializing a whole batch request.
#[non_exhaustive]
#[derive(Error, Debug, Clone)]
pub enum BatchSerializationError {
    /// The number of statements does not fit the statement count field.
    /// Carries the actual number of statements.
    #[error("Too many statements in the batch. Received {0} statements, when u16::MAX is maximum possible value.")]
    TooManyStatements(usize),

    /// The number of value lists differs from the number of statements.
    #[error("Number of provided value lists must be equal to number of batch statements (got {n_value_lists} value lists, {n_statements} statements)")]
    ValuesAndStatementsLengthMismatch {
        n_value_lists: usize,
        n_statements: usize,
    },

    /// Serialization of a single statement (or of its values) failed.
    /// `statement_idx` is the zero-based position of that statement in the batch.
    #[error("Failed to serialize batch statement. statement idx: {statement_idx}, error: {error}")]
    StatementSerialization {
        statement_idx: usize,
        error: BatchStatementSerializationError,
    },

    /// Fewer statements were serialized than the count announced at the
    /// start of the batch.
    #[error("Invalid Batch constructed: not as many statements serialized as announced (announced: {n_announced_statements}, serialized: {n_serialized_statements})")]
    BadBatchConstructed {
        n_announced_statements: usize,
        n_serialized_statements: usize,
    },
}
329
/// Errors that can occur while serializing a single statement of a batch,
/// together with its values.
#[non_exhaustive]
#[derive(Error, Debug, Clone)]
pub enum BatchStatementSerializationError {
    /// Writing the text of a `Query` statement failed.
    #[error("Failed to serialize unprepared statement's content: {0}")]
    StatementStringSerialization(TryFromIntError),

    /// Writing the id of a `Prepared` statement failed.
    #[error("Malformed prepared statement's id: {0}")]
    StatementIdSerialization(TryFromIntError),

    /// Serializing the values bound to the statement failed.
    #[error("Failed to serialize statement's values: {0}")]
    ValuesSerialiation(SerializationError),

    /// The statement has more values than fit in the two-byte value count.
    /// Carries the actual number of values.
    #[error("Too many values provided for the statement: {0}")]
    TooManyValues(usize),
}