bitvec/serdes/
utils.rs

1#![doc=include_str!("../../doc/serdes/utils.md")]
2
3use core::{
4	any,
5	fmt::{
6		self,
7		Formatter,
8	},
9	marker::PhantomData,
10	mem::MaybeUninit,
11};
12
13use serde::{
14	de::{
15		Deserialize,
16		Deserializer,
17		Error,
18		MapAccess,
19		SeqAccess,
20		Unexpected,
21		Visitor,
22	},
23	ser::{
24		Serialize,
25		SerializeSeq,
26		SerializeStruct,
27		SerializeTuple,
28		Serializer,
29	},
30};
31use wyz::comu::Const;
32
33use crate::{
34	domain::Domain,
35	index::BitIdx,
36	mem::{
37		bits_of,
38		BitRegister,
39	},
40	order::BitOrder,
41	store::BitStore,
42	view::BitViewSized,
43};
44
45/// A zero-sized type that deserializes from any string as long as it is equal
46/// to `any::type_name::<T>()`.
47pub(super) struct TypeName<T>(PhantomData<T>);
48
49impl<T> TypeName<T> {
50	/// Creates a type-name ghost for any type.
51	fn new() -> Self {
52		TypeName(PhantomData)
53	}
54}
55
56impl<'de, T> Deserialize<'de> for TypeName<T> {
57	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
58	where D: Deserializer<'de> {
59		deserializer.deserialize_str(Self::new())
60	}
61}
62
63impl<'de, T> Visitor<'de> for TypeName<T> {
64	type Value = Self;
65
66	fn expecting(&self, fmt: &mut Formatter) -> fmt::Result {
67		write!(fmt, "the string {:?}", any::type_name::<T>())
68	}
69
70	fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
71	where E: serde::de::Error {
72		if value == any::type_name::<T>() {
73			Ok(self)
74		}
75		else {
76			Err(serde::de::Error::invalid_value(
77				Unexpected::Str(value),
78				&self,
79			))
80		}
81	}
82}
83
/// Field names of the `BitIdx` wire format, in serialization order.
static FIELDS: &[&str] = &[
	"width",
	"index",
];

/// Identifies one component of a serialized bit-index.
enum Field {
	/// The bit-width of the register type the index belongs to.
	Width,
	/// The index value itself.
	Index,
}

/// `Visitor` that decodes `Field` identifiers.
struct FieldVisitor;
97
98impl<'de> Deserialize<'de> for Field {
99	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
100	where D: Deserializer<'de> {
101		deserializer.deserialize_identifier(FieldVisitor)
102	}
103}
104
105impl<'de> Visitor<'de> for FieldVisitor {
106	type Value = Field;
107
108	fn expecting(&self, fmt: &mut Formatter) -> fmt::Result {
109		fmt.write_str("field identifier")
110	}
111
112	fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
113	where E: serde::de::Error {
114		match value {
115			"width" => Ok(Field::Width),
116			"index" => Ok(Field::Index),
117			_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
118		}
119	}
120}
121
122impl<R> Serialize for BitIdx<R>
123where R: BitRegister
124{
125	#[inline]
126	fn serialize<S>(&self, serializer: S) -> super::Result<S>
127	where S: Serializer {
128		let mut state = serializer.serialize_struct("BitIdx", FIELDS.len())?;
129
130		//  Emit the bit-width of the `R` type.
131		state.serialize_field(FIELDS[0], &(bits_of::<R>() as u8))?;
132		//  Emit the actual head-bit index.
133		state.serialize_field(FIELDS[1], &self.into_inner())?;
134
135		state.end()
136	}
137}
138
139impl<'de, R> Deserialize<'de> for BitIdx<R>
140where R: BitRegister
141{
142	#[inline]
143	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
144	where D: Deserializer<'de> {
145		deserializer.deserialize_struct(
146			"BitIdx",
147			FIELDS,
148			BitIdxVisitor::<R>::THIS,
149		)
150	}
151}
152
153impl<T, O> Serialize for Domain<'_, Const, T, O>
154where
155	T: BitStore,
156	O: BitOrder,
157	T::Mem: Serialize,
158{
159	#[inline]
160	fn serialize<S>(&self, serializer: S) -> super::Result<S>
161	where S: Serializer {
162		//  Domain<T> is functionally equivalent to `[T::Mem]`.
163		let mut state = serializer.serialize_seq(Some(self.len()))?;
164		for elem in *self {
165			state.serialize_element(&elem)?;
166		}
167		state.end()
168	}
169}
170
171/** `serde` only provides implementations for `[T; 0 ..= 32]`. This wrapper
172provides the same de/ser logic, but allows it to be used on arrays of any size.
173**/
174#[repr(transparent)]
175#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
176pub(super) struct Array<T, const N: usize>
177where T: BitStore
178{
179	/// The data buffer being transported.
180	pub(super) inner: [T; N],
181}
182
183impl<T, const N: usize> Array<T, N>
184where T: BitStore
185{
186	/// Constructs a `&Array` reference from an `&[T; N]` reference.
187	///
188	/// ## Safety
189	///
190	/// `Array` is `#[repr(transparent)]`, so this address transformation is
191	/// always sound.
192	pub(super) fn from_ref(arr: &[T; N]) -> &Self {
193		unsafe { &*(arr as *const [T; N] as *const Self) }
194	}
195}
196
197impl<T, const N: usize> Serialize for Array<T, N>
198where
199	T: BitStore,
200	T::Mem: Serialize,
201{
202	#[inline]
203	fn serialize<S>(&self, serializer: S) -> super::Result<S>
204	where S: Serializer {
205		//  `serde` serializes arrays as a tuple, so that transport formats can
206		//  safely choose to keep or discard the length counter.
207		let mut state = serializer.serialize_tuple(N)?;
208		for elem in self.inner.as_raw_slice().iter().map(BitStore::load_value) {
209			state.serialize_element(&elem)?
210		}
211		state.end()
212	}
213}
214
215impl<'de, T, const N: usize> Deserialize<'de> for Array<T, N>
216where
217	T: BitStore,
218	T::Mem: Deserialize<'de>,
219{
220	#[inline]
221	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222	where D: Deserializer<'de> {
223		deserializer.deserialize_tuple(N, ArrayVisitor::<T, N>::THIS)
224	}
225}
226
227/// Assists in deserialization of a static `[T; N]` for any `N`.
228struct ArrayVisitor<T, const N: usize>
229where T: BitStore
230{
231	/// This produces an array during its work.
232	inner: PhantomData<[T; N]>,
233}
234
235impl<T, const N: usize> ArrayVisitor<T, N>
236where T: BitStore
237{
238	/// A blank visitor in its ready state.
239	const THIS: Self = Self { inner: PhantomData };
240}
241
242impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
243where
244	T: BitStore,
245	T::Mem: Deserialize<'de>,
246{
247	type Value = Array<T, N>;
248
249	#[inline]
250	fn expecting(&self, fmt: &mut Formatter) -> fmt::Result {
251		write!(fmt, "a [{}; {}]", any::type_name::<T>(), N)
252	}
253
254	#[inline]
255	fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
256	where V: SeqAccess<'de> {
257		let mut uninit = [MaybeUninit::<T::Mem>::uninit(); N];
258		for (idx, slot) in uninit.iter_mut().enumerate() {
259			slot.write(
260				seq.next_element::<T::Mem>()?
261					.ok_or_else(|| <V::Error>::invalid_length(idx, &self))?,
262			);
263		}
264		Ok(Array {
265			inner: uninit
266				.map(|elem| unsafe { MaybeUninit::assume_init(elem) })
267				.map(BitStore::new),
268		})
269	}
270}
271
272/// Assists in deserialization of a `BitIdx` value.
273struct BitIdxVisitor<R>
274where R: BitRegister
275{
276	/// This requires carrying the register type information.
277	inner: PhantomData<R>,
278}
279
280impl<R> BitIdxVisitor<R>
281where R: BitRegister
282{
283	/// A blank visitor in its ready state.
284	const THIS: Self = Self { inner: PhantomData };
285
286	/// Attempts to assemble deserialized components into an output value.
287	#[inline]
288	fn assemble<E>(self, width: u8, index: u8) -> Result<BitIdx<R>, E>
289	where E: Error {
290		//  Fail if the transported type width does not match the destination.
291		if width != bits_of::<R>() as u8 {
292			return Err(E::invalid_type(
293				Unexpected::Unsigned(width as u64),
294				&self,
295			));
296		}
297
298		//  Capture an invalid index value and route it to the error handler.
299		BitIdx::<R>::new(index).map_err(|_| {
300			E::invalid_value(Unexpected::Unsigned(index as u64), &self)
301		})
302	}
303}
304
305impl<'de, R> Visitor<'de> for BitIdxVisitor<R>
306where R: BitRegister
307{
308	type Value = BitIdx<R>;
309
310	#[inline]
311	fn expecting(&self, fmt: &mut Formatter) -> fmt::Result {
312		write!(fmt, "a valid `BitIdx<u{}>`", bits_of::<R>())
313	}
314
315	#[inline]
316	fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
317	where V: SeqAccess<'de> {
318		let width = seq
319			.next_element::<u8>()?
320			.ok_or_else(|| <V::Error>::invalid_length(0, &self))?;
321		let index = seq
322			.next_element::<u8>()?
323			.ok_or_else(|| <V::Error>::invalid_length(1, &self))?;
324
325		self.assemble(width, index)
326	}
327
328	#[inline]
329	fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
330	where V: MapAccess<'de> {
331		let mut width = None;
332		let mut index = None;
333
334		while let Some(key) = map.next_key()? {
335			match key {
336				Field::Width => {
337					if width.replace(map.next_value::<u8>()?).is_some() {
338						return Err(<V::Error>::duplicate_field("width"));
339					}
340				},
341				Field::Index => {
342					if index.replace(map.next_value::<u8>()?).is_some() {
343						return Err(<V::Error>::duplicate_field("index"));
344					}
345				},
346			}
347		}
348
349		let width = width.ok_or_else(|| <V::Error>::missing_field("width"))?;
350		let index = index.ok_or_else(|| <V::Error>::missing_field("index"))?;
351
352		self.assemble(width, index)
353	}
354}
355
#[cfg(test)]
mod tests {
	use serde_test::{
		assert_de_tokens,
		assert_de_tokens_error,
		assert_ser_tokens,
		Token,
	};

	use super::*;

	/// Round-trips a 40-element `Array`, which is longer than the
	/// `[T; 0 ..= 32]` arrays `serde` supports natively. Also checks the
	/// error produced when the input tuple is too short.
	#[test]
	fn array_wrapper() {
		let array = Array { inner: [0u8; 40] };
		//  Expected wire form: a 40-element tuple of `u8` zeros.
		#[rustfmt::skip]
		let tokens = &[
			Token::Tuple { len: 40 },
			Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
			Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
			Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
			Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
			Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
			Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
			Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
			Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0), Token::U8(0),
			Token::TupleEnd,
		];
		assert_ser_tokens(&array, tokens);
		assert_de_tokens(&array, tokens);

		//  Only one of the two required elements is present, so
		//  `ArrayVisitor` reports the count it managed to read.
		let tokens = &[Token::Tuple { len: 1 }, Token::U32(0), Token::TupleEnd];
		assert_de_tokens_error::<Array<u32, 2>>(
			tokens,
			"invalid length 1, expected a [u32; 2]",
		);
	}

	/// Covers `BitIdx` serialization, both deserialization forms (struct map
	/// and sequence), and each error path in `BitIdxVisitor`.
	#[test]
	fn bit_idx() {
		let idx = BitIdx::<u32>::new(20).unwrap();
		let tokens = &mut [
			Token::Struct {
				name: "BitIdx",
				len:  2,
			},
			Token::Str("width"),
			Token::U8(32),
			Token::Str("index"),
			Token::U8(20),
			Token::StructEnd,
		];
		assert_ser_tokens(&idx, tokens);
		//  Reuse the same token stream for deserialization, with the field
		//  names supplied as borrowed strings instead.
		tokens[1] = Token::BorrowedStr("width");
		tokens[3] = Token::BorrowedStr("index");
		assert_de_tokens(&idx, tokens);

		//  Sequence form: `width` then `index`, by position.
		let idx = BitIdx::<u16>::new(10).unwrap();
		let tokens = &[
			Token::Seq { len: Some(2) },
			Token::U8(16),
			Token::U8(10),
			Token::SeqEnd,
		];
		assert_de_tokens(&idx, tokens);

		//  Width 8 does not match the 16 bits of `u16`.
		assert_de_tokens_error::<BitIdx<u16>>(
			&[
				Token::Seq { len: Some(2) },
				Token::U8(8),
				Token::U8(0),
				Token::SeqEnd,
			],
			"invalid type: integer `8`, expected a valid `BitIdx<u16>`",
		);
		//  Index 16 is rejected for a `u16` register.
		assert_de_tokens_error::<BitIdx<u16>>(
			&[
				Token::Seq { len: Some(2) },
				Token::U8(16),
				Token::U8(16),
				Token::SeqEnd,
			],
			"invalid value: integer `16`, expected a valid `BitIdx<u16>`",
		);
		//  Names outside `FIELDS` are rejected while decoding the key.
		assert_de_tokens_error::<BitIdx<u8>>(
			&[
				Token::Struct {
					name: "BitIdx",
					len:  1,
				},
				Token::BorrowedStr("unknown"),
			],
			"unknown field `unknown`, expected `width` or `index`",
		);
		//  A repeated `width` key is rejected.
		assert_de_tokens_error::<BitIdx<u8>>(
			&[
				Token::Struct {
					name: "BitIdx",
					len:  2,
				},
				Token::BorrowedStr("width"),
				Token::U8(8),
				Token::BorrowedStr("width"),
				Token::U8(8),
				Token::StructEnd,
			],
			"duplicate field `width`",
		);
		//  A repeated `index` key is rejected.
		assert_de_tokens_error::<BitIdx<u8>>(
			&[
				Token::Struct {
					name: "BitIdx",
					len:  2,
				},
				Token::BorrowedStr("index"),
				Token::U8(7),
				Token::BorrowedStr("index"),
				Token::U8(7),
				Token::StructEnd,
			],
			"duplicate field `index`",
		);
	}
}