rkyv/rc/
validation.rs

1//! Validation implementations for shared pointers.
2
3use super::{ArchivedRc, ArchivedRcWeak, ArchivedRcWeakTag, ArchivedRcWeakVariantSome};
4use crate::{
5    validation::{ArchiveContext, LayoutRaw, SharedContext},
6    ArchivePointee, RelPtr,
7};
8use bytecheck::{CheckBytes, Error};
9use core::{any::TypeId, convert::Infallible, fmt, ptr};
10use ptr_meta::Pointee;
11
/// Errors that can occur while checking archived shared pointers.
///
/// When produced by the `CheckBytes` impl for `ArchivedRc`, the type parameters are:
/// `T`, the error from checking the relative pointer's archived metadata; `R`, the
/// error from checking the pointed-to value; and `C`, the validation context's error.
#[derive(Debug)]
pub enum SharedPointerError<T, R, C> {
    /// An error occurred while checking the bytes of the shared reference itself
    /// (its relative pointer and archived metadata).
    PointerCheckBytesError(T),
    /// An error occurred while checking the bytes of the shared value that the
    /// pointer targets.
    ValueCheckBytesError(R),
    /// A context error occurred (for example while resolving, registering, or
    /// bounds-checking the shared pointer's target).
    ContextError(C),
}
22
23impl<T, R, C> fmt::Display for SharedPointerError<T, R, C>
24where
25    T: fmt::Display,
26    R: fmt::Display,
27    C: fmt::Display,
28{
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        match self {
31            SharedPointerError::PointerCheckBytesError(e) => e.fmt(f),
32            SharedPointerError::ValueCheckBytesError(e) => e.fmt(f),
33            SharedPointerError::ContextError(e) => e.fmt(f),
34        }
35    }
36}
37
#[cfg(feature = "std")]
const _: () = {
    // Aliased so it does not shadow the `bytecheck::Error` imported at file scope.
    use std::error::Error as StdError;

    // Every variant wraps exactly one underlying error, which is reported as the source.
    impl<T, R, C> StdError for SharedPointerError<T, R, C>
    where
        T: StdError + 'static,
        R: StdError + 'static,
        C: StdError + 'static,
    {
        fn source(&self) -> Option<&(dyn StdError + 'static)> {
            let cause: &(dyn StdError + 'static) = match self {
                Self::PointerCheckBytesError(inner) => inner,
                Self::ValueCheckBytesError(inner) => inner,
                Self::ContextError(inner) => inner,
            };
            Some(cause)
        }
    }
};
57
58/// Errors that can occur while checking archived weak pointers.
59#[derive(Debug)]
60pub enum WeakPointerError<T, R, C> {
61    /// The weak pointer had an invalid tag
62    InvalidTag(u8),
63    /// An error occurred while checking the underlying shared pointer
64    CheckBytes(SharedPointerError<T, R, C>),
65}
66
67impl<T: fmt::Display, R: fmt::Display, C: fmt::Display> fmt::Display for WeakPointerError<T, R, C> {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        match self {
70            WeakPointerError::InvalidTag(tag) => {
71                write!(f, "archived weak had invalid tag: {}", tag)
72            }
73            WeakPointerError::CheckBytes(e) => e.fmt(f),
74        }
75    }
76}
77
#[cfg(feature = "std")]
const _: () = {
    // Aliased so it does not shadow the `bytecheck::Error` imported at file scope.
    use std::error::Error as StdError;

    impl<T, R, C> StdError for WeakPointerError<T, R, C>
    where
        T: StdError + 'static,
        R: StdError + 'static,
        C: StdError + 'static,
    {
        fn source(&self) -> Option<&(dyn StdError + 'static)> {
            // Only a failed check of the inner shared pointer has an underlying cause;
            // an invalid tag stands on its own.
            if let Self::CheckBytes(inner) = self {
                Some(inner as &(dyn StdError + 'static))
            } else {
                None
            }
        }
    }
};
96
97impl<T, R, C> From<Infallible> for WeakPointerError<T, R, C> {
98    fn from(_: Infallible) -> Self {
99        unsafe { core::hint::unreachable_unchecked() }
100    }
101}
102
103impl<T, F, C> CheckBytes<C> for ArchivedRc<T, F>
104where
105    T: ArchivePointee + CheckBytes<C> + LayoutRaw + Pointee + ?Sized + 'static,
106    C: ArchiveContext + SharedContext + ?Sized,
107    T::ArchivedMetadata: CheckBytes<C>,
108    C::Error: Error,
109    F: 'static,
110{
111    type Error =
112        SharedPointerError<<T::ArchivedMetadata as CheckBytes<C>>::Error, T::Error, C::Error>;
113
114    #[inline]
115    unsafe fn check_bytes<'a>(
116        value: *const Self,
117        context: &mut C,
118    ) -> Result<&'a Self, Self::Error> {
119        let rel_ptr = RelPtr::<T>::manual_check_bytes(value.cast(), context)
120            .map_err(SharedPointerError::PointerCheckBytesError)?;
121        let ptr = context
122            .check_rel_ptr(rel_ptr)
123            .map_err(SharedPointerError::ContextError)?;
124
125        let type_id = TypeId::of::<Self>();
126        if context
127            .register_shared_ptr(ptr.cast(), type_id)
128            .map_err(SharedPointerError::ContextError)?
129        {
130            context
131                .bounds_check_subtree_ptr(ptr)
132                .map_err(SharedPointerError::ContextError)?;
133
134            let range = context
135                .push_prefix_subtree(ptr)
136                .map_err(SharedPointerError::ContextError)?;
137            T::check_bytes(ptr, context).map_err(SharedPointerError::ValueCheckBytesError)?;
138            context
139                .pop_prefix_range(range)
140                .map_err(SharedPointerError::ContextError)?;
141        }
142        Ok(&*value)
143    }
144}
145
146impl ArchivedRcWeakTag {
147    const TAG_NONE: u8 = ArchivedRcWeakTag::None as u8;
148    const TAG_SOME: u8 = ArchivedRcWeakTag::Some as u8;
149}
150
151impl<T, F, C> CheckBytes<C> for ArchivedRcWeak<T, F>
152where
153    T: ArchivePointee + CheckBytes<C> + LayoutRaw + Pointee + ?Sized + 'static,
154    C: ArchiveContext + SharedContext + ?Sized,
155    T::ArchivedMetadata: CheckBytes<C>,
156    C::Error: Error,
157    F: 'static,
158{
159    type Error =
160        WeakPointerError<<T::ArchivedMetadata as CheckBytes<C>>::Error, T::Error, C::Error>;
161
162    #[inline]
163    unsafe fn check_bytes<'a>(
164        value: *const Self,
165        context: &mut C,
166    ) -> Result<&'a Self, Self::Error> {
167        let tag = *u8::check_bytes(value.cast::<u8>(), context)?;
168        match tag {
169            ArchivedRcWeakTag::TAG_NONE => (),
170            ArchivedRcWeakTag::TAG_SOME => {
171                let value = value.cast::<ArchivedRcWeakVariantSome<T, F>>();
172                ArchivedRc::<T, F>::check_bytes(ptr::addr_of!((*value).1), context)
173                    .map_err(WeakPointerError::CheckBytes)?;
174            }
175            _ => return Err(WeakPointerError::InvalidTag(tag)),
176        }
177        Ok(&*value)
178    }
179}