use serde::{de::DeserializeOwned, Deserialize, Serialize};
use static_assertions as sa;
use thiserror::Error;

use crate::{
    batch::{Batch, BatchValueWriter, DeletePrefixExpander, SimplifiedBatch},
    store::{
        AdminKeyValueStore, KeyIterable, ReadableKeyValueStore, WithError, WritableKeyValueStore,
    },
    views::MIN_VIEW_TAG,
};

/// The tag reserved for journal keys.
const JOURNAL_TAG: u8 = 0;
// The journal tag must stay below `MIN_VIEW_TAG` so that journal keys can never
// collide with keys used by views.
sa::const_assert!(JOURNAL_TAG < MIN_VIEW_TAG);

/// The errors that can arise while operating on the journal.
#[derive(Error, Debug)]
#[allow(missing_docs)]
pub enum JournalConsistencyError {
    #[error("The journal block could not be retrieved; it may be missing or corrupted.")]
    FailureToRetrieveJournalBlock,

    #[error("Refusing to use the journal without exclusive database access to the root object.")]
    JournalRequiresExclusiveAccess,
}

/// The tags used for the journal's own keys.
#[repr(u8)]
enum KeyTag {
    /// Prefix for the journal header.
    Journal = 1,
    /// Prefix for the journal blocks.
    Entry,
}

/// Builds the key under which journal data with the given tag and position is stored.
fn get_journaling_key(tag: u8, pos: u32) -> Result<Vec<u8>, bcs::Error> {
    let mut key = vec![JOURNAL_TAG];
    key.extend([tag]);
    bcs::serialize_into(&mut key, &pos)?;
    Ok(key)
}

/// Low-level, direct write operations on a key-value store, based on a simplified batch.
#[cfg_attr(not(web), trait_variant::make(Send + Sync))]
pub trait DirectWritableKeyValueStore: WithError {
    /// The maximal number of items in a batch.
    const MAX_BATCH_SIZE: usize;

    /// The maximal number of bytes of a batch.
    const MAX_BATCH_TOTAL_SIZE: usize;

    /// The maximal size of values that can be stored.
    const MAX_VALUE_SIZE: usize;

    /// The batch type used by the store.
    type Batch: SimplifiedBatch + Serialize + DeserializeOwned + Default;

    /// Writes the batch to the database.
    async fn write_batch(&self, batch: Self::Batch) -> Result<(), Self::Error>;
}

/// Low-level, direct read/write operations on a key-value store, based on a simplified batch.
pub trait DirectKeyValueStore:
    ReadableKeyValueStore + DirectWritableKeyValueStore + AdminKeyValueStore
{
}

impl<T> DirectKeyValueStore for T where
    T: ReadableKeyValueStore + DirectWritableKeyValueStore + AdminKeyValueStore
{
}

/// The journal header, recording how many journal blocks are currently pending.
#[derive(Serialize, Deserialize, Debug, Default)]
struct JournalHeader {
    block_count: u32,
}

/// A key-value store that journals batches on top of an inner [`DirectKeyValueStore`],
/// so that writes exceeding the inner store's limits are split into several coherent
/// steps and can be completed after a failure.
#[derive(Clone)]
pub struct JournalingKeyValueStore<K> {
    /// The inner store.
    store: K,
    /// Whether we have exclusive access to the database.
    has_exclusive_access: bool,
}

impl<K> DeletePrefixExpander for &JournalingKeyValueStore<K>
where
    K: DirectKeyValueStore,
{
    type Error = K::Error;
    async fn expand_delete_prefix(&self, key_prefix: &[u8]) -> Result<Vec<Vec<u8>>, Self::Error> {
        let mut vector_list = Vec::new();
        for key in self.store.find_keys_by_prefix(key_prefix).await?.iterator() {
            vector_list.push(key?.to_vec());
        }
        Ok(vector_list)
    }
}

impl<K> WithError for JournalingKeyValueStore<K>
where
    K: WithError,
{
    type Error = K::Error;
}

impl<K> ReadableKeyValueStore for JournalingKeyValueStore<K>
where
    K: ReadableKeyValueStore,
    K::Error: From<JournalConsistencyError>,
{
    const MAX_KEY_SIZE: usize = K::MAX_KEY_SIZE;

    type Keys = K::Keys;
    type KeyValues = K::KeyValues;

    fn max_stream_queries(&self) -> usize {
        self.store.max_stream_queries()
    }

    async fn read_value_bytes(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
        self.store.read_value_bytes(key).await
    }

    async fn contains_key(&self, key: &[u8]) -> Result<bool, Self::Error> {
        self.store.contains_key(key).await
    }

    async fn contains_keys(&self, keys: Vec<Vec<u8>>) -> Result<Vec<bool>, Self::Error> {
        self.store.contains_keys(keys).await
    }

    async fn read_multi_values_bytes(
        &self,
        keys: Vec<Vec<u8>>,
    ) -> Result<Vec<Option<Vec<u8>>>, Self::Error> {
        self.store.read_multi_values_bytes(keys).await
    }

    async fn find_keys_by_prefix(&self, key_prefix: &[u8]) -> Result<Self::Keys, Self::Error> {
        self.store.find_keys_by_prefix(key_prefix).await
    }

    async fn find_key_values_by_prefix(
        &self,
        key_prefix: &[u8],
    ) -> Result<Self::KeyValues, Self::Error> {
        self.store.find_key_values_by_prefix(key_prefix).await
    }
}

impl<K> AdminKeyValueStore for JournalingKeyValueStore<K>
where
    K: AdminKeyValueStore,
{
    type Config = K::Config;

    fn get_name() -> String {
        format!("journaling {}", K::get_name())
    }

    async fn connect(config: &Self::Config, namespace: &str) -> Result<Self, Self::Error> {
        let store = K::connect(config, namespace).await?;
        Ok(Self {
            store,
            has_exclusive_access: false,
        })
    }

    fn open_exclusive(&self, root_key: &[u8]) -> Result<Self, Self::Error> {
        let store = self.store.open_exclusive(root_key)?;
        Ok(Self {
            store,
            has_exclusive_access: true,
        })
    }

    async fn list_all(config: &Self::Config) -> Result<Vec<String>, Self::Error> {
        K::list_all(config).await
    }

    async fn list_root_keys(
        config: &Self::Config,
        namespace: &str,
    ) -> Result<Vec<Vec<u8>>, Self::Error> {
        K::list_root_keys(config, namespace).await
    }

    async fn delete_all(config: &Self::Config) -> Result<(), Self::Error> {
        K::delete_all(config).await
    }

    async fn exists(config: &Self::Config, namespace: &str) -> Result<bool, Self::Error> {
        K::exists(config, namespace).await
    }

    async fn create(config: &Self::Config, namespace: &str) -> Result<(), Self::Error> {
        K::create(config, namespace).await
    }

    async fn delete(config: &Self::Config, namespace: &str) -> Result<(), Self::Error> {
        K::delete(config, namespace).await
    }
}

impl<K> WritableKeyValueStore for JournalingKeyValueStore<K>
where
    K: DirectKeyValueStore,
    K::Error: From<JournalConsistencyError>,
{
    const MAX_VALUE_SIZE: usize = K::MAX_VALUE_SIZE;

    async fn write_batch(&self, batch: Batch) -> Result<(), Self::Error> {
        let batch = K::Batch::from_batch(self, batch).await?;
        if Self::is_fastpath_feasible(&batch) {
            // The batch fits within the inner store's limits: write it directly.
            self.store.write_batch(batch).await
        } else {
            // The batch is too large: it must be journaled first, which requires
            // exclusive access to the root object.
            if !self.has_exclusive_access {
                return Err(JournalConsistencyError::JournalRequiresExclusiveAccess.into());
            }
            let header = self.write_journal(batch).await?;
            self.coherently_resolve_journal(header).await
        }
    }

    async fn clear_journal(&self) -> Result<(), Self::Error> {
        let key = get_journaling_key(KeyTag::Journal as u8, 0)?;
        let value = self.read_value::<JournalHeader>(&key).await?;
        if let Some(header) = value {
            // A journal was left behind by an interrupted write: finish resolving it.
            self.coherently_resolve_journal(header).await?;
        }
        Ok(())
    }
}

impl<K> JournalingKeyValueStore<K>
where
    K: DirectKeyValueStore,
    K::Error: From<JournalConsistencyError>,
{
    /// Resolves the pending journal described by `header`, applying and deleting its
    /// blocks one by one.
    async fn coherently_resolve_journal(&self, mut header: JournalHeader) -> Result<(), K::Error> {
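        // Blocks are processed from the highest index down to zero. Each iteration
        // reads one journal block, appends to it the deletion of the block itself and
        // the updated (or deleted) header, and writes everything as a single batch of
        // the inner store. Because the header update travels with the block's own
        // operations, an interrupted resolution can be resumed from the stored header.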
        let header_key = get_journaling_key(KeyTag::Journal as u8, 0)?;
        while header.block_count > 0 {
            let block_key = get_journaling_key(KeyTag::Entry as u8, header.block_count - 1)?;
            // Read the batch stored in the last journal block.
            let mut batch = self
                .store
                .read_value::<K::Batch>(&block_key)
                .await?
                .ok_or(JournalConsistencyError::FailureToRetrieveJournalBlock)?;
            // Delete the block itself and record the new journal state in the same batch.
            batch.add_delete(block_key);
            header.block_count -= 1;
            if header.block_count > 0 {
                let value = bcs::to_bytes(&header)?;
                batch.add_insert(header_key.clone(), value);
            } else {
                batch.add_delete(header_key.clone());
            }
            self.store.write_batch(batch).await?;
        }
        Ok(())
    }

    /// Writes the batch to the journal as a sequence of blocks and returns the
    /// resulting journal header.
    async fn write_journal(&self, batch: K::Batch) -> Result<JournalHeader, K::Error> {
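        // The simplified batch is split into journal blocks, each serialized as the
        // value of one `KeyTag::Entry` key. Every block must fit within
        // `K::MAX_VALUE_SIZE`, and the blocks themselves are written in transactions
        // bounded by `K::MAX_BATCH_SIZE` entries and `K::MAX_BATCH_TOTAL_SIZE` bytes,
        // leaving room for the header that is written last.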
        let header_key = get_journaling_key(KeyTag::Journal as u8, 0)?;
        let key_len = header_key.len();
        let header_value_len = bcs::serialized_size(&JournalHeader::default())?;
        let journal_len_upper_bound = key_len + header_value_len;
        // The per-transaction and per-block byte budgets.
        let max_transaction_size = K::MAX_BATCH_TOTAL_SIZE;
        let max_block_size = std::cmp::min(
            K::MAX_VALUE_SIZE,
            K::MAX_BATCH_TOTAL_SIZE - key_len - journal_len_upper_bound,
        );

        let mut iter = batch.into_iter();
        let mut block_batch = K::Batch::default();
        let mut block_size = 0;
        let mut block_count = 0;
        let mut transaction_batch = K::Batch::default();
        let mut transaction_size = 0;
        while iter.write_next_value(&mut block_batch, &mut block_size)? {
            let (block_flush, transaction_flush) = {
                if iter.is_empty() || transaction_batch.len() == K::MAX_BATCH_SIZE - 1 {
                    (true, true)
                } else {
                    let next_block_size = iter
                        .next_batch_size(&block_batch, block_size)?
                        .expect("iter is not empty");
                    let next_transaction_size = transaction_size + next_block_size + key_len;
                    let transaction_flush = next_transaction_size > max_transaction_size;
                    let block_flush = transaction_flush
                        || block_batch.len() == K::MAX_BATCH_SIZE - 2
                        || next_block_size > max_block_size;
                    (block_flush, transaction_flush)
                }
            };
            if block_flush {
                // Serialize the current block and add it to the pending transaction.
                block_size += block_batch.overhead_size();
                let value = bcs::to_bytes(&block_batch)?;
                block_batch = K::Batch::default();
                assert_eq!(value.len(), block_size);
                let key = get_journaling_key(KeyTag::Entry as u8, block_count)?;
                transaction_batch.add_insert(key, value);
                block_count += 1;
                transaction_size += block_size + key_len;
                block_size = 0;
            }
            if transaction_flush {
                // Write the accumulated blocks to storage.
                let batch = std::mem::take(&mut transaction_batch);
                self.store.write_batch(batch).await?;
                transaction_size = 0;
            }
        }
        // Finally, record the number of blocks in the journal header.
        let header = JournalHeader { block_count };
        if block_count > 0 {
            let value = bcs::to_bytes(&header)?;
            let mut batch = K::Batch::default();
            batch.add_insert(header_key, value);
            self.store.write_batch(batch).await?;
        }
        Ok(header)
    }

    /// Returns whether the batch can be written directly to the inner store.
    fn is_fastpath_feasible(batch: &K::Batch) -> bool {
        batch.len() <= K::MAX_BATCH_SIZE && batch.num_bytes() <= K::MAX_BATCH_TOTAL_SIZE
    }
}

impl<K> JournalingKeyValueStore<K> {
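    /// Creates a new journaling store wrapping the given inner store.
    ///
    /// A minimal usage sketch, assuming `inner` is some store implementing
    /// [`DirectKeyValueStore`] and `batch` is a [`Batch`] built elsewhere:
    ///
    /// ```ignore
    /// let store = JournalingKeyValueStore::new(inner);
    /// // Small batches are written directly; batches exceeding the inner store's
    /// // limits are journaled first, which requires `open_exclusive` access.
    /// store.write_batch(batch).await?;
    /// ```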
    pub fn new(store: K) -> Self {
        Self {
            store,
            has_exclusive_access: false,
        }
    }
}