linera_views/views/hashable_wrapper.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{
5    marker::PhantomData,
6    ops::{Deref, DerefMut},
7    sync::Mutex,
8};
9
10use serde::{de::DeserializeOwned, Serialize};
11
12use crate::{
13    batch::Batch,
14    common::from_bytes_option,
15    context::Context,
16    views::{ClonableView, HashableView, Hasher, ReplaceContext, View, ViewError, MIN_VIEW_TAG},
17};
18
/// A view wrapper that memoizes the hash of the wrapped view.
///
/// The memoized hash is also written to storage next to the inner view, so a
/// freshly loaded wrapper can return it without recomputing it.
#[derive(Debug)]
pub struct WrappedHashableContainerView<C, W, O> {
    /// Marker for the context type `C`, which is not otherwise stored.
    _phantom: PhantomData<C>,
    /// The hash as last read from, or written to, storage (`None` if absent).
    stored_hash: Option<O>,
    /// The current memoized hash of `inner`; `None` means it must be recomputed.
    /// Kept behind a `Mutex` so that `hash(&self)` can fill it in.
    hash: Mutex<Option<O>>,
    /// The wrapped view.
    inner: W,
}
27
28/// Key tags to create the sub-keys of a `WrappedHashableContainerView` on top of the base key.
29#[repr(u8)]
30enum KeyTag {
31    /// Prefix for the indices of the view.
32    Inner = MIN_VIEW_TAG,
33    /// Prefix for the hash.
34    Hash,
35}
36
37impl<C, W, O, C2> ReplaceContext<C2> for WrappedHashableContainerView<C, W, O>
38where
39    W: HashableView<Hasher: Hasher<Output = O>, Context = C> + ReplaceContext<C2>,
40    <W as ReplaceContext<C2>>::Target: HashableView<Hasher: Hasher<Output = O>>,
41    O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
42    C: Context,
43    C2: Context,
44{
45    type Target = WrappedHashableContainerView<C2, <W as ReplaceContext<C2>>::Target, O>;
46
47    async fn with_context(
48        &mut self,
49        ctx: impl FnOnce(&Self::Context) -> C2 + Clone,
50    ) -> Self::Target {
51        let hash = *self.hash.lock().unwrap();
52        WrappedHashableContainerView {
53            _phantom: PhantomData,
54            stored_hash: self.stored_hash,
55            hash: Mutex::new(hash),
56            inner: self.inner.with_context(ctx).await,
57        }
58    }
59}
60
61impl<W: HashableView, O> View for WrappedHashableContainerView<W::Context, W, O>
62where
63    W: HashableView<Hasher: Hasher<Output = O>>,
64    O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
65{
66    const NUM_INIT_KEYS: usize = 1 + W::NUM_INIT_KEYS;
67
68    type Context = W::Context;
69
70    fn context(&self) -> &Self::Context {
71        self.inner.context()
72    }
73
74    fn pre_load(context: &Self::Context) -> Result<Vec<Vec<u8>>, ViewError> {
75        let mut v = vec![context.base_key().base_tag(KeyTag::Hash as u8)];
76        let base_key = context.base_key().base_tag(KeyTag::Inner as u8);
77        let context = context.clone_with_base_key(base_key);
78        v.extend(W::pre_load(&context)?);
79        Ok(v)
80    }
81
82    fn post_load(context: Self::Context, values: &[Option<Vec<u8>>]) -> Result<Self, ViewError> {
83        let hash = from_bytes_option(values.first().ok_or(ViewError::PostLoadValuesError)?)?;
84        let base_key = context.base_key().base_tag(KeyTag::Inner as u8);
85        let context = context.clone_with_base_key(base_key);
86        let inner = W::post_load(
87            context,
88            values.get(1..).ok_or(ViewError::PostLoadValuesError)?,
89        )?;
90        Ok(Self {
91            _phantom: PhantomData,
92            stored_hash: hash,
93            hash: Mutex::new(hash),
94            inner,
95        })
96    }
97
98    fn rollback(&mut self) {
99        self.inner.rollback();
100        *self.hash.get_mut().unwrap() = self.stored_hash;
101    }
102
103    async fn has_pending_changes(&self) -> bool {
104        if self.inner.has_pending_changes().await {
105            return true;
106        }
107        let hash = self.hash.lock().unwrap();
108        self.stored_hash != *hash
109    }
110
111    fn flush(&mut self, batch: &mut Batch) -> Result<bool, ViewError> {
112        let delete_view = self.inner.flush(batch)?;
113        let hash = self.hash.get_mut().unwrap();
114        if delete_view {
115            let mut key_prefix = self.inner.context().base_key().bytes.clone();
116            key_prefix.pop();
117            batch.delete_key_prefix(key_prefix);
118            self.stored_hash = None;
119            *hash = None;
120        } else if self.stored_hash != *hash {
121            let mut key = self.inner.context().base_key().bytes.clone();
122            let tag = key.last_mut().unwrap();
123            *tag = KeyTag::Hash as u8;
124            match hash {
125                None => batch.delete_key(key),
126                Some(hash) => batch.put_key_value(key, hash)?,
127            }
128            self.stored_hash = *hash;
129        }
130        Ok(delete_view)
131    }
132
133    fn clear(&mut self) {
134        self.inner.clear();
135        *self.hash.get_mut().unwrap() = None;
136    }
137}
138
139impl<W, O> ClonableView for WrappedHashableContainerView<W::Context, W, O>
140where
141    W: HashableView + ClonableView,
142    O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
143    W::Hasher: Hasher<Output = O>,
144{
145    fn clone_unchecked(&mut self) -> Result<Self, ViewError> {
146        Ok(WrappedHashableContainerView {
147            _phantom: PhantomData,
148            stored_hash: self.stored_hash,
149            hash: Mutex::new(*self.hash.get_mut().unwrap()),
150            inner: self.inner.clone_unchecked()?,
151        })
152    }
153}
154
155impl<W, O> HashableView for WrappedHashableContainerView<W::Context, W, O>
156where
157    W: HashableView,
158    O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
159    W::Hasher: Hasher<Output = O>,
160{
161    type Hasher = W::Hasher;
162
163    async fn hash_mut(&mut self) -> Result<<Self::Hasher as Hasher>::Output, ViewError> {
164        let hash = *self.hash.get_mut().unwrap();
165        match hash {
166            Some(hash) => Ok(hash),
167            None => {
168                let new_hash = self.inner.hash_mut().await?;
169                let hash = self.hash.get_mut().unwrap();
170                *hash = Some(new_hash);
171                Ok(new_hash)
172            }
173        }
174    }
175
176    async fn hash(&self) -> Result<<Self::Hasher as Hasher>::Output, ViewError> {
177        let hash = *self.hash.lock().unwrap();
178        match hash {
179            Some(hash) => Ok(hash),
180            None => {
181                let new_hash = self.inner.hash().await?;
182                let mut hash = self.hash.lock().unwrap();
183                *hash = Some(new_hash);
184                Ok(new_hash)
185            }
186        }
187    }
188}
189
190impl<C, W, O> Deref for WrappedHashableContainerView<C, W, O> {
191    type Target = W;
192
193    fn deref(&self) -> &W {
194        &self.inner
195    }
196}
197
198impl<C, W, O> DerefMut for WrappedHashableContainerView<C, W, O> {
199    fn deref_mut(&mut self) -> &mut W {
200        *self.hash.get_mut().unwrap() = None;
201        &mut self.inner
202    }
203}
204
#[cfg(with_graphql)]
mod graphql {
    use std::borrow::Cow;

    use super::WrappedHashableContainerView;
    use crate::context::Context;

    /// Presents the wrapper to GraphQL exactly as the wrapped view.
    ///
    /// The type name and type info are the inner view's, and field
    /// resolution is forwarded to the inner view.
    impl<C, W, O> async_graphql::OutputType for WrappedHashableContainerView<C, W, O>
    where
        C: Context,
        W: async_graphql::OutputType + Send + Sync,
        O: Send + Sync,
    {
        fn type_name() -> Cow<'static, str> {
            <W as async_graphql::OutputType>::type_name()
        }

        fn qualified_type_name() -> String {
            <W as async_graphql::OutputType>::qualified_type_name()
        }

        fn create_type_info(registry: &mut async_graphql::registry::Registry) -> String {
            <W as async_graphql::OutputType>::create_type_info(registry)
        }

        async fn resolve(
            &self,
            ctx: &async_graphql::ContextSelectionSet<'_>,
            field: &async_graphql::Positioned<async_graphql::parser::types::Field>,
        ) -> async_graphql::ServerResult<async_graphql::Value> {
            // `&**self` goes through the wrapper's `Deref` impl to reach the inner view.
            <W as async_graphql::OutputType>::resolve(&**self, ctx, field).await
        }
    }
}