1use std::{
5 marker::PhantomData,
6 ops::{Deref, DerefMut},
7 sync::Mutex,
8};
9
10use serde::{de::DeserializeOwned, Serialize};
11
12use crate::{
13 batch::Batch,
14 common::from_bytes_option,
15 context::Context,
16 views::{ClonableView, HashableView, Hasher, ReplaceContext, View, ViewError, MIN_VIEW_TAG},
17};
18
/// A wrapper around a hashable view `W` that keeps the hash of the wrapped
/// view around so that it does not have to be recomputed on every request.
///
/// `C` is the context type of the view and `O` is the output type of the
/// hasher. The wrapper dereferences to `W`; obtaining a mutable reference to
/// the wrapped view discards the in-memory hash.
#[derive(Debug)]
pub struct WrappedHashableContainerView<C, W, O> {
    /// Ties the context type `C` to the wrapper without storing a value of it.
    _phantom: PhantomData<C>,
    /// The hash as it was last loaded from, or written to, storage.
    stored_hash: Option<O>,
    /// The in-memory hash; `None` means it has to be recomputed from `inner`.
    /// It sits behind a `Mutex` so that it can be filled in through `&self`.
    hash: Mutex<Option<O>>,
    /// The wrapped view.
    inner: W,
}
27
28#[repr(u8)]
30enum KeyTag {
31 Inner = MIN_VIEW_TAG,
33 Hash,
35}
36
37impl<C, W, O, C2> ReplaceContext<C2> for WrappedHashableContainerView<C, W, O>
38where
39 W: HashableView<Hasher: Hasher<Output = O>, Context = C> + ReplaceContext<C2>,
40 <W as ReplaceContext<C2>>::Target: HashableView<Hasher: Hasher<Output = O>>,
41 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
42 C: Context,
43 C2: Context,
44{
45 type Target = WrappedHashableContainerView<C2, <W as ReplaceContext<C2>>::Target, O>;
46
47 async fn with_context(
48 &mut self,
49 ctx: impl FnOnce(&Self::Context) -> C2 + Clone,
50 ) -> Self::Target {
51 let hash = *self.hash.lock().unwrap();
52 WrappedHashableContainerView {
53 _phantom: PhantomData,
54 stored_hash: self.stored_hash,
55 hash: Mutex::new(hash),
56 inner: self.inner.with_context(ctx).await,
57 }
58 }
59}
60
61impl<W: HashableView, O> View for WrappedHashableContainerView<W::Context, W, O>
62where
63 W: HashableView<Hasher: Hasher<Output = O>>,
64 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
65{
66 const NUM_INIT_KEYS: usize = 1 + W::NUM_INIT_KEYS;
67
68 type Context = W::Context;
69
70 fn context(&self) -> &Self::Context {
71 self.inner.context()
72 }
73
74 fn pre_load(context: &Self::Context) -> Result<Vec<Vec<u8>>, ViewError> {
75 let mut v = vec![context.base_key().base_tag(KeyTag::Hash as u8)];
76 let base_key = context.base_key().base_tag(KeyTag::Inner as u8);
77 let context = context.clone_with_base_key(base_key);
78 v.extend(W::pre_load(&context)?);
79 Ok(v)
80 }
81
82 fn post_load(context: Self::Context, values: &[Option<Vec<u8>>]) -> Result<Self, ViewError> {
83 let hash = from_bytes_option(values.first().ok_or(ViewError::PostLoadValuesError)?)?;
84 let base_key = context.base_key().base_tag(KeyTag::Inner as u8);
85 let context = context.clone_with_base_key(base_key);
86 let inner = W::post_load(
87 context,
88 values.get(1..).ok_or(ViewError::PostLoadValuesError)?,
89 )?;
90 Ok(Self {
91 _phantom: PhantomData,
92 stored_hash: hash,
93 hash: Mutex::new(hash),
94 inner,
95 })
96 }
97
98 fn rollback(&mut self) {
99 self.inner.rollback();
100 *self.hash.get_mut().unwrap() = self.stored_hash;
101 }
102
103 async fn has_pending_changes(&self) -> bool {
104 if self.inner.has_pending_changes().await {
105 return true;
106 }
107 let hash = self.hash.lock().unwrap();
108 self.stored_hash != *hash
109 }
110
111 fn flush(&mut self, batch: &mut Batch) -> Result<bool, ViewError> {
112 let delete_view = self.inner.flush(batch)?;
113 let hash = self.hash.get_mut().unwrap();
114 if delete_view {
115 let mut key_prefix = self.inner.context().base_key().bytes.clone();
116 key_prefix.pop();
117 batch.delete_key_prefix(key_prefix);
118 self.stored_hash = None;
119 *hash = None;
120 } else if self.stored_hash != *hash {
121 let mut key = self.inner.context().base_key().bytes.clone();
122 let tag = key.last_mut().unwrap();
123 *tag = KeyTag::Hash as u8;
124 match hash {
125 None => batch.delete_key(key),
126 Some(hash) => batch.put_key_value(key, hash)?,
127 }
128 self.stored_hash = *hash;
129 }
130 Ok(delete_view)
131 }
132
133 fn clear(&mut self) {
134 self.inner.clear();
135 *self.hash.get_mut().unwrap() = None;
136 }
137}
138
139impl<W, O> ClonableView for WrappedHashableContainerView<W::Context, W, O>
140where
141 W: HashableView + ClonableView,
142 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
143 W::Hasher: Hasher<Output = O>,
144{
145 fn clone_unchecked(&mut self) -> Result<Self, ViewError> {
146 Ok(WrappedHashableContainerView {
147 _phantom: PhantomData,
148 stored_hash: self.stored_hash,
149 hash: Mutex::new(*self.hash.get_mut().unwrap()),
150 inner: self.inner.clone_unchecked()?,
151 })
152 }
153}
154
155impl<W, O> HashableView for WrappedHashableContainerView<W::Context, W, O>
156where
157 W: HashableView,
158 O: Serialize + DeserializeOwned + Send + Sync + Copy + PartialEq,
159 W::Hasher: Hasher<Output = O>,
160{
161 type Hasher = W::Hasher;
162
163 async fn hash_mut(&mut self) -> Result<<Self::Hasher as Hasher>::Output, ViewError> {
164 let hash = *self.hash.get_mut().unwrap();
165 match hash {
166 Some(hash) => Ok(hash),
167 None => {
168 let new_hash = self.inner.hash_mut().await?;
169 let hash = self.hash.get_mut().unwrap();
170 *hash = Some(new_hash);
171 Ok(new_hash)
172 }
173 }
174 }
175
176 async fn hash(&self) -> Result<<Self::Hasher as Hasher>::Output, ViewError> {
177 let hash = *self.hash.lock().unwrap();
178 match hash {
179 Some(hash) => Ok(hash),
180 None => {
181 let new_hash = self.inner.hash().await?;
182 let mut hash = self.hash.lock().unwrap();
183 *hash = Some(new_hash);
184 Ok(new_hash)
185 }
186 }
187 }
188}
189
190impl<C, W, O> Deref for WrappedHashableContainerView<C, W, O> {
191 type Target = W;
192
193 fn deref(&self) -> &W {
194 &self.inner
195 }
196}
197
198impl<C, W, O> DerefMut for WrappedHashableContainerView<C, W, O> {
199 fn deref_mut(&mut self) -> &mut W {
200 *self.hash.get_mut().unwrap() = None;
201 &mut self.inner
202 }
203}
204
/// GraphQL support: the wrapper is exposed exactly like the view it wraps,
/// every `OutputType` item being forwarded to `W`.
#[cfg(with_graphql)]
mod graphql {
    use std::borrow::Cow;

    use super::WrappedHashableContainerView;
    use crate::context::Context;

    impl<C, W, O> async_graphql::OutputType for WrappedHashableContainerView<C, W, O>
    where
        C: Context,
        W: async_graphql::OutputType + Send + Sync,
        O: Send + Sync,
    {
        /// Same type name as the wrapped view.
        fn type_name() -> Cow<'static, str> {
            <W as async_graphql::OutputType>::type_name()
        }

        /// Same qualified type name as the wrapped view.
        fn qualified_type_name() -> String {
            <W as async_graphql::OutputType>::qualified_type_name()
        }

        /// Registers the type information of the wrapped view.
        fn create_type_info(registry: &mut async_graphql::registry::Registry) -> String {
            <W as async_graphql::OutputType>::create_type_info(registry)
        }

        /// Resolves the field on the wrapped view, reached through `Deref`.
        async fn resolve(
            &self,
            ctx: &async_graphql::ContextSelectionSet<'_>,
            field: &async_graphql::Positioned<async_graphql::parser::types::Field>,
        ) -> async_graphql::ServerResult<async_graphql::Value> {
            let inner: &W = self;
            inner.resolve(ctx, field).await
        }
    }
}