1use std::{
5 marker::PhantomData,
6 ops::{Deref, DerefMut},
7 sync::Mutex,
8};
9
10use allocative::Allocative;
11use linera_base::visit_allocative_simple;
12use serde::{de::DeserializeOwned, Serialize};
13
14use crate::{
15 batch::Batch,
16 common::from_bytes_option,
17 context::Context,
18 views::{ClonableView, HashableView, Hasher, ReplaceContext, View, ViewError, MIN_VIEW_TAG},
19};
20
/// A view wrapping an inner hashable view `W` together with a cached copy of its hash.
///
/// The inner view's keys live under the `KeyTag::Inner` tag and the persisted hash
/// under the `KeyTag::Hash` tag of this container's base key.
#[derive(Debug, Allocative)]
#[allocative(bound = "C, O, W: Allocative")]
pub struct WrappedHashableContainerView<C, W, O> {
    /// Type marker for the context `C`; no `C` value is stored in this struct.
    #[allocative(skip)]
    _phantom: PhantomData<C>,
    /// The hash this view considers persisted: set from storage on load and
    /// refreshed from the cached hash after each save.
    #[allocative(visit = visit_allocative_simple)]
    stored_hash: Option<O>,
    /// In-memory cached hash; `None` when not computed yet or after it was
    /// invalidated (e.g. by mutable access to the inner view or by `clear`).
    #[allocative(visit = visit_allocative_simple)]
    hash: Mutex<Option<O>>,
    /// The wrapped view.
    inner: W,
}
37
38#[repr(u8)]
40enum KeyTag {
41 Inner = MIN_VIEW_TAG,
43 Hash,
45}
46
47impl<C, W, O, C2> ReplaceContext<C2> for WrappedHashableContainerView<C, W, O>
48where
49 W: HashableView<Hasher: Hasher<Output = O>, Context = C> + ReplaceContext<C2>,
50 <W as ReplaceContext<C2>>::Target: HashableView<Hasher: Hasher<Output = O>>,
51 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
52 C: Context,
53 C2: Context,
54{
55 type Target = WrappedHashableContainerView<C2, <W as ReplaceContext<C2>>::Target, O>;
56
57 async fn with_context(
58 &mut self,
59 ctx: impl FnOnce(&Self::Context) -> C2 + Clone,
60 ) -> Self::Target {
61 let hash = *self.hash.lock().unwrap();
62 WrappedHashableContainerView {
63 _phantom: PhantomData,
64 stored_hash: self.stored_hash,
65 hash: Mutex::new(hash),
66 inner: self.inner.with_context(ctx).await,
67 }
68 }
69}
70
71impl<W: HashableView, O> View for WrappedHashableContainerView<W::Context, W, O>
72where
73 W: HashableView<Hasher: Hasher<Output = O>>,
74 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
75{
76 const NUM_INIT_KEYS: usize = 1 + W::NUM_INIT_KEYS;
77
78 type Context = W::Context;
79
80 fn context(&self) -> Self::Context {
81 self.inner.context().clone_with_trimmed_key(1)
83 }
84
85 fn pre_load(context: &Self::Context) -> Result<Vec<Vec<u8>>, ViewError> {
86 let mut v = vec![context.base_key().base_tag(KeyTag::Hash as u8)];
87 let base_key = context.base_key().base_tag(KeyTag::Inner as u8);
88 let context = context.clone_with_base_key(base_key);
89 v.extend(W::pre_load(&context)?);
90 Ok(v)
91 }
92
93 fn post_load(context: Self::Context, values: &[Option<Vec<u8>>]) -> Result<Self, ViewError> {
94 let hash = from_bytes_option(values.first().ok_or(ViewError::PostLoadValuesError)?)?;
95 let base_key = context.base_key().base_tag(KeyTag::Inner as u8);
96 let context = context.clone_with_base_key(base_key);
97 let inner = W::post_load(
98 context,
99 values.get(1..).ok_or(ViewError::PostLoadValuesError)?,
100 )?;
101 Ok(Self {
102 _phantom: PhantomData,
103 stored_hash: hash,
104 hash: Mutex::new(hash),
105 inner,
106 })
107 }
108
109 fn rollback(&mut self) {
110 self.inner.rollback();
111 *self.hash.get_mut().unwrap() = self.stored_hash;
112 }
113
114 async fn has_pending_changes(&self) -> bool {
115 if self.inner.has_pending_changes().await {
116 return true;
117 }
118 let hash = self.hash.lock().unwrap();
119 self.stored_hash != *hash
120 }
121
122 fn pre_save(&self, batch: &mut Batch) -> Result<bool, ViewError> {
123 let delete_view = self.inner.pre_save(batch)?;
124 let hash = *self.hash.lock().unwrap();
125 if delete_view {
126 let mut key_prefix = self.inner.context().base_key().bytes.clone();
127 key_prefix.pop();
128 batch.delete_key_prefix(key_prefix);
129 } else if self.stored_hash != hash {
130 let mut key = self.inner.context().base_key().bytes.clone();
131 let tag = key.last_mut().unwrap();
132 *tag = KeyTag::Hash as u8;
133 match hash {
134 None => batch.delete_key(key),
135 Some(hash) => batch.put_key_value(key, &hash)?,
136 }
137 }
138 Ok(delete_view)
139 }
140
141 fn post_save(&mut self) {
142 self.inner.post_save();
143 let hash = *self.hash.get_mut().unwrap();
144 self.stored_hash = hash;
145 }
146
147 fn clear(&mut self) {
148 self.inner.clear();
149 *self.hash.get_mut().unwrap() = None;
150 }
151}
152
153impl<W, O> ClonableView for WrappedHashableContainerView<W::Context, W, O>
154where
155 W: HashableView + ClonableView,
156 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
157 W::Hasher: Hasher<Output = O>,
158{
159 fn clone_unchecked(&mut self) -> Result<Self, ViewError> {
160 Ok(WrappedHashableContainerView {
161 _phantom: PhantomData,
162 stored_hash: self.stored_hash,
163 hash: Mutex::new(*self.hash.get_mut().unwrap()),
164 inner: self.inner.clone_unchecked()?,
165 })
166 }
167}
168
169impl<W, O> HashableView for WrappedHashableContainerView<W::Context, W, O>
170where
171 W: HashableView,
172 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
173 W::Hasher: Hasher<Output = O>,
174{
175 type Hasher = W::Hasher;
176
177 async fn hash_mut(&mut self) -> Result<<Self::Hasher as Hasher>::Output, ViewError> {
178 let hash = *self.hash.get_mut().unwrap();
179 match hash {
180 Some(hash) => Ok(hash),
181 None => {
182 let new_hash = self.inner.hash_mut().await?;
183 let hash = self.hash.get_mut().unwrap();
184 *hash = Some(new_hash);
185 Ok(new_hash)
186 }
187 }
188 }
189
190 async fn hash(&self) -> Result<<Self::Hasher as Hasher>::Output, ViewError> {
191 let hash = *self.hash.lock().unwrap();
192 match hash {
193 Some(hash) => Ok(hash),
194 None => {
195 let new_hash = self.inner.hash().await?;
196 let mut hash = self.hash.lock().unwrap();
197 *hash = Some(new_hash);
198 Ok(new_hash)
199 }
200 }
201 }
202}
203
204impl<C, W, O> Deref for WrappedHashableContainerView<C, W, O> {
205 type Target = W;
206
207 fn deref(&self) -> &W {
208 &self.inner
209 }
210}
211
212impl<C, W, O> DerefMut for WrappedHashableContainerView<C, W, O> {
213 fn deref_mut(&mut self) -> &mut W {
214 *self.hash.get_mut().unwrap() = None;
215 &mut self.inner
216 }
217}
218
#[cfg(with_graphql)]
mod graphql {
    use std::borrow::Cow;

    use super::WrappedHashableContainerView;
    use crate::context::Context;

    /// Exposes the container to GraphQL exactly as its inner view `W`: type
    /// name, qualified name, type info and resolution are all delegated to `W`.
    impl<C, W, O> async_graphql::OutputType for WrappedHashableContainerView<C, W, O>
    where
        C: Context,
        W: async_graphql::OutputType + Send + Sync,
        O: Send + Sync,
    {
        fn type_name() -> Cow<'static, str> {
            <W as async_graphql::OutputType>::type_name()
        }

        fn qualified_type_name() -> String {
            <W as async_graphql::OutputType>::qualified_type_name()
        }

        fn create_type_info(registry: &mut async_graphql::registry::Registry) -> String {
            <W as async_graphql::OutputType>::create_type_info(registry)
        }

        async fn resolve(
            &self,
            ctx: &async_graphql::ContextSelectionSet<'_>,
            field: &async_graphql::Positioned<async_graphql::parser::types::Field>,
        ) -> async_graphql::ServerResult<async_graphql::Value> {
            // Resolve the wrapped view directly; the cached hash is not exposed.
            let inner: &W = &self.inner;
            inner.resolve(ctx, field).await
        }
    }
}