1use serde::{Deserialize, Serialize};
23use static_assertions as sa;
24use thiserror::Error;
25
26use crate::{
27 batch::{Batch, BatchValueWriter, DeletePrefixExpander, SimplifiedBatch},
28 store::{
29 DirectKeyValueStore, KeyValueDatabase, ReadableKeyValueStore, WithError,
30 WritableKeyValueStore,
31 },
32 views::MIN_VIEW_TAG,
33};
34
/// A database wrapper that equips every store it opens with a write-ahead
/// journal: `open_shared` / `open_exclusive` return `JournalingKeyValueStore`s
/// around the inner database's stores.
#[derive(Clone)]
pub struct JournalingKeyValueDatabase<D> {
    /// The wrapped database to which all operations are delegated.
    database: D,
}
40
/// A store wrapper that adds journaling support: batches too large for the
/// backend's limits are first persisted as journal blocks, then replayed.
#[derive(Clone)]
pub struct JournalingKeyValueStore<S> {
    /// The inner store performing the actual reads and writes.
    store: S,
    /// Whether this store was opened with exclusive access to its root object.
    /// Using the journal requires exclusivity (checked in `write_batch`).
    has_exclusive_access: bool,
}
49
50#[derive(Error, Debug)]
52#[allow(missing_docs)]
53pub enum JournalConsistencyError {
54 #[error("The journal block could not be retrieved, it could be missing or corrupted.")]
55 FailureToRetrieveJournalBlock,
56
57 #[error("Refusing to use the journal without exclusive database access to the root object.")]
58 JournalRequiresExclusiveAccess,
59}
60
/// First byte of every key used by the journal's own bookkeeping records.
const JOURNAL_TAG: u8 = 0;
// Enforced at compile time: the journal tag must stay strictly below
// `MIN_VIEW_TAG`, keeping the journal's key space separate from the tag
// range used by views.
sa::const_assert!(JOURNAL_TAG < MIN_VIEW_TAG);
66
/// Second byte of a journaling key (after `JOURNAL_TAG`), selecting which
/// kind of journal record the key addresses.
#[repr(u8)]
enum KeyTag {
    /// Tag for the journal header (a serialized `JournalHeader`).
    Journal = 1,
    /// Tag for an individual journal block; the block index follows in the key.
    Entry,
}
74
75fn get_journaling_key(tag: u8, pos: u32) -> Result<Vec<u8>, bcs::Error> {
76 let mut key = vec![JOURNAL_TAG];
77 key.extend([tag]);
78 bcs::serialize_into(&mut key, &pos)?;
79 Ok(key)
80}
81
/// The record that makes a journal discoverable: it is stored under the
/// `KeyTag::Journal` key and counts the blocks the journal holds. Its absence
/// means there is no pending journal.
#[derive(Serialize, Deserialize, Debug, Default)]
struct JournalHeader {
    // Number of journal blocks written, stored under entry indices 0..block_count.
    block_count: u32,
}
87
/// Lets the batch-simplification machinery (see `S::Batch::from_batch` in
/// `write_batch`) expand a prefix deletion into the concrete keys currently
/// stored under that prefix, by delegating to the inner store.
impl<S> DeletePrefixExpander for &JournalingKeyValueStore<S>
where
    S: DirectKeyValueStore,
{
    type Error = S::Error;

    async fn expand_delete_prefix(&self, key_prefix: &[u8]) -> Result<Vec<Vec<u8>>, Self::Error> {
        self.store.find_keys_by_prefix(key_prefix).await
    }
}
98
/// The journaling database reports the wrapped database's error type unchanged.
impl<D> WithError for JournalingKeyValueDatabase<D>
where
    D: WithError,
{
    type Error = D::Error;
}
105
/// The journaling store reports the wrapped store's error type unchanged.
impl<S> WithError for JournalingKeyValueStore<S>
where
    S: WithError,
{
    type Error = S::Error;
}
112
/// Read operations are unaffected by the journal: every method delegates
/// directly to the wrapped store.
impl<S> ReadableKeyValueStore for JournalingKeyValueStore<S>
where
    S: ReadableKeyValueStore,
    S::Error: From<JournalConsistencyError>,
{
    /// The key size limit is inherited unchanged from the wrapped store.
    const MAX_KEY_SIZE: usize = S::MAX_KEY_SIZE;

    fn max_stream_queries(&self) -> usize {
        self.store.max_stream_queries()
    }

    fn root_key(&self) -> Result<Vec<u8>, Self::Error> {
        self.store.root_key()
    }

    async fn read_value_bytes(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
        self.store.read_value_bytes(key).await
    }

    async fn contains_key(&self, key: &[u8]) -> Result<bool, Self::Error> {
        self.store.contains_key(key).await
    }

    async fn contains_keys(&self, keys: &[Vec<u8>]) -> Result<Vec<bool>, Self::Error> {
        self.store.contains_keys(keys).await
    }

    async fn read_multi_values_bytes(
        &self,
        keys: &[Vec<u8>],
    ) -> Result<Vec<Option<Vec<u8>>>, Self::Error> {
        self.store.read_multi_values_bytes(keys).await
    }

    async fn find_keys_by_prefix(&self, key_prefix: &[u8]) -> Result<Vec<Vec<u8>>, Self::Error> {
        self.store.find_keys_by_prefix(key_prefix).await
    }

    async fn find_key_values_by_prefix(
        &self,
        key_prefix: &[u8],
    ) -> Result<Vec<(Vec<u8>, Vec<u8>)>, Self::Error> {
        self.store.find_key_values_by_prefix(key_prefix).await
    }
}
160
/// Database operations delegate to the wrapped database; the only added
/// behavior is that opened stores are wrapped in `JournalingKeyValueStore`
/// with the access mode recorded.
impl<D> KeyValueDatabase for JournalingKeyValueDatabase<D>
where
    D: KeyValueDatabase,
{
    type Config = D::Config;
    type Store = JournalingKeyValueStore<D::Store>;

    fn get_name() -> String {
        // e.g. "journaling <inner name>".
        format!("journaling {}", D::get_name())
    }

    async fn connect(config: &Self::Config, namespace: &str) -> Result<Self, Self::Error> {
        let database = D::connect(config, namespace).await?;
        Ok(Self { database })
    }

    /// Opens a store with shared access: `has_exclusive_access` is `false`,
    /// so batches that would need the journal are refused (see
    /// `JournalConsistencyError::JournalRequiresExclusiveAccess`).
    fn open_shared(&self, root_key: &[u8]) -> Result<Self::Store, Self::Error> {
        let store = self.database.open_shared(root_key)?;
        Ok(JournalingKeyValueStore {
            store,
            has_exclusive_access: false,
        })
    }

    /// Opens a store with exclusive access, enabling use of the journal.
    fn open_exclusive(&self, root_key: &[u8]) -> Result<Self::Store, Self::Error> {
        let store = self.database.open_exclusive(root_key)?;
        Ok(JournalingKeyValueStore {
            store,
            has_exclusive_access: true,
        })
    }

    async fn list_all(config: &Self::Config) -> Result<Vec<String>, Self::Error> {
        D::list_all(config).await
    }

    async fn list_root_keys(&self) -> Result<Vec<Vec<u8>>, Self::Error> {
        self.database.list_root_keys().await
    }

    async fn delete_all(config: &Self::Config) -> Result<(), Self::Error> {
        D::delete_all(config).await
    }

    async fn exists(config: &Self::Config, namespace: &str) -> Result<bool, Self::Error> {
        D::exists(config, namespace).await
    }

    async fn create(config: &Self::Config, namespace: &str) -> Result<(), Self::Error> {
        D::create(config, namespace).await
    }

    async fn delete(config: &Self::Config, namespace: &str) -> Result<(), Self::Error> {
        D::delete(config, namespace).await
    }
}
217
218impl<S> WritableKeyValueStore for JournalingKeyValueStore<S>
219where
220 S: DirectKeyValueStore,
221 S::Error: From<JournalConsistencyError>,
222{
223 const MAX_VALUE_SIZE: usize = S::MAX_VALUE_SIZE;
225
226 async fn write_batch(&self, batch: Batch) -> Result<(), Self::Error> {
227 let batch = S::Batch::from_batch(self, batch).await?;
228 if Self::is_fastpath_feasible(&batch) {
229 self.store.write_batch(batch).await
230 } else {
231 if !self.has_exclusive_access {
232 return Err(JournalConsistencyError::JournalRequiresExclusiveAccess.into());
233 }
234 let header = self.write_journal(batch).await?;
235 self.coherently_resolve_journal(header).await
236 }
237 }
238
239 async fn clear_journal(&self) -> Result<(), Self::Error> {
240 let key = get_journaling_key(KeyTag::Journal as u8, 0)?;
241 let value = self.read_value::<JournalHeader>(&key).await?;
242 if let Some(header) = value {
243 self.coherently_resolve_journal(header).await?;
244 }
245 Ok(())
246 }
247}
248
impl<S> JournalingKeyValueStore<S>
where
    S: DirectKeyValueStore,
    S::Error: From<JournalConsistencyError>,
{
    /// Replays the journal described by `header` and removes it.
    ///
    /// Blocks are processed from the highest index (`block_count - 1`) down
    /// to 0. Each iteration writes one backend batch containing:
    ///   - the operations recorded in the journal block itself,
    ///   - the deletion of that block's `Entry` key,
    ///   - an update of the header with the decremented count — or, for the
    ///     last block, the deletion of the header.
    /// This makes the procedure resumable: after a crash between iterations,
    /// the stored header still describes exactly the blocks left to replay.
    /// NOTE(review): crash safety relies on the backend applying a single
    /// batch atomically — confirm `S::write_batch` provides that guarantee.
    async fn coherently_resolve_journal(&self, mut header: JournalHeader) -> Result<(), S::Error> {
        let header_key = get_journaling_key(KeyTag::Journal as u8, 0)?;
        while header.block_count > 0 {
            let block_key = get_journaling_key(KeyTag::Entry as u8, header.block_count - 1)?;
            // A block listed in the header must exist; a missing or unreadable
            // one is a consistency error, not an ordinary absent value.
            let mut batch = self
                .store
                .read_value::<S::Batch>(&block_key)
                .await?
                .ok_or(JournalConsistencyError::FailureToRetrieveJournalBlock)?;
            batch.add_delete(block_key);
            header.block_count -= 1;
            if header.block_count > 0 {
                // More blocks remain: persist the decremented header together
                // with this block's operations.
                let value = bcs::to_bytes(&header)?;
                batch.add_insert(header_key.clone(), value);
            } else {
                // Last block: delete the header, marking the journal as empty.
                batch.add_delete(header_key.clone());
            }
            self.store.write_batch(batch).await?;
        }
        Ok(())
    }

    /// Splits `batch` into journal blocks, persists them under `Entry` keys,
    /// and finally writes the `Journal` header recording the block count.
    ///
    /// Each block must fit in one backend value (`MAX_VALUE_SIZE`) and — once
    /// replayed together with the journal bookkeeping operations added by
    /// `coherently_resolve_journal` — within the backend's per-batch limits
    /// (`MAX_BATCH_SIZE`, `MAX_BATCH_TOTAL_SIZE`). Blocks are grouped into
    /// transactions that themselves respect those limits.
    async fn write_journal(&self, batch: S::Batch) -> Result<JournalHeader, S::Error> {
        let header_key = get_journaling_key(KeyTag::Journal as u8, 0)?;
        let key_len = header_key.len();
        let header_value_len = bcs::serialized_size(&JournalHeader::default())?;
        // Upper bound on a header update: its key plus its serialized value.
        let journal_len_upper_bound = key_len + header_value_len;
        let max_transaction_size = S::MAX_BATCH_TOTAL_SIZE;
        // A replayed block is written alongside its own entry key and a header
        // update, so reserve room for both in the total-size budget.
        let max_block_size = std::cmp::min(
            S::MAX_VALUE_SIZE,
            S::MAX_BATCH_TOTAL_SIZE - key_len - journal_len_upper_bound,
        );

        let mut iter = batch.into_iter();
        let mut block_batch = S::Batch::default();
        let mut block_size = 0;
        let mut block_count = 0;
        let mut transaction_batch = S::Batch::default();
        let mut transaction_size = 0;
        // Move one value at a time into the current block, flushing the block
        // and/or the transaction whenever adding the next value would breach
        // a limit.
        while iter.write_next_value(&mut block_batch, &mut block_size)? {
            let (block_flush, transaction_flush) = {
                // `MAX_BATCH_SIZE - 1` leaves room in the transaction for the
                // insert of the block being flushed now.
                if iter.is_empty() || transaction_batch.len() == S::MAX_BATCH_SIZE - 1 {
                    (true, true)
                } else {
                    let next_block_size = iter
                        .next_batch_size(&block_batch, block_size)?
                        .expect("iter is not empty");
                    let next_transaction_size = transaction_size + next_block_size + key_len;
                    let transaction_flush = next_transaction_size > max_transaction_size;
                    // `MAX_BATCH_SIZE - 2`: replay adds two operations to each
                    // block (the entry deletion and the header update), see
                    // `coherently_resolve_journal`.
                    let block_flush = transaction_flush
                        || block_batch.len() == S::MAX_BATCH_SIZE - 2
                        || next_block_size > max_block_size;
                    (block_flush, transaction_flush)
                }
            };
            if block_flush {
                block_size += block_batch.overhead_size();
                let value = bcs::to_bytes(&block_batch)?;
                block_batch = S::Batch::default();
                // Invariant: the incrementally tracked size must match the
                // actual serialized size of the block.
                assert_eq!(value.len(), block_size);
                let key = get_journaling_key(KeyTag::Entry as u8, block_count)?;
                transaction_batch.add_insert(key, value);
                block_count += 1;
                transaction_size += block_size + key_len;
                block_size = 0;
            }
            if transaction_flush {
                let batch = std::mem::take(&mut transaction_batch);
                self.store.write_batch(batch).await?;
                transaction_size = 0;
            }
        }
        // The header is written last, in its own batch: the journal only
        // becomes discoverable once all of its blocks are stored.
        let header = JournalHeader { block_count };
        if block_count > 0 {
            let value = bcs::to_bytes(&header)?;
            let mut batch = S::Batch::default();
            batch.add_insert(header_key, value);
            self.store.write_batch(batch).await?;
        }
        Ok(header)
    }

    /// A batch can bypass the journal when it respects the backend's limits
    /// on operation count and total byte size.
    fn is_fastpath_feasible(batch: &S::Batch) -> bool {
        batch.len() <= S::MAX_BATCH_SIZE && batch.num_bytes() <= S::MAX_BATCH_TOTAL_SIZE
    }
}
403
impl<S> JournalingKeyValueStore<S> {
    /// Creates a new journaling store wrapping `store`.
    ///
    /// The store starts without exclusive access, so batches requiring the
    /// journal are refused by `write_batch`; only fast-path batches succeed.
    pub fn new(store: S) -> Self {
        Self {
            store,
            has_exclusive_access: false,
        }
    }
}
412}