1use serde::{Deserialize, Serialize};
23use static_assertions as sa;
24use thiserror::Error;
25
26use crate::{
27 batch::{Batch, BatchValueWriter, DeletePrefixExpander, SimplifiedBatch},
28 store::{
29 DirectKeyValueStore, KeyValueDatabase, KeyValueStoreError, ReadableKeyValueStore,
30 WithError, WritableKeyValueStore,
31 },
32 views::MIN_VIEW_TAG,
33};
34
/// A key-value database wrapper whose stores add journaling on top of the
/// stores produced by the inner database `D`.
#[derive(Clone)]
pub struct JournalingKeyValueDatabase<D> {
    /// The wrapped database that every operation is delegated to.
    database: D,
}
40
/// A key-value store wrapper that falls back to a journal for batches that
/// the inner store `S` cannot accept in a single write.
#[derive(Clone)]
pub struct JournalingKeyValueStore<S> {
    /// The wrapped store that reads and writes are delegated to.
    store: S,
    /// Whether this store was opened with exclusive access to its root key.
    /// The journal is only used when this is `true`.
    has_exclusive_access: bool,
}
49
/// Errors returned by the journaling database and store wrappers.
///
/// `E` is the error type of the wrapped database or store.
#[derive(Error, Debug)]
pub enum JournalingError<E> {
    /// An error reported by the wrapped database or store.
    #[error(transparent)]
    Inner(#[from] E),

    /// A BCS (de)serialization error raised while handling journal keys,
    /// headers or blocks.
    #[error(transparent)]
    BcsError(bcs::Error),

    /// A batch needed the journal, but the store does not have exclusive
    /// access to its root key.
    #[error("Refusing to use the journal without exclusive database access to the root object.")]
    JournalRequiresExclusiveAccess,

    /// The journal was written but applying it to the store failed.
    /// `must_reload_view` returns `true` for this variant.
    #[error("Journal resolution failed: {0}")]
    JournalResolutionFailed(JournalingResolutionError<E>),
}
70
/// Errors that can occur while applying the journal blocks to the store.
///
/// `E` is the error type of the wrapped store.
#[derive(Error, Debug)]
pub enum JournalingResolutionError<E> {
    /// An error reported by the wrapped store.
    #[error(transparent)]
    Inner(#[from] E),

    /// A BCS (de)serialization error raised while handling journal keys,
    /// headers or blocks.
    #[error(transparent)]
    BcsError(bcs::Error),

    /// A block announced by the journal header was not found in the store.
    #[error("The journal block could not be retrieved, it could be missing or corrupted.")]
    FailureToRetrieveJournalBlock,
}
86
87impl<E: KeyValueStoreError> From<bcs::Error> for JournalingError<E> {
88 fn from(error: bcs::Error) -> Self {
89 JournalingError::BcsError(error)
90 }
91}
92
93impl<E: KeyValueStoreError + 'static> KeyValueStoreError for JournalingError<E> {
94 const BACKEND: &'static str = "journaling";
95
96 fn must_reload_view(&self) -> bool {
97 matches!(self, JournalingError::JournalResolutionFailed(_))
98 }
99}
100
101impl<E: KeyValueStoreError> From<bcs::Error> for JournalingResolutionError<E> {
102 fn from(error: bcs::Error) -> Self {
103 JournalingResolutionError::BcsError(error)
104 }
105}
106
/// First byte of every key produced by `get_journaling_key`.
const JOURNAL_TAG: u8 = 0;
// Compile-time check that the journal tag is strictly below `MIN_VIEW_TAG`,
// presumably so that journal keys cannot collide with keys written by views
// (TODO confirm against the definition of `MIN_VIEW_TAG`).
sa::const_assert!(JOURNAL_TAG < MIN_VIEW_TAG);
112
/// Second byte of a journaling key, selecting which kind of journal record
/// the key addresses.
#[repr(u8)]
enum KeyTag {
    /// Tag of the key under which the journal header is stored.
    Journal = 1,
    /// Tag of the keys under which the journal blocks are stored.
    Entry,
}
120
121fn get_journaling_key(tag: u8, pos: u32) -> Result<Vec<u8>, bcs::Error> {
122 let mut key = vec![JOURNAL_TAG];
123 key.extend([tag]);
124 bcs::serialize_into(&mut key, &pos)?;
125 Ok(key)
126}
127
/// Header record describing the current state of the journal.
#[derive(Serialize, Deserialize, Debug, Default)]
struct JournalHeader {
    /// Number of journal blocks that still have to be applied to the store.
    block_count: u32,
}
133
134impl<S> DeletePrefixExpander for &JournalingKeyValueStore<S>
135where
136 S: DirectKeyValueStore,
137{
138 type Error = S::Error;
139
140 async fn expand_delete_prefix(&self, key_prefix: &[u8]) -> Result<Vec<Vec<u8>>, Self::Error> {
141 self.store.find_keys_by_prefix(key_prefix).await
142 }
143}
144
145impl<D> WithError for JournalingKeyValueDatabase<D>
146where
147 D: WithError,
148 D::Error: 'static,
149{
150 type Error = JournalingError<D::Error>;
151}
152
153impl<S> WithError for JournalingKeyValueStore<S>
154where
155 S: WithError,
156 S::Error: 'static,
157{
158 type Error = JournalingError<S::Error>;
159}
160
161impl<S> ReadableKeyValueStore for JournalingKeyValueStore<S>
162where
163 S: ReadableKeyValueStore,
164 S::Error: 'static,
165{
166 const MAX_KEY_SIZE: usize = S::MAX_KEY_SIZE;
167
168 fn max_stream_queries(&self) -> usize {
169 self.store.max_stream_queries()
170 }
171
172 fn root_key(&self) -> Result<Vec<u8>, Self::Error> {
173 Ok(self.store.root_key()?)
174 }
175
176 async fn read_value_bytes(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
177 Ok(self.store.read_value_bytes(key).await?)
178 }
179
180 async fn contains_key(&self, key: &[u8]) -> Result<bool, Self::Error> {
181 Ok(self.store.contains_key(key).await?)
182 }
183
184 async fn contains_keys(&self, keys: &[Vec<u8>]) -> Result<Vec<bool>, Self::Error> {
185 Ok(self.store.contains_keys(keys).await?)
186 }
187
188 async fn read_multi_values_bytes(
189 &self,
190 keys: &[Vec<u8>],
191 ) -> Result<Vec<Option<Vec<u8>>>, Self::Error> {
192 Ok(self.store.read_multi_values_bytes(keys).await?)
193 }
194
195 async fn find_keys_by_prefix(&self, key_prefix: &[u8]) -> Result<Vec<Vec<u8>>, Self::Error> {
196 Ok(self.store.find_keys_by_prefix(key_prefix).await?)
197 }
198
199 async fn find_key_values_by_prefix(
200 &self,
201 key_prefix: &[u8],
202 ) -> Result<Vec<(Vec<u8>, Vec<u8>)>, Self::Error> {
203 Ok(self.store.find_key_values_by_prefix(key_prefix).await?)
204 }
205}
206
207impl<D> KeyValueDatabase for JournalingKeyValueDatabase<D>
208where
209 D: KeyValueDatabase,
210 D::Error: 'static,
211{
212 type Config = D::Config;
213 type Store = JournalingKeyValueStore<D::Store>;
214
215 fn get_name() -> String {
216 format!("journaling {}", D::get_name())
217 }
218
219 async fn connect(config: &Self::Config, namespace: &str) -> Result<Self, Self::Error> {
220 let database = D::connect(config, namespace).await?;
221 Ok(Self { database })
222 }
223
224 fn open_shared(&self, root_key: &[u8]) -> Result<Self::Store, Self::Error> {
225 let store = self.database.open_shared(root_key)?;
226 Ok(JournalingKeyValueStore {
227 store,
228 has_exclusive_access: false,
229 })
230 }
231
232 fn open_exclusive(&self, root_key: &[u8]) -> Result<Self::Store, Self::Error> {
233 let store = self.database.open_exclusive(root_key)?;
234 Ok(JournalingKeyValueStore {
235 store,
236 has_exclusive_access: true,
237 })
238 }
239
240 async fn list_all(config: &Self::Config) -> Result<Vec<String>, Self::Error> {
241 Ok(D::list_all(config).await?)
242 }
243
244 async fn list_root_keys(&self) -> Result<Vec<Vec<u8>>, Self::Error> {
245 Ok(self.database.list_root_keys().await?)
246 }
247
248 async fn delete_all(config: &Self::Config) -> Result<(), Self::Error> {
249 Ok(D::delete_all(config).await?)
250 }
251
252 async fn exists(config: &Self::Config, namespace: &str) -> Result<bool, Self::Error> {
253 Ok(D::exists(config, namespace).await?)
254 }
255
256 async fn create(config: &Self::Config, namespace: &str) -> Result<(), Self::Error> {
257 Ok(D::create(config, namespace).await?)
258 }
259
260 async fn delete(config: &Self::Config, namespace: &str) -> Result<(), Self::Error> {
261 Ok(D::delete(config, namespace).await?)
262 }
263}
264
265impl<S> WritableKeyValueStore for JournalingKeyValueStore<S>
266where
267 S: DirectKeyValueStore,
268 S::Error: 'static,
269{
270 const MAX_VALUE_SIZE: usize = S::MAX_VALUE_SIZE;
271
272 async fn write_batch(&self, batch: Batch) -> Result<(), Self::Error> {
273 let batch = S::Batch::from_batch(self, batch).await?;
274 if Self::is_fastpath_feasible(&batch) {
275 Ok(self.store.write_batch(batch).await?)
276 } else {
277 if !self.has_exclusive_access {
278 return Err(JournalingError::JournalRequiresExclusiveAccess);
279 }
280 let header = self.write_journal(batch).await?;
281 match self.coherently_resolve_journal(header).await {
282 Ok(()) => Ok(()),
283 Err(e) => Err(JournalingError::JournalResolutionFailed(e)),
284 }
285 }
286 }
287
288 async fn clear_journal(&self) -> Result<(), Self::Error> {
289 let key = get_journaling_key(KeyTag::Journal as u8, 0)?;
290 let value = self.read_value::<JournalHeader>(&key).await?;
291 if let Some(header) = value {
292 match self.coherently_resolve_journal(header).await {
293 Ok(()) => Ok(()),
294 Err(e) => Err(JournalingError::JournalResolutionFailed(e)),
295 }
296 } else {
297 Ok(())
298 }
299 }
300}
301
302impl<S> JournalingKeyValueStore<S>
303where
304 S: DirectKeyValueStore,
305 S::Error: 'static,
306{
307 async fn coherently_resolve_journal(
328 &self,
329 mut header: JournalHeader,
330 ) -> Result<(), JournalingResolutionError<S::Error>> {
331 let header_key = get_journaling_key(KeyTag::Journal as u8, 0)?;
332 while header.block_count > 0 {
333 let block_key = get_journaling_key(KeyTag::Entry as u8, header.block_count - 1)?;
334 let mut batch = self
336 .store
337 .read_value::<S::Batch>(&block_key)
338 .await?
339 .ok_or(JournalingResolutionError::FailureToRetrieveJournalBlock)?;
340 batch.add_delete(block_key);
342 header.block_count -= 1;
343 if header.block_count > 0 {
344 let value = bcs::to_bytes(&header)?;
345 batch.add_insert(header_key.clone(), value);
346 } else {
347 batch.add_delete(header_key.clone());
348 }
349 self.store.write_batch(batch).await?;
350 }
351 Ok(())
352 }
353
354 async fn write_journal(
395 &self,
396 batch: S::Batch,
397 ) -> Result<JournalHeader, JournalingError<S::Error>> {
398 let header_key = get_journaling_key(KeyTag::Journal as u8, 0)?;
399 let key_len = header_key.len();
400 let header_value_len = bcs::serialized_size(&JournalHeader::default())?;
401 let journal_len_upper_bound = key_len + header_value_len;
402 let max_transaction_size = S::MAX_BATCH_TOTAL_SIZE;
404 let max_block_size = std::cmp::min(
405 S::MAX_VALUE_SIZE,
406 S::MAX_BATCH_TOTAL_SIZE - key_len - journal_len_upper_bound,
407 );
408
409 let mut iter = batch.into_iter();
410 let mut block_batch = S::Batch::default();
411 let mut block_size = 0;
412 let mut block_count = 0;
413 let mut transaction_batch = S::Batch::default();
414 let mut transaction_size = 0;
415 while iter.write_next_value(&mut block_batch, &mut block_size)? {
416 let (block_flush, transaction_flush) = {
417 if iter.is_empty() || transaction_batch.len() == S::MAX_BATCH_SIZE - 1 {
418 (true, true)
419 } else {
420 let next_block_size = iter
421 .next_batch_size(&block_batch, block_size)?
422 .expect("iter is not empty");
423 let next_transaction_size = transaction_size + next_block_size + key_len;
424 let transaction_flush = next_transaction_size > max_transaction_size;
425 let block_flush = transaction_flush
426 || block_batch.len() == S::MAX_BATCH_SIZE - 2
427 || next_block_size > max_block_size;
428 (block_flush, transaction_flush)
429 }
430 };
431 if block_flush {
432 block_size += block_batch.overhead_size();
433 let value = bcs::to_bytes(&block_batch)?;
434 block_batch = S::Batch::default();
435 assert_eq!(value.len(), block_size);
436 let key = get_journaling_key(KeyTag::Entry as u8, block_count)?;
437 transaction_batch.add_insert(key, value);
438 block_count += 1;
439 transaction_size += block_size + key_len;
440 block_size = 0;
441 }
442 if transaction_flush {
443 let batch = std::mem::take(&mut transaction_batch);
444 self.store.write_batch(batch).await?;
445 transaction_size = 0;
446 }
447 }
448 let header = JournalHeader { block_count };
449 if block_count > 0 {
450 let value = bcs::to_bytes(&header)?;
451 let mut batch = S::Batch::default();
452 batch.add_insert(header_key, value);
453 self.store.write_batch(batch).await?;
454 }
455 Ok(header)
456 }
457
458 fn is_fastpath_feasible(batch: &S::Batch) -> bool {
459 batch.len() <= S::MAX_BATCH_SIZE && batch.num_bytes() <= S::MAX_BATCH_TOTAL_SIZE
460 }
461}
462
463impl<S> JournalingKeyValueStore<S> {
464 pub fn new(store: S) -> Self {
466 Self {
467 store,
468 has_exclusive_access: false,
469 }
470 }
471}