scylla/statement/
batch.rs1use std::borrow::Cow;
2use std::sync::Arc;
3
4use crate::client::execution_profile::ExecutionProfileHandle;
5use crate::observability::history::HistoryListener;
6use crate::policies::retry::RetryPolicy;
7use crate::statement::prepared::PreparedStatement;
8use crate::statement::unprepared::Statement;
9
10use super::StatementConfig;
11use super::{Consistency, SerialConsistency};
12pub use crate::frame::request::batch::BatchType;
13
14#[derive(Clone)]
18pub struct Batch {
19 pub(crate) config: StatementConfig,
20
21 pub statements: Vec<BatchStatement>,
22 batch_type: BatchType,
23}
24
25impl Batch {
26 pub fn new(batch_type: BatchType) -> Self {
28 Self {
29 batch_type,
30 ..Default::default()
31 }
32 }
33
34 pub(crate) fn new_from(batch: &Batch) -> Batch {
36 let batch_type = batch.get_type();
37 let config = batch.config.clone();
38 Batch {
39 batch_type,
40 config,
41 ..Default::default()
42 }
43 }
44
45 pub fn new_with_statements(batch_type: BatchType, statements: Vec<BatchStatement>) -> Self {
47 Self {
48 batch_type,
49 statements,
50 ..Default::default()
51 }
52 }
53
54 pub fn append_statement(&mut self, statement: impl Into<BatchStatement>) {
56 self.statements.push(statement.into());
57 }
58
59 pub fn get_type(&self) -> BatchType {
61 self.batch_type
62 }
63
64 pub fn set_consistency(&mut self, c: Consistency) {
66 self.config.consistency = Some(c);
67 }
68
69 pub fn get_consistency(&self) -> Option<Consistency> {
72 self.config.consistency
73 }
74
75 pub fn set_serial_consistency(&mut self, sc: Option<SerialConsistency>) {
78 self.config.serial_consistency = Some(sc);
79 }
80
81 pub fn get_serial_consistency(&self) -> Option<SerialConsistency> {
84 self.config.serial_consistency.flatten()
85 }
86
87 pub fn set_is_idempotent(&mut self, is_idempotent: bool) {
93 self.config.is_idempotent = is_idempotent;
94 }
95
96 pub fn get_is_idempotent(&self) -> bool {
98 self.config.is_idempotent
99 }
100
101 pub fn set_tracing(&mut self, should_trace: bool) {
105 self.config.tracing = should_trace;
106 }
107
108 pub fn get_tracing(&self) -> bool {
110 self.config.tracing
111 }
112
113 pub fn set_timestamp(&mut self, timestamp: Option<i64>) {
117 self.config.timestamp = timestamp
118 }
119
120 pub fn get_timestamp(&self) -> Option<i64> {
122 self.config.timestamp
123 }
124
125 #[inline]
127 pub fn set_retry_policy(&mut self, retry_policy: Option<Arc<dyn RetryPolicy>>) {
128 self.config.retry_policy = retry_policy;
129 }
130
131 #[inline]
133 pub fn get_retry_policy(&self) -> Option<&Arc<dyn RetryPolicy>> {
134 self.config.retry_policy.as_ref()
135 }
136
137 pub fn set_history_listener(&mut self, history_listener: Arc<dyn HistoryListener>) {
139 self.config.history_listener = Some(history_listener);
140 }
141
142 pub fn remove_history_listener(&mut self) -> Option<Arc<dyn HistoryListener>> {
144 self.config.history_listener.take()
145 }
146
147 pub fn set_execution_profile_handle(&mut self, profile_handle: Option<ExecutionProfileHandle>) {
150 self.config.execution_profile_handle = profile_handle;
151 }
152
153 pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
155 self.config.execution_profile_handle.as_ref()
156 }
157}
158
159impl Default for Batch {
160 fn default() -> Self {
161 Self {
162 statements: Vec::new(),
163 batch_type: BatchType::Logged,
164 config: Default::default(),
165 }
166 }
167}
168
169#[derive(Clone)]
171#[non_exhaustive]
172pub enum BatchStatement {
173 Query(Statement),
174 PreparedStatement(PreparedStatement),
175}
176
177impl From<&str> for BatchStatement {
178 fn from(s: &str) -> Self {
179 BatchStatement::Query(Statement::from(s))
180 }
181}
182
183impl From<Statement> for BatchStatement {
184 fn from(q: Statement) -> Self {
185 BatchStatement::Query(q)
186 }
187}
188
189impl From<PreparedStatement> for BatchStatement {
190 fn from(p: PreparedStatement) -> Self {
191 BatchStatement::PreparedStatement(p)
192 }
193}
194
195impl<'a: 'b, 'b> From<&'a BatchStatement>
196 for scylla_cql::frame::request::batch::BatchStatement<'b>
197{
198 fn from(val: &'a BatchStatement) -> Self {
199 match val {
200 BatchStatement::Query(query) => {
201 scylla_cql::frame::request::batch::BatchStatement::Query {
202 text: Cow::Borrowed(&query.contents),
203 }
204 }
205 BatchStatement::PreparedStatement(prepared) => {
206 scylla_cql::frame::request::batch::BatchStatement::Prepared {
207 id: Cow::Borrowed(prepared.get_id()),
208 }
209 }
210 }
211 }
212}
213
214pub(crate) mod batch_values {
215 use scylla_cql::serialize::batch::BatchValues;
216 use scylla_cql::serialize::batch::BatchValuesIterator;
217 use scylla_cql::serialize::row::RowSerializationContext;
218 use scylla_cql::serialize::row::SerializedValues;
219 use scylla_cql::serialize::{RowWriter, SerializationError};
220
221 use crate::errors::ExecutionError;
222 use crate::routing::Token;
223 use crate::statement::prepared::PartitionKeyError;
224
225 use super::BatchStatement;
226
227 pub(crate) fn peek_first_token<'bv>(
239 values: impl BatchValues + 'bv,
240 statement: Option<&BatchStatement>,
241 ) -> Result<(Option<Token>, impl BatchValues + 'bv), ExecutionError> {
242 let mut values_iter = values.batch_values_iter();
243 let (token, first_values) = match statement {
244 Some(BatchStatement::PreparedStatement(ps)) => {
245 let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata());
246 let (first_values, did_write) = SerializedValues::from_closure(|writer| {
247 values_iter
248 .serialize_next(&ctx, writer)
249 .transpose()
250 .map(|o| o.is_some())
251 })?;
252 if did_write {
253 let token = ps
254 .calculate_token_untyped(&first_values)
255 .map_err(PartitionKeyError::into_execution_error)?;
256 (token, Some(first_values))
257 } else {
258 (None, None)
259 }
260 }
261 _ => (None, None),
262 };
263
264 std::mem::drop(values_iter);
267
268 let values = BatchValuesFirstSerialized::new(values, first_values);
270
271 Ok((token, values))
272 }
273
274 struct BatchValuesFirstSerialized<BV> {
275 first: Option<SerializedValues>,
278 rest: BV,
279 }
280
281 impl<BV> BatchValuesFirstSerialized<BV> {
282 fn new(rest: BV, first: Option<SerializedValues>) -> Self {
283 Self { first, rest }
284 }
285 }
286
287 impl<BV> BatchValues for BatchValuesFirstSerialized<BV>
288 where
289 BV: BatchValues,
290 {
291 type BatchValuesIter<'r>
292 = BatchValuesFirstSerializedIterator<'r, BV::BatchValuesIter<'r>>
293 where
294 Self: 'r;
295
296 fn batch_values_iter(&self) -> Self::BatchValuesIter<'_> {
297 BatchValuesFirstSerializedIterator {
298 first: self.first.as_ref(),
299 rest: self.rest.batch_values_iter(),
300 }
301 }
302 }
303
304 struct BatchValuesFirstSerializedIterator<'f, BVI> {
305 first: Option<&'f SerializedValues>,
306 rest: BVI,
307 }
308
309 impl<'f, BVI> BatchValuesIterator<'f> for BatchValuesFirstSerializedIterator<'f, BVI>
310 where
311 BVI: BatchValuesIterator<'f>,
312 {
313 #[inline]
314 fn serialize_next(
315 &mut self,
316 ctx: &RowSerializationContext<'_>,
317 writer: &mut RowWriter,
318 ) -> Option<Result<(), SerializationError>> {
319 match self.first.take() {
320 Some(sr) => {
321 writer.append_serialize_row(sr);
322 self.rest.skip_next();
323 Some(Ok(()))
324 }
325 None => self.rest.serialize_next(ctx, writer),
326 }
327 }
328
329 #[inline]
330 fn is_empty_next(&mut self) -> Option<bool> {
331 match self.first.take() {
332 Some(s) => {
333 self.rest.skip_next();
334 Some(s.is_empty())
335 }
336 None => self.rest.is_empty_next(),
337 }
338 }
339
340 #[inline]
341 fn skip_next(&mut self) -> Option<()> {
342 self.first = None;
343 self.rest.skip_next()
344 }
345
346 #[inline]
347 fn count(self) -> usize
348 where
349 Self: Sized,
350 {
351 self.rest.count()
352 }
353 }
354}