1#![doc=include_str!("../../doc/serdes/utils.md")]
2
3use core::{
4 any,
5 fmt::{
6 self,
7 Formatter,
8 },
9 marker::PhantomData,
10 mem::MaybeUninit,
11};
12
13use serde::{
14 de::{
15 Deserialize,
16 Deserializer,
17 Error,
18 MapAccess,
19 SeqAccess,
20 Unexpected,
21 Visitor,
22 },
23 ser::{
24 Serialize,
25 SerializeSeq,
26 SerializeStruct,
27 SerializeTuple,
28 Serializer,
29 },
30};
31use wyz::comu::Const;
32
33use crate::{
34 domain::Domain,
35 index::BitIdx,
36 mem::{
37 bits_of,
38 BitRegister,
39 },
40 order::BitOrder,
41 store::BitStore,
42 view::BitViewSized,
43};
44
45pub(super) struct TypeName<T>(PhantomData<T>);
48
49impl<T> TypeName<T> {
50 fn new() -> Self {
52 TypeName(PhantomData)
53 }
54}
55
56impl<'de, T> Deserialize<'de> for TypeName<T> {
57 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
58 where D: Deserializer<'de> {
59 deserializer.deserialize_str(Self::new())
60 }
61}
62
63impl<'de, T> Visitor<'de> for TypeName<T> {
64 type Value = Self;
65
66 fn expecting(&self, fmt: &mut Formatter) -> fmt::Result {
67 write!(fmt, "the string {:?}", any::type_name::<T>())
68 }
69
70 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
71 where E: serde::de::Error {
72 if value == any::type_name::<T>() {
73 Ok(self)
74 }
75 else {
76 Err(serde::de::Error::invalid_value(
77 Unexpected::Str(value),
78 &self,
79 ))
80 }
81 }
82}
83
84static FIELDS: &[&str] = &["width", "index"];
86
87enum Field {
89 Width,
91 Index,
93}
94
95struct FieldVisitor;
97
98impl<'de> Deserialize<'de> for Field {
99 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
100 where D: Deserializer<'de> {
101 deserializer.deserialize_identifier(FieldVisitor)
102 }
103}
104
105impl<'de> Visitor<'de> for FieldVisitor {
106 type Value = Field;
107
108 fn expecting(&self, fmt: &mut Formatter) -> fmt::Result {
109 fmt.write_str("field identifier")
110 }
111
112 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
113 where E: serde::de::Error {
114 match value {
115 "width" => Ok(Field::Width),
116 "index" => Ok(Field::Index),
117 _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
118 }
119 }
120}
121
122impl<R> Serialize for BitIdx<R>
123where R: BitRegister
124{
125 #[inline]
126 fn serialize<S>(&self, serializer: S) -> super::Result<S>
127 where S: Serializer {
128 let mut state = serializer.serialize_struct("BitIdx", FIELDS.len())?;
129
130 state.serialize_field(FIELDS[0], &(bits_of::<R>() as u8))?;
132 state.serialize_field(FIELDS[1], &self.into_inner())?;
134
135 state.end()
136 }
137}
138
139impl<'de, R> Deserialize<'de> for BitIdx<R>
140where R: BitRegister
141{
142 #[inline]
143 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
144 where D: Deserializer<'de> {
145 deserializer.deserialize_struct(
146 "BitIdx",
147 FIELDS,
148 BitIdxVisitor::<R>::THIS,
149 )
150 }
151}
152
153impl<T, O> Serialize for Domain<'_, Const, T, O>
154where
155 T: BitStore,
156 O: BitOrder,
157 T::Mem: Serialize,
158{
159 #[inline]
160 fn serialize<S>(&self, serializer: S) -> super::Result<S>
161 where S: Serializer {
162 let mut state = serializer.serialize_seq(Some(self.len()))?;
164 for elem in *self {
165 state.serialize_element(&elem)?;
166 }
167 state.end()
168 }
169}
170
171#[repr(transparent)]
175#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
176pub(super) struct Array<T, const N: usize>
177where T: BitStore
178{
179 pub(super) inner: [T; N],
181}
182
183impl<T, const N: usize> Array<T, N>
184where T: BitStore
185{
186 pub(super) fn from_ref(arr: &[T; N]) -> &Self {
193 unsafe { &*(arr as *const [T; N] as *const Self) }
194 }
195}
196
197impl<T, const N: usize> Serialize for Array<T, N>
198where
199 T: BitStore,
200 T::Mem: Serialize,
201{
202 #[inline]
203 fn serialize<S>(&self, serializer: S) -> super::Result<S>
204 where S: Serializer {
205 let mut state = serializer.serialize_tuple(N)?;
208 for elem in self.inner.as_raw_slice().iter().map(BitStore::load_value) {
209 state.serialize_element(&elem)?
210 }
211 state.end()
212 }
213}
214
215impl<'de, T, const N: usize> Deserialize<'de> for Array<T, N>
216where
217 T: BitStore,
218 T::Mem: Deserialize<'de>,
219{
220 #[inline]
221 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222 where D: Deserializer<'de> {
223 deserializer.deserialize_tuple(N, ArrayVisitor::<T, N>::THIS)
224 }
225}
226
227struct ArrayVisitor<T, const N: usize>
229where T: BitStore
230{
231 inner: PhantomData<[T; N]>,
233}
234
235impl<T, const N: usize> ArrayVisitor<T, N>
236where T: BitStore
237{
238 const THIS: Self = Self { inner: PhantomData };
240}
241
242impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
243where
244 T: BitStore,
245 T::Mem: Deserialize<'de>,
246{
247 type Value = Array<T, N>;
248
249 #[inline]
250 fn expecting(&self, fmt: &mut Formatter) -> fmt::Result {
251 write!(fmt, "a [{}; {}]", any::type_name::<T>(), N)
252 }
253
254 #[inline]
255 fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
256 where V: SeqAccess<'de> {
257 let mut uninit = [MaybeUninit::<T::Mem>::uninit(); N];
258 for (idx, slot) in uninit.iter_mut().enumerate() {
259 slot.write(
260 seq.next_element::<T::Mem>()?
261 .ok_or_else(|| <V::Error>::invalid_length(idx, &self))?,
262 );
263 }
264 Ok(Array {
265 inner: uninit
266 .map(|elem| unsafe { MaybeUninit::assume_init(elem) })
267 .map(BitStore::new),
268 })
269 }
270}
271
272struct BitIdxVisitor<R>
274where R: BitRegister
275{
276 inner: PhantomData<R>,
278}
279
280impl<R> BitIdxVisitor<R>
281where R: BitRegister
282{
283 const THIS: Self = Self { inner: PhantomData };
285
286 #[inline]
288 fn assemble<E>(self, width: u8, index: u8) -> Result<BitIdx<R>, E>
289 where E: Error {
290 if width != bits_of::<R>() as u8 {
292 return Err(E::invalid_type(
293 Unexpected::Unsigned(width as u64),
294 &self,
295 ));
296 }
297
298 BitIdx::<R>::new(index).map_err(|_| {
300 E::invalid_value(Unexpected::Unsigned(index as u64), &self)
301 })
302 }
303}
304
305impl<'de, R> Visitor<'de> for BitIdxVisitor<R>
306where R: BitRegister
307{
308 type Value = BitIdx<R>;
309
310 #[inline]
311 fn expecting(&self, fmt: &mut Formatter) -> fmt::Result {
312 write!(fmt, "a valid `BitIdx<u{}>`", bits_of::<R>())
313 }
314
315 #[inline]
316 fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
317 where V: SeqAccess<'de> {
318 let width = seq
319 .next_element::<u8>()?
320 .ok_or_else(|| <V::Error>::invalid_length(0, &self))?;
321 let index = seq
322 .next_element::<u8>()?
323 .ok_or_else(|| <V::Error>::invalid_length(1, &self))?;
324
325 self.assemble(width, index)
326 }
327
328 #[inline]
329 fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
330 where V: MapAccess<'de> {
331 let mut width = None;
332 let mut index = None;
333
334 while let Some(key) = map.next_key()? {
335 match key {
336 Field::Width => {
337 if width.replace(map.next_value::<u8>()?).is_some() {
338 return Err(<V::Error>::duplicate_field("width"));
339 }
340 },
341 Field::Index => {
342 if index.replace(map.next_value::<u8>()?).is_some() {
343 return Err(<V::Error>::duplicate_field("index"));
344 }
345 },
346 }
347 }
348
349 let width = width.ok_or_else(|| <V::Error>::missing_field("width"))?;
350 let index = index.ok_or_else(|| <V::Error>::missing_field("index"))?;
351
352 self.assemble(width, index)
353 }
354}
355
#[cfg(test)]
mod tests {
    use serde_test::{
        assert_de_tokens,
        assert_de_tokens_error,
        assert_ser_tokens,
        Token,
    };

    use super::*;

    /// `Array` round-trips as a fixed-length tuple. A tuple that is too short
    /// is rejected with an `invalid_length` error.
    #[test]
    fn array_wrapper() {
        let array = Array { inner: [0u8; 40] };
        // A `Tuple` header, forty `U8(0)` elements (five per row), `TupleEnd`.
        #[rustfmt::skip]
        let tokens = &[
            Token::Tuple { len: 40 },
            Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
            Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
            Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
            Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
            Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
            Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
            Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
            Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
            Token::TupleEnd,
        ];
        assert_ser_tokens(&array, tokens);
        assert_de_tokens(&array, tokens);

        // Only one of the two required elements is supplied.
        let tokens = &[Token::Tuple { len: 1 }, Token::U32(0), Token::TupleEnd];
        assert_de_tokens_error::<Array<u32, 2>>(
            tokens,
            "invalid length 1, expected a [u32; 2]",
        );
    }

    /// `BitIdx` serializes as a `width`/`index` struct and deserializes from
    /// either struct (map) or sequence input, with the error cases below.
    #[test]
    fn bit_idx() {
        let idx = BitIdx::<u32>::new(20).unwrap();
        let tokens = &mut [
            Token::Struct {
                name: "BitIdx",
                len: 2,
            },
            Token::Str("width"),
            Token::U8(32),
            Token::Str("index"),
            Token::U8(20),
            Token::StructEnd,
        ];
        assert_ser_tokens(&idx, tokens);
        // Swap the key tokens to their borrowed form before checking
        // deserialization of the same token stream.
        tokens[1] = Token::BorrowedStr("width");
        tokens[3] = Token::BorrowedStr("index");
        assert_de_tokens(&idx, tokens);

        // Sequence form: width first, then index.
        let idx = BitIdx::<u16>::new(10).unwrap();
        let tokens = &[
            Token::Seq { len: Some(2) },
            Token::U8(16),
            Token::U8(10),
            Token::SeqEnd,
        ];
        assert_de_tokens(&idx, tokens);

        // Width that does not match `u16` is reported as `invalid type`.
        assert_de_tokens_error::<BitIdx<u16>>(
            &[
                Token::Seq { len: Some(2) },
                Token::U8(8),
                Token::U8(0),
                Token::SeqEnd,
            ],
            "invalid type: integer `8`, expected a valid `BitIdx<u16>`",
        );
        // Index rejected by `BitIdx::new` is reported as `invalid value`.
        assert_de_tokens_error::<BitIdx<u16>>(
            &[
                Token::Seq { len: Some(2) },
                Token::U8(16),
                Token::U8(16),
                Token::SeqEnd,
            ],
            "invalid value: integer `16`, expected a valid `BitIdx<u16>`",
        );
        // Keys outside `FIELDS` are rejected.
        assert_de_tokens_error::<BitIdx<u8>>(
            &[
                Token::Struct {
                    name: "BitIdx",
                    len: 1,
                },
                Token::BorrowedStr("unknown"),
            ],
            "unknown field `unknown`, expected `width` or `index`",
        );
        // Repeated keys are rejected.
        assert_de_tokens_error::<BitIdx<u8>>(
            &[
                Token::Struct {
                    name: "BitIdx",
                    len: 2,
                },
                Token::BorrowedStr("width"),
                Token::U8(8),
                Token::BorrowedStr("width"),
                Token::U8(8),
                Token::StructEnd,
            ],
            "duplicate field `width`",
        );
        assert_de_tokens_error::<BitIdx<u8>>(
            &[
                Token::Struct {
                    name: "BitIdx",
                    len: 2,
                },
                Token::BorrowedStr("index"),
                Token::U8(7),
                Token::BorrowedStr("index"),
                Token::U8(7),
                Token::StructEnd,
            ],
            "duplicate field `index`",
        );
    }
}