1use crate::errors::{ExecutionError, PagerExecutionError, PrepareError};
2use crate::response::query_result::QueryResult;
3use crate::response::{PagingState, PagingStateResponse};
4use crate::routing::partitioner::PartitionerName;
5use crate::statement::batch::{Batch, BatchStatement};
6use crate::statement::prepared::PreparedStatement;
7use crate::statement::unprepared::Statement;
8use bytes::Bytes;
9use dashmap::DashMap;
10use futures::future::try_join_all;
11use scylla_cql::frame::response::result::{PreparedMetadata, ResultMetadata};
12use scylla_cql::serialize::batch::BatchValues;
13use scylla_cql::serialize::row::SerializeRow;
14use std::collections::hash_map::RandomState;
15use std::fmt;
16use std::hash::BuildHasher;
17use std::sync::Arc;
18
19use crate::client::pager::QueryPager;
20use crate::client::session::Session;
21
22#[derive(Debug)]
27struct RawPreparedStatementData {
28 id: Bytes,
29 is_confirmed_lwt: bool,
30 metadata: PreparedMetadata,
31 result_metadata: Arc<ResultMetadata<'static>>,
32 partitioner_name: PartitionerName,
33}
34
35pub struct CachingSession<S = RandomState>
37where
38 S: Clone + BuildHasher,
39{
40 session: Session,
41 max_capacity: usize,
45 cache: DashMap<String, RawPreparedStatementData, S>,
46}
47
48impl<S> fmt::Debug for CachingSession<S>
49where
50 S: Clone + BuildHasher,
51{
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 f.debug_struct("GenericCachingSession")
54 .field("session", &self.session)
55 .field("max_capacity", &self.max_capacity)
56 .field("cache", &self.cache)
57 .finish()
58 }
59}
60
61impl<S> CachingSession<S>
62where
63 S: Default + BuildHasher + Clone,
64{
65 pub fn from(session: Session, cache_size: usize) -> Self {
66 Self {
67 session,
68 max_capacity: cache_size,
69 cache: Default::default(),
70 }
71 }
72}
73
74impl<S> CachingSession<S>
75where
76 S: BuildHasher + Clone,
77{
78 pub fn with_hasher(session: Session, cache_size: usize, hasher: S) -> Self {
81 Self {
82 session,
83 max_capacity: cache_size,
84 cache: DashMap::with_hasher(hasher),
85 }
86 }
87}
88
89impl<S> CachingSession<S>
90where
91 S: BuildHasher + Clone,
92{
93 pub async fn execute_unpaged(
96 &self,
97 query: impl Into<Statement>,
98 values: impl SerializeRow,
99 ) -> Result<QueryResult, ExecutionError> {
100 let query = query.into();
101 let prepared = self.add_prepared_statement_owned(query).await?;
102 self.session.execute_unpaged(&prepared, values).await
103 }
104
105 pub async fn execute_iter(
108 &self,
109 query: impl Into<Statement>,
110 values: impl SerializeRow,
111 ) -> Result<QueryPager, PagerExecutionError> {
112 let query = query.into();
113 let prepared = self.add_prepared_statement_owned(query).await?;
114 self.session.execute_iter(prepared, values).await
115 }
116
117 pub async fn execute_single_page(
120 &self,
121 query: impl Into<Statement>,
122 values: impl SerializeRow,
123 paging_state: PagingState,
124 ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> {
125 let query = query.into();
126 let prepared = self.add_prepared_statement_owned(query).await?;
127 self.session
128 .execute_single_page(&prepared, values, paging_state)
129 .await
130 }
131
132 pub async fn batch(
137 &self,
138 batch: &Batch,
139 values: impl BatchValues,
140 ) -> Result<QueryResult, ExecutionError> {
141 let all_prepared: bool = batch
142 .statements
143 .iter()
144 .all(|stmt| matches!(stmt, BatchStatement::PreparedStatement(_)));
145
146 if all_prepared {
147 self.session.batch(batch, &values).await
148 } else {
149 let prepared_batch: Batch = self.prepare_batch(batch).await?;
150
151 self.session.batch(&prepared_batch, &values).await
152 }
153 }
154}
155
156impl<S> CachingSession<S>
157where
158 S: BuildHasher + Clone,
159{
160 pub async fn prepare_batch(&self, batch: &Batch) -> Result<Batch, ExecutionError> {
164 let mut prepared_batch = batch.clone();
165
166 try_join_all(
167 prepared_batch
168 .statements
169 .iter_mut()
170 .map(|statement| async move {
171 if let BatchStatement::Query(query) = statement {
172 let prepared = self.add_prepared_statement(&*query).await?;
173 *statement = BatchStatement::PreparedStatement(prepared);
174 }
175 Ok::<(), ExecutionError>(())
176 }),
177 )
178 .await?;
179
180 Ok(prepared_batch)
181 }
182
183 pub async fn add_prepared_statement(
185 &self,
186 query: impl Into<&Statement>,
187 ) -> Result<PreparedStatement, PrepareError> {
188 self.add_prepared_statement_owned(query.into().clone())
189 .await
190 }
191
192 async fn add_prepared_statement_owned(
193 &self,
194 query: impl Into<Statement>,
195 ) -> Result<PreparedStatement, PrepareError> {
196 let query = query.into();
197
198 if let Some(raw) = self.cache.get(&query.contents) {
199 let page_size = query.get_validated_page_size();
200 let mut stmt = PreparedStatement::new(
201 raw.id.clone(),
202 raw.is_confirmed_lwt,
203 raw.metadata.clone(),
204 raw.result_metadata.clone(),
205 query.contents,
206 page_size,
207 query.config,
208 );
209 stmt.set_partitioner_name(raw.partitioner_name.clone());
210 Ok(stmt)
211 } else {
212 let query_contents = query.contents.clone();
213 let prepared = self.session.prepare(query).await?;
214
215 if self.max_capacity == self.cache.len() {
216 let query = self.cache.iter().next().map(|c| c.key().to_string());
221
222 if let Some(q) = query {
224 self.cache.remove(&q);
225 }
226 }
227
228 let raw = RawPreparedStatementData {
229 id: prepared.get_id().clone(),
230 is_confirmed_lwt: prepared.is_confirmed_lwt(),
231 metadata: prepared.get_prepared_metadata().clone(),
232 result_metadata: prepared.get_result_metadata().clone(),
233 partitioner_name: prepared.get_partitioner_name().clone(),
234 };
235 self.cache.insert(query_contents, raw);
236
237 Ok(prepared)
238 }
239 }
240
241 pub fn get_max_capacity(&self) -> usize {
242 self.max_capacity
243 }
244
245 pub fn get_session(&self) -> &Session {
246 &self.session
247 }
248}
249
250#[cfg(test)]
251mod tests {
252 use crate::client::session::Session;
253 use crate::response::PagingState;
254 use crate::routing::partitioner::PartitionerName;
255 use crate::statement::batch::{Batch, BatchStatement};
256 use crate::statement::prepared::PreparedStatement;
257 use crate::statement::unprepared::Statement;
258 use crate::test_utils::{
259 create_new_session_builder, scylla_supports_tablets, setup_tracing, PerformDDL,
260 };
261 use crate::utils::test_utils::unique_keyspace_name;
262 use crate::value::Row;
263 use futures::TryStreamExt;
264 use std::collections::BTreeSet;
265
266 use super::CachingSession;
267
268 async fn new_for_test(with_tablet_support: bool) -> Session {
269 let session = create_new_session_builder()
270 .build()
271 .await
272 .expect("Could not create session");
273 let ks = unique_keyspace_name();
274
275 let mut create_ks = format!(
276 "CREATE KEYSPACE IF NOT EXISTS {ks}
277 WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}"
278 );
279 if !with_tablet_support && scylla_supports_tablets(&session).await {
280 create_ks += " AND TABLETS = {'enabled': false}";
281 }
282
283 session
284 .ddl(create_ks)
285 .await
286 .expect("Could not create keyspace");
287
288 session
289 .ddl(format!(
290 "CREATE TABLE IF NOT EXISTS {}.test_table (a int primary key, b int)",
291 ks
292 ))
293 .await
294 .expect("Could not create table");
295
296 session
297 .use_keyspace(ks, false)
298 .await
299 .expect("Could not set keyspace");
300
301 session
302 }
303
304 async fn create_caching_session() -> CachingSession {
305 let session = CachingSession::from(new_for_test(true).await, 2);
306
307 session
309 .execute_unpaged("insert into test_table(a, b) values (1, 2)", &[])
310 .await
311 .unwrap();
312
313 assert_eq!(session.cache.len(), 1);
315
316 session.cache.clear();
317
318 session
319 }
320
321 #[tokio::test]
324 async fn test_full() {
325 setup_tracing();
326 let session = create_caching_session().await;
327
328 let first_query = "select * from test_table";
329 let middle_query = "insert into test_table(a, b) values (?, ?)";
330 let last_query = "update test_table set b = ? where a = 1";
331
332 session
333 .add_prepared_statement(&first_query.into())
334 .await
335 .unwrap();
336 session
337 .add_prepared_statement(&middle_query.into())
338 .await
339 .unwrap();
340 session
341 .add_prepared_statement(&last_query.into())
342 .await
343 .unwrap();
344
345 assert_eq!(2, session.cache.len());
346
347 assert!(session.cache.get(last_query).is_some());
349
350 let first_query_removed = session.cache.get(first_query).is_none();
352 let middle_query_removed = session.cache.get(middle_query).is_none();
353
354 assert!(first_query_removed || middle_query_removed);
355 }
356
357 #[tokio::test]
359 async fn test_execute_unpaged_cached() {
360 setup_tracing();
361 let session = create_caching_session().await;
362 let result = session
363 .execute_unpaged("select * from test_table", &[])
364 .await
365 .unwrap();
366 let result_rows = result.into_rows_result().unwrap();
367
368 assert_eq!(1, session.cache.len());
369 assert_eq!(1, result_rows.rows_num());
370
371 let result = session
372 .execute_unpaged("select * from test_table", &[])
373 .await
374 .unwrap();
375
376 let result_rows = result.into_rows_result().unwrap();
377
378 assert_eq!(1, session.cache.len());
379 assert_eq!(1, result_rows.rows_num());
380 }
381
382 #[tokio::test]
384 async fn test_execute_iter_cached() {
385 setup_tracing();
386 let session = create_caching_session().await;
387
388 assert!(session.cache.is_empty());
389
390 let iter = session
391 .execute_iter("select * from test_table", &[])
392 .await
393 .unwrap()
394 .rows_stream::<Row>()
395 .unwrap()
396 .into_stream();
397
398 let rows = iter
399 .into_stream()
400 .try_collect::<Vec<_>>()
401 .await
402 .unwrap()
403 .len();
404
405 assert_eq!(1, rows);
406 assert_eq!(1, session.cache.len());
407 }
408
409 #[tokio::test]
411 async fn test_execute_single_page_cached() {
412 setup_tracing();
413 let session = create_caching_session().await;
414
415 assert!(session.cache.is_empty());
416
417 let (result, _paging_state) = session
418 .execute_single_page("select * from test_table", &[], PagingState::start())
419 .await
420 .unwrap();
421
422 assert_eq!(1, session.cache.len());
423 assert_eq!(1, result.into_rows_result().unwrap().rows_num());
424 }
425
426 async fn assert_test_batch_table_rows_contain(
427 sess: &CachingSession,
428 expected_rows: &[(i32, i32)],
429 ) {
430 let selected_rows: BTreeSet<(i32, i32)> = sess
431 .execute_unpaged("SELECT a, b FROM test_batch_table", ())
432 .await
433 .unwrap()
434 .into_rows_result()
435 .unwrap()
436 .rows::<(i32, i32)>()
437 .unwrap()
438 .map(|r| r.unwrap())
439 .collect();
440 for expected_row in expected_rows.iter() {
441 if !selected_rows.contains(expected_row) {
442 panic!(
443 "Expected {:?} to contain row: {:?}, but they didn't",
444 selected_rows, expected_row
445 );
446 }
447 }
448 }
449
450 #[tokio::test]
452 async fn test_custom_hasher() {
453 setup_tracing();
454 #[derive(Default, Clone)]
455 struct CustomBuildHasher;
456 impl std::hash::BuildHasher for CustomBuildHasher {
457 type Hasher = CustomHasher;
458 fn build_hasher(&self) -> Self::Hasher {
459 CustomHasher(0)
460 }
461 }
462
463 struct CustomHasher(u8);
464 impl std::hash::Hasher for CustomHasher {
465 fn write(&mut self, bytes: &[u8]) {
466 for b in bytes {
467 self.0 ^= *b;
468 }
469 }
470 fn finish(&self) -> u64 {
471 self.0 as u64
472 }
473 }
474
475 let _session: CachingSession<std::collections::hash_map::RandomState> =
476 CachingSession::from(new_for_test(true).await, 2);
477 let _session: CachingSession<CustomBuildHasher> =
478 CachingSession::from(new_for_test(true).await, 2);
479 let _session: CachingSession<CustomBuildHasher> =
480 CachingSession::with_hasher(new_for_test(true).await, 2, Default::default());
481 }
482
483 #[tokio::test]
484 async fn test_batch() {
485 setup_tracing();
486 let session: CachingSession = create_caching_session().await;
487
488 session
489 .ddl("CREATE TABLE IF NOT EXISTS test_batch_table (a int, b int, primary key (a, b))")
490 .await
491 .unwrap();
492
493 let unprepared_insert_a_b: &str = "insert into test_batch_table (a, b) values (?, ?)";
494 let unprepared_insert_a_7: &str = "insert into test_batch_table (a, b) values (?, 7)";
495 let unprepared_insert_8_b: &str = "insert into test_batch_table (a, b) values (8, ?)";
496 let prepared_insert_a_b: PreparedStatement = session
497 .add_prepared_statement(&unprepared_insert_a_b.into())
498 .await
499 .unwrap();
500 let prepared_insert_a_7: PreparedStatement = session
501 .add_prepared_statement(&unprepared_insert_a_7.into())
502 .await
503 .unwrap();
504 let prepared_insert_8_b: PreparedStatement = session
505 .add_prepared_statement(&unprepared_insert_8_b.into())
506 .await
507 .unwrap();
508
509 let assert_batch_prepared = |b: &Batch| {
510 for stmt in &b.statements {
511 match stmt {
512 BatchStatement::PreparedStatement(_) => {}
513 _ => panic!("Unprepared statement in prepared batch!"),
514 }
515 }
516 };
517
518 {
519 let mut unprepared_batch: Batch = Default::default();
520 unprepared_batch.append_statement(unprepared_insert_a_b);
521 unprepared_batch.append_statement(unprepared_insert_a_7);
522 unprepared_batch.append_statement(unprepared_insert_8_b);
523
524 session
525 .batch(&unprepared_batch, ((10, 20), (10,), (20,)))
526 .await
527 .unwrap();
528 assert_test_batch_table_rows_contain(&session, &[(10, 20), (10, 7), (8, 20)]).await;
529
530 let prepared_batch: Batch = session.prepare_batch(&unprepared_batch).await.unwrap();
531 assert_batch_prepared(&prepared_batch);
532
533 session
534 .batch(&prepared_batch, ((15, 25), (15,), (25,)))
535 .await
536 .unwrap();
537 assert_test_batch_table_rows_contain(&session, &[(15, 25), (15, 7), (8, 25)]).await;
538 }
539
540 {
541 let mut partially_prepared_batch: Batch = Default::default();
542 partially_prepared_batch.append_statement(unprepared_insert_a_b);
543 partially_prepared_batch.append_statement(prepared_insert_a_7.clone());
544 partially_prepared_batch.append_statement(unprepared_insert_8_b);
545
546 session
547 .batch(&partially_prepared_batch, ((30, 40), (30,), (40,)))
548 .await
549 .unwrap();
550 assert_test_batch_table_rows_contain(&session, &[(30, 40), (30, 7), (8, 40)]).await;
551
552 let prepared_batch: Batch = session
553 .prepare_batch(&partially_prepared_batch)
554 .await
555 .unwrap();
556 assert_batch_prepared(&prepared_batch);
557
558 session
559 .batch(&prepared_batch, ((35, 45), (35,), (45,)))
560 .await
561 .unwrap();
562 assert_test_batch_table_rows_contain(&session, &[(35, 45), (35, 7), (8, 45)]).await;
563 }
564
565 {
566 let mut fully_prepared_batch: Batch = Default::default();
567 fully_prepared_batch.append_statement(prepared_insert_a_b);
568 fully_prepared_batch.append_statement(prepared_insert_a_7);
569 fully_prepared_batch.append_statement(prepared_insert_8_b);
570
571 session
572 .batch(&fully_prepared_batch, ((50, 60), (50,), (60,)))
573 .await
574 .unwrap();
575 assert_test_batch_table_rows_contain(&session, &[(50, 60), (50, 7), (8, 60)]).await;
576
577 let prepared_batch: Batch = session.prepare_batch(&fully_prepared_batch).await.unwrap();
578 assert_batch_prepared(&prepared_batch);
579
580 session
581 .batch(&prepared_batch, ((55, 65), (55,), (65,)))
582 .await
583 .unwrap();
584
585 assert_test_batch_table_rows_contain(&session, &[(55, 65), (55, 7), (8, 65)]).await;
586 }
587
588 {
589 let mut bad_batch: Batch = Default::default();
590 bad_batch.append_statement(unprepared_insert_a_b);
591 bad_batch.append_statement("This isnt even CQL");
592 bad_batch.append_statement(unprepared_insert_8_b);
593
594 assert!(session.batch(&bad_batch, ((1, 2), (), (2,))).await.is_err());
595 assert!(session.prepare_batch(&bad_batch).await.is_err());
596 }
597 }
598
599 #[tokio::test]
604 async fn test_parameters_caching() {
605 setup_tracing();
606 let session: CachingSession = CachingSession::from(new_for_test(true).await, 100);
607
608 session
609 .ddl("CREATE TABLE tbl (a int PRIMARY KEY, b int)")
610 .await
611 .unwrap();
612
613 let q = Statement::new("INSERT INTO tbl (a, b) VALUES (?, ?)");
614
615 let mut q1 = q.clone();
617 q1.set_timestamp(Some(1000));
618
619 session
620 .execute_unpaged(q1, (1, 1))
621 .await
622 .unwrap()
623 .result_not_rows()
624 .unwrap();
625
626 let mut q2 = q.clone();
628 q2.set_timestamp(Some(2000));
629
630 session
631 .execute_unpaged(q2, (2, 2))
632 .await
633 .unwrap()
634 .result_not_rows()
635 .unwrap();
636
637 let mut rows = session
639 .execute_unpaged("SELECT b, WRITETIME(b) FROM tbl", ())
640 .await
641 .unwrap()
642 .into_rows_result()
643 .unwrap()
644 .rows::<(i32, i64)>()
645 .unwrap()
646 .collect::<Result<Vec<_>, _>>()
647 .unwrap();
648
649 rows.sort_unstable();
650 assert_eq!(rows, vec![(1, 1000), (2, 2000)]);
651 }
652
653 #[tokio::test]
655 async fn test_partitioner_name_caching() {
656 setup_tracing();
657 if option_env!("CDC") == Some("disabled") {
658 return;
659 }
660
661 let session: CachingSession = CachingSession::from(new_for_test(false).await, 100);
663
664 session
665 .ddl("CREATE TABLE tbl (a int PRIMARY KEY) with cdc = {'enabled': true}")
666 .await
667 .unwrap();
668
669 session
670 .get_session()
671 .await_schema_agreement()
672 .await
673 .unwrap();
674
675 let verify_partitioner = || async {
680 let query =
681 Statement::new("SELECT * FROM tbl_scylla_cdc_log WHERE \"cdc$stream_id\" = ?");
682 let prepared = session.add_prepared_statement(&query).await.unwrap();
683 assert_eq!(prepared.get_partitioner_name(), &PartitionerName::CDC);
684 };
685
686 verify_partitioner().await;
689 verify_partitioner().await;
690 }
691
692 fn _caching_session_impls_debug() {
694 fn assert_debug<T: std::fmt::Debug>() {}
695 assert_debug::<CachingSession>();
696 }
697}