1use std::{
5 marker::PhantomData,
6 ops::{Deref, DerefMut},
7 sync::Mutex,
8};
9
10use serde::{de::DeserializeOwned, Serialize};
11
12use crate::{
13 batch::Batch,
14 common::from_bytes_option,
15 context::Context,
16 store::ReadableKeyValueStore as _,
17 views::{ClonableView, HashableView, Hasher, ReplaceContext, View, ViewError, MIN_VIEW_TAG},
18};
19
/// A view wrapping an inner hashable view `W` and caching its hash.
///
/// The wrapper stores its data under two sub-keys of its base key:
/// `KeyTag::Inner` for the inner view and `KeyTag::Hash` for the persisted hash.
/// The `Mutex` around the cached hash lets `HashableView::hash` fill the cache
/// from a shared reference.
#[derive(Debug)]
pub struct WrappedHashableContainerView<C, W, O> {
    // Marks the context type parameter `C` without storing a value of it.
    _phantom: PhantomData<C>,
    // Hash as it is currently persisted: set in `post_load`, updated in `flush`.
    stored_hash: Option<O>,
    // Current cached hash; `None` means it must be recomputed from `inner`.
    hash: Mutex<Option<O>>,
    // The wrapped view, whose context base key ends with the `KeyTag::Inner` byte.
    inner: W,
}
28
29#[repr(u8)]
31enum KeyTag {
32 Inner = MIN_VIEW_TAG,
34 Hash,
36}
37
38impl<C, W, O, C2> ReplaceContext<C2> for WrappedHashableContainerView<C, W, O>
39where
40 W: HashableView<Hasher: Hasher<Output = O>, Context = C> + ReplaceContext<C2>,
41 <W as ReplaceContext<C2>>::Target: HashableView<Hasher: Hasher<Output = O>>,
42 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
43 C: Context,
44 C2: Context,
45{
46 type Target = WrappedHashableContainerView<C2, <W as ReplaceContext<C2>>::Target, O>;
47
48 async fn with_context(
49 &mut self,
50 ctx: impl FnOnce(&Self::Context) -> C2 + Clone,
51 ) -> Self::Target {
52 let hash = *self.hash.lock().unwrap();
53 WrappedHashableContainerView {
54 _phantom: PhantomData,
55 stored_hash: self.stored_hash,
56 hash: Mutex::new(hash),
57 inner: self.inner.with_context(ctx).await,
58 }
59 }
60}
61
62impl<W: HashableView, O> View for WrappedHashableContainerView<W::Context, W, O>
63where
64 W: HashableView<Hasher: Hasher<Output = O>>,
65 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
66{
67 const NUM_INIT_KEYS: usize = 1 + W::NUM_INIT_KEYS;
68
69 type Context = W::Context;
70
71 fn context(&self) -> &Self::Context {
72 self.inner.context()
73 }
74
75 fn pre_load(context: &Self::Context) -> Result<Vec<Vec<u8>>, ViewError> {
76 let mut v = vec![context.base_key().base_tag(KeyTag::Hash as u8)];
77 let base_key = context.base_key().base_tag(KeyTag::Inner as u8);
78 let context = context.clone_with_base_key(base_key);
79 v.extend(W::pre_load(&context)?);
80 Ok(v)
81 }
82
83 fn post_load(context: Self::Context, values: &[Option<Vec<u8>>]) -> Result<Self, ViewError> {
84 let hash = from_bytes_option(values.first().ok_or(ViewError::PostLoadValuesError)?)?;
85 let base_key = context.base_key().base_tag(KeyTag::Inner as u8);
86 let context = context.clone_with_base_key(base_key);
87 let inner = W::post_load(
88 context,
89 values.get(1..).ok_or(ViewError::PostLoadValuesError)?,
90 )?;
91 Ok(Self {
92 _phantom: PhantomData,
93 stored_hash: hash,
94 hash: Mutex::new(hash),
95 inner,
96 })
97 }
98
99 async fn load(context: Self::Context) -> Result<Self, ViewError> {
100 let keys = Self::pre_load(&context)?;
101 let values = context.store().read_multi_values_bytes(keys).await?;
102 Self::post_load(context, &values)
103 }
104
105 fn rollback(&mut self) {
106 self.inner.rollback();
107 *self.hash.get_mut().unwrap() = self.stored_hash;
108 }
109
110 async fn has_pending_changes(&self) -> bool {
111 if self.inner.has_pending_changes().await {
112 return true;
113 }
114 let hash = self.hash.lock().unwrap();
115 self.stored_hash != *hash
116 }
117
118 fn flush(&mut self, batch: &mut Batch) -> Result<bool, ViewError> {
119 let delete_view = self.inner.flush(batch)?;
120 let hash = self.hash.get_mut().unwrap();
121 if delete_view {
122 let mut key_prefix = self.inner.context().base_key().bytes.clone();
123 key_prefix.pop();
124 batch.delete_key_prefix(key_prefix);
125 self.stored_hash = None;
126 *hash = None;
127 } else if self.stored_hash != *hash {
128 let mut key = self.inner.context().base_key().bytes.clone();
129 let tag = key.last_mut().unwrap();
130 *tag = KeyTag::Hash as u8;
131 match hash {
132 None => batch.delete_key(key),
133 Some(hash) => batch.put_key_value(key, hash)?,
134 }
135 self.stored_hash = *hash;
136 }
137 Ok(delete_view)
138 }
139
140 fn clear(&mut self) {
141 self.inner.clear();
142 *self.hash.get_mut().unwrap() = None;
143 }
144}
145
146impl<W, O> ClonableView for WrappedHashableContainerView<W::Context, W, O>
147where
148 W: HashableView + ClonableView,
149 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
150 W::Hasher: Hasher<Output = O>,
151{
152 fn clone_unchecked(&mut self) -> Self {
153 WrappedHashableContainerView {
154 _phantom: PhantomData,
155 stored_hash: self.stored_hash,
156 hash: Mutex::new(*self.hash.get_mut().unwrap()),
157 inner: self.inner.clone_unchecked(),
158 }
159 }
160}
161
162impl<W, O> HashableView for WrappedHashableContainerView<W::Context, W, O>
163where
164 W: HashableView,
165 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
166 W::Hasher: Hasher<Output = O>,
167{
168 type Hasher = W::Hasher;
169
170 async fn hash_mut(&mut self) -> Result<<Self::Hasher as Hasher>::Output, ViewError> {
171 let hash = *self.hash.get_mut().unwrap();
172 match hash {
173 Some(hash) => Ok(hash),
174 None => {
175 let new_hash = self.inner.hash_mut().await?;
176 let hash = self.hash.get_mut().unwrap();
177 *hash = Some(new_hash);
178 Ok(new_hash)
179 }
180 }
181 }
182
183 async fn hash(&self) -> Result<<Self::Hasher as Hasher>::Output, ViewError> {
184 let hash = *self.hash.lock().unwrap();
185 match hash {
186 Some(hash) => Ok(hash),
187 None => {
188 let new_hash = self.inner.hash().await?;
189 let mut hash = self.hash.lock().unwrap();
190 *hash = Some(new_hash);
191 Ok(new_hash)
192 }
193 }
194 }
195}
196
197impl<C, W, O> Deref for WrappedHashableContainerView<C, W, O> {
198 type Target = W;
199
200 fn deref(&self) -> &W {
201 &self.inner
202 }
203}
204
205impl<C, W, O> DerefMut for WrappedHashableContainerView<C, W, O> {
206 fn deref_mut(&mut self) -> &mut W {
207 *self.hash.get_mut().unwrap() = None;
208 &mut self.inner
209 }
210}
211
212#[cfg(with_graphql)]
213mod graphql {
214 use std::borrow::Cow;
215
216 use super::WrappedHashableContainerView;
217 use crate::context::Context;
218
219 impl<C, W, O> async_graphql::OutputType for WrappedHashableContainerView<C, W, O>
220 where
221 C: Context,
222 W: async_graphql::OutputType + Send + Sync,
223 O: Send + Sync,
224 {
225 fn type_name() -> Cow<'static, str> {
226 W::type_name()
227 }
228
229 fn qualified_type_name() -> String {
230 W::qualified_type_name()
231 }
232
233 fn create_type_info(registry: &mut async_graphql::registry::Registry) -> String {
234 W::create_type_info(registry)
235 }
236
237 async fn resolve(
238 &self,
239 ctx: &async_graphql::ContextSelectionSet<'_>,
240 field: &async_graphql::Positioned<async_graphql::parser::types::Field>,
241 ) -> async_graphql::ServerResult<async_graphql::Value> {
242 (**self).resolve(ctx, field).await
243 }
244 }
245}