1use bytes::{Bytes, BytesMut};
2use scylla_cql::frame::response::result::{
3 ColumnSpec, PartitionKeyIndex, ResultMetadata, TableSpec,
4};
5use scylla_cql::frame::types::RawValue;
6use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
7use scylla_cql::serialize::SerializationError;
8use smallvec::{smallvec, SmallVec};
9use std::convert::TryInto;
10use std::sync::Arc;
11use std::time::Duration;
12use thiserror::Error;
13use uuid::Uuid;
14
15use super::{PageSize, StatementConfig};
16use crate::client::execution_profile::ExecutionProfileHandle;
17use crate::errors::{BadQuery, ExecutionError};
18use crate::frame::response::result::PreparedMetadata;
19use crate::frame::types::{Consistency, SerialConsistency};
20use crate::observability::history::HistoryListener;
21use crate::policies::retry::RetryPolicy;
22use crate::response::query_result::ColumnSpecs;
23use crate::routing::partitioner::{Partitioner, PartitionerHasher, PartitionerName};
24use crate::routing::Token;
25
/// A statement together with everything stored about it after preparation.
///
/// The statement text and both metadata sets live behind a shared [`Arc`]
/// (`shared`), so cloning this struct does not copy them. The remaining
/// fields are per-instance settings.
#[derive(Debug)]
pub struct PreparedStatement {
    /// Per-statement execution settings, read and written by the
    /// `set_*` / `get_*` accessors of this type.
    pub(crate) config: StatementConfig,
    /// Tracing ids attached to this statement. Starts empty in `new` and is
    /// reset to empty by `Clone`.
    // NOTE(review): presumably filled with the tracing ids of the prepare
    // requests by the code that prepares the statement -- confirm there.
    pub prepare_tracing_ids: Vec<Uuid>,

    /// Identifier the statement was constructed with.
    id: Bytes,
    /// Data shared between all clones: bind-marker metadata, result metadata
    /// and the statement text.
    shared: Arc<PreparedStatementSharedData>,
    /// Page size, held as a `PageSize`.
    page_size: PageSize,
    /// Partitioner used by this statement's token calculation.
    partitioner_name: PartitionerName,
    /// Value of the `is_lwt` flag passed to `new`.
    is_confirmed_lwt: bool,
}
102
/// The parts of a [`PreparedStatement`] that are stored once behind an `Arc`
/// and shared by every clone of the statement.
#[derive(Debug)]
struct PreparedStatementSharedData {
    /// Metadata of the statement's bound values: their column specs and the
    /// partition key indexes.
    metadata: PreparedMetadata,
    /// Metadata of the result set; its column specs are exposed through
    /// `get_result_set_col_specs`.
    result_metadata: Arc<ResultMetadata<'static>>,
    /// The statement text.
    statement: String,
}
109
110impl Clone for PreparedStatement {
111 fn clone(&self) -> Self {
112 Self {
113 config: self.config.clone(),
114 prepare_tracing_ids: Vec::new(),
115 id: self.id.clone(),
116 shared: self.shared.clone(),
117 page_size: self.page_size,
118 partitioner_name: self.partitioner_name.clone(),
119 is_confirmed_lwt: self.is_confirmed_lwt,
120 }
121 }
122}
123
124impl PreparedStatement {
125 pub(crate) fn new(
126 id: Bytes,
127 is_lwt: bool,
128 metadata: PreparedMetadata,
129 result_metadata: Arc<ResultMetadata<'static>>,
130 statement: String,
131 page_size: PageSize,
132 config: StatementConfig,
133 ) -> Self {
134 Self {
135 id,
136 shared: Arc::new(PreparedStatementSharedData {
137 metadata,
138 result_metadata,
139 statement,
140 }),
141 prepare_tracing_ids: Vec::new(),
142 page_size,
143 config,
144 partitioner_name: Default::default(),
145 is_confirmed_lwt: is_lwt,
146 }
147 }
148
149 pub fn get_id(&self) -> &Bytes {
150 &self.id
151 }
152
153 pub fn get_statement(&self) -> &str {
154 &self.shared.statement
155 }
156
157 pub fn set_page_size(&mut self, page_size: i32) {
161 self.page_size = page_size
162 .try_into()
163 .unwrap_or_else(|err| panic!("PreparedStatement::set_page_size: {err}"));
164 }
165
166 pub(crate) fn get_validated_page_size(&self) -> PageSize {
168 self.page_size
169 }
170
171 pub fn get_page_size(&self) -> i32 {
173 self.page_size.inner()
174 }
175
176 pub fn get_prepare_tracing_ids(&self) -> &[Uuid] {
178 &self.prepare_tracing_ids
179 }
180
181 pub fn is_token_aware(&self) -> bool {
185 !self.get_prepared_metadata().pk_indexes.is_empty()
186 }
187
188 pub fn is_confirmed_lwt(&self) -> bool {
196 self.is_confirmed_lwt
197 }
198
199 pub fn compute_partition_key(
209 &self,
210 bound_values: &impl SerializeRow,
211 ) -> Result<Bytes, PartitionKeyError> {
212 let serialized = self.serialize_values(bound_values)?;
213 let partition_key = self.extract_partition_key(&serialized)?;
214 let mut buf = BytesMut::new();
215 let mut writer = |chunk: &[u8]| buf.extend_from_slice(chunk);
216
217 partition_key.write_encoded_partition_key(&mut writer)?;
218
219 Ok(buf.freeze())
220 }
221
222 pub(crate) fn extract_partition_key<'ps>(
226 &'ps self,
227 bound_values: &'ps SerializedValues,
228 ) -> Result<PartitionKey<'ps>, PartitionKeyExtractionError> {
229 PartitionKey::new(self.get_prepared_metadata(), bound_values)
230 }
231
232 pub(crate) fn extract_partition_key_and_calculate_token<'ps>(
233 &'ps self,
234 partitioner_name: &'ps PartitionerName,
235 serialized_values: &'ps SerializedValues,
236 ) -> Result<Option<(PartitionKey<'ps>, Token)>, PartitionKeyError> {
237 if !self.is_token_aware() {
238 return Ok(None);
239 }
240
241 let partition_key = self.extract_partition_key(serialized_values)?;
242 let token = partition_key.calculate_token(partitioner_name)?;
243
244 Ok(Some((partition_key, token)))
245 }
246
247 pub fn calculate_token(
255 &self,
256 values: &impl SerializeRow,
257 ) -> Result<Option<Token>, PartitionKeyError> {
258 self.calculate_token_untyped(&self.serialize_values(values)?)
259 }
260
261 pub(crate) fn calculate_token_untyped(
264 &self,
265 values: &SerializedValues,
266 ) -> Result<Option<Token>, PartitionKeyError> {
267 self.extract_partition_key_and_calculate_token(&self.partitioner_name, values)
268 .map(|opt| opt.map(|(_pk, token)| token))
269 }
270
271 pub fn get_table_spec(&self) -> Option<&TableSpec> {
273 self.get_prepared_metadata()
274 .col_specs
275 .first()
276 .map(|spec| spec.table_spec())
277 }
278
279 pub fn get_keyspace_name(&self) -> Option<&str> {
281 self.get_prepared_metadata()
282 .col_specs
283 .first()
284 .map(|col_spec| col_spec.table_spec().ks_name())
285 }
286
287 pub fn get_table_name(&self) -> Option<&str> {
289 self.get_prepared_metadata()
290 .col_specs
291 .first()
292 .map(|col_spec| col_spec.table_spec().table_name())
293 }
294
295 pub fn set_consistency(&mut self, c: Consistency) {
297 self.config.consistency = Some(c);
298 }
299
300 pub fn get_consistency(&self) -> Option<Consistency> {
303 self.config.consistency
304 }
305
306 pub fn set_serial_consistency(&mut self, sc: Option<SerialConsistency>) {
309 self.config.serial_consistency = Some(sc);
310 }
311
312 pub fn get_serial_consistency(&self) -> Option<SerialConsistency> {
315 self.config.serial_consistency.flatten()
316 }
317
318 pub fn set_is_idempotent(&mut self, is_idempotent: bool) {
324 self.config.is_idempotent = is_idempotent;
325 }
326
327 pub fn get_is_idempotent(&self) -> bool {
329 self.config.is_idempotent
330 }
331
332 pub fn set_tracing(&mut self, should_trace: bool) {
336 self.config.tracing = should_trace;
337 }
338
339 pub fn get_tracing(&self) -> bool {
341 self.config.tracing
342 }
343
344 pub fn set_use_cached_result_metadata(&mut self, use_cached_metadata: bool) {
356 self.config.skip_result_metadata = use_cached_metadata;
357 }
358
359 pub fn get_use_cached_result_metadata(&self) -> bool {
362 self.config.skip_result_metadata
363 }
364
365 pub fn set_timestamp(&mut self, timestamp: Option<i64>) {
370 self.config.timestamp = timestamp
371 }
372
373 pub fn get_timestamp(&self) -> Option<i64> {
375 self.config.timestamp
376 }
377
378 pub fn set_request_timeout(&mut self, timeout: Option<Duration>) {
383 self.config.request_timeout = timeout
384 }
385
386 pub fn get_request_timeout(&self) -> Option<Duration> {
388 self.config.request_timeout
389 }
390
391 pub(crate) fn set_partitioner_name(&mut self, partitioner_name: PartitionerName) {
393 self.partitioner_name = partitioner_name;
394 }
395
396 pub(crate) fn get_prepared_metadata(&self) -> &PreparedMetadata {
398 &self.shared.metadata
399 }
400
401 pub fn get_variable_col_specs(&self) -> ColumnSpecs<'_, 'static> {
403 ColumnSpecs::new(&self.shared.metadata.col_specs)
404 }
405
406 pub fn get_variable_pk_indexes(&self) -> &[PartitionKeyIndex] {
408 &self.shared.metadata.pk_indexes
409 }
410
411 pub(crate) fn get_result_metadata(&self) -> &Arc<ResultMetadata<'static>> {
413 &self.shared.result_metadata
414 }
415
416 pub fn get_result_set_col_specs(&self) -> ColumnSpecs<'_, 'static> {
418 ColumnSpecs::new(self.shared.result_metadata.col_specs())
419 }
420
421 pub fn get_partitioner_name(&self) -> &PartitionerName {
423 &self.partitioner_name
424 }
425
426 #[inline]
428 pub fn set_retry_policy(&mut self, retry_policy: Option<Arc<dyn RetryPolicy>>) {
429 self.config.retry_policy = retry_policy;
430 }
431
432 #[inline]
434 pub fn get_retry_policy(&self) -> Option<&Arc<dyn RetryPolicy>> {
435 self.config.retry_policy.as_ref()
436 }
437
438 pub fn set_history_listener(&mut self, history_listener: Arc<dyn HistoryListener>) {
440 self.config.history_listener = Some(history_listener);
441 }
442
443 pub fn remove_history_listener(&mut self) -> Option<Arc<dyn HistoryListener>> {
445 self.config.history_listener.take()
446 }
447
448 pub fn set_execution_profile_handle(&mut self, profile_handle: Option<ExecutionProfileHandle>) {
451 self.config.execution_profile_handle = profile_handle;
452 }
453
454 pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
456 self.config.execution_profile_handle.as_ref()
457 }
458
459 pub(crate) fn serialize_values(
460 &self,
461 values: &impl SerializeRow,
462 ) -> Result<SerializedValues, SerializationError> {
463 let ctx = RowSerializationContext::from_prepared(self.get_prepared_metadata());
464 SerializedValues::from_serializable(&ctx, values)
465 }
466}
467
/// Error returned when a partition key cannot be extracted from the bound
/// values of a prepared statement.
#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum PartitionKeyExtractionError {
    /// The bound values contain no value at a partition key index.
    ///
    /// The first field is the offending partition key index; the second is
    /// the number of bound values that were supplied.
    #[error("No value with given pk_index! pk_index: {0}, values.len(): {1}")]
    NoPkIndexValue(u16, u16),
}
474
/// Error returned when a partition key cannot be encoded for token
/// calculation.
#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum TokenCalculationError {
    /// A value of a composite partition key is too long: its length does not
    /// fit in the `u16` length prefix used by the encoding.
    ///
    /// The field is the length of the offending value, in bytes.
    #[error("Value bytes too long to create partition key, max 65 535 allowed! value.len(): {0}")]
    ValueTooLong(usize),
}
481
/// Any error that can occur on the way from user-provided values to a
/// partition key or token: serialization, key extraction, or key encoding.
///
/// Every variant is transparent, i.e. it displays as the wrapped error.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum PartitionKeyError {
    /// The partition key could not be extracted from the bound values.
    #[error(transparent)]
    PartitionKeyExtraction(#[from] PartitionKeyExtractionError),

    /// The partition key could not be encoded for token calculation.
    #[error(transparent)]
    TokenCalculation(#[from] TokenCalculationError),

    /// The values could not be serialized.
    #[error(transparent)]
    Serialization(#[from] SerializationError),
}
498
499impl PartitionKeyError {
500 pub fn into_execution_error(self) -> ExecutionError {
502 match self {
503 PartitionKeyError::PartitionKeyExtraction(_) => {
504 ExecutionError::BadQuery(BadQuery::PartitionKeyExtraction)
505 }
506 PartitionKeyError::TokenCalculation(TokenCalculationError::ValueTooLong(
507 values_len,
508 )) => {
509 ExecutionError::BadQuery(BadQuery::ValuesTooLongForKey(values_len, u16::MAX.into()))
510 }
511 PartitionKeyError::Serialization(err) => {
512 ExecutionError::BadQuery(BadQuery::SerializationError(err))
513 }
514 }
515 }
516}
517
/// One component of a partition key: the serialized bytes of the bound value,
/// paired with the spec of the column it was bound to.
pub(crate) type PartitionKeyValue<'ps> = (&'ps [u8], &'ps ColumnSpec<'ps>);
519
/// Borrowed view of the partition key values of one set of bound values.
pub(crate) struct PartitionKey<'ps> {
    /// One slot per partition key index, positioned by the index's `sequence`.
    /// A slot stays `None` when the bound value at that index is not a
    /// `RawValue::Value`. Up to `SMALLVEC_ON_STACK_SIZE` slots are stored
    /// inline.
    pk_values: SmallVec<[Option<PartitionKeyValue<'ps>>; PartitionKey::SMALLVEC_ON_STACK_SIZE]>,
}
523
524impl<'ps> PartitionKey<'ps> {
525 const SMALLVEC_ON_STACK_SIZE: usize = 8;
526
527 fn new(
528 prepared_metadata: &'ps PreparedMetadata,
529 bound_values: &'ps SerializedValues,
530 ) -> Result<Self, PartitionKeyExtractionError> {
531 let mut pk_values: SmallVec<[_; PartitionKey::SMALLVEC_ON_STACK_SIZE]> =
534 smallvec![None; prepared_metadata.pk_indexes.len()];
535 let mut values_iter = bound_values.iter();
536 let mut values_iter_offset = 0;
541 for pk_index in prepared_metadata.pk_indexes.iter().copied() {
542 let next_val = values_iter
544 .nth((pk_index.index - values_iter_offset) as usize)
545 .ok_or_else(|| {
546 PartitionKeyExtractionError::NoPkIndexValue(
547 pk_index.index,
548 bound_values.element_count(),
549 )
550 })?;
551 if let RawValue::Value(v) = next_val {
553 let spec = &prepared_metadata.col_specs[pk_index.index as usize];
554 pk_values[pk_index.sequence as usize] = Some((v, spec));
555 }
556 values_iter_offset = pk_index.index + 1;
557 }
558 Ok(Self { pk_values })
559 }
560
561 pub(crate) fn iter(&self) -> impl Iterator<Item = PartitionKeyValue<'ps>> + Clone + '_ {
562 self.pk_values.iter().flatten().copied()
563 }
564
565 fn write_encoded_partition_key(
566 &self,
567 writer: &mut impl FnMut(&[u8]),
568 ) -> Result<(), TokenCalculationError> {
569 let mut pk_val_iter = self.iter().map(|(val, _spec)| val);
570 if let Some(first_value) = pk_val_iter.next() {
571 if let Some(second_value) = pk_val_iter.next() {
572 for value in std::iter::once(first_value)
574 .chain(std::iter::once(second_value))
575 .chain(pk_val_iter)
576 {
577 let v_len_u16: u16 = value
578 .len()
579 .try_into()
580 .map_err(|_| TokenCalculationError::ValueTooLong(value.len()))?;
581 writer(&v_len_u16.to_be_bytes());
582 writer(value);
583 writer(&[0u8]);
584 }
585 } else {
586 writer(first_value);
588 }
589 }
590 Ok(())
591 }
592
593 pub(crate) fn calculate_token(
594 &self,
595 partitioner_name: &PartitionerName,
596 ) -> Result<Token, TokenCalculationError> {
597 let mut partitioner_hasher = partitioner_name.build_hasher();
598 let mut writer = |chunk: &[u8]| partitioner_hasher.write(chunk);
599
600 self.write_encoded_partition_key(&mut writer)?;
601
602 Ok(partitioner_hasher.finish())
603 }
604}
605
#[cfg(test)]
mod tests {
    use scylla_cql::frame::response::result::{
        ColumnSpec, ColumnType, NativeType, PartitionKeyIndex, PreparedMetadata, TableSpec,
    };
    use scylla_cql::serialize::row::SerializedValues;

    use crate::statement::prepared::PartitionKey;
    use crate::test_utils::setup_tracing;

    // Builds prepared metadata for table `ks.t` with one column per entry of
    // `cols` (named `col_0`, `col_1`, ...). `idx` lists the column indexes that
    // form the partition key, in partition key order; the resulting
    // `pk_indexes` are sorted by column index.
    fn make_meta(
        cols: impl IntoIterator<Item = ColumnType<'static>>,
        idx: impl IntoIterator<Item = usize>,
    ) -> PreparedMetadata {
        let table_spec = TableSpec::owned("ks".to_owned(), "t".to_owned());

        let mut col_specs = Vec::new();
        for (i, typ) in cols.into_iter().enumerate() {
            col_specs.push(ColumnSpec::owned(
                format!("col_{}", i),
                typ,
                table_spec.clone(),
            ));
        }

        let mut pk_indexes = Vec::new();
        for (sequence, index) in idx.into_iter().enumerate() {
            pk_indexes.push(PartitionKeyIndex {
                index: index as u16,
                sequence: sequence as u16,
            });
        }
        pk_indexes.sort_unstable_by_key(|pki| pki.index);

        PreparedMetadata {
            flags: 0,
            col_count: col_specs.len(),
            col_specs,
            pk_indexes,
        }
    }

    // The partition key columns are given out of column order (4, 0, 3); the
    // extracted key must follow that partition key order, not the order in
    // which the values were bound.
    #[test]
    fn test_partition_key_multiple_columns_shuffled() {
        setup_tracing();

        let tinyint = ColumnType::Native(NativeType::TinyInt);
        let smallint = ColumnType::Native(NativeType::SmallInt);
        let int = ColumnType::Native(NativeType::Int);
        let bigint = ColumnType::Native(NativeType::BigInt);
        let blob = ColumnType::Native(NativeType::Blob);

        let meta = make_meta(
            [
                ColumnType::Native(NativeType::TinyInt),
                ColumnType::Native(NativeType::SmallInt),
                ColumnType::Native(NativeType::Int),
                ColumnType::Native(NativeType::BigInt),
                ColumnType::Native(NativeType::Blob),
            ],
            [4, 0, 3],
        );

        let blob_bytes = [1u8, 2, 3, 4, 5];

        let mut bound = SerializedValues::new();
        bound.add_value(&67i8, &tinyint).unwrap();
        bound.add_value(&42i16, &smallint).unwrap();
        bound.add_value(&23i32, &int).unwrap();
        bound.add_value(&89i64, &bigint).unwrap();
        bound.add_value(&blob_bytes, &blob).unwrap();

        let pk = PartitionKey::new(&meta, &bound).unwrap();
        let pk_cols: Vec<_> = pk.iter().collect();

        let tinyint_bytes = 67i8.to_be_bytes();
        let bigint_bytes = 89i64.to_be_bytes();
        let expected = vec![
            (blob_bytes.as_slice(), &meta.col_specs[4]),
            (tinyint_bytes.as_slice(), &meta.col_specs[0]),
            (bigint_bytes.as_slice(), &meta.col_specs[3]),
        ];
        assert_eq!(pk_cols, expected);
    }
}