linera_views/views/
hashable_wrapper.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{
5    marker::PhantomData,
6    ops::{Deref, DerefMut},
7    sync::Mutex,
8};
9
10use allocative::Allocative;
11use linera_base::visit_allocative_simple;
12use serde::{de::DeserializeOwned, Serialize};
13
14use crate::{
15    batch::Batch,
16    common::from_bytes_option,
17    context::Context,
18    views::{ClonableView, HashableView, Hasher, ReplaceContext, View, ViewError, MIN_VIEW_TAG},
19};
20
/// Wrapping a view to memoize its hash.
///
/// The wrapper keeps two copies of the hash of the wrapped view `W`: the value that was
/// last loaded from / saved to storage (`stored_hash`) and an in-memory memoized value
/// (`hash`). The two are compared to decide whether the hash entry has to be rewritten
/// when the view is saved.
///
/// Type parameters: `C` is the context type, `W` the wrapped view and `O` the hash
/// output type.
#[derive(Debug, Allocative)]
#[allocative(bound = "C, O, W: Allocative")]
pub struct WrappedHashableContainerView<C, W, O> {
    /// Phantom data for the context type.
    #[allocative(skip)]
    _phantom: PhantomData<C>,
    /// The hash persisted in storage.
    #[allocative(visit = visit_allocative_simple)]
    stored_hash: Option<O>,
    /// Memoized hash, if any.
    ///
    /// Kept behind a `Mutex` so that it can also be filled in through a shared
    /// reference (see `HashableView::hash`).
    #[allocative(visit = visit_allocative_simple)]
    hash: Mutex<Option<O>>,
    /// The wrapped view.
    inner: W,
}
37
/// Key tags to create the sub-keys of a `WrappedHashableContainerView` on top of the base key.
///
/// Each tag is a single byte appended to the base key. The numbering starts at
/// `MIN_VIEW_TAG`, and the variant order determines the byte values, so reordering the
/// variants changes the keys used in storage.
#[repr(u8)]
enum KeyTag {
    /// Prefix for the indices of the view.
    Inner = MIN_VIEW_TAG,
    /// Prefix for the hash.
    // Implicit discriminant: the previous variant's value plus one.
    Hash,
}
46
47impl<C, W, O, C2> ReplaceContext<C2> for WrappedHashableContainerView<C, W, O>
48where
49    W: HashableView<Hasher: Hasher<Output = O>, Context = C> + ReplaceContext<C2>,
50    <W as ReplaceContext<C2>>::Target: HashableView<Hasher: Hasher<Output = O>>,
51    O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
52    C: Context,
53    C2: Context,
54{
55    type Target = WrappedHashableContainerView<C2, <W as ReplaceContext<C2>>::Target, O>;
56
57    async fn with_context(
58        &mut self,
59        ctx: impl FnOnce(&Self::Context) -> C2 + Clone,
60    ) -> Self::Target {
61        let hash = *self.hash.lock().unwrap();
62        WrappedHashableContainerView {
63            _phantom: PhantomData,
64            stored_hash: self.stored_hash,
65            hash: Mutex::new(hash),
66            inner: self.inner.with_context(ctx).await,
67        }
68    }
69}
70
71impl<W: HashableView, O> View for WrappedHashableContainerView<W::Context, W, O>
72where
73    W: HashableView<Hasher: Hasher<Output = O>>,
74    O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
75{
76    const NUM_INIT_KEYS: usize = 1 + W::NUM_INIT_KEYS;
77
78    type Context = W::Context;
79
80    fn context(&self) -> Self::Context {
81        // The inner context has our base key + the KeyTag::Inner byte
82        self.inner.context().clone_with_trimmed_key(1)
83    }
84
85    fn pre_load(context: &Self::Context) -> Result<Vec<Vec<u8>>, ViewError> {
86        let mut v = vec![context.base_key().base_tag(KeyTag::Hash as u8)];
87        let base_key = context.base_key().base_tag(KeyTag::Inner as u8);
88        let context = context.clone_with_base_key(base_key);
89        v.extend(W::pre_load(&context)?);
90        Ok(v)
91    }
92
93    fn post_load(context: Self::Context, values: &[Option<Vec<u8>>]) -> Result<Self, ViewError> {
94        let hash = from_bytes_option(values.first().ok_or(ViewError::PostLoadValuesError)?)?;
95        let base_key = context.base_key().base_tag(KeyTag::Inner as u8);
96        let context = context.clone_with_base_key(base_key);
97        let inner = W::post_load(
98            context,
99            values.get(1..).ok_or(ViewError::PostLoadValuesError)?,
100        )?;
101        Ok(Self {
102            _phantom: PhantomData,
103            stored_hash: hash,
104            hash: Mutex::new(hash),
105            inner,
106        })
107    }
108
109    fn rollback(&mut self) {
110        self.inner.rollback();
111        *self.hash.get_mut().unwrap() = self.stored_hash;
112    }
113
114    async fn has_pending_changes(&self) -> bool {
115        if self.inner.has_pending_changes().await {
116            return true;
117        }
118        let hash = self.hash.lock().unwrap();
119        self.stored_hash != *hash
120    }
121
122    fn pre_save(&self, batch: &mut Batch) -> Result<bool, ViewError> {
123        let delete_view = self.inner.pre_save(batch)?;
124        let hash = *self.hash.lock().unwrap();
125        if delete_view {
126            let mut key_prefix = self.inner.context().base_key().bytes.clone();
127            key_prefix.pop();
128            batch.delete_key_prefix(key_prefix);
129        } else if self.stored_hash != hash {
130            let mut key = self.inner.context().base_key().bytes.clone();
131            let tag = key.last_mut().unwrap();
132            *tag = KeyTag::Hash as u8;
133            match hash {
134                None => batch.delete_key(key),
135                Some(hash) => batch.put_key_value(key, &hash)?,
136            }
137        }
138        Ok(delete_view)
139    }
140
141    fn post_save(&mut self) {
142        self.inner.post_save();
143        let hash = *self.hash.get_mut().unwrap();
144        self.stored_hash = hash;
145    }
146
147    fn clear(&mut self) {
148        self.inner.clear();
149        *self.hash.get_mut().unwrap() = None;
150    }
151}
152
153impl<W, O> ClonableView for WrappedHashableContainerView<W::Context, W, O>
154where
155    W: HashableView + ClonableView,
156    O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
157    W::Hasher: Hasher<Output = O>,
158{
159    fn clone_unchecked(&mut self) -> Result<Self, ViewError> {
160        Ok(WrappedHashableContainerView {
161            _phantom: PhantomData,
162            stored_hash: self.stored_hash,
163            hash: Mutex::new(*self.hash.get_mut().unwrap()),
164            inner: self.inner.clone_unchecked()?,
165        })
166    }
167}
168
169impl<W, O> HashableView for WrappedHashableContainerView<W::Context, W, O>
170where
171    W: HashableView,
172    O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
173    W::Hasher: Hasher<Output = O>,
174{
175    type Hasher = W::Hasher;
176
177    async fn hash_mut(&mut self) -> Result<<Self::Hasher as Hasher>::Output, ViewError> {
178        let hash = *self.hash.get_mut().unwrap();
179        match hash {
180            Some(hash) => Ok(hash),
181            None => {
182                let new_hash = self.inner.hash_mut().await?;
183                let hash = self.hash.get_mut().unwrap();
184                *hash = Some(new_hash);
185                Ok(new_hash)
186            }
187        }
188    }
189
190    async fn hash(&self) -> Result<<Self::Hasher as Hasher>::Output, ViewError> {
191        let hash = *self.hash.lock().unwrap();
192        match hash {
193            Some(hash) => Ok(hash),
194            None => {
195                let new_hash = self.inner.hash().await?;
196                let mut hash = self.hash.lock().unwrap();
197                *hash = Some(new_hash);
198                Ok(new_hash)
199            }
200        }
201    }
202}
203
204impl<C, W, O> Deref for WrappedHashableContainerView<C, W, O> {
205    type Target = W;
206
207    fn deref(&self) -> &W {
208        &self.inner
209    }
210}
211
212impl<C, W, O> DerefMut for WrappedHashableContainerView<C, W, O> {
213    fn deref_mut(&mut self) -> &mut W {
214        *self.hash.get_mut().unwrap() = None;
215        &mut self.inner
216    }
217}
218
#[cfg(with_graphql)]
mod graphql {
    use std::borrow::Cow;

    use super::WrappedHashableContainerView;
    use crate::context::Context;

    /// GraphQL support: the wrapper is exposed exactly like the wrapped view `W`, with
    /// every method forwarding to `W`'s implementation.
    impl<C, W, O> async_graphql::OutputType for WrappedHashableContainerView<C, W, O>
    where
        C: Context,
        W: async_graphql::OutputType + Send + Sync,
        O: Send + Sync,
    {
        fn type_name() -> Cow<'static, str> {
            <W as async_graphql::OutputType>::type_name()
        }

        fn qualified_type_name() -> String {
            <W as async_graphql::OutputType>::qualified_type_name()
        }

        fn create_type_info(registry: &mut async_graphql::registry::Registry) -> String {
            <W as async_graphql::OutputType>::create_type_info(registry)
        }

        async fn resolve(
            &self,
            ctx: &async_graphql::ContextSelectionSet<'_>,
            field: &async_graphql::Positioned<async_graphql::parser::types::Field>,
        ) -> async_graphql::ServerResult<async_graphql::Value> {
            // Shared access through `Deref`, so the memoized hash is not invalidated.
            let wrapped: &W = self;
            <W as async_graphql::OutputType>::resolve(wrapped, ctx, field).await
        }
    }
}