Skip to main content

serde_reflection/
trace.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    de::Deserializer,
6    error::{Error, Result},
7    format::*,
8    ser::Serializer,
9    value::Value,
10};
11use erased_discriminant::Discriminant;
12use once_cell::sync::Lazy;
13use serde::{de::DeserializeSeed, Deserialize, Serialize};
14use std::any::TypeId;
15use std::collections::{BTreeMap, BTreeSet};
16
17/// A map of container formats.
18pub type Registry = BTreeMap<String, ContainerFormat>;
19
20/// Structure to drive the tracing of Serde serialization and deserialization.
21/// This typically aims at computing a `Registry`.
22#[derive(Debug)]
23pub struct Tracer {
24    /// Hold configuration options.
25    pub(crate) config: TracerConfig,
26
27    /// Formats of the named containers discovered so far, while tracing
28    /// serialization and/or deserialization.
29    pub(crate) registry: Registry,
30
31    /// Enums that have detected to be yet incomplete (i.e. missing variants)
32    /// while tracing deserialization.
33    pub(crate) incomplete_enums: BTreeMap<String, IncompleteEnumReason>,
34
35    /// Discriminant associated with each variant of each enum.
36    pub(crate) discriminants: BTreeMap<(TypeId, VariantId<'static>), Discriminant>,
37
38    /// Enum variants whose serialized `VariantFormat` is already complete.
39    /// Keyed by (enum_name, variant_name). This allows `deserialize_enum` to skip
40    /// re-exploring variants that do not need further tracing.
41    pub(crate) serialized_variants: BTreeSet<(String, String)>,
42}
43
/// Reason why an enum is still considered incompletely traced.
///
/// The type is a plain tag: it is `Copy` and can be compared with `==`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum IncompleteEnumReason {
    /// There are variant names that have not yet been traced.
    NamedVariantsRemaining,
    /// There are variant numbers that have not yet been traced.
    IndexedVariantsRemaining,
}
52
/// Identifier of an enum variant: either its numeric index or its name.
///
/// The derived ordering sorts every `Index` before every `Name`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum VariantId<'a> {
    /// Variant designated by its index.
    Index(u32),
    /// Variant designated by its name.
    Name(&'a str),
}
58
59/// User inputs, aka "samples", recorded during serialization.
60/// This will help passing user-defined checks during deserialization.
61#[derive(Debug, Default)]
62pub struct Samples {
63    pub(crate) values: BTreeMap<&'static str, Value>,
64}
65
66impl Samples {
67    /// Create a new structure to hold value samples.
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    /// Obtain a (serialized) sample.
73    pub fn value(&self, name: &'static str) -> Option<&Value> {
74        self.values.get(name)
75    }
76}
77
/// Configuration object to create a tracer.
///
/// Every field has a builder-style setter of the same name.
#[derive(Debug)]
pub struct TracerConfig {
    /// Whether to trace the human readable encoding of (de)serialization.
    pub(crate) is_human_readable: bool,
    /// Record samples of newtype structs during serialization and inject them
    /// during deserialization.
    pub(crate) record_samples_for_newtype_structs: bool,
    /// Record samples of tuple structs during serialization and inject them
    /// during deserialization.
    pub(crate) record_samples_for_tuple_structs: bool,
    /// Record samples of (regular) structs during serialization and inject them
    /// during deserialization.
    pub(crate) record_samples_for_structs: bool,

    // Default serialized value for each primitive type.
    pub(crate) default_bool_value: bool,
    pub(crate) default_u8_value: u8,
    pub(crate) default_u16_value: u16,
    pub(crate) default_u32_value: u32,
    pub(crate) default_u64_value: u64,
    pub(crate) default_u128_value: u128,
    pub(crate) default_i8_value: i8,
    pub(crate) default_i16_value: i16,
    pub(crate) default_i32_value: i32,
    pub(crate) default_i64_value: i64,
    pub(crate) default_i128_value: i128,
    pub(crate) default_f32_value: f32,
    pub(crate) default_f64_value: f64,
    pub(crate) default_char_value: char,
    pub(crate) default_borrowed_str_value: &'static str,
    pub(crate) default_string_value: String,
    pub(crate) default_borrowed_bytes_value: &'static [u8],
    pub(crate) default_byte_buf_value: Vec<u8>,
}
104
105impl Default for TracerConfig {
106    /// Create a new structure to hold value samples.
107    fn default() -> Self {
108        Self {
109            is_human_readable: false,
110            record_samples_for_newtype_structs: true,
111            record_samples_for_tuple_structs: false,
112            record_samples_for_structs: false,
113            default_bool_value: false,
114            default_u8_value: 0,
115            default_u16_value: 0,
116            default_u32_value: 0,
117            default_u64_value: 0,
118            default_u128_value: 0,
119            default_i8_value: 0,
120            default_i16_value: 0,
121            default_i32_value: 0,
122            default_i64_value: 0,
123            default_i128_value: 0,
124            default_f32_value: 0.0,
125            default_f64_value: 0.0,
126            default_char_value: 'A',
127            default_borrowed_str_value: "",
128            default_string_value: String::new(),
129            default_borrowed_bytes_value: b"",
130            default_byte_buf_value: Vec::new(),
131        }
132    }
133}
134
// Generate a builder-style setter named after the field it assigns:
// the method consumes `self`, replaces the field `$field` (of type
// `$value_ty`) and returns the updated value.
macro_rules! define_default_value_setter {
    ($field:ident, $value_ty:ty) => {
        /// The default serialized value for this primitive type.
        pub fn $field(self, value: $value_ty) -> Self {
            Self {
                $field: value,
                ..self
            }
        }
    };
}
144
145impl TracerConfig {
146    /// Whether to trace the human readable encoding of (de)serialization.
147    #[allow(clippy::wrong_self_convention)]
148    pub fn is_human_readable(mut self, value: bool) -> Self {
149        self.is_human_readable = value;
150        self
151    }
152
153    /// Record samples of newtype structs during serialization and inject them during deserialization.
154    pub fn record_samples_for_newtype_structs(mut self, value: bool) -> Self {
155        self.record_samples_for_newtype_structs = value;
156        self
157    }
158
159    /// Record samples of tuple structs during serialization and inject them during deserialization.
160    pub fn record_samples_for_tuple_structs(mut self, value: bool) -> Self {
161        self.record_samples_for_tuple_structs = value;
162        self
163    }
164
165    /// Record samples of (regular) structs during serialization and inject them during deserialization.
166    pub fn record_samples_for_structs(mut self, value: bool) -> Self {
167        self.record_samples_for_structs = value;
168        self
169    }
170
171    define_default_value_setter!(default_bool_value, bool);
172    define_default_value_setter!(default_u8_value, u8);
173    define_default_value_setter!(default_u16_value, u16);
174    define_default_value_setter!(default_u32_value, u32);
175    define_default_value_setter!(default_u64_value, u64);
176    define_default_value_setter!(default_u128_value, u128);
177    define_default_value_setter!(default_i8_value, i8);
178    define_default_value_setter!(default_i16_value, i16);
179    define_default_value_setter!(default_i32_value, i32);
180    define_default_value_setter!(default_i64_value, i64);
181    define_default_value_setter!(default_i128_value, i128);
182    define_default_value_setter!(default_f32_value, f32);
183    define_default_value_setter!(default_f64_value, f64);
184    define_default_value_setter!(default_char_value, char);
185    define_default_value_setter!(default_borrowed_str_value, &'static str);
186    define_default_value_setter!(default_string_value, String);
187    define_default_value_setter!(default_borrowed_bytes_value, &'static [u8]);
188    define_default_value_setter!(default_byte_buf_value, Vec<u8>);
189}
190
191impl Tracer {
192    /// Start tracing deserialization.
193    pub fn new(config: TracerConfig) -> Self {
194        Self {
195            config,
196            registry: BTreeMap::new(),
197            incomplete_enums: BTreeMap::new(),
198            discriminants: BTreeMap::new(),
199            serialized_variants: BTreeSet::new(),
200        }
201    }
202
203    /// Trace the serialization of a particular value.
204    /// * Nested containers will be added to the tracing registry, indexed by
205    ///   their (non-qualified) name.
206    /// * Sampled Rust values will be inserted into `samples` to benefit future calls
207    ///   to the `trace_type_*` methods.
208    pub fn trace_value<T>(&mut self, samples: &mut Samples, value: &T) -> Result<(Format, Value)>
209    where
210        T: ?Sized + Serialize,
211    {
212        let serializer = Serializer::new(self, samples);
213        let (mut format, sample) = value.serialize(serializer)?;
214        format.reduce();
215        Ok((format, sample))
216    }
217
218    /// Trace a single deserialization of a particular type.
219    /// * Nested containers will be added to the tracing registry, indexed by
220    ///   their (non-qualified) name.
221    /// * As a byproduct of deserialization, we also return a value of type `T`.
222    /// * Tracing deserialization of a type may fail if this type or some dependencies
223    ///   have implemented a custom deserializer that validates data. The solution is
224    ///   to make sure that `samples` holds enough sampled Rust values to cover all the
225    ///   custom types.
226    pub fn trace_type_once<'de, T>(&mut self, samples: &'de Samples) -> Result<(Format, T)>
227    where
228        T: Deserialize<'de>,
229    {
230        let mut format = Format::unknown();
231        let deserializer = Deserializer::new(self, samples, &mut format);
232        let value = T::deserialize(deserializer)?;
233        format.reduce();
234        Ok((format, value))
235    }
236
237    /// Same as `trace_type_once` for seeded deserialization.
238    pub fn trace_type_once_with_seed<'de, S>(
239        &mut self,
240        samples: &'de Samples,
241        seed: S,
242    ) -> Result<(Format, S::Value)>
243    where
244        S: DeserializeSeed<'de>,
245    {
246        let mut format = Format::unknown();
247        let deserializer = Deserializer::new(self, samples, &mut format);
248        let value = seed.deserialize(deserializer)?;
249        format.reduce();
250        Ok((format, value))
251    }
252
253    /// Read the status of an enum and reset the value.
254    pub fn check_incomplete_enum(&mut self, name: &str) -> Option<IncompleteEnumReason> {
255        self.incomplete_enums.remove(name)
256    }
257
258    /// Same as `trace_type_once` but if `T` is an enum, we repeat the process
259    /// until all variants of `T` are covered.
260    /// We accumulate and return all the sampled values at the end.
261    pub fn trace_type<'de, T>(&mut self, samples: &'de Samples) -> Result<(Format, Vec<T>)>
262    where
263        T: Deserialize<'de>,
264    {
265        let mut values = Vec::new();
266        loop {
267            let (format, value) = self.trace_type_once::<T>(samples)?;
268            values.push(value);
269            if let Format::TypeName(name) = &format {
270                if let Some(reason) = self.check_incomplete_enum(name) {
271                    // Restart the analysis to find more variants of T.
272                    if let IncompleteEnumReason::NamedVariantsRemaining = reason {
273                        values.pop().unwrap();
274                    }
275                    continue;
276                }
277            }
278            return Ok((format, values));
279        }
280    }
281
282    /// Trace a type `T` that is simple enough that no samples of values are needed.
283    /// * If `T` is an enum, the tracing iterates until all variants of `T` are covered.
284    /// * Accumulate and return all the sampled values at the end.
285    ///   This is merely a shortcut for `self.trace_type` with a fixed empty set of samples.
286    pub fn trace_simple_type<'de, T>(&mut self) -> Result<(Format, Vec<T>)>
287    where
288        T: Deserialize<'de>,
289    {
290        static SAMPLES: Lazy<Samples> = Lazy::new(Samples::new);
291        self.trace_type(&SAMPLES)
292    }
293
294    /// Same as `trace_type` for seeded deserialization.
295    pub fn trace_type_with_seed<'de, S>(
296        &mut self,
297        samples: &'de Samples,
298        seed: S,
299    ) -> Result<(Format, Vec<S::Value>)>
300    where
301        S: DeserializeSeed<'de> + Clone,
302    {
303        let mut values = Vec::new();
304        loop {
305            let (format, value) = self.trace_type_once_with_seed(samples, seed.clone())?;
306            values.push(value);
307            if let Format::TypeName(name) = &format {
308                if let Some(reason) = self.check_incomplete_enum(name) {
309                    // Restart the analysis to find more variants of T.
310                    if let IncompleteEnumReason::NamedVariantsRemaining = reason {
311                        values.pop().unwrap();
312                    }
313                    continue;
314                }
315            }
316            return Ok((format, values));
317        }
318    }
319
320    /// Finish tracing and recover a map of normalized formats.
321    /// Returns an error if we detect incompletely traced types.
322    /// This may happen in a few of cases:
323    /// * We traced serialization of user-provided values but we are still missing the content
324    ///   of an option type, the content of a sequence type, the key or the value of a dictionary type.
325    /// * We traced deserialization of an enum type but we detect that some enum variants are still missing.
326    pub fn registry(self) -> Result<Registry> {
327        let mut registry = self.registry;
328        for (name, format) in registry.iter_mut() {
329            format
330                .normalize()
331                .map_err(|_| Error::UnknownFormatInContainer(name.clone()))?;
332        }
333        if self.incomplete_enums.is_empty() {
334            Ok(registry)
335        } else {
336            Err(Error::MissingVariants(
337                self.incomplete_enums.into_keys().collect(),
338            ))
339        }
340    }
341
342    /// Same as registry but always return a value, even if we detected issues.
343    /// This should only be use for debugging.
344    pub fn registry_unchecked(self) -> Registry {
345        let mut registry = self.registry;
346        for format in registry.values_mut() {
347            format.normalize().unwrap_or(());
348        }
349        registry
350    }
351
352    pub(crate) fn record_container(
353        &mut self,
354        samples: &mut Samples,
355        name: &'static str,
356        format: ContainerFormat,
357        value: Value,
358        record_value: bool,
359    ) -> Result<(Format, Value)> {
360        self.registry.entry(name.to_string()).unify(format)?;
361        if record_value {
362            samples.values.insert(name, value.clone());
363        }
364        Ok((Format::TypeName(name.into()), value))
365    }
366
367    pub(crate) fn record_variant(
368        &mut self,
369        samples: &mut Samples,
370        name: &'static str,
371        variant_index: u32,
372        variant_name: &'static str,
373        variant: VariantFormat,
374        variant_value: Value,
375    ) -> Result<(Format, Value)> {
376        let mut normalized_variant = variant.clone();
377        let is_complete = normalized_variant.normalize().is_ok();
378        let mut variants = BTreeMap::new();
379        variants.insert(
380            variant_index,
381            Named {
382                name: variant_name.into(),
383                value: variant,
384            },
385        );
386        let format = ContainerFormat::Enum(variants);
387        let value = Value::Variant(variant_index, Box::new(variant_value));
388        if is_complete {
389            self.serialized_variants
390                .insert((name.to_string(), variant_name.to_string()));
391        }
392        self.record_container(samples, name, format, value, false)
393    }
394
395    pub(crate) fn get_sample<'de, 'a>(
396        &'a self,
397        samples: &'de Samples,
398        name: &'static str,
399    ) -> Option<(&'a ContainerFormat, &'de Value)> {
400        match samples.value(name) {
401            Some(value) => {
402                let format = self
403                    .registry
404                    .get(name)
405                    .expect("recorded containers should have a format already");
406                Some((format, value))
407            }
408            None => None,
409        }
410    }
411}