scylla/client/
caching_session.rs

1use crate::errors::{ExecutionError, PagerExecutionError, PrepareError};
2use crate::response::query_result::QueryResult;
3use crate::response::{PagingState, PagingStateResponse};
4use crate::routing::partitioner::PartitionerName;
5use crate::statement::batch::{Batch, BatchStatement};
6use crate::statement::prepared::PreparedStatement;
7use crate::statement::unprepared::Statement;
8use bytes::Bytes;
9use dashmap::DashMap;
10use futures::future::try_join_all;
11use scylla_cql::frame::response::result::{PreparedMetadata, ResultMetadata};
12use scylla_cql::serialize::batch::BatchValues;
13use scylla_cql::serialize::row::SerializeRow;
14use std::collections::hash_map::RandomState;
15use std::fmt;
16use std::hash::BuildHasher;
17use std::sync::Arc;
18
19use crate::client::pager::QueryPager;
20use crate::client::session::Session;
21
22/// Contains just the parts of a prepared statement that were returned
23/// from the database. All remaining parts (query string, page size,
24/// consistency, etc.) are taken from the Query passed
25/// to the `CachingSession::execute` family of methods.
26#[derive(Debug)]
27struct RawPreparedStatementData {
28    id: Bytes,
29    is_confirmed_lwt: bool,
30    metadata: PreparedMetadata,
31    result_metadata: Arc<ResultMetadata<'static>>,
32    partitioner_name: PartitionerName,
33}
34
35/// Provides auto caching while executing queries
36pub struct CachingSession<S = RandomState>
37where
38    S: Clone + BuildHasher,
39{
40    session: Session,
41    /// The prepared statement cache size
42    /// If a prepared statement is added while the limit is reached, the oldest prepared statement
43    /// is removed from the cache
44    max_capacity: usize,
45    cache: DashMap<String, RawPreparedStatementData, S>,
46}
47
48impl<S> fmt::Debug for CachingSession<S>
49where
50    S: Clone + BuildHasher,
51{
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        f.debug_struct("GenericCachingSession")
54            .field("session", &self.session)
55            .field("max_capacity", &self.max_capacity)
56            .field("cache", &self.cache)
57            .finish()
58    }
59}
60
61impl<S> CachingSession<S>
62where
63    S: Default + BuildHasher + Clone,
64{
65    pub fn from(session: Session, cache_size: usize) -> Self {
66        Self {
67            session,
68            max_capacity: cache_size,
69            cache: Default::default(),
70        }
71    }
72}
73
74impl<S> CachingSession<S>
75where
76    S: BuildHasher + Clone,
77{
78    /// Builds a [`CachingSession`] from a [`Session`], a cache size,
79    /// and a [`BuildHasher`], using a customer hasher.
80    pub fn with_hasher(session: Session, cache_size: usize, hasher: S) -> Self {
81        Self {
82            session,
83            max_capacity: cache_size,
84            cache: DashMap::with_hasher(hasher),
85        }
86    }
87}
88
89impl<S> CachingSession<S>
90where
91    S: BuildHasher + Clone,
92{
93    /// Does the same thing as [`Session::execute_unpaged`]
94    /// but uses the prepared statement cache.
95    pub async fn execute_unpaged(
96        &self,
97        query: impl Into<Statement>,
98        values: impl SerializeRow,
99    ) -> Result<QueryResult, ExecutionError> {
100        let query = query.into();
101        let prepared = self.add_prepared_statement_owned(query).await?;
102        self.session.execute_unpaged(&prepared, values).await
103    }
104
105    /// Does the same thing as [`Session::execute_iter`]
106    /// but uses the prepared statement cache.
107    pub async fn execute_iter(
108        &self,
109        query: impl Into<Statement>,
110        values: impl SerializeRow,
111    ) -> Result<QueryPager, PagerExecutionError> {
112        let query = query.into();
113        let prepared = self.add_prepared_statement_owned(query).await?;
114        self.session.execute_iter(prepared, values).await
115    }
116
117    /// Does the same thing as [`Session::execute_single_page`]
118    /// but uses the prepared statement cache.
119    pub async fn execute_single_page(
120        &self,
121        query: impl Into<Statement>,
122        values: impl SerializeRow,
123        paging_state: PagingState,
124    ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> {
125        let query = query.into();
126        let prepared = self.add_prepared_statement_owned(query).await?;
127        self.session
128            .execute_single_page(&prepared, values, paging_state)
129            .await
130    }
131
132    /// Does the same thing as [`Session::batch`] but uses the
133    /// prepared statement cache.\
134    /// Prepares batch using [`CachingSession::prepare_batch`]
135    /// if needed and then executes it.
136    pub async fn batch(
137        &self,
138        batch: &Batch,
139        values: impl BatchValues,
140    ) -> Result<QueryResult, ExecutionError> {
141        let all_prepared: bool = batch
142            .statements
143            .iter()
144            .all(|stmt| matches!(stmt, BatchStatement::PreparedStatement(_)));
145
146        if all_prepared {
147            self.session.batch(batch, &values).await
148        } else {
149            let prepared_batch: Batch = self.prepare_batch(batch).await?;
150
151            self.session.batch(&prepared_batch, &values).await
152        }
153    }
154}
155
156impl<S> CachingSession<S>
157where
158    S: BuildHasher + Clone,
159{
160    /// Prepares all statements within the batch and returns a new batch where every
161    /// statement is prepared.
162    /// Uses the prepared statements cache.
163    pub async fn prepare_batch(&self, batch: &Batch) -> Result<Batch, ExecutionError> {
164        let mut prepared_batch = batch.clone();
165
166        try_join_all(
167            prepared_batch
168                .statements
169                .iter_mut()
170                .map(|statement| async move {
171                    if let BatchStatement::Query(query) = statement {
172                        let prepared = self.add_prepared_statement(&*query).await?;
173                        *statement = BatchStatement::PreparedStatement(prepared);
174                    }
175                    Ok::<(), ExecutionError>(())
176                }),
177        )
178        .await?;
179
180        Ok(prepared_batch)
181    }
182
183    /// Adds a prepared statement to the cache
184    pub async fn add_prepared_statement(
185        &self,
186        query: impl Into<&Statement>,
187    ) -> Result<PreparedStatement, PrepareError> {
188        self.add_prepared_statement_owned(query.into().clone())
189            .await
190    }
191
192    async fn add_prepared_statement_owned(
193        &self,
194        query: impl Into<Statement>,
195    ) -> Result<PreparedStatement, PrepareError> {
196        let query = query.into();
197
198        if let Some(raw) = self.cache.get(&query.contents) {
199            let page_size = query.get_validated_page_size();
200            let mut stmt = PreparedStatement::new(
201                raw.id.clone(),
202                raw.is_confirmed_lwt,
203                raw.metadata.clone(),
204                raw.result_metadata.clone(),
205                query.contents,
206                page_size,
207                query.config,
208            );
209            stmt.set_partitioner_name(raw.partitioner_name.clone());
210            Ok(stmt)
211        } else {
212            let query_contents = query.contents.clone();
213            let prepared = self.session.prepare(query).await?;
214
215            if self.max_capacity == self.cache.len() {
216                // Cache is full, remove the first entry
217                // Don't hold a reference into the map (that's why the to_string() is called)
218                // This is because the documentation of the remove fn tells us that it may deadlock
219                // when holding some sort of reference into the map
220                let query = self.cache.iter().next().map(|c| c.key().to_string());
221
222                // Don't inline this: https://stackoverflow.com/questions/69873846/an-owned-value-is-still-references-somehow
223                if let Some(q) = query {
224                    self.cache.remove(&q);
225                }
226            }
227
228            let raw = RawPreparedStatementData {
229                id: prepared.get_id().clone(),
230                is_confirmed_lwt: prepared.is_confirmed_lwt(),
231                metadata: prepared.get_prepared_metadata().clone(),
232                result_metadata: prepared.get_result_metadata().clone(),
233                partitioner_name: prepared.get_partitioner_name().clone(),
234            };
235            self.cache.insert(query_contents, raw);
236
237            Ok(prepared)
238        }
239    }
240
241    pub fn get_max_capacity(&self) -> usize {
242        self.max_capacity
243    }
244
245    pub fn get_session(&self) -> &Session {
246        &self.session
247    }
248}
249
250#[cfg(test)]
251mod tests {
252    use crate::client::session::Session;
253    use crate::response::PagingState;
254    use crate::routing::partitioner::PartitionerName;
255    use crate::statement::batch::{Batch, BatchStatement};
256    use crate::statement::prepared::PreparedStatement;
257    use crate::statement::unprepared::Statement;
258    use crate::test_utils::{
259        create_new_session_builder, scylla_supports_tablets, setup_tracing, PerformDDL,
260    };
261    use crate::utils::test_utils::unique_keyspace_name;
262    use crate::value::Row;
263    use futures::TryStreamExt;
264    use std::collections::BTreeSet;
265
266    use super::CachingSession;
267
268    async fn new_for_test(with_tablet_support: bool) -> Session {
269        let session = create_new_session_builder()
270            .build()
271            .await
272            .expect("Could not create session");
273        let ks = unique_keyspace_name();
274
275        let mut create_ks = format!(
276            "CREATE KEYSPACE IF NOT EXISTS {ks}
277        WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}"
278        );
279        if !with_tablet_support && scylla_supports_tablets(&session).await {
280            create_ks += " AND TABLETS = {'enabled': false}";
281        }
282
283        session
284            .ddl(create_ks)
285            .await
286            .expect("Could not create keyspace");
287
288        session
289            .ddl(format!(
290                "CREATE TABLE IF NOT EXISTS {}.test_table (a int primary key, b int)",
291                ks
292            ))
293            .await
294            .expect("Could not create table");
295
296        session
297            .use_keyspace(ks, false)
298            .await
299            .expect("Could not set keyspace");
300
301        session
302    }
303
304    async fn create_caching_session() -> CachingSession {
305        let session = CachingSession::from(new_for_test(true).await, 2);
306
307        // Add a row, this makes it easier to check if the caching works combined with the regular execute fn on Session
308        session
309            .execute_unpaged("insert into test_table(a, b) values (1, 2)", &[])
310            .await
311            .unwrap();
312
313        // Clear the cache because it now contains an insert
314        assert_eq!(session.cache.len(), 1);
315
316        session.cache.clear();
317
318        session
319    }
320
321    /// Test that when the cache is full and a different query comes in, that query will be added
322    /// to the cache and a random query is removed
323    #[tokio::test]
324    async fn test_full() {
325        setup_tracing();
326        let session = create_caching_session().await;
327
328        let first_query = "select * from test_table";
329        let middle_query = "insert into test_table(a, b) values (?, ?)";
330        let last_query = "update test_table set b = ? where a = 1";
331
332        session
333            .add_prepared_statement(&first_query.into())
334            .await
335            .unwrap();
336        session
337            .add_prepared_statement(&middle_query.into())
338            .await
339            .unwrap();
340        session
341            .add_prepared_statement(&last_query.into())
342            .await
343            .unwrap();
344
345        assert_eq!(2, session.cache.len());
346
347        // This query should be in the cache
348        assert!(session.cache.get(last_query).is_some());
349
350        // Either the first or middle query should be removed
351        let first_query_removed = session.cache.get(first_query).is_none();
352        let middle_query_removed = session.cache.get(middle_query).is_none();
353
354        assert!(first_query_removed || middle_query_removed);
355    }
356
357    /// Checks that the same prepared statement is reused when executing the same query twice
358    #[tokio::test]
359    async fn test_execute_unpaged_cached() {
360        setup_tracing();
361        let session = create_caching_session().await;
362        let result = session
363            .execute_unpaged("select * from test_table", &[])
364            .await
365            .unwrap();
366        let result_rows = result.into_rows_result().unwrap();
367
368        assert_eq!(1, session.cache.len());
369        assert_eq!(1, result_rows.rows_num());
370
371        let result = session
372            .execute_unpaged("select * from test_table", &[])
373            .await
374            .unwrap();
375
376        let result_rows = result.into_rows_result().unwrap();
377
378        assert_eq!(1, session.cache.len());
379        assert_eq!(1, result_rows.rows_num());
380    }
381
382    /// Checks that caching works with execute_iter
383    #[tokio::test]
384    async fn test_execute_iter_cached() {
385        setup_tracing();
386        let session = create_caching_session().await;
387
388        assert!(session.cache.is_empty());
389
390        let iter = session
391            .execute_iter("select * from test_table", &[])
392            .await
393            .unwrap()
394            .rows_stream::<Row>()
395            .unwrap()
396            .into_stream();
397
398        let rows = iter
399            .into_stream()
400            .try_collect::<Vec<_>>()
401            .await
402            .unwrap()
403            .len();
404
405        assert_eq!(1, rows);
406        assert_eq!(1, session.cache.len());
407    }
408
409    /// Checks that caching works with execute_single_page
410    #[tokio::test]
411    async fn test_execute_single_page_cached() {
412        setup_tracing();
413        let session = create_caching_session().await;
414
415        assert!(session.cache.is_empty());
416
417        let (result, _paging_state) = session
418            .execute_single_page("select * from test_table", &[], PagingState::start())
419            .await
420            .unwrap();
421
422        assert_eq!(1, session.cache.len());
423        assert_eq!(1, result.into_rows_result().unwrap().rows_num());
424    }
425
426    async fn assert_test_batch_table_rows_contain(
427        sess: &CachingSession,
428        expected_rows: &[(i32, i32)],
429    ) {
430        let selected_rows: BTreeSet<(i32, i32)> = sess
431            .execute_unpaged("SELECT a, b FROM test_batch_table", ())
432            .await
433            .unwrap()
434            .into_rows_result()
435            .unwrap()
436            .rows::<(i32, i32)>()
437            .unwrap()
438            .map(|r| r.unwrap())
439            .collect();
440        for expected_row in expected_rows.iter() {
441            if !selected_rows.contains(expected_row) {
442                panic!(
443                    "Expected {:?} to contain row: {:?}, but they didn't",
444                    selected_rows, expected_row
445                );
446            }
447        }
448    }
449
450    /// This test checks that we can construct a CachingSession with custom HashBuilder implementations
451    #[tokio::test]
452    async fn test_custom_hasher() {
453        setup_tracing();
454        #[derive(Default, Clone)]
455        struct CustomBuildHasher;
456        impl std::hash::BuildHasher for CustomBuildHasher {
457            type Hasher = CustomHasher;
458            fn build_hasher(&self) -> Self::Hasher {
459                CustomHasher(0)
460            }
461        }
462
463        struct CustomHasher(u8);
464        impl std::hash::Hasher for CustomHasher {
465            fn write(&mut self, bytes: &[u8]) {
466                for b in bytes {
467                    self.0 ^= *b;
468                }
469            }
470            fn finish(&self) -> u64 {
471                self.0 as u64
472            }
473        }
474
475        let _session: CachingSession<std::collections::hash_map::RandomState> =
476            CachingSession::from(new_for_test(true).await, 2);
477        let _session: CachingSession<CustomBuildHasher> =
478            CachingSession::from(new_for_test(true).await, 2);
479        let _session: CachingSession<CustomBuildHasher> =
480            CachingSession::with_hasher(new_for_test(true).await, 2, Default::default());
481    }
482
483    #[tokio::test]
484    async fn test_batch() {
485        setup_tracing();
486        let session: CachingSession = create_caching_session().await;
487
488        session
489            .ddl("CREATE TABLE IF NOT EXISTS test_batch_table (a int, b int, primary key (a, b))")
490            .await
491            .unwrap();
492
493        let unprepared_insert_a_b: &str = "insert into test_batch_table (a, b) values (?, ?)";
494        let unprepared_insert_a_7: &str = "insert into test_batch_table (a, b) values (?, 7)";
495        let unprepared_insert_8_b: &str = "insert into test_batch_table (a, b) values (8, ?)";
496        let prepared_insert_a_b: PreparedStatement = session
497            .add_prepared_statement(&unprepared_insert_a_b.into())
498            .await
499            .unwrap();
500        let prepared_insert_a_7: PreparedStatement = session
501            .add_prepared_statement(&unprepared_insert_a_7.into())
502            .await
503            .unwrap();
504        let prepared_insert_8_b: PreparedStatement = session
505            .add_prepared_statement(&unprepared_insert_8_b.into())
506            .await
507            .unwrap();
508
509        let assert_batch_prepared = |b: &Batch| {
510            for stmt in &b.statements {
511                match stmt {
512                    BatchStatement::PreparedStatement(_) => {}
513                    _ => panic!("Unprepared statement in prepared batch!"),
514                }
515            }
516        };
517
518        {
519            let mut unprepared_batch: Batch = Default::default();
520            unprepared_batch.append_statement(unprepared_insert_a_b);
521            unprepared_batch.append_statement(unprepared_insert_a_7);
522            unprepared_batch.append_statement(unprepared_insert_8_b);
523
524            session
525                .batch(&unprepared_batch, ((10, 20), (10,), (20,)))
526                .await
527                .unwrap();
528            assert_test_batch_table_rows_contain(&session, &[(10, 20), (10, 7), (8, 20)]).await;
529
530            let prepared_batch: Batch = session.prepare_batch(&unprepared_batch).await.unwrap();
531            assert_batch_prepared(&prepared_batch);
532
533            session
534                .batch(&prepared_batch, ((15, 25), (15,), (25,)))
535                .await
536                .unwrap();
537            assert_test_batch_table_rows_contain(&session, &[(15, 25), (15, 7), (8, 25)]).await;
538        }
539
540        {
541            let mut partially_prepared_batch: Batch = Default::default();
542            partially_prepared_batch.append_statement(unprepared_insert_a_b);
543            partially_prepared_batch.append_statement(prepared_insert_a_7.clone());
544            partially_prepared_batch.append_statement(unprepared_insert_8_b);
545
546            session
547                .batch(&partially_prepared_batch, ((30, 40), (30,), (40,)))
548                .await
549                .unwrap();
550            assert_test_batch_table_rows_contain(&session, &[(30, 40), (30, 7), (8, 40)]).await;
551
552            let prepared_batch: Batch = session
553                .prepare_batch(&partially_prepared_batch)
554                .await
555                .unwrap();
556            assert_batch_prepared(&prepared_batch);
557
558            session
559                .batch(&prepared_batch, ((35, 45), (35,), (45,)))
560                .await
561                .unwrap();
562            assert_test_batch_table_rows_contain(&session, &[(35, 45), (35, 7), (8, 45)]).await;
563        }
564
565        {
566            let mut fully_prepared_batch: Batch = Default::default();
567            fully_prepared_batch.append_statement(prepared_insert_a_b);
568            fully_prepared_batch.append_statement(prepared_insert_a_7);
569            fully_prepared_batch.append_statement(prepared_insert_8_b);
570
571            session
572                .batch(&fully_prepared_batch, ((50, 60), (50,), (60,)))
573                .await
574                .unwrap();
575            assert_test_batch_table_rows_contain(&session, &[(50, 60), (50, 7), (8, 60)]).await;
576
577            let prepared_batch: Batch = session.prepare_batch(&fully_prepared_batch).await.unwrap();
578            assert_batch_prepared(&prepared_batch);
579
580            session
581                .batch(&prepared_batch, ((55, 65), (55,), (65,)))
582                .await
583                .unwrap();
584
585            assert_test_batch_table_rows_contain(&session, &[(55, 65), (55, 7), (8, 65)]).await;
586        }
587
588        {
589            let mut bad_batch: Batch = Default::default();
590            bad_batch.append_statement(unprepared_insert_a_b);
591            bad_batch.append_statement("This isnt even CQL");
592            bad_batch.append_statement(unprepared_insert_8_b);
593
594            assert!(session.batch(&bad_batch, ((1, 2), (), (2,))).await.is_err());
595            assert!(session.prepare_batch(&bad_batch).await.is_err());
596        }
597    }
598
599    // The CachingSession::execute and friends should have the same StatementConfig
600    // and the page size as the Query provided as a parameter. It must not cache
601    // those parameters internally.
602    // Reproduces #597
603    #[tokio::test]
604    async fn test_parameters_caching() {
605        setup_tracing();
606        let session: CachingSession = CachingSession::from(new_for_test(true).await, 100);
607
608        session
609            .ddl("CREATE TABLE tbl (a int PRIMARY KEY, b int)")
610            .await
611            .unwrap();
612
613        let q = Statement::new("INSERT INTO tbl (a, b) VALUES (?, ?)");
614
615        // Insert one row with timestamp 1000
616        let mut q1 = q.clone();
617        q1.set_timestamp(Some(1000));
618
619        session
620            .execute_unpaged(q1, (1, 1))
621            .await
622            .unwrap()
623            .result_not_rows()
624            .unwrap();
625
626        // Insert another row with timestamp 2000
627        let mut q2 = q.clone();
628        q2.set_timestamp(Some(2000));
629
630        session
631            .execute_unpaged(q2, (2, 2))
632            .await
633            .unwrap()
634            .result_not_rows()
635            .unwrap();
636
637        // Fetch both rows with their timestamps
638        let mut rows = session
639            .execute_unpaged("SELECT b, WRITETIME(b) FROM tbl", ())
640            .await
641            .unwrap()
642            .into_rows_result()
643            .unwrap()
644            .rows::<(i32, i64)>()
645            .unwrap()
646            .collect::<Result<Vec<_>, _>>()
647            .unwrap();
648
649        rows.sort_unstable();
650        assert_eq!(rows, vec![(1, 1000), (2, 2000)]);
651    }
652
653    // Checks whether the PartitionerName is cached properly.
654    #[tokio::test]
655    async fn test_partitioner_name_caching() {
656        setup_tracing();
657        if option_env!("CDC") == Some("disabled") {
658            return;
659        }
660
661        // This test uses CDC which is not yet compatible with Scylla's tablets.
662        let session: CachingSession = CachingSession::from(new_for_test(false).await, 100);
663
664        session
665            .ddl("CREATE TABLE tbl (a int PRIMARY KEY) with cdc = {'enabled': true}")
666            .await
667            .unwrap();
668
669        session
670            .get_session()
671            .await_schema_agreement()
672            .await
673            .unwrap();
674
675        // This creates a query with default partitioner name (murmur hash),
676        // but after adding the statement it should be changed to the cdc
677        // partitioner. It should happen when the query is prepared
678        // and after it is fetched from the cache.
679        let verify_partitioner = || async {
680            let query =
681                Statement::new("SELECT * FROM tbl_scylla_cdc_log WHERE \"cdc$stream_id\" = ?");
682            let prepared = session.add_prepared_statement(&query).await.unwrap();
683            assert_eq!(prepared.get_partitioner_name(), &PartitionerName::CDC);
684        };
685
686        // Using a closure here instead of a loop so that, when the test fails,
687        // one can see which case failed by looking at the full backtrace
688        verify_partitioner().await;
689        verify_partitioner().await;
690    }
691
692    // NOTE: intentionally no `#[test]`: this is a compile-time test
693    fn _caching_session_impls_debug() {
694        fn assert_debug<T: std::fmt::Debug>() {}
695        assert_debug::<CachingSession>();
696    }
697}