1#[cfg(with_testing)]
7use std::sync::Arc;
8
9use linera_base::ensure;
10use linera_views::{
11    batch::Batch,
12    store::{ReadableKeyValueStore, WithError, WritableKeyValueStore},
13};
14use thiserror::Error;
15
16#[cfg(with_testing)]
17use super::mock_key_value_store::MockKeyValueStore;
18use crate::{
19    contract::wit::{
20        base_runtime_api::{self as contract_wit},
21        contract_runtime_api::{self, WriteOperation},
22    },
23    service::wit::base_runtime_api as service_wit,
24    util::yield_once,
25};
26
27const MAX_KEY_SIZE: usize = 900;
34
35#[derive(Clone)]
37pub struct KeyValueStore {
38    wit_api: WitInterface,
39}
40
41#[cfg_attr(with_testing, allow(dead_code))]
42impl KeyValueStore {
43    pub(crate) fn for_contracts() -> Self {
45        KeyValueStore {
46            wit_api: WitInterface::Contract,
47        }
48    }
49
50    pub(crate) fn for_services() -> Self {
52        KeyValueStore {
53            wit_api: WitInterface::Service,
54        }
55    }
56
57    #[cfg(with_testing)]
59    pub fn mock() -> Self {
60        KeyValueStore {
61            wit_api: WitInterface::Mock {
62                store: Arc::new(MockKeyValueStore::default()),
63                read_only: true,
64            },
65        }
66    }
67
68    #[cfg(with_testing)]
71    pub fn to_mut(&self) -> Self {
72        let WitInterface::Mock { store, .. } = &self.wit_api else {
73            panic!("Real `KeyValueStore` should not be used in unit tests");
74        };
75
76        KeyValueStore {
77            wit_api: WitInterface::Mock {
78                store: store.clone(),
79                read_only: false,
80            },
81        }
82    }
83}
84
85impl WithError for KeyValueStore {
86    type Error = KeyValueStoreError;
87}
88
89#[derive(Error, Debug)]
91pub enum KeyValueStoreError {
92    #[error("Key too long")]
94    KeyTooLong,
95
96    #[error(transparent)]
98    BcsError(#[from] bcs::Error),
99}
100
101impl linera_views::store::KeyValueStoreError for KeyValueStoreError {
102    const BACKEND: &'static str = "key_value_store";
103}
104
105impl ReadableKeyValueStore for KeyValueStore {
106    const MAX_KEY_SIZE: usize = MAX_KEY_SIZE;
109
110    fn max_stream_queries(&self) -> usize {
111        1
112    }
113
114    async fn contains_key(&self, key: &[u8]) -> Result<bool, KeyValueStoreError> {
115        ensure!(
116            key.len() <= Self::MAX_KEY_SIZE,
117            KeyValueStoreError::KeyTooLong
118        );
119        let promise = self.wit_api.contains_key_new(key);
120        yield_once().await;
121        Ok(self.wit_api.contains_key_wait(promise))
122    }
123
124    async fn contains_keys(&self, keys: Vec<Vec<u8>>) -> Result<Vec<bool>, KeyValueStoreError> {
125        for key in &keys {
126            ensure!(
127                key.len() <= Self::MAX_KEY_SIZE,
128                KeyValueStoreError::KeyTooLong
129            );
130        }
131        let promise = self.wit_api.contains_keys_new(&keys);
132        yield_once().await;
133        Ok(self.wit_api.contains_keys_wait(promise))
134    }
135
136    async fn read_multi_values_bytes(
137        &self,
138        keys: Vec<Vec<u8>>,
139    ) -> Result<Vec<Option<Vec<u8>>>, KeyValueStoreError> {
140        for key in &keys {
141            ensure!(
142                key.len() <= Self::MAX_KEY_SIZE,
143                KeyValueStoreError::KeyTooLong
144            );
145        }
146        let promise = self.wit_api.read_multi_values_bytes_new(&keys);
147        yield_once().await;
148        Ok(self.wit_api.read_multi_values_bytes_wait(promise))
149    }
150
151    async fn read_value_bytes(&self, key: &[u8]) -> Result<Option<Vec<u8>>, KeyValueStoreError> {
152        ensure!(
153            key.len() <= Self::MAX_KEY_SIZE,
154            KeyValueStoreError::KeyTooLong
155        );
156        let promise = self.wit_api.read_value_bytes_new(key);
157        yield_once().await;
158        Ok(self.wit_api.read_value_bytes_wait(promise))
159    }
160
161    async fn find_keys_by_prefix(
162        &self,
163        key_prefix: &[u8],
164    ) -> Result<Vec<Vec<u8>>, KeyValueStoreError> {
165        ensure!(
166            key_prefix.len() <= Self::MAX_KEY_SIZE,
167            KeyValueStoreError::KeyTooLong
168        );
169        let promise = self.wit_api.find_keys_new(key_prefix);
170        yield_once().await;
171        Ok(self.wit_api.find_keys_wait(promise))
172    }
173
174    async fn find_key_values_by_prefix(
175        &self,
176        key_prefix: &[u8],
177    ) -> Result<Vec<(Vec<u8>, Vec<u8>)>, KeyValueStoreError> {
178        ensure!(
179            key_prefix.len() <= Self::MAX_KEY_SIZE,
180            KeyValueStoreError::KeyTooLong
181        );
182        let promise = self.wit_api.find_key_values_new(key_prefix);
183        yield_once().await;
184        Ok(self.wit_api.find_key_values_wait(promise))
185    }
186}
187
188impl WritableKeyValueStore for KeyValueStore {
189    const MAX_VALUE_SIZE: usize = usize::MAX;
190
191    async fn write_batch(&self, batch: Batch) -> Result<(), KeyValueStoreError> {
192        self.wit_api.write_batch(batch);
193        Ok(())
194    }
195
196    async fn clear_journal(&self) -> Result<(), KeyValueStoreError> {
197        Ok(())
198    }
199}
200
201#[derive(Clone)]
203#[cfg_attr(with_testing, allow(dead_code))]
204enum WitInterface {
205    Contract,
207    Service,
209    #[cfg(with_testing)]
210    Mock {
212        store: Arc<MockKeyValueStore>,
213        read_only: bool,
214    },
215}
216
217impl WitInterface {
218    fn contains_key_new(&self, key: &[u8]) -> u32 {
220        match self {
221            WitInterface::Contract => contract_wit::contains_key_new(key),
222            WitInterface::Service => service_wit::contains_key_new(key),
223            #[cfg(with_testing)]
224            WitInterface::Mock { store, .. } => store.contains_key_new(key),
225        }
226    }
227
228    fn contains_key_wait(&self, promise: u32) -> bool {
230        match self {
231            WitInterface::Contract => contract_wit::contains_key_wait(promise),
232            WitInterface::Service => service_wit::contains_key_wait(promise),
233            #[cfg(with_testing)]
234            WitInterface::Mock { store, .. } => store.contains_key_wait(promise),
235        }
236    }
237
238    fn contains_keys_new(&self, keys: &[Vec<u8>]) -> u32 {
240        match self {
241            WitInterface::Contract => contract_wit::contains_keys_new(keys),
242            WitInterface::Service => service_wit::contains_keys_new(keys),
243            #[cfg(with_testing)]
244            WitInterface::Mock { store, .. } => store.contains_keys_new(keys),
245        }
246    }
247
248    fn contains_keys_wait(&self, promise: u32) -> Vec<bool> {
250        match self {
251            WitInterface::Contract => contract_wit::contains_keys_wait(promise),
252            WitInterface::Service => service_wit::contains_keys_wait(promise),
253            #[cfg(with_testing)]
254            WitInterface::Mock { store, .. } => store.contains_keys_wait(promise),
255        }
256    }
257
258    fn read_multi_values_bytes_new(&self, keys: &[Vec<u8>]) -> u32 {
260        match self {
261            WitInterface::Contract => contract_wit::read_multi_values_bytes_new(keys),
262            WitInterface::Service => service_wit::read_multi_values_bytes_new(keys),
263            #[cfg(with_testing)]
264            WitInterface::Mock { store, .. } => store.read_multi_values_bytes_new(keys),
265        }
266    }
267
268    fn read_multi_values_bytes_wait(&self, promise: u32) -> Vec<Option<Vec<u8>>> {
270        match self {
271            WitInterface::Contract => contract_wit::read_multi_values_bytes_wait(promise),
272            WitInterface::Service => service_wit::read_multi_values_bytes_wait(promise),
273            #[cfg(with_testing)]
274            WitInterface::Mock { store, .. } => store.read_multi_values_bytes_wait(promise),
275        }
276    }
277
278    fn read_value_bytes_new(&self, key: &[u8]) -> u32 {
280        match self {
281            WitInterface::Contract => contract_wit::read_value_bytes_new(key),
282            WitInterface::Service => service_wit::read_value_bytes_new(key),
283            #[cfg(with_testing)]
284            WitInterface::Mock { store, .. } => store.read_value_bytes_new(key),
285        }
286    }
287
288    fn read_value_bytes_wait(&self, promise: u32) -> Option<Vec<u8>> {
290        match self {
291            WitInterface::Contract => contract_wit::read_value_bytes_wait(promise),
292            WitInterface::Service => service_wit::read_value_bytes_wait(promise),
293            #[cfg(with_testing)]
294            WitInterface::Mock { store, .. } => store.read_value_bytes_wait(promise),
295        }
296    }
297
298    fn find_keys_new(&self, key_prefix: &[u8]) -> u32 {
300        match self {
301            WitInterface::Contract => contract_wit::find_keys_new(key_prefix),
302            WitInterface::Service => service_wit::find_keys_new(key_prefix),
303            #[cfg(with_testing)]
304            WitInterface::Mock { store, .. } => store.find_keys_new(key_prefix),
305        }
306    }
307
308    fn find_keys_wait(&self, promise: u32) -> Vec<Vec<u8>> {
310        match self {
311            WitInterface::Contract => contract_wit::find_keys_wait(promise),
312            WitInterface::Service => service_wit::find_keys_wait(promise),
313            #[cfg(with_testing)]
314            WitInterface::Mock { store, .. } => store.find_keys_wait(promise),
315        }
316    }
317
318    fn find_key_values_new(&self, key_prefix: &[u8]) -> u32 {
320        match self {
321            WitInterface::Contract => contract_wit::find_key_values_new(key_prefix),
322            WitInterface::Service => service_wit::find_key_values_new(key_prefix),
323            #[cfg(with_testing)]
324            WitInterface::Mock { store, .. } => store.find_key_values_new(key_prefix),
325        }
326    }
327
328    fn find_key_values_wait(&self, promise: u32) -> Vec<(Vec<u8>, Vec<u8>)> {
330        match self {
331            WitInterface::Contract => contract_wit::find_key_values_wait(promise),
332            WitInterface::Service => service_wit::find_key_values_wait(promise),
333            #[cfg(with_testing)]
334            WitInterface::Mock { store, .. } => store.find_key_values_wait(promise),
335        }
336    }
337
338    fn write_batch(&self, batch: Batch) {
340        match self {
341            WitInterface::Contract => {
342                let batch_operations = batch
343                    .operations
344                    .into_iter()
345                    .map(WriteOperation::from)
346                    .collect::<Vec<_>>();
347
348                contract_runtime_api::write_batch(&batch_operations);
349            }
350            WitInterface::Service => panic!("Attempt to modify storage from a service"),
351            #[cfg(with_testing)]
352            WitInterface::Mock {
353                store,
354                read_only: false,
355            } => {
356                store.write_batch(batch);
357            }
358            #[cfg(with_testing)]
359            WitInterface::Mock {
360                read_only: true, ..
361            } => {
362                panic!("Attempt to modify storage from a service")
363            }
364        }
365    }
366}
367
/// A `linera_views` view context whose storage is [`KeyValueStore`].
///
/// NOTE(review): the first type parameter is `()`; presumably this is the context's
/// extra-data slot, left unused here — confirm against `linera_views::context::ViewContext`.
pub type ViewStorageContext = linera_views::context::ViewContext<(), KeyValueStore>;
371
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    /// Writes two values through a writable mock store and reads them back.
    #[tokio::test]
    async fn test_key_value_store_mock() -> anyhow::Result<()> {
        let store = KeyValueStore::mock();
        let mock_store = store.to_mut();

        let is_key_existing = mock_store.contains_key(b"foo").await?;
        assert!(!is_key_existing);

        let is_keys_existing = mock_store
            .contains_keys(vec![b"foo".to_vec(), b"bar".to_vec()])
            .await?;
        assert!(!is_keys_existing[0]);
        assert!(!is_keys_existing[1]);

        let mut batch = Batch::new();
        batch.put_key_value(b"foo".to_vec(), &32_u128)?;
        batch.put_key_value(b"bar".to_vec(), &42_u128)?;
        mock_store.write_batch(batch).await?;

        let is_key_existing = mock_store.contains_key(b"foo").await?;
        assert!(is_key_existing);

        let value = mock_store.read_value(b"foo").await?;
        assert_eq!(value, Some(32_u128));

        let value = mock_store.read_value(b"bar").await?;
        assert_eq!(value, Some(42_u128));

        Ok(())
    }

    /// A handle obtained from `to_mut` shares the mock storage of the store it came from,
    /// so its writes are visible through the original read-only store.
    #[tokio::test]
    async fn test_to_mut_shares_mock_storage() -> anyhow::Result<()> {
        let store = KeyValueStore::mock();

        let mut batch = Batch::new();
        batch.put_key_value(b"shared".to_vec(), &7_u64)?;
        store.to_mut().write_batch(batch).await?;

        let value: Option<u64> = store.read_value(b"shared").await?;
        assert_eq!(value, Some(7));

        Ok(())
    }

    /// Every read rejects a key or prefix of `MAX_KEY_SIZE + 1` bytes with `KeyTooLong`,
    /// while a key of exactly `MAX_KEY_SIZE` bytes is still accepted.
    #[tokio::test]
    async fn test_key_value_store_rejects_oversized_keys() -> anyhow::Result<()> {
        let store = KeyValueStore::mock();
        let max_key = vec![0_u8; MAX_KEY_SIZE];
        let long_key = vec![0_u8; MAX_KEY_SIZE + 1];

        // Boundary: exactly at the limit is allowed.
        assert!(!store.contains_key(&max_key).await?);

        assert!(matches!(
            store.contains_key(&long_key).await,
            Err(KeyValueStoreError::KeyTooLong)
        ));
        assert!(matches!(
            store
                .contains_keys(vec![b"foo".to_vec(), long_key.clone()])
                .await,
            Err(KeyValueStoreError::KeyTooLong)
        ));
        assert!(matches!(
            store.read_value_bytes(&long_key).await,
            Err(KeyValueStoreError::KeyTooLong)
        ));
        assert!(matches!(
            store.read_multi_values_bytes(vec![long_key.clone()]).await,
            Err(KeyValueStoreError::KeyTooLong)
        ));
        assert!(matches!(
            store.find_keys_by_prefix(&long_key).await,
            Err(KeyValueStoreError::KeyTooLong)
        ));
        assert!(matches!(
            store.find_key_values_by_prefix(&long_key).await,
            Err(KeyValueStoreError::KeyTooLong)
        ));

        Ok(())
    }

    /// The store returned by `mock` is read-only, so any write through it panics.
    #[tokio::test]
    #[should_panic(expected = "Attempt to modify storage from a service")]
    async fn test_read_only_mock_panics_on_write() {
        let store = KeyValueStore::mock();
        let _ = store.write_batch(Batch::new()).await;
    }
}