linera_core/environment/wallet/
memory.rs1use futures::{Stream, StreamExt as _};
5use linera_base::identifiers::ChainId;
6use serde::{ser::SerializeMap, Serialize, Serializer};
7
8use super::{Chain, Wallet};
9
10#[derive(Default, Clone, serde::Deserialize)]
16pub struct Memory(papaya::HashMap<ChainId, Chain>);
17
18impl Serialize for Memory {
20 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
21 let guard = self.0.pin();
22 let mut items: Vec<_> = guard.iter().collect();
23 items.sort_by_key(|(k, _)| *k);
24 let mut map = serializer.serialize_map(Some(items.len()))?;
25 for (k, v) in items {
26 map.serialize_entry(k, v)?;
27 }
28 map.end()
29 }
30}
31
32impl Memory {
33 pub fn get(&self, id: ChainId) -> Option<Chain> {
34 self.0.pin().get(&id).cloned()
35 }
36
37 pub fn insert(&self, id: ChainId, chain: Chain) -> Option<Chain> {
38 self.0.pin().insert(id, chain).cloned()
39 }
40
41 pub fn try_insert(&self, id: ChainId, chain: Chain) -> Option<Chain> {
42 match self.0.pin().try_insert(id, chain) {
43 Ok(_inserted) => None,
44 Err(error) => Some(error.not_inserted),
45 }
46 }
47
48 pub fn remove(&self, id: ChainId) -> Option<Chain> {
49 self.0.pin().remove(&id).cloned()
50 }
51
52 pub fn items(&self) -> Vec<(ChainId, Chain)> {
53 self.0
54 .pin()
55 .iter()
56 .map(|(id, chain)| (*id, chain.clone()))
57 .collect::<Vec<_>>()
58 }
59
60 pub fn chain_ids(&self) -> Vec<ChainId> {
61 self.0.pin().keys().copied().collect::<Vec<_>>()
62 }
63
64 pub fn owned_chain_ids(&self) -> Vec<ChainId> {
65 self.0
66 .pin()
67 .iter()
68 .filter_map(|(id, chain)| chain.owner.as_ref().map(|_| *id))
69 .collect::<Vec<_>>()
70 }
71
72 pub fn mutate<R>(
73 &self,
74 chain_id: ChainId,
75 mut mutate: impl FnMut(&mut Chain) -> R,
76 ) -> Option<R> {
77 use papaya::Operation::*;
78
79 let mut outcome = None;
80 self.0.pin().compute(chain_id, |chain| {
81 if let Some((_, chain)) = chain {
82 let mut chain = chain.clone();
83 outcome = Some(mutate(&mut chain));
84 Insert(chain)
85 } else {
86 Abort(())
87 }
88 });
89
90 outcome
91 }
92}
93
94impl Extend<(ChainId, Chain)> for Memory {
95 fn extend<It: IntoIterator<Item = (ChainId, Chain)>>(&mut self, chains: It) {
96 let map = self.0.pin();
97 for (id, chain) in chains {
98 map.insert(id, chain);
99 }
100 }
101}
102
103impl Wallet for Memory {
104 type Error = std::convert::Infallible;
105
106 async fn get(&self, id: ChainId) -> Result<Option<Chain>, Self::Error> {
107 Ok(self.get(id))
108 }
109
110 async fn insert(&self, id: ChainId, chain: Chain) -> Result<Option<Chain>, Self::Error> {
111 Ok(self.insert(id, chain))
112 }
113
114 async fn try_insert(&self, id: ChainId, chain: Chain) -> Result<Option<Chain>, Self::Error> {
115 Ok(self.try_insert(id, chain))
116 }
117
118 async fn remove(&self, id: ChainId) -> Result<Option<Chain>, Self::Error> {
119 Ok(self.remove(id))
120 }
121
122 fn items(&self) -> impl Stream<Item = Result<(ChainId, Chain), Self::Error>> {
123 futures::stream::iter(self.items()).map(Ok)
124 }
125
126 async fn modify(
127 &self,
128 id: ChainId,
129 f: impl FnMut(&mut Chain) + Send,
130 ) -> Result<Option<()>, Self::Error> {
131 Ok(self.mutate(id, f))
132 }
133}
134
#[cfg(test)]
mod tests {
    use linera_base::{crypto::CryptoHash, data_types::Timestamp};

    use super::*;

    /// Builds an unowned chain whose only distinguishing field is `next_block_height`.
    fn make_chain(height: u64) -> Chain {
        Chain {
            owner: None,
            block_hash: None,
            next_block_height: height.into(),
            timestamp: Timestamp::from(0),
            pending_proposal: None,
            epoch: None,
        }
    }

    #[test]
    fn test_memory_serialization_roundtrip() {
        let memory = Memory::default();

        let id1 = ChainId(CryptoHash::test_hash("chain1"));
        let id2 = ChainId(CryptoHash::test_hash("chain2"));
        let id3 = ChainId(CryptoHash::test_hash("chain3"));

        memory.insert(id2, make_chain(2));
        memory.insert(id1, make_chain(1));
        memory.insert(id3, make_chain(3));

        let json = serde_json::to_string_pretty(&memory).unwrap();

        let restored: Memory = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.get(id1).unwrap().next_block_height, 1.into());
        assert_eq!(restored.get(id2).unwrap().next_block_height, 2.into());
        assert_eq!(restored.get(id3).unwrap().next_block_height, 3.into());
    }

    #[test]
    fn test_memory_serialization_is_sorted() {
        let memory = Memory::default();

        let id1 = ChainId(CryptoHash::test_hash("a"));
        let id2 = ChainId(CryptoHash::test_hash("b"));
        let id3 = ChainId(CryptoHash::test_hash("c"));

        // Insert out of order so the test cannot pass merely because of insertion order.
        memory.insert(id3, make_chain(3));
        memory.insert(id1, make_chain(1));
        memory.insert(id2, make_chain(2));

        let json = serde_json::to_string(&memory).unwrap();

        // Inspect the key order in the raw JSON text. Parsing into `serde_json::Value` would
        // not test anything: without serde_json's `preserve_order` feature its object map is a
        // `BTreeMap`, which re-sorts the keys and hides unsorted output.
        let mut ids = vec![id1, id2, id3];
        ids.sort();
        let positions = ids
            .iter()
            .map(|id| {
                let key = serde_json::to_string(id).unwrap();
                json.find(&key)
                    .unwrap_or_else(|| panic!("key {key} missing from {json}"))
            })
            .collect::<Vec<_>>();
        assert!(
            positions.windows(2).all(|pair| pair[0] < pair[1]),
            "chain IDs are not serialized in ascending order: {json}"
        );
    }
}