papaya/
serde_impls.rs

1use serde::de::{MapAccess, SeqAccess, Visitor};
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3
4use std::fmt::{self, Formatter};
5use std::hash::{BuildHasher, Hash};
6use std::marker::PhantomData;
7
8use crate::{Guard, HashMap, HashMapRef, HashSet, HashSetRef};
9
/// Serde visitor used to deserialize a [`HashMap`].
///
/// Holds no runtime data; the marker only ties the visitor to the key, value
/// and hasher types of the map it builds.
struct MapVisitor<K, V, S> {
    // Zero-sized marker naming the map type this visitor produces.
    _marker: PhantomData<HashMap<K, V, S>>,
}
13
14impl<K, V, S, G> Serialize for HashMapRef<'_, K, V, S, G>
15where
16    K: Serialize + Hash + Eq,
17    V: Serialize,
18    G: Guard,
19    S: BuildHasher,
20{
21    fn serialize<Sr>(&self, serializer: Sr) -> Result<Sr::Ok, Sr::Error>
22    where
23        Sr: Serializer,
24    {
25        serializer.collect_map(self)
26    }
27}
28
29impl<K, V, S> Serialize for HashMap<K, V, S>
30where
31    K: Serialize + Hash + Eq,
32    V: Serialize,
33    S: BuildHasher,
34{
35    fn serialize<Sr>(&self, serializer: Sr) -> Result<Sr::Ok, Sr::Error>
36    where
37        Sr: Serializer,
38    {
39        self.pin().serialize(serializer)
40    }
41}
42
43impl<'de, K, V, S> Deserialize<'de> for HashMap<K, V, S>
44where
45    K: Deserialize<'de> + Hash + Eq,
46    V: Deserialize<'de>,
47    S: Default + BuildHasher,
48{
49    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
50    where
51        D: Deserializer<'de>,
52    {
53        deserializer.deserialize_map(MapVisitor::new())
54    }
55}
56
57impl<K, V, S> MapVisitor<K, V, S> {
58    pub(crate) fn new() -> Self {
59        Self {
60            _marker: PhantomData,
61        }
62    }
63}
64
65impl<'de, K, V, S> Visitor<'de> for MapVisitor<K, V, S>
66where
67    K: Deserialize<'de> + Hash + Eq,
68    V: Deserialize<'de>,
69    S: Default + BuildHasher,
70{
71    type Value = HashMap<K, V, S>;
72
73    fn expecting(&self, f: &mut Formatter<'_>) -> fmt::Result {
74        write!(f, "a map")
75    }
76
77    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
78    where
79        M: MapAccess<'de>,
80    {
81        let values = match access.size_hint() {
82            Some(size) => HashMap::with_capacity_and_hasher(size, S::default()),
83            None => HashMap::default(),
84        };
85
86        {
87            let values = values.pin();
88            while let Some((key, value)) = access.next_entry()? {
89                values.insert(key, value);
90            }
91        }
92
93        Ok(values)
94    }
95}
96
/// Serde visitor used to deserialize a [`HashSet`].
///
/// Holds no runtime data; the marker only ties the visitor to the element
/// and hasher types of the set it builds.
struct SetVisitor<K, S> {
    // Zero-sized marker naming the set type this visitor produces.
    _marker: PhantomData<HashSet<K, S>>,
}
100
101impl<K, S, G> Serialize for HashSetRef<'_, K, S, G>
102where
103    K: Serialize + Hash + Eq,
104    G: Guard,
105    S: BuildHasher,
106{
107    fn serialize<Sr>(&self, serializer: Sr) -> Result<Sr::Ok, Sr::Error>
108    where
109        Sr: Serializer,
110    {
111        serializer.collect_seq(self)
112    }
113}
114
115impl<K, S> Serialize for HashSet<K, S>
116where
117    K: Serialize + Hash + Eq,
118    S: BuildHasher,
119{
120    fn serialize<Sr>(&self, serializer: Sr) -> Result<Sr::Ok, Sr::Error>
121    where
122        Sr: Serializer,
123    {
124        self.pin().serialize(serializer)
125    }
126}
127
128impl<'de, K, S> Deserialize<'de> for HashSet<K, S>
129where
130    K: Deserialize<'de> + Hash + Eq,
131    S: Default + BuildHasher,
132{
133    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134    where
135        D: Deserializer<'de>,
136    {
137        deserializer.deserialize_seq(SetVisitor::new())
138    }
139}
140
141impl<K, S> SetVisitor<K, S> {
142    pub(crate) fn new() -> Self {
143        Self {
144            _marker: PhantomData,
145        }
146    }
147}
148
149impl<'de, K, S> Visitor<'de> for SetVisitor<K, S>
150where
151    K: Deserialize<'de> + Hash + Eq,
152    S: Default + BuildHasher,
153{
154    type Value = HashSet<K, S>;
155
156    fn expecting(&self, f: &mut Formatter<'_>) -> fmt::Result {
157        write!(f, "a set")
158    }
159
160    fn visit_seq<M>(self, mut access: M) -> Result<Self::Value, M::Error>
161    where
162        M: SeqAccess<'de>,
163    {
164        let values = match access.size_hint() {
165            Some(size) => HashSet::with_capacity_and_hasher(size, S::default()),
166            None => HashSet::default(),
167        };
168
169        {
170            let values = values.pin();
171            while let Some(key) = access.next_element()? {
172                values.insert(key);
173            }
174        }
175
176        Ok(values)
177    }
178}
179
#[cfg(test)]
mod test {
    use crate::{HashMap, HashSet};

    /// A map must compare equal to itself after a JSON round trip.
    #[test]
    fn test_map() {
        let original: HashMap<u8, u8> = HashMap::new();
        let guard = original.guard();

        for (key, value) in [(0, 4), (1, 3), (2, 2), (3, 1), (4, 0)] {
            original.insert(key, value, &guard);
        }

        let json = serde_json::to_string(&original).unwrap();
        let restored: HashMap<u8, u8> = serde_json::from_str(&json).unwrap();

        assert_eq!(original, restored);
    }

    /// A set must compare equal to itself after a JSON round trip.
    #[test]
    fn test_set() {
        let original: HashSet<u8> = HashSet::new();
        let guard = original.guard();

        for key in 0..5u8 {
            original.insert(key, &guard);
        }

        let json = serde_json::to_string(&original).unwrap();
        let restored: HashSet<u8> = serde_json::from_str(&json).unwrap();

        assert_eq!(original, restored);
    }
}