1#![deny(
4 rust_2018_compatibility,
5 rust_2018_idioms,
6 future_incompatible,
7 nonstandard_style,
8 unused,
9 clippy::all
10)]
11
12use proc_macro2::{Group, Span, TokenStream, TokenTree};
13use quote::{quote, quote_spanned};
14use syn::{
15 parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, AttrStyle, Data,
16 DeriveInput, Error, Fields, Ident, Index, Lit, LitStr, Meta, NestedMeta, Path, Token,
17 WherePredicate,
18};
19
20#[derive(Default)]
21struct Repr {
22 pub transparent: Option<Path>,
23 pub packed: Option<Path>,
24 pub c: Option<Path>,
25 pub int: Option<Path>,
26}
27
28#[derive(Default)]
29struct Attributes {
30 pub repr: Repr,
31 pub bound: Option<LitStr>,
32 pub bytecheck_crate: Option<Path>,
33}
34
35fn parse_check_bytes_attributes(attributes: &mut Attributes, meta: &Meta) -> Result<(), Error> {
36 match meta {
37 Meta::NameValue(meta) => {
38 if meta.path.is_ident("bound") {
39 if let Lit::Str(ref lit_str) = meta.lit {
40 if attributes.bound.is_none() {
41 attributes.bound = Some(lit_str.clone());
42 Ok(())
43 } else {
44 Err(Error::new_spanned(
45 meta,
46 "check_bytes bound already specified",
47 ))
48 }
49 } else {
50 Err(Error::new_spanned(
51 &meta.lit,
52 "bound arguments must be a string",
53 ))
54 }
55 } else if meta.path.is_ident("crate") {
56 if let Lit::Str(ref lit_str) = meta.lit {
57 if attributes.bytecheck_crate.is_none() {
58 let tokens = respan(syn::parse_str(&lit_str.value())?, lit_str.span());
59 let parsed: Path = syn::parse2(tokens)?;
60 attributes.bytecheck_crate = Some(parsed);
61 Ok(())
62 } else {
63 Err(Error::new_spanned(
64 meta,
65 "check_bytes crate already specified",
66 ))
67 }
68 } else {
69 Err(Error::new_spanned(
70 &meta.lit,
71 "crate argument must be a string",
72 ))
73 }
74 } else {
75 Err(Error::new_spanned(
76 &meta.path,
77 "unrecognized check_bytes argument",
78 ))
79 }
80 }
81 _ => Err(Error::new_spanned(
82 meta,
83 "unrecognized check_bytes argument",
84 )),
85 }
86}
87
88fn parse_attributes(input: &DeriveInput) -> Result<Attributes, Error> {
89 let mut result = Attributes::default();
90 for a in input.attrs.iter() {
91 if let AttrStyle::Outer = a.style {
92 if let Ok(Meta::List(meta)) = a.parse_meta() {
93 if meta.path.is_ident("check_bytes") {
94 for nested in meta.nested.iter() {
95 if let NestedMeta::Meta(meta) = nested {
96 parse_check_bytes_attributes(&mut result, meta)?;
97 } else {
98 return Err(Error::new_spanned(
99 nested,
100 "check_bytes parameters must be metas",
101 ));
102 }
103 }
104 } else if meta.path.is_ident("repr") {
105 for n in meta.nested.iter() {
106 if let NestedMeta::Meta(Meta::Path(path)) = n {
107 if path.is_ident("transparent") {
108 result.repr.transparent = Some(path.clone());
109 } else if path.is_ident("packed") {
110 result.repr.packed = Some(path.clone());
111 } else if path.is_ident("C") {
112 result.repr.c = Some(path.clone());
113 } else if path.is_ident("align") {
114 } else {
116 let is_int_repr = path.is_ident("i8")
117 || path.is_ident("i16")
118 || path.is_ident("i32")
119 || path.is_ident("i64")
120 || path.is_ident("i128")
121 || path.is_ident("u8")
122 || path.is_ident("u16")
123 || path.is_ident("u32")
124 || path.is_ident("u64")
125 || path.is_ident("u128");
126
127 if is_int_repr {
128 result.repr.int = Some(path.clone());
129 } else {
130 return Err(Error::new_spanned(
131 path,
132 "invalid repr, available reprs are transparent, C, i* and u*",
133 ));
134 }
135 }
136 }
137 }
138 }
139 }
140 }
141 }
142 Ok(result)
143}
144
145#[proc_macro_derive(CheckBytes, attributes(check_bytes, omit_bounds))]
160pub fn check_bytes_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
161 match derive_check_bytes(parse_macro_input!(input as DeriveInput)) {
162 Ok(result) => result.into(),
163 Err(e) => e.to_compile_error().into(),
164 }
165}
166
167fn derive_check_bytes(mut input: DeriveInput) -> Result<TokenStream, Error> {
168 let attributes = parse_attributes(&input)?;
169
170 let mut impl_input_generics = input.generics.clone();
171 let impl_where_clause = impl_input_generics.make_where_clause();
172 if let Some(ref bounds) = attributes.bound {
173 let clauses =
174 bounds.parse_with(Punctuated::<WherePredicate, Token![,]>::parse_terminated)?;
175 for clause in clauses {
176 impl_where_clause.predicates.push(clause);
177 }
178 }
179 impl_input_generics
180 .params
181 .insert(0, parse_quote! { __C: ?Sized });
182
183 let name = &input.ident;
184
185 let (impl_generics, _, impl_where_clause) = impl_input_generics.split_for_impl();
186 let impl_where_clause = impl_where_clause.unwrap();
187
188 input.generics.make_where_clause();
189 let (struct_generics, ty_generics, where_clause) = input.generics.split_for_impl();
190 let where_clause = where_clause.unwrap();
191
192 let check_bytes_impl = match input.data {
193 Data::Struct(ref data) => match data.fields {
194 Fields::Named(ref fields) => {
195 let mut check_where = impl_where_clause.clone();
196 for field in fields
197 .named
198 .iter()
199 .filter(|f| !f.attrs.iter().any(|a| a.path.is_ident("omit_bounds")))
200 {
201 let ty = &field.ty;
202 check_where
203 .predicates
204 .push(parse_quote! { #ty: CheckBytes<__C> });
205 }
206
207 let field_checks = fields.named.iter().map(|f| {
208 let field = &f.ident;
209 let ty = &f.ty;
210 quote_spanned! { ty.span() =>
211 <#ty as CheckBytes<__C>>::check_bytes(
212 ::core::ptr::addr_of!((*value).#field),
213 context
214 ).map_err(|e| StructCheckError {
215 field_name: stringify!(#field),
216 inner: ErrorBox::new(e),
217 })?;
218 }
219 });
220
221 quote! {
222 #[automatically_derived]
223 impl #impl_generics CheckBytes<__C> for #name #ty_generics #check_where {
224 type Error = StructCheckError;
225
226 unsafe fn check_bytes<'__bytecheck>(
227 value: *const Self,
228 context: &mut __C,
229 ) -> ::core::result::Result<&'__bytecheck Self, StructCheckError> {
230 let bytes = value.cast::<u8>();
231 #(#field_checks)*
232 Ok(&*value)
233 }
234 }
235 }
236 }
237 Fields::Unnamed(ref fields) => {
238 let mut check_where = impl_where_clause.clone();
239 for field in fields
240 .unnamed
241 .iter()
242 .filter(|f| !f.attrs.iter().any(|a| a.path.is_ident("omit_bounds")))
243 {
244 let ty = &field.ty;
245 check_where
246 .predicates
247 .push(parse_quote! { #ty: CheckBytes<__C> });
248 }
249
250 let field_checks = fields.unnamed.iter().enumerate().map(|(i, f)| {
251 let ty = &f.ty;
252 let index = Index::from(i);
253 quote_spanned! { ty.span() =>
254 <#ty as CheckBytes<__C>>::check_bytes(
255 ::core::ptr::addr_of!((*value).#index),
256 context
257 ).map_err(|e| TupleStructCheckError {
258 field_index: #i,
259 inner: ErrorBox::new(e),
260 })?;
261 }
262 });
263
264 quote! {
265 #[automatically_derived]
266 impl #impl_generics CheckBytes<__C> for #name #ty_generics #check_where {
267 type Error = TupleStructCheckError;
268
269 unsafe fn check_bytes<'__bytecheck>(
270 value: *const Self,
271 context: &mut __C,
272 ) -> ::core::result::Result<&'__bytecheck Self, TupleStructCheckError> {
273 let bytes = value.cast::<u8>();
274 #(#field_checks)*
275 Ok(&*value)
276 }
277 }
278 }
279 }
280 Fields::Unit => {
281 quote! {
282 #[automatically_derived]
283 impl #impl_generics CheckBytes<__C> for #name #ty_generics #impl_where_clause {
284 type Error = Infallible;
285
286 unsafe fn check_bytes<'__bytecheck>(
287 value: *const Self,
288 context: &mut __C,
289 ) -> ::core::result::Result<&'__bytecheck Self, Infallible> {
290 Ok(&*value)
291 }
292 }
293 }
294 }
295 },
296 Data::Enum(ref data) => {
297 if let Some(path) = attributes.repr.transparent.or(attributes.repr.packed) {
298 return Err(Error::new_spanned(
299 path,
300 "enums implementing CheckBytes cannot be repr(transparent) or repr(packed)",
301 ));
302 }
303
304 let repr = match attributes.repr.int {
305 None => {
306 return Err(Error::new(
307 input.span(),
308 "enums implementing CheckBytes must be repr(Int)",
309 ));
310 }
311 Some(ref repr) => repr,
312 };
313
314 let mut check_where = impl_where_clause.clone();
315 for v in data.variants.iter() {
316 match v.fields {
317 Fields::Named(ref fields) => {
318 for field in fields
319 .named
320 .iter()
321 .filter(|f| !f.attrs.iter().any(|a| a.path.is_ident("omit_bounds")))
322 {
323 let ty = &field.ty;
324 check_where
325 .predicates
326 .push(parse_quote! { #ty: CheckBytes<__C> });
327 }
328 }
329 Fields::Unnamed(ref fields) => {
330 for field in fields
331 .unnamed
332 .iter()
333 .filter(|f| !f.attrs.iter().any(|a| a.path.is_ident("omit_bounds")))
334 {
335 let ty = &field.ty;
336 check_where
337 .predicates
338 .push(parse_quote! { #ty: CheckBytes<__C> });
339 }
340 }
341 Fields::Unit => (),
342 }
343 }
344
345 let tag_variant_defs = data.variants.iter().map(|v| {
346 let variant = &v.ident;
347 if let Some((_, expr)) = &v.discriminant {
348 quote_spanned! { variant.span() => #variant = #expr }
349 } else {
350 quote_spanned! { variant.span() => #variant }
351 }
352 });
353
354 let discriminant_const_defs = data.variants.iter().map(|v| {
355 let variant = &v.ident;
356 quote! {
357 #[allow(non_upper_case_globals)]
358 const #variant: #repr = Tag::#variant as #repr;
359 }
360 });
361
362 let tag_variant_values = data.variants.iter().map(|v| {
363 let name = &v.ident;
364 quote_spanned! { name.span() => Discriminant::#name }
365 });
366
367 let variant_structs = data.variants.iter().map(|v| {
368 let variant = &v.ident;
369 let variant_name = Ident::new(&format!("Variant{}", variant), v.span());
370 match v.fields {
371 Fields::Named(ref fields) => {
372 let fields = fields.named.iter().map(|f| {
373 let name = &f.ident;
374 let ty = &f.ty;
375 quote_spanned! { f.span() => #name: #ty }
376 });
377 quote_spanned! { name.span() =>
378 #[repr(C)]
379 struct #variant_name #struct_generics #where_clause {
380 __tag: Tag,
381 #(#fields,)*
382 __phantom: PhantomData<#name #ty_generics>,
383 }
384 }
385 }
386 Fields::Unnamed(ref fields) => {
387 let fields = fields.unnamed.iter().map(|f| {
388 let ty = &f.ty;
389 quote_spanned! { f.span() => #ty }
390 });
391 quote_spanned! { name.span() =>
392 #[repr(C)]
393 struct #variant_name #struct_generics (
394 Tag,
395 #(#fields,)*
396 PhantomData<#name #ty_generics>
397 ) #where_clause;
398 }
399 }
400 Fields::Unit => quote! {},
401 }
402 });
403
404 let check_arms = data.variants.iter().map(|v| {
405 let variant = &v.ident;
406 let variant_name = Ident::new(&format!("Variant{}", variant), v.span());
407 match v.fields {
408 Fields::Named(ref fields) => {
409 let checks = fields.named.iter().map(|f| {
410 let name = &f.ident;
411 let ty = &f.ty;
412 quote! {
413 <#ty as CheckBytes<__C>>::check_bytes(
414 ::core::ptr::addr_of!((*value).#name),
415 context
416 ).map_err(|e| EnumCheckError::InvalidStruct {
417 variant_name: stringify!(#variant),
418 inner: StructCheckError {
419 field_name: stringify!(#name),
420 inner: ErrorBox::new(e),
421 },
422 })?;
423 }
424 });
425 quote_spanned! { variant.span() => {
426 let value = value.cast::<#variant_name #ty_generics>();
427 #(#checks)*
428 } }
429 }
430 Fields::Unnamed(ref fields) => {
431 let checks = fields.unnamed.iter().enumerate().map(|(i, f)| {
432 let ty = &f.ty;
433 let index = Index::from(i + 1);
434 quote! {
435 <#ty as CheckBytes<__C>>::check_bytes(
436 ::core::ptr::addr_of!((*value).#index),
437 context
438 ).map_err(|e| EnumCheckError::InvalidTuple {
439 variant_name: stringify!(#variant),
440 inner: TupleStructCheckError {
441 field_index: #i,
442 inner: ErrorBox::new(e),
443 },
444 })?;
445 }
446 });
447 quote_spanned! { variant.span() => {
448 let value = value.cast::<#variant_name #ty_generics>();
449 #(#checks)*
450 } }
451 }
452 Fields::Unit => quote_spanned! { name.span() => (), },
453 }
454 });
455
456 quote! {
457 #[repr(#repr)]
458 enum Tag {
459 #(#tag_variant_defs,)*
460 }
461
462 struct Discriminant;
463
464 #[automatically_derived]
465 impl Discriminant {
466 #(#discriminant_const_defs)*
467 }
468
469 #(#variant_structs)*
470
471 #[automatically_derived]
472 impl #impl_generics CheckBytes<__C> for #name #ty_generics #check_where {
473 type Error = EnumCheckError<#repr>;
474
475 unsafe fn check_bytes<'__bytecheck>(
476 value: *const Self,
477 context: &mut __C,
478 ) -> ::core::result::Result<&'__bytecheck Self, EnumCheckError<#repr>> {
479 let tag = *value.cast::<#repr>();
480 match tag {
481 #(#tag_variant_values => #check_arms)*
482 _ => return Err(EnumCheckError::InvalidTag(tag)),
483 }
484 Ok(&*value)
485 }
486 }
487 }
488 }
489 Data::Union(_) => {
490 return Err(Error::new(
491 input.span(),
492 "CheckBytes cannot be derived for unions",
493 ));
494 }
495 };
496
497 let bytecheck_crate = attributes
500 .bytecheck_crate
501 .unwrap_or(parse_quote!(bytecheck));
502
503 Ok(quote! {
504 #[allow(unused_results)]
505 const _: () = {
506 use ::core::{convert::Infallible, marker::PhantomData};
507 use #bytecheck_crate::{
508 CheckBytes,
509 EnumCheckError,
510 ErrorBox,
511 StructCheckError,
512 TupleStructCheckError,
513 };
514
515 #check_bytes_impl
516 };
517 })
518}
519
520fn respan(stream: TokenStream, span: Span) -> TokenStream {
521 stream
522 .into_iter()
523 .map(|token| respan_token(token, span))
524 .collect()
525}
526
527fn respan_token(mut token: TokenTree, span: Span) -> TokenTree {
528 if let TokenTree::Group(g) = &mut token {
529 *g = Group::new(g.delimiter(), respan(g.stream(), span));
530 }
531 token.set_span(span);
532 token
533}