bytecheck_derive/
lib.rs

1//! Procedural macros for bytecheck.
2
3#![deny(
4    rust_2018_compatibility,
5    rust_2018_idioms,
6    future_incompatible,
7    nonstandard_style,
8    unused,
9    clippy::all
10)]
11
12use proc_macro2::{Group, Span, TokenStream, TokenTree};
13use quote::{quote, quote_spanned};
14use syn::{
15    parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, AttrStyle, Data,
16    DeriveInput, Error, Fields, Ident, Index, Lit, LitStr, Meta, NestedMeta, Path, Token,
17    WherePredicate,
18};
19
20#[derive(Default)]
21struct Repr {
22    pub transparent: Option<Path>,
23    pub packed: Option<Path>,
24    pub c: Option<Path>,
25    pub int: Option<Path>,
26}
27
28#[derive(Default)]
29struct Attributes {
30    pub repr: Repr,
31    pub bound: Option<LitStr>,
32    pub bytecheck_crate: Option<Path>,
33}
34
35fn parse_check_bytes_attributes(attributes: &mut Attributes, meta: &Meta) -> Result<(), Error> {
36    match meta {
37        Meta::NameValue(meta) => {
38            if meta.path.is_ident("bound") {
39                if let Lit::Str(ref lit_str) = meta.lit {
40                    if attributes.bound.is_none() {
41                        attributes.bound = Some(lit_str.clone());
42                        Ok(())
43                    } else {
44                        Err(Error::new_spanned(
45                            meta,
46                            "check_bytes bound already specified",
47                        ))
48                    }
49                } else {
50                    Err(Error::new_spanned(
51                        &meta.lit,
52                        "bound arguments must be a string",
53                    ))
54                }
55            } else if meta.path.is_ident("crate") {
56                if let Lit::Str(ref lit_str) = meta.lit {
57                    if attributes.bytecheck_crate.is_none() {
58                        let tokens = respan(syn::parse_str(&lit_str.value())?, lit_str.span());
59                        let parsed: Path = syn::parse2(tokens)?;
60                        attributes.bytecheck_crate = Some(parsed);
61                        Ok(())
62                    } else {
63                        Err(Error::new_spanned(
64                            meta,
65                            "check_bytes crate already specified",
66                        ))
67                    }
68                } else {
69                    Err(Error::new_spanned(
70                        &meta.lit,
71                        "crate argument must be a string",
72                    ))
73                }
74            } else {
75                Err(Error::new_spanned(
76                    &meta.path,
77                    "unrecognized check_bytes argument",
78                ))
79            }
80        }
81        _ => Err(Error::new_spanned(
82            meta,
83            "unrecognized check_bytes argument",
84        )),
85    }
86}
87
88fn parse_attributes(input: &DeriveInput) -> Result<Attributes, Error> {
89    let mut result = Attributes::default();
90    for a in input.attrs.iter() {
91        if let AttrStyle::Outer = a.style {
92            if let Ok(Meta::List(meta)) = a.parse_meta() {
93                if meta.path.is_ident("check_bytes") {
94                    for nested in meta.nested.iter() {
95                        if let NestedMeta::Meta(meta) = nested {
96                            parse_check_bytes_attributes(&mut result, meta)?;
97                        } else {
98                            return Err(Error::new_spanned(
99                                nested,
100                                "check_bytes parameters must be metas",
101                            ));
102                        }
103                    }
104                } else if meta.path.is_ident("repr") {
105                    for n in meta.nested.iter() {
106                        if let NestedMeta::Meta(Meta::Path(path)) = n {
107                            if path.is_ident("transparent") {
108                                result.repr.transparent = Some(path.clone());
109                            } else if path.is_ident("packed") {
110                                result.repr.packed = Some(path.clone());
111                            } else if path.is_ident("C") {
112                                result.repr.c = Some(path.clone());
113                            } else if path.is_ident("align") {
114                                // Ignore alignment modifiers
115                            } else {
116                                let is_int_repr = path.is_ident("i8")
117                                    || path.is_ident("i16")
118                                    || path.is_ident("i32")
119                                    || path.is_ident("i64")
120                                    || path.is_ident("i128")
121                                    || path.is_ident("u8")
122                                    || path.is_ident("u16")
123                                    || path.is_ident("u32")
124                                    || path.is_ident("u64")
125                                    || path.is_ident("u128");
126
127                                if is_int_repr {
128                                    result.repr.int = Some(path.clone());
129                                } else {
130                                    return Err(Error::new_spanned(
131                                        path,
132                                        "invalid repr, available reprs are transparent, C, i* and u*",
133                                    ));
134                                }
135                            }
136                        }
137                    }
138                }
139            }
140        }
141    }
142    Ok(result)
143}
144
145/// Derives `CheckBytes` for the labeled type.
146///
147/// Additional arguments can be specified using the `#[check_bytes(...)]` attribute:
148///
149/// - `bound = "..."`: Adds additional bounds to the `CheckBytes` implementation. This can be
150///   especially useful when dealing with recursive structures, where bounds may need to be omitted
151///   to prevent recursive type definitions.
152///
153/// This derive macro automatically adds a type bound `field: CheckBytes<__C>` for each field type.
154/// This can cause an overflow while evaluating trait bounds if the structure eventually references
155/// its own type, as the implementation of `CheckBytes` for a struct depends on each field type
156/// implementing it as well. Adding the attribute `#[omit_bounds]` to a field will suppress this
157/// trait bound and allow recursive structures. This may be too coarse for some types, in which case
158/// additional type bounds may be required with `bound = "..."`.
159#[proc_macro_derive(CheckBytes, attributes(check_bytes, omit_bounds))]
160pub fn check_bytes_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
161    match derive_check_bytes(parse_macro_input!(input as DeriveInput)) {
162        Ok(result) => result.into(),
163        Err(e) => e.to_compile_error().into(),
164    }
165}
166
167fn derive_check_bytes(mut input: DeriveInput) -> Result<TokenStream, Error> {
168    let attributes = parse_attributes(&input)?;
169
170    let mut impl_input_generics = input.generics.clone();
171    let impl_where_clause = impl_input_generics.make_where_clause();
172    if let Some(ref bounds) = attributes.bound {
173        let clauses =
174            bounds.parse_with(Punctuated::<WherePredicate, Token![,]>::parse_terminated)?;
175        for clause in clauses {
176            impl_where_clause.predicates.push(clause);
177        }
178    }
179    impl_input_generics
180        .params
181        .insert(0, parse_quote! { __C: ?Sized });
182
183    let name = &input.ident;
184
185    let (impl_generics, _, impl_where_clause) = impl_input_generics.split_for_impl();
186    let impl_where_clause = impl_where_clause.unwrap();
187
188    input.generics.make_where_clause();
189    let (struct_generics, ty_generics, where_clause) = input.generics.split_for_impl();
190    let where_clause = where_clause.unwrap();
191
192    let check_bytes_impl = match input.data {
193        Data::Struct(ref data) => match data.fields {
194            Fields::Named(ref fields) => {
195                let mut check_where = impl_where_clause.clone();
196                for field in fields
197                    .named
198                    .iter()
199                    .filter(|f| !f.attrs.iter().any(|a| a.path.is_ident("omit_bounds")))
200                {
201                    let ty = &field.ty;
202                    check_where
203                        .predicates
204                        .push(parse_quote! { #ty: CheckBytes<__C> });
205                }
206
207                let field_checks = fields.named.iter().map(|f| {
208                    let field = &f.ident;
209                    let ty = &f.ty;
210                    quote_spanned! { ty.span() =>
211                        <#ty as CheckBytes<__C>>::check_bytes(
212                            ::core::ptr::addr_of!((*value).#field),
213                            context
214                        ).map_err(|e| StructCheckError {
215                            field_name: stringify!(#field),
216                            inner: ErrorBox::new(e),
217                        })?;
218                    }
219                });
220
221                quote! {
222                    #[automatically_derived]
223                    impl #impl_generics CheckBytes<__C> for #name #ty_generics #check_where {
224                        type Error = StructCheckError;
225
226                        unsafe fn check_bytes<'__bytecheck>(
227                            value: *const Self,
228                            context: &mut __C,
229                        ) -> ::core::result::Result<&'__bytecheck Self, StructCheckError> {
230                            let bytes = value.cast::<u8>();
231                            #(#field_checks)*
232                            Ok(&*value)
233                        }
234                    }
235                }
236            }
237            Fields::Unnamed(ref fields) => {
238                let mut check_where = impl_where_clause.clone();
239                for field in fields
240                    .unnamed
241                    .iter()
242                    .filter(|f| !f.attrs.iter().any(|a| a.path.is_ident("omit_bounds")))
243                {
244                    let ty = &field.ty;
245                    check_where
246                        .predicates
247                        .push(parse_quote! { #ty: CheckBytes<__C> });
248                }
249
250                let field_checks = fields.unnamed.iter().enumerate().map(|(i, f)| {
251                    let ty = &f.ty;
252                    let index = Index::from(i);
253                    quote_spanned! { ty.span() =>
254                        <#ty as CheckBytes<__C>>::check_bytes(
255                            ::core::ptr::addr_of!((*value).#index),
256                            context
257                        ).map_err(|e| TupleStructCheckError {
258                            field_index: #i,
259                            inner: ErrorBox::new(e),
260                        })?;
261                    }
262                });
263
264                quote! {
265                    #[automatically_derived]
266                    impl #impl_generics CheckBytes<__C> for #name #ty_generics #check_where {
267                        type Error = TupleStructCheckError;
268
269                        unsafe fn check_bytes<'__bytecheck>(
270                            value: *const Self,
271                            context: &mut __C,
272                        ) -> ::core::result::Result<&'__bytecheck Self, TupleStructCheckError> {
273                            let bytes = value.cast::<u8>();
274                            #(#field_checks)*
275                            Ok(&*value)
276                        }
277                    }
278                }
279            }
280            Fields::Unit => {
281                quote! {
282                    #[automatically_derived]
283                    impl #impl_generics CheckBytes<__C> for #name #ty_generics #impl_where_clause {
284                        type Error = Infallible;
285
286                        unsafe fn check_bytes<'__bytecheck>(
287                            value: *const Self,
288                            context: &mut __C,
289                        ) -> ::core::result::Result<&'__bytecheck Self, Infallible> {
290                            Ok(&*value)
291                        }
292                    }
293                }
294            }
295        },
296        Data::Enum(ref data) => {
297            if let Some(path) = attributes.repr.transparent.or(attributes.repr.packed) {
298                return Err(Error::new_spanned(
299                    path,
300                    "enums implementing CheckBytes cannot be repr(transparent) or repr(packed)",
301                ));
302            }
303
304            let repr = match attributes.repr.int {
305                None => {
306                    return Err(Error::new(
307                        input.span(),
308                        "enums implementing CheckBytes must be repr(Int)",
309                    ));
310                }
311                Some(ref repr) => repr,
312            };
313
314            let mut check_where = impl_where_clause.clone();
315            for v in data.variants.iter() {
316                match v.fields {
317                    Fields::Named(ref fields) => {
318                        for field in fields
319                            .named
320                            .iter()
321                            .filter(|f| !f.attrs.iter().any(|a| a.path.is_ident("omit_bounds")))
322                        {
323                            let ty = &field.ty;
324                            check_where
325                                .predicates
326                                .push(parse_quote! { #ty: CheckBytes<__C> });
327                        }
328                    }
329                    Fields::Unnamed(ref fields) => {
330                        for field in fields
331                            .unnamed
332                            .iter()
333                            .filter(|f| !f.attrs.iter().any(|a| a.path.is_ident("omit_bounds")))
334                        {
335                            let ty = &field.ty;
336                            check_where
337                                .predicates
338                                .push(parse_quote! { #ty: CheckBytes<__C> });
339                        }
340                    }
341                    Fields::Unit => (),
342                }
343            }
344
345            let tag_variant_defs = data.variants.iter().map(|v| {
346                let variant = &v.ident;
347                if let Some((_, expr)) = &v.discriminant {
348                    quote_spanned! { variant.span() => #variant = #expr }
349                } else {
350                    quote_spanned! { variant.span() => #variant }
351                }
352            });
353
354            let discriminant_const_defs = data.variants.iter().map(|v| {
355                let variant = &v.ident;
356                quote! {
357                    #[allow(non_upper_case_globals)]
358                    const #variant: #repr = Tag::#variant as #repr;
359                }
360            });
361
362            let tag_variant_values = data.variants.iter().map(|v| {
363                let name = &v.ident;
364                quote_spanned! { name.span() => Discriminant::#name }
365            });
366
367            let variant_structs = data.variants.iter().map(|v| {
368                let variant = &v.ident;
369                let variant_name = Ident::new(&format!("Variant{}", variant), v.span());
370                match v.fields {
371                    Fields::Named(ref fields) => {
372                        let fields = fields.named.iter().map(|f| {
373                            let name = &f.ident;
374                            let ty = &f.ty;
375                            quote_spanned! { f.span() => #name: #ty }
376                        });
377                        quote_spanned! { name.span() =>
378                            #[repr(C)]
379                            struct #variant_name #struct_generics #where_clause {
380                                __tag: Tag,
381                                #(#fields,)*
382                                __phantom: PhantomData<#name #ty_generics>,
383                            }
384                        }
385                    }
386                    Fields::Unnamed(ref fields) => {
387                        let fields = fields.unnamed.iter().map(|f| {
388                            let ty = &f.ty;
389                            quote_spanned! { f.span() => #ty }
390                        });
391                        quote_spanned! { name.span() =>
392                            #[repr(C)]
393                            struct #variant_name #struct_generics (
394                                Tag,
395                                #(#fields,)*
396                                PhantomData<#name #ty_generics>
397                            ) #where_clause;
398                        }
399                    }
400                    Fields::Unit => quote! {},
401                }
402            });
403
404            let check_arms = data.variants.iter().map(|v| {
405                let variant = &v.ident;
406                let variant_name = Ident::new(&format!("Variant{}", variant), v.span());
407                match v.fields {
408                    Fields::Named(ref fields) => {
409                        let checks = fields.named.iter().map(|f| {
410                            let name = &f.ident;
411                            let ty = &f.ty;
412                            quote! {
413                                <#ty as CheckBytes<__C>>::check_bytes(
414                                    ::core::ptr::addr_of!((*value).#name),
415                                    context
416                                ).map_err(|e| EnumCheckError::InvalidStruct {
417                                    variant_name: stringify!(#variant),
418                                    inner: StructCheckError {
419                                        field_name: stringify!(#name),
420                                        inner: ErrorBox::new(e),
421                                    },
422                                })?;
423                            }
424                        });
425                        quote_spanned! { variant.span() => {
426                            let value = value.cast::<#variant_name #ty_generics>();
427                            #(#checks)*
428                        } }
429                    }
430                    Fields::Unnamed(ref fields) => {
431                        let checks = fields.unnamed.iter().enumerate().map(|(i, f)| {
432                            let ty = &f.ty;
433                            let index = Index::from(i + 1);
434                            quote! {
435                                <#ty as CheckBytes<__C>>::check_bytes(
436                                    ::core::ptr::addr_of!((*value).#index),
437                                    context
438                                ).map_err(|e| EnumCheckError::InvalidTuple {
439                                    variant_name: stringify!(#variant),
440                                    inner: TupleStructCheckError {
441                                        field_index: #i,
442                                        inner: ErrorBox::new(e),
443                                    },
444                                })?;
445                            }
446                        });
447                        quote_spanned! { variant.span() => {
448                            let value = value.cast::<#variant_name #ty_generics>();
449                            #(#checks)*
450                        } }
451                    }
452                    Fields::Unit => quote_spanned! { name.span() => (), },
453                }
454            });
455
456            quote! {
457                #[repr(#repr)]
458                enum Tag {
459                    #(#tag_variant_defs,)*
460                }
461
462                struct Discriminant;
463
464                #[automatically_derived]
465                impl Discriminant {
466                    #(#discriminant_const_defs)*
467                }
468
469                #(#variant_structs)*
470
471                #[automatically_derived]
472                impl #impl_generics CheckBytes<__C> for #name #ty_generics #check_where {
473                    type Error = EnumCheckError<#repr>;
474
475                    unsafe fn check_bytes<'__bytecheck>(
476                        value: *const Self,
477                        context: &mut __C,
478                    ) -> ::core::result::Result<&'__bytecheck Self, EnumCheckError<#repr>> {
479                        let tag = *value.cast::<#repr>();
480                        match tag {
481                            #(#tag_variant_values => #check_arms)*
482                            _ => return Err(EnumCheckError::InvalidTag(tag)),
483                        }
484                        Ok(&*value)
485                    }
486                }
487            }
488        }
489        Data::Union(_) => {
490            return Err(Error::new(
491                input.span(),
492                "CheckBytes cannot be derived for unions",
493            ));
494        }
495    };
496
497    // Default to `bytecheck`, rather than `::bytecheck`,
498    // to allow providing it from a reexport, e.g. `use rkyv::bytecheck;`.
499    let bytecheck_crate = attributes
500        .bytecheck_crate
501        .unwrap_or(parse_quote!(bytecheck));
502
503    Ok(quote! {
504        #[allow(unused_results)]
505        const _: () = {
506            use ::core::{convert::Infallible, marker::PhantomData};
507            use #bytecheck_crate::{
508                CheckBytes,
509                EnumCheckError,
510                ErrorBox,
511                StructCheckError,
512                TupleStructCheckError,
513            };
514
515            #check_bytes_impl
516        };
517    })
518}
519
520fn respan(stream: TokenStream, span: Span) -> TokenStream {
521    stream
522        .into_iter()
523        .map(|token| respan_token(token, span))
524        .collect()
525}
526
527fn respan_token(mut token: TokenTree, span: Span) -> TokenTree {
528    if let TokenTree::Group(g) = &mut token {
529        *g = Group::new(g.delimiter(), respan(g.stream(), span));
530    }
531    token.set_span(span);
532    token
533}