use super::{ArchivedRc, ArchivedRcWeak, ArchivedRcWeakTag, ArchivedRcWeakVariantSome};
use crate::{
validation::{ArchiveContext, LayoutRaw, SharedContext},
ArchivePointee, RelPtr,
};
use bytecheck::{CheckBytes, Error};
use core::{any::TypeId, convert::Infallible, fmt, ptr};
use ptr_meta::Pointee;
/// Error returned when validating an archived shared pointer fails.
///
/// Produced by the `CheckBytes` impl for `ArchivedRc` in this module. The
/// three type parameters are the error types of the three stages that can
/// fail:
///
/// - `T`: checking the relative pointer (including its archived metadata),
/// - `R`: checking the value the pointer refers to,
/// - `C`: operations performed on the validation context.
#[derive(Debug)]
pub enum SharedPointerError<T, R, C> {
    /// The relative pointer itself did not pass validation.
    PointerCheckBytesError(T),
    /// The value the pointer refers to did not pass validation.
    ValueCheckBytesError(R),
    /// The validation context reported an error.
    ContextError(C),
}

/// Displays exactly the wrapped error's message.
///
/// The caller's `Formatter` (including any width/fill/precision flags) is
/// handed straight to the inner error's `Display` impl; no extra text is added.
impl<T, R, C> fmt::Display for SharedPointerError<T, R, C>
where
    T: fmt::Display,
    R: fmt::Display,
    C: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant is a plain wrapper: pick out the payload, then delegate.
        let inner: &dyn fmt::Display = match self {
            Self::PointerCheckBytesError(err) => err,
            Self::ValueCheckBytesError(err) => err,
            Self::ContextError(err) => err,
        };
        fmt::Display::fmt(inner, f)
    }
}
#[cfg(feature = "std")]
const _: () = {
    use std::error::Error as StdError;

    // Every variant wraps exactly one underlying error, which is reported as
    // this error's `source`.
    //
    // NOTE(review): `Display` for this type also forwards straight to the
    // wrapped error, so reporters that print the full `source()` chain will
    // show the same message twice. Consider returning the inner error's own
    // `source()` instead (transparent wrapping). It is left unchanged here to
    // keep the public behavior stable.
    impl<T, R, C> StdError for SharedPointerError<T, R, C>
    where
        T: StdError + 'static,
        R: StdError + 'static,
        C: StdError + 'static,
    {
        fn source(&self) -> Option<&(dyn StdError + 'static)> {
            let cause: &(dyn StdError + 'static) = match self {
                Self::PointerCheckBytesError(err) => err,
                Self::ValueCheckBytesError(err) => err,
                Self::ContextError(err) => err,
            };
            Some(cause)
        }
    }
};
#[derive(Debug)]
pub enum WeakPointerError<T, R, C> {
InvalidTag(u8),
CheckBytes(SharedPointerError<T, R, C>),
}
impl<T: fmt::Display, R: fmt::Display, C: fmt::Display> fmt::Display for WeakPointerError<T, R, C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WeakPointerError::InvalidTag(tag) => {
write!(f, "archived weak had invalid tag: {}", tag)
}
WeakPointerError::CheckBytes(e) => e.fmt(f),
}
}
}
#[cfg(feature = "std")]
const _: () = {
    use std::error::Error as StdError;

    // Only a failed nested pointer check carries an underlying error. An
    // invalid tag has no further cause.
    //
    // NOTE(review): `Display` for the `CheckBytes` variant forwards to the
    // wrapped error too, so reporters that walk `source()` will repeat the
    // same message. Consider returning the inner error's own `source()`
    // instead. It is left as-is to keep behavior stable.
    impl<T, R, C> StdError for WeakPointerError<T, R, C>
    where
        T: StdError + 'static,
        R: StdError + 'static,
        C: StdError + 'static,
    {
        fn source(&self) -> Option<&(dyn StdError + 'static)> {
            match self {
                Self::CheckBytes(inner) => {
                    let cause: &(dyn StdError + 'static) = inner;
                    Some(cause)
                }
                Self::InvalidTag(_) => None,
            }
        }
    }
};
/// Lets `?` convert an [`Infallible`] error into a [`WeakPointerError`].
///
/// `Infallible` has no values, so this conversion can never actually run. It
/// exists only so that `?` type-checks on results whose error type is
/// `Infallible` (presumably the `u8` tag check in `ArchivedRcWeak::check_bytes`).
impl<T, R, C> From<Infallible> for WeakPointerError<T, R, C> {
    fn from(never: Infallible) -> Self {
        // An empty match on an uninhabited type is exhaustive, so the compiler
        // itself proves this body is unreachable; no `unsafe` is required.
        match never {}
    }
}
/// Validates an archived shared pointer and, when the context's
/// `register_shared_ptr` returns `true`, the value it points to.
///
/// Failures are reported as [`SharedPointerError`]. Pointer/metadata failures
/// become `PointerCheckBytesError`, failures of the pointee become
/// `ValueCheckBytesError`, and any error from the validation context becomes
/// `ContextError`.
impl<T, F, C> CheckBytes<C> for ArchivedRc<T, F>
where
    T: ArchivePointee + CheckBytes<C> + LayoutRaw + Pointee + ?Sized + 'static,
    C: ArchiveContext + SharedContext + ?Sized,
    T::ArchivedMetadata: CheckBytes<C>,
    C::Error: Error,
    F: 'static,
{
    type Error =
        SharedPointerError<<T::ArchivedMetadata as CheckBytes<C>>::Error, T::Error, C::Error>;

    /// # Safety
    ///
    /// The caller must uphold the `CheckBytes::check_bytes` contract for
    /// `value` (see bytecheck's `CheckBytes` documentation).
    #[inline]
    unsafe fn check_bytes<'a>(
        value: *const Self,
        context: &mut C,
    ) -> Result<&'a Self, Self::Error> {
        // Validate the relative pointer stored at `value`, then let the context
        // resolve it to the target address.
        let rel_ptr = RelPtr::<T>::manual_check_bytes(value.cast(), context)
            .map_err(SharedPointerError::PointerCheckBytesError)?;
        let target = context
            .check_rel_ptr(rel_ptr)
            .map_err(SharedPointerError::ContextError)?;

        // Register the target together with this pointer's concrete type. The
        // pointee is validated below only when registration reports `true`.
        // Presumably `true` means "first time this address was seen", so a
        // target shared by several pointers is checked once. Confirm against
        // `SharedContext::register_shared_ptr`.
        let needs_validation = context
            .register_shared_ptr(target.cast(), TypeId::of::<Self>())
            .map_err(SharedPointerError::ContextError)?;
        if !needs_validation {
            return Ok(&*value);
        }

        // Validate the pointee as its own subtree: bounds-check it, push a
        // prefix range for it, check the value, then pop that range again.
        context
            .bounds_check_subtree_ptr(target)
            .map_err(SharedPointerError::ContextError)?;
        let subtree = context
            .push_prefix_subtree(target)
            .map_err(SharedPointerError::ContextError)?;
        T::check_bytes(target, context).map_err(SharedPointerError::ValueCheckBytesError)?;
        context
            .pop_prefix_range(subtree)
            .map_err(SharedPointerError::ContextError)?;

        Ok(&*value)
    }
}
impl ArchivedRcWeakTag {
const TAG_NONE: u8 = ArchivedRcWeakTag::None as u8;
const TAG_SOME: u8 = ArchivedRcWeakTag::Some as u8;
}
/// Validates an archived weak pointer.
///
/// The leading tag byte selects the variant:
/// - the `None` tag needs no further checks;
/// - the `Some` tag validates the `ArchivedRc` stored in field `.1` of
///   `ArchivedRcWeakVariantSome`;
/// - any other byte is rejected with `WeakPointerError::InvalidTag`.
impl<T, F, C> CheckBytes<C> for ArchivedRcWeak<T, F>
where
    T: ArchivePointee + CheckBytes<C> + LayoutRaw + Pointee + ?Sized + 'static,
    C: ArchiveContext + SharedContext + ?Sized,
    T::ArchivedMetadata: CheckBytes<C>,
    C::Error: Error,
    F: 'static,
{
    type Error =
        WeakPointerError<<T::ArchivedMetadata as CheckBytes<C>>::Error, T::Error, C::Error>;

    /// # Safety
    ///
    /// The caller must uphold the `CheckBytes::check_bytes` contract for
    /// `value` (see bytecheck's `CheckBytes` documentation).
    #[inline]
    unsafe fn check_bytes<'a>(
        value: *const Self,
        context: &mut C,
    ) -> Result<&'a Self, Self::Error> {
        // The tag is read from the first byte of the value. The `?` relies on
        // this file's `From<Infallible>` impl for `WeakPointerError`
        // (presumably the `u8` check cannot fail).
        let &tag = u8::check_bytes(value.cast::<u8>(), context)?;

        if tag == ArchivedRcWeakTag::TAG_SOME {
            let some = value.cast::<ArchivedRcWeakVariantSome<T, F>>();
            // `addr_of!` computes the field address without creating a
            // reference to the not-yet-validated variant.
            ArchivedRc::<T, F>::check_bytes(ptr::addr_of!((*some).1), context)
                .map_err(WeakPointerError::CheckBytes)?;
        } else if tag != ArchivedRcWeakTag::TAG_NONE {
            return Err(WeakPointerError::InvalidTag(tag));
        }

        // Either the `None` tag (nothing more to check) or a validated `Some`.
        Ok(&*value)
    }
}