linera_core/environment/wallet/
memory.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use futures::{Stream, StreamExt as _};
5use linera_base::identifiers::ChainId;
6use serde::{ser::SerializeMap, Serialize, Serializer};
7
8use super::{Chain, Wallet};
9
10/// A basic implementation of `Wallet` that doesn't persist anything and merely tracks the
11/// chains in memory.
12///
13/// This can be used as-is as an ephemeral wallet for testing or ephemeral clients, or as
14/// a building block for more complex wallets that layer persistence on top of it.
15#[derive(Default, Clone, serde::Deserialize)]
16pub struct Memory(papaya::HashMap<ChainId, Chain>);
17
18/// Custom Serialize implementation that ensures stable ordering by sorting entries by ChainId.
19impl Serialize for Memory {
20    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
21        let guard = self.0.pin();
22        let mut items: Vec<_> = guard.iter().collect();
23        items.sort_by_key(|(k, _)| *k);
24        let mut map = serializer.serialize_map(Some(items.len()))?;
25        for (k, v) in items {
26            map.serialize_entry(k, v)?;
27        }
28        map.end()
29    }
30}
31
32impl Memory {
33    pub fn get(&self, id: ChainId) -> Option<Chain> {
34        self.0.pin().get(&id).cloned()
35    }
36
37    pub fn insert(&self, id: ChainId, chain: Chain) -> Option<Chain> {
38        self.0.pin().insert(id, chain).cloned()
39    }
40
41    pub fn try_insert(&self, id: ChainId, chain: Chain) -> Option<Chain> {
42        match self.0.pin().try_insert(id, chain) {
43            Ok(_inserted) => None,
44            Err(error) => Some(error.not_inserted),
45        }
46    }
47
48    pub fn remove(&self, id: ChainId) -> Option<Chain> {
49        self.0.pin().remove(&id).cloned()
50    }
51
52    pub fn items(&self) -> Vec<(ChainId, Chain)> {
53        self.0
54            .pin()
55            .iter()
56            .map(|(id, chain)| (*id, chain.clone()))
57            .collect::<Vec<_>>()
58    }
59
60    pub fn chain_ids(&self) -> Vec<ChainId> {
61        self.0.pin().keys().copied().collect::<Vec<_>>()
62    }
63
64    pub fn owned_chain_ids(&self) -> Vec<ChainId> {
65        self.0
66            .pin()
67            .iter()
68            .filter_map(|(id, chain)| chain.owner.as_ref().map(|_| *id))
69            .collect::<Vec<_>>()
70    }
71
72    pub fn mutate<R>(
73        &self,
74        chain_id: ChainId,
75        mut mutate: impl FnMut(&mut Chain) -> R,
76    ) -> Option<R> {
77        use papaya::Operation::*;
78
79        let mut outcome = None;
80        self.0.pin().compute(chain_id, |chain| {
81            if let Some((_, chain)) = chain {
82                let mut chain = chain.clone();
83                outcome = Some(mutate(&mut chain));
84                Insert(chain)
85            } else {
86                Abort(())
87            }
88        });
89
90        outcome
91    }
92}
93
94impl Extend<(ChainId, Chain)> for Memory {
95    fn extend<It: IntoIterator<Item = (ChainId, Chain)>>(&mut self, chains: It) {
96        let map = self.0.pin();
97        for (id, chain) in chains {
98            map.insert(id, chain);
99        }
100    }
101}
102
103impl Wallet for Memory {
104    type Error = std::convert::Infallible;
105
106    async fn get(&self, id: ChainId) -> Result<Option<Chain>, Self::Error> {
107        Ok(self.get(id))
108    }
109
110    async fn insert(&self, id: ChainId, chain: Chain) -> Result<Option<Chain>, Self::Error> {
111        Ok(self.insert(id, chain))
112    }
113
114    async fn try_insert(&self, id: ChainId, chain: Chain) -> Result<Option<Chain>, Self::Error> {
115        Ok(self.try_insert(id, chain))
116    }
117
118    async fn remove(&self, id: ChainId) -> Result<Option<Chain>, Self::Error> {
119        Ok(self.remove(id))
120    }
121
122    fn items(&self) -> impl Stream<Item = Result<(ChainId, Chain), Self::Error>> {
123        futures::stream::iter(self.items()).map(Ok)
124    }
125
126    async fn modify(
127        &self,
128        id: ChainId,
129        f: impl FnMut(&mut Chain) + Send,
130    ) -> Result<Option<()>, Self::Error> {
131        Ok(self.mutate(id, f))
132    }
133}
134
#[cfg(test)]
mod tests {
    use linera_base::{crypto::CryptoHash, data_types::Timestamp};

    use super::*;

    /// Builds an unowned chain whose only distinguishing field is `next_block_height`.
    fn make_chain(height: u64) -> Chain {
        Chain {
            owner: None,
            block_hash: None,
            next_block_height: height.into(),
            timestamp: Timestamp::from(0),
            pending_proposal: None,
            epoch: None,
        }
    }

    #[test]
    fn test_memory_serialization_roundtrip() {
        let memory = Memory::default();

        // Insert chains in non-sorted order using different hashes.
        let id1 = ChainId(CryptoHash::test_hash("chain1"));
        let id2 = ChainId(CryptoHash::test_hash("chain2"));
        let id3 = ChainId(CryptoHash::test_hash("chain3"));

        memory.insert(id2, make_chain(2));
        memory.insert(id1, make_chain(1));
        memory.insert(id3, make_chain(3));

        let json = serde_json::to_string_pretty(&memory).unwrap();
        let restored: Memory = serde_json::from_str(&json).unwrap();

        // No chains were lost or invented, and each one kept its data.
        assert_eq!(restored.chain_ids().len(), 3);
        assert_eq!(restored.get(id1).unwrap().next_block_height, 1.into());
        assert_eq!(restored.get(id2).unwrap().next_block_height, 2.into());
        assert_eq!(restored.get(id3).unwrap().next_block_height, 3.into());
    }

    #[test]
    fn test_memory_serialization_is_sorted() {
        let memory = Memory::default();

        let id1 = ChainId(CryptoHash::test_hash("a"));
        let id2 = ChainId(CryptoHash::test_hash("b"));
        let id3 = ChainId(CryptoHash::test_hash("c"));

        // Insert in non-sorted order.
        memory.insert(id3, make_chain(3));
        memory.insert(id1, make_chain(1));
        memory.insert(id2, make_chain(2));

        let json = serde_json::to_string(&memory).unwrap();

        // Inspect key order in the raw JSON text. Parsing into `serde_json::Value` would not
        // work: unless serde_json's `preserve_order` feature is enabled, its object map is a
        // `BTreeMap`, which re-sorts keys and makes such a check pass regardless of the order
        // actually emitted.
        let mut expected = [id1, id2, id3];
        expected.sort();
        let positions: Vec<usize> = expected
            .iter()
            .map(|id| {
                let key = serde_json::to_string(id).unwrap();
                json.find(&key)
                    .unwrap_or_else(|| panic!("key {key} missing from {json}"))
            })
            .collect();
        assert!(
            positions.windows(2).all(|pair| pair[0] < pair[1]),
            "keys not emitted in ascending `ChainId` order: {json}"
        );
    }
}