enumset_derive/
lib.rs

1#![recursion_limit = "256"]
2
3extern crate proc_macro;
4
5use darling::util::SpannedValue;
6use darling::*;
7use proc_macro::TokenStream;
8use proc_macro2::{Literal, Span, TokenStream as SynTokenStream};
9use quote::*;
10use std::{collections::HashSet, fmt::Display};
11use syn::spanned::Spanned;
12use syn::{Error, Result, *};
13
14/// Helper function for emitting compile errors.
15fn error<T>(span: Span, message: impl Display) -> Result<T> {
16    Err(Error::new(span, message))
17}
18
19/// Decodes the custom attributes for our custom derive.
20#[derive(FromDeriveInput, Default)]
21#[darling(attributes(enumset), default)]
22struct EnumsetAttrs {
23    no_ops: bool,
24    no_super_impls: bool,
25    #[darling(default)]
26    repr: SpannedValue<Option<String>>,
27    #[darling(default)]
28    serialize_repr: SpannedValue<Option<String>>,
29    serialize_deny_unknown: bool,
30    #[darling(default)]
31    crate_name: Option<String>,
32
33    // legacy options
34    serialize_as_list: SpannedValue<bool>, // replaced with serialize_repr
35    serialize_as_map: SpannedValue<bool>,  // replaced with serialize_repr
36}
37
/// How an `EnumSet` stores its bits in memory.
#[derive(Copy, Clone)]
enum InternalRepr {
    /// A single `u8` word.
    U8,
    /// A single `u16` word.
    U16,
    /// A single `u32` word.
    U32,
    /// A single `u64` word.
    U64,
    /// A single `u128` word.
    U128,
    /// `[u64; n]`, where the payload `n` is the number of words.
    Array(usize),
}
impl InternalRepr {
    /// Number of bits this storage offers, i.e. one past the largest discriminant it can hold.
    fn supported_variants(&self) -> usize {
        match *self {
            Self::U8 => u8::BITS as usize,
            Self::U16 => u16::BITS as usize,
            Self::U32 => u32::BITS as usize,
            Self::U64 => u64::BITS as usize,
            Self::U128 => u128::BITS as usize,
            Self::Array(words) => words * u64::BITS as usize,
        }
    }
}
67
/// The shape an `EnumSet` takes when it goes through serde.
#[derive(Copy, Clone)]
enum SerdeRepr {
    /// A single `u8` bitmask.
    U8,
    /// A single `u16` bitmask.
    U16,
    /// A single `u32` bitmask.
    U32,
    /// A single `u64` bitmask.
    U64,
    /// A single `u128` bitmask.
    U128,
    /// A sequence of the enum's variants.
    List,
    /// A map from each variant to `bool`.
    Map,
    /// A sequence of `u64` bitmask words.
    Array,
}
impl SerdeRepr {
    /// Number of bits the format can carry, or `None` when it is not size-limited.
    fn supported_variants(&self) -> Option<usize> {
        let bits = match self {
            Self::U8 => u8::BITS,
            Self::U16 => u16::BITS,
            Self::U32 => u32::BITS,
            Self::U64 => u64::BITS,
            Self::U128 => u128::BITS,
            // Sequence/map formats grow with the set, so they impose no upper bound.
            Self::List | Self::Map | Self::Array => return None,
        };
        Some(bits as usize)
    }
}
103
104/// An variant in the enum set type.
105struct EnumSetValue {
106    /// The name of the variant.
107    name: Ident,
108    /// The discriminant of the variant.
109    variant_repr: u32,
110}
111
112/// Stores information about the enum set type.
113#[allow(dead_code)]
114struct EnumSetInfo {
115    /// The name of the enum.
116    name: Ident,
117    /// The crate name to use.
118    crate_name: Option<Ident>,
119    /// The numeric type to represent the `EnumSet` as in memory.
120    explicit_internal_repr: Option<InternalRepr>,
121    /// Forces the internal numeric type of the `EnumSet` to be an array.
122    internal_repr_force_array: bool,
123    /// The numeric type to serialize the enum as.
124    explicit_serde_repr: Option<SerdeRepr>,
125    /// A list of variants in the enum.
126    variants: Vec<EnumSetValue>,
127    /// Visbility
128    vis: Visibility,
129
130    /// The highest encountered variant discriminant.
131    max_discrim: u32,
132    /// The span of the highest encountered variant.
133    max_discrim_span: Option<Span>,
134    /// The current variant discriminant. Used to track, e.g. `A=10,B,C`.
135    cur_discrim: u32,
136    /// A list of variant names that are already in use.
137    used_variant_names: HashSet<String>,
138    /// A list of variant discriminants that are already in use.
139    used_discriminants: HashSet<u32>,
140
141    /// Avoid generating operator overloads on the enum type.
142    no_ops: bool,
143    /// Avoid generating implementations for `Clone`, `Copy`, `Eq`, and `PartialEq`.
144    no_super_impls: bool,
145    /// Disallow unknown bits while deserializing the enum.
146    serialize_deny_unknown: bool,
147}
148impl EnumSetInfo {
149    fn new(input: &DeriveInput, attrs: &EnumsetAttrs) -> EnumSetInfo {
150        EnumSetInfo {
151            name: input.ident.clone(),
152            crate_name: attrs
153                .crate_name
154                .as_ref()
155                .map(|x| Ident::new(x, Span::call_site())),
156            explicit_internal_repr: None,
157            internal_repr_force_array: false,
158            explicit_serde_repr: None,
159            variants: Vec::new(),
160            vis: input.vis.clone(),
161            max_discrim: 0,
162            max_discrim_span: None,
163            cur_discrim: 0,
164            used_variant_names: HashSet::new(),
165            used_discriminants: HashSet::new(),
166            no_ops: attrs.no_ops,
167            no_super_impls: attrs.no_super_impls,
168            serialize_deny_unknown: attrs.serialize_deny_unknown,
169        }
170    }
171
172    /// Explicits sets the serde representation of the enumset from a string.
173    fn push_serialize_repr(&mut self, span: Span, ty: &str) -> Result<()> {
174        match ty {
175            "u8" => self.explicit_serde_repr = Some(SerdeRepr::U8),
176            "u16" => self.explicit_serde_repr = Some(SerdeRepr::U16),
177            "u32" => self.explicit_serde_repr = Some(SerdeRepr::U32),
178            "u64" => self.explicit_serde_repr = Some(SerdeRepr::U64),
179            "u128" => self.explicit_serde_repr = Some(SerdeRepr::U128),
180            "list" => self.explicit_serde_repr = Some(SerdeRepr::List),
181            "map" => self.explicit_serde_repr = Some(SerdeRepr::Map),
182            "array" => self.explicit_serde_repr = Some(SerdeRepr::Array),
183            _ => error(span, format!("`{ty}` is not a valid serialized representation."))?,
184        }
185        Ok(())
186    }
187
188    /// Explicitly sets the representation of the enumset from a string.
189    fn push_repr(&mut self, span: Span, ty: &str) -> Result<()> {
190        match ty {
191            "u8" => self.explicit_internal_repr = Some(InternalRepr::U8),
192            "u16" => self.explicit_internal_repr = Some(InternalRepr::U16),
193            "u32" => self.explicit_internal_repr = Some(InternalRepr::U32),
194            "u64" => self.explicit_internal_repr = Some(InternalRepr::U64),
195            "u128" => self.explicit_internal_repr = Some(InternalRepr::U128),
196            "array" => self.internal_repr_force_array = true,
197            _ => error(span, format!("`{ty}` is not a valid internal enumset representation."))?,
198        }
199        Ok(())
200    }
201
202    /// Adds a variant to the enumset.
203    fn push_variant(&mut self, variant: &Variant) -> Result<()> {
204        if self.used_variant_names.contains(&variant.ident.to_string()) {
205            error(variant.span(), "Duplicated variant name.")
206        } else if let Fields::Unit = variant.fields {
207            // Parse the discriminant.
208            if let Some((_, expr)) = &variant.discriminant {
209                if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
210                    match i.base10_parse() {
211                        Ok(val) => self.cur_discrim = val,
212                        Err(_) => error(expr.span(), "Enum discriminants must fit into `u32`.")?,
213                    }
214                } else if let Expr::Unary(ExprUnary { op: UnOp::Neg(_), .. }) = expr {
215                    error(expr.span(), "Enum discriminants must not be negative.")?;
216                } else {
217                    error(variant.span(), "Enum discriminants must be literal expressions.")?;
218                }
219            }
220
221            // Validate the discriminant.
222            let discriminant = self.cur_discrim;
223            if discriminant >= 0xFFFFFFC0 {
224                error(variant.span(), "Maximum discriminant allowed is `0xFFFFFFBF`.")?;
225            }
226            if self.used_discriminants.contains(&discriminant) {
227                error(variant.span(), "Duplicated enum discriminant.")?;
228            }
229
230            // Add the variant to the info.
231            self.cur_discrim += 1;
232            if discriminant > self.max_discrim {
233                self.max_discrim = discriminant;
234                self.max_discrim_span = Some(variant.span());
235            }
236            self.variants
237                .push(EnumSetValue { name: variant.ident.clone(), variant_repr: discriminant });
238            self.used_variant_names.insert(variant.ident.to_string());
239            self.used_discriminants.insert(discriminant);
240
241            Ok(())
242        } else {
243            error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
244        }
245    }
246
247    /// Returns the actual internal representation of the set.
248    fn internal_repr(&self) -> InternalRepr {
249        match self.explicit_internal_repr {
250            Some(x) => x,
251            None => match self.max_discrim {
252                x if x < 8 && !self.internal_repr_force_array => InternalRepr::U8,
253                x if x < 16 && !self.internal_repr_force_array => InternalRepr::U16,
254                x if x < 32 && !self.internal_repr_force_array => InternalRepr::U32,
255                x if x < 64 && !self.internal_repr_force_array => InternalRepr::U64,
256                x => InternalRepr::Array((x as usize + 64) / 64),
257            },
258        }
259    }
260
261    /// Returns the actual serde representation of the set.
262    fn serde_repr(&self) -> SerdeRepr {
263        match self.explicit_serde_repr {
264            Some(x) => x,
265            None => match self.max_discrim {
266                x if x < 8 => SerdeRepr::U8,
267                x if x < 16 => SerdeRepr::U16,
268                x if x < 32 => SerdeRepr::U32,
269                x if x < 64 => SerdeRepr::U64,
270                x if x < 128 => SerdeRepr::U128,
271                _ => SerdeRepr::Array,
272            },
273        }
274    }
275
276    /// Validate the enumset type.
277    fn validate(&self) -> Result<()> {
278        // Gets the span of the maximum value.
279        let largest_discriminant_span = match &self.max_discrim_span {
280            Some(x) => *x,
281            None => Span::call_site(),
282        };
283
284        // Check if all bits of the bitset can fit in the memory representation, if one was given.
285        if self.internal_repr().supported_variants() <= self.max_discrim as usize {
286            error(
287                largest_discriminant_span,
288                "`repr` is too small to contain the largest discriminant.",
289            )?;
290        }
291
292        // Check if all bits of the bitset can fit in the serialization representation.
293        if let Some(supported_variants) = self.serde_repr().supported_variants() {
294            if supported_variants <= self.max_discrim as usize {
295                error(
296                    largest_discriminant_span,
297                    "`serialize_repr` is too small to contain the largest discriminant.",
298                )?;
299            }
300        }
301
302        Ok(())
303    }
304
305    /// Returns a bitmask of all variants in the set.
306    fn variant_map(&self) -> Vec<u64> {
307        let mut vec = vec![0];
308        for variant in &self.variants {
309            let (idx, bit) = (variant.variant_repr as usize / 64, variant.variant_repr % 64);
310            while idx >= vec.len() {
311                vec.push(0);
312            }
313            vec[idx] |= 1u64 << bit;
314        }
315        vec
316    }
317}
318
319/// Generates the actual `EnumSetType` impl.
320fn enum_set_type_impl(info: EnumSetInfo, warnings: Vec<(Span, &'static str)>) -> SynTokenStream {
321    let name = &info.name;
322
323    let enumset = match &info.crate_name {
324        Some(crate_name) => quote!(::#crate_name),
325        None => {
326            #[cfg(feature = "proc-macro-crate")]
327            {
328                use proc_macro_crate::FoundCrate;
329
330                let crate_name = proc_macro_crate::crate_name("enumset");
331                match crate_name {
332                    Ok(FoundCrate::Name(name)) => {
333                        let ident = Ident::new(&name, Span::call_site());
334                        quote!(::#ident)
335                    }
336                    _ => quote!(::enumset),
337                }
338            }
339
340            #[cfg(not(feature = "proc-macro-crate"))]
341            {
342                quote!(::enumset)
343            }
344        }
345    };
346    let typed_enumset = quote!(#enumset::EnumSet<#name>);
347    let core = quote!(#enumset::__internal::core_export);
348    let internal = quote!(#enumset::__internal);
349    let serde = quote!(#enumset::__internal::serde);
350
351    let repr = match info.internal_repr() {
352        InternalRepr::U8 => quote! { u8 },
353        InternalRepr::U16 => quote! { u16 },
354        InternalRepr::U32 => quote! { u32 },
355        InternalRepr::U64 => quote! { u64 },
356        InternalRepr::U128 => quote! { u128 },
357        InternalRepr::Array(size) => quote! { #internal::ArrayRepr<{ #size }> },
358    };
359    let variant_map = info.variant_map();
360    let all_variants = match info.internal_repr() {
361        InternalRepr::U8 | InternalRepr::U16 | InternalRepr::U32 | InternalRepr::U64 => {
362            let lit = Literal::u64_unsuffixed(variant_map[0]);
363            quote! { #lit }
364        }
365        InternalRepr::U128 => {
366            let lit = Literal::u128_unsuffixed(
367                variant_map[0] as u128 | variant_map.get(1).map_or(0, |x| (*x as u128) << 64),
368            );
369            quote! { #lit }
370        }
371        InternalRepr::Array(size) => {
372            let mut new = Vec::new();
373            for i in 0..size {
374                new.push(Literal::u64_unsuffixed(*variant_map.get(i).unwrap_or(&0)));
375            }
376            quote! { #internal::ArrayRepr::<{ #size }>([#(#new,)*]) }
377        }
378    };
379
380    let ops = if info.no_ops {
381        quote! {}
382    } else {
383        quote! {
384            #[automatically_derived]
385            impl<O: Into<#typed_enumset>> #core::ops::Sub<O> for #name {
386                type Output = #typed_enumset;
387                fn sub(self, other: O) -> Self::Output {
388                    #enumset::EnumSet::only(self) - other.into()
389                }
390            }
391            #[automatically_derived]
392            impl<O: Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
393                type Output = #typed_enumset;
394                fn bitand(self, other: O) -> Self::Output {
395                    #enumset::EnumSet::only(self) & other.into()
396                }
397            }
398            #[automatically_derived]
399            impl<O: Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
400                type Output = #typed_enumset;
401                fn bitor(self, other: O) -> Self::Output {
402                    #enumset::EnumSet::only(self) | other.into()
403                }
404            }
405            #[automatically_derived]
406            impl<O: Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
407                type Output = #typed_enumset;
408                fn bitxor(self, other: O) -> Self::Output {
409                    #enumset::EnumSet::only(self) ^ other.into()
410                }
411            }
412            #[automatically_derived]
413            impl #core::ops::Not for #name {
414                type Output = #typed_enumset;
415                fn not(self) -> Self::Output {
416                    !#enumset::EnumSet::only(self)
417                }
418            }
419            #[automatically_derived]
420            impl #core::cmp::PartialEq<#typed_enumset> for #name {
421                fn eq(&self, other: &#typed_enumset) -> bool {
422                    #enumset::EnumSet::only(*self) == *other
423                }
424            }
425        }
426    };
427
428    let serde_repr = info.serde_repr();
429    let serde_ops = match serde_repr {
430        SerdeRepr::U8 | SerdeRepr::U16 | SerdeRepr::U32 | SerdeRepr::U64 | SerdeRepr::U128 => {
431            let (serialize_repr, from_fn, to_fn) = match serde_repr {
432                SerdeRepr::U8 => (quote! { u8 }, quote! { from_u8 }, quote! { to_u8 }),
433                SerdeRepr::U16 => (quote! { u16 }, quote! { from_u16 }, quote! { to_u16 }),
434                SerdeRepr::U32 => (quote! { u32 }, quote! { from_u32 }, quote! { to_u32 }),
435                SerdeRepr::U64 => (quote! { u64 }, quote! { from_u64 }, quote! { to_u64 }),
436                SerdeRepr::U128 => (quote! { u128 }, quote! { from_u128 }, quote! { to_u128 }),
437                _ => unreachable!(),
438            };
439            let check_unknown = if info.serialize_deny_unknown {
440                quote! {
441                    if value & !#all_variants != 0 {
442                        use #serde::de::Error;
443                        return #core::prelude::v1::Err(
444                            D::Error::custom("enumset contains unknown bits")
445                        )
446                    }
447                }
448            } else {
449                quote! {}
450            };
451            quote! {
452                fn serialize<S: #serde::Serializer>(
453                    set: #enumset::EnumSet<#name>, ser: S,
454                ) -> #core::result::Result<S::Ok, S::Error> {
455                    let value =
456                        <#repr as #enumset::__internal::EnumSetTypeRepr>::#to_fn(&set.__priv_repr);
457                    #serde::Serialize::serialize(&value, ser)
458                }
459                fn deserialize<'de, D: #serde::Deserializer<'de>>(
460                    de: D,
461                ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
462                    let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
463                    #check_unknown
464                    let value = <#repr as #enumset::__internal::EnumSetTypeRepr>::#from_fn(value);
465                    #core::prelude::v1::Ok(#enumset::EnumSet {
466                        __priv_repr: value & #all_variants,
467                    })
468                }
469            }
470        }
471        SerdeRepr::List => {
472            let expecting_str = format!("a list of {name}");
473            quote! {
474                fn serialize<S: #serde::Serializer>(
475                    set: #enumset::EnumSet<#name>, ser: S,
476                ) -> #core::result::Result<S::Ok, S::Error> {
477                    use #serde::ser::SerializeSeq;
478                    let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
479                    for bit in set {
480                        seq.serialize_element(&bit)?;
481                    }
482                    seq.end()
483                }
484                fn deserialize<'de, D: #serde::Deserializer<'de>>(
485                    de: D,
486                ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
487                    struct Visitor;
488                    impl <'de> #serde::de::Visitor<'de> for Visitor {
489                        type Value = #enumset::EnumSet<#name>;
490                        fn expecting(
491                            &self, formatter: &mut #core::fmt::Formatter,
492                        ) -> #core::fmt::Result {
493                            write!(formatter, #expecting_str)
494                        }
495                        fn visit_seq<A>(
496                            mut self, mut seq: A,
497                        ) -> #core::result::Result<Self::Value, A::Error> where
498                            A: #serde::de::SeqAccess<'de>
499                        {
500                            let mut accum = #enumset::EnumSet::<#name>::new();
501                            while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
502                                accum |= val;
503                            }
504                            #core::prelude::v1::Ok(accum)
505                        }
506                    }
507                    de.deserialize_seq(Visitor)
508                }
509            }
510        }
511        SerdeRepr::Map => {
512            let expecting_str = format!("a map from {name} to bool");
513            quote! {
514                fn serialize<S: #serde::Serializer>(
515                    set: #enumset::EnumSet<#name>, ser: S,
516                ) -> #core::result::Result<S::Ok, S::Error> {
517                    use #serde::ser::SerializeMap;
518                    let mut map = ser.serialize_map(#core::prelude::v1::Some(set.len()))?;
519                    for bit in set {
520                        map.serialize_entry(&bit, &true)?;
521                    }
522                    map.end()
523                }
524                fn deserialize<'de, D: #serde::Deserializer<'de>>(
525                    de: D,
526                ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
527                    struct Visitor;
528                    impl <'de> #serde::de::Visitor<'de> for Visitor {
529                        type Value = #enumset::EnumSet<#name>;
530                        fn expecting(
531                            &self, formatter: &mut #core::fmt::Formatter,
532                        ) -> #core::fmt::Result {
533                            write!(formatter, #expecting_str)
534                        }
535                        fn visit_map<A>(
536                            mut self, mut map: A,
537                        ) -> #core::result::Result<Self::Value, A::Error> where
538                            A: #serde::de::MapAccess<'de>
539                        {
540                            let mut accum = #enumset::EnumSet::<#name>::new();
541                            while let #core::prelude::v1::Some((val, is_present)) =
542                                map.next_entry::<#name, bool>()?
543                            {
544                                if is_present {
545                                    accum |= val;
546                                }
547                            }
548                            #core::prelude::v1::Ok(accum)
549                        }
550                    }
551                    de.deserialize_map(Visitor)
552                }
553            }
554        }
555        SerdeRepr::Array => {
556            let preferred_size = quote! {
557                <<#name as #internal::EnumSetTypePrivate>::Repr as #internal::EnumSetTypeRepr>
558                    ::PREFERRED_ARRAY_LEN
559            };
560            let (check_extra, convert_array) = if info.serialize_deny_unknown {
561                (
562                    quote! {
563                        if _val != 0 {
564                            return #core::prelude::v1::Err(
565                                D::Error::custom("enumset contains unknown bits")
566                            )
567                        }
568                    },
569                    quote! {
570                        match #enumset::EnumSet::<#name>::try_from_array(accum) {
571                            Some(x) => x,
572                            None => #core::prelude::v1::Err(
573                                D::Error::custom("enumset contains unknown bits")
574                            ),
575                        }
576                    },
577                )
578            } else {
579                (quote! {}, quote! {
580                    #core::prelude::v1::Ok(#enumset::EnumSet::<#name>::from_array(accum))
581                })
582            };
583            quote! {
584                fn serialize<S: #serde::Serializer>(
585                    set: #enumset::EnumSet<#name>, ser: S,
586                ) -> #core::result::Result<S::Ok, S::Error> {
587                    // read the enum as an array
588                    let array = set.as_array::<{ #preferred_size }>();
589
590                    // find the last non-zero value in the array
591                    let mut end = array.len();
592                    for i in (0..array.len()).rev() {
593                        if array[i] != 0 {
594                            break;
595                        }
596                        end = i + 1;
597                    }
598
599                    // serialize the array
600                    #serde::Serialize::serialize(&array[..end], ser)
601                }
602                fn deserialize<'de, D: #serde::Deserializer<'de>>(
603                    de: D,
604                ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
605                    struct Visitor;
606                    impl <'de> #serde::de::Visitor<'de> for Visitor {
607                        type Value = #enumset::EnumSet<#name>;
608                        fn expecting(
609                            &self, formatter: &mut #core::fmt::Formatter,
610                        ) -> #core::fmt::Result {
611                            write!(formatter, "a list of u64")
612                        }
613                        fn visit_seq<A>(
614                            mut self, mut seq: A,
615                        ) -> #core::result::Result<Self::Value, A::Error> where
616                            A: #serde::de::SeqAccess<'de>
617                        {
618                            let mut accum = [0; #preferred_size];
619
620                            let mut i = 0;
621                            while let #core::prelude::v1::Some(val) = seq.next_element::<u64>()? {
622                                accum[i] = val;
623                                i += 1;
624
625                                if i == accum.len() {
626                                    break;
627                                }
628                            }
629                            while let #core::prelude::v1::Some(_val) = seq.next_element::<u64>()? {
630                                #check_extra
631                            }
632
633                            #convert_array
634                        }
635                    }
636                    de.deserialize_seq(Visitor)
637                }
638            }
639        }
640    };
641
642    let is_uninhabited = info.variants.is_empty();
643    let is_zst = info.variants.len() == 1;
644    let into_impl = if is_uninhabited {
645        quote! {
646            fn enum_into_u32(self) -> u32 {
647                panic!(concat!(stringify!(#name), " is uninhabited."))
648            }
649            unsafe fn enum_from_u32(val: u32) -> Self {
650                panic!(concat!(stringify!(#name), " is uninhabited."))
651            }
652        }
653    } else if is_zst {
654        let variant = &info.variants[0].name;
655        quote! {
656            fn enum_into_u32(self) -> u32 {
657                self as u32
658            }
659            unsafe fn enum_from_u32(val: u32) -> Self {
660                #name::#variant
661            }
662        }
663    } else {
664        let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
665        let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
666
667        let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
668            .iter()
669            .map(|x| Ident::new(x, Span::call_site()))
670            .collect();
671        let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
672            .iter()
673            .map(|x| Ident::new(x, Span::call_site()))
674            .collect();
675
676        quote! {
677            fn enum_into_u32(self) -> u32 {
678                self as u32
679            }
680            unsafe fn enum_from_u32(val: u32) -> Self {
681                // We put these in const fields so the branches they guard aren't generated even
682                // on -O0
683                #(const #const_field: bool =
684                    #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
685                match val {
686                    // Every valid variant value has an explicit branch. If they get optimized out,
687                    // great. If the representation has changed somehow, and they don't, oh well,
688                    // there's still no UB.
689                    #(#variant_value => #name::#variant_name,)*
690                    // Helps hint to the LLVM that this is a transmute. Note that this branch is
691                    // still unreachable.
692                    #(x if #const_field => {
693                        let x = x as #int_type;
694                        *(&x as *const _ as *const #name)
695                    })*
696                    // Default case. Sometimes causes LLVM to generate a table instead of a simple
697                    // transmute, but, oh well.
698                    _ => #core::hint::unreachable_unchecked(),
699                }
700            }
701        }
702    };
703
704    let eq_impl = if is_uninhabited {
705        quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
706    } else {
707        quote!((*self as u32) == (*other as u32))
708    };
709
710    let super_impls = if info.no_super_impls {
711        quote! {}
712    } else {
713        quote! {
714            #[automatically_derived]
715            impl #core::cmp::PartialEq for #name {
716                fn eq(&self, other: &Self) -> bool {
717                    #eq_impl
718                }
719            }
720            #[automatically_derived]
721            impl #core::cmp::Eq for #name { }
722            #[automatically_derived]
723            #[allow(clippy::expl_impl_clone_on_copy)]
724            impl #core::clone::Clone for #name {
725                fn clone(&self) -> Self {
726                    *self
727                }
728            }
729            #[automatically_derived]
730            impl #core::marker::Copy for #name { }
731        }
732    };
733
734    let impl_with_repr = if info.explicit_internal_repr.is_some() {
735        quote! {
736            #[automatically_derived]
737            unsafe impl #enumset::EnumSetTypeWithRepr for #name {
738                type Repr = #repr;
739            }
740        }
741    } else {
742        quote! {}
743    };
744
745    let inherent_impl_blocks = match info.internal_repr() {
746        InternalRepr::U8
747        | InternalRepr::U16
748        | InternalRepr::U32
749        | InternalRepr::U64
750        | InternalRepr::U128 => {
751            let self_as_repr_mask = if is_uninhabited {
752                quote! { 0 } // impossible anyway
753            } else {
754                quote! { 1 << self as #repr }
755            };
756
757            quote! {
758                #[automatically_derived]
759                #[doc(hidden)]
760                impl #name {
761                    /// Creates a new enumset with only this variant.
762                    #[deprecated(note = "This method is an internal implementation detail \
763                                         generated by the `enumset` crate's procedural macro. It \
764                                         should not be used directly.")]
765                    #[doc(hidden)]
766                    pub const fn __impl_enumset_internal__const_only(
767                        self,
768                    ) -> #enumset::EnumSet<#name> {
769                        #enumset::EnumSet { __priv_repr: #self_as_repr_mask }
770                    }
771                }
772
773                #[automatically_derived]
774                #[doc(hidden)]
775                impl __EnumSetConstHelper {
776                    pub const fn const_union(
777                        &self,
778                        chain_a: #enumset::EnumSet<#name>,
779                        chain_b: #enumset::EnumSet<#name>,
780                    ) -> #enumset::EnumSet<#name> {
781                        #enumset::EnumSet {
782                            __priv_repr: chain_a.__priv_repr | chain_b.__priv_repr,
783                        }
784                    }
785
786                    pub const fn const_intersection(
787                        &self,
788                        chain_a: #enumset::EnumSet<#name>,
789                        chain_b: #enumset::EnumSet<#name>,
790                    ) -> #enumset::EnumSet<#name> {
791                        #enumset::EnumSet {
792                            __priv_repr: chain_a.__priv_repr & chain_b.__priv_repr,
793                        }
794                    }
795
796                    pub const fn const_symmetric_difference(
797                        &self,
798                        chain_a: #enumset::EnumSet<#name>,
799                        chain_b: #enumset::EnumSet<#name>,
800                    ) -> #enumset::EnumSet<#name> {
801                        #enumset::EnumSet {
802                            __priv_repr: chain_a.__priv_repr ^ chain_b.__priv_repr,
803                        }
804                    }
805
806                    pub const fn const_complement(
807                        &self,
808                        chain: #enumset::EnumSet<#name>,
809                    ) -> #enumset::EnumSet<#name> {
810                        let mut all = #enumset::EnumSet::<#name>::all();
811                        #enumset::EnumSet {
812                            __priv_repr: !chain.__priv_repr & all.__priv_repr,
813                        }
814                    }
815                }
816            }
817        }
818        InternalRepr::Array(size) => {
819            quote! {
820                #[automatically_derived]
821                #[doc(hidden)]
822                impl #name {
823                    /// Creates a new enumset with only this variant.
824                    #[deprecated(note = "This method is an internal implementation detail \
825                                         generated by the `enumset` crate's procedural macro. It \
826                                         should not be used directly.")]
827                    #[doc(hidden)]
828                    pub const fn __impl_enumset_internal__const_only(
829                        self,
830                    ) -> #enumset::EnumSet<#name> {
831                        let mut set = #enumset::EnumSet::<#name> {
832                            __priv_repr: #internal::ArrayRepr::<{ #size }>([0; #size]),
833                        };
834                        let bit = self as u32;
835                        let (idx, bit) = (bit as usize / 64, bit % 64);
836                        set.__priv_repr.0[idx] |= 1u64 << bit;
837                        set
838                    }
839                }
840
841                #[automatically_derived]
842                #[doc(hidden)]
843                impl __EnumSetConstHelper {
844                    pub const fn const_union(
845                        &self,
846                        mut chain_a: #enumset::EnumSet<#name>,
847                        chain_b: #enumset::EnumSet<#name>,
848                    ) -> #enumset::EnumSet<#name> {
849                        let mut i = 0;
850                        while i < #size {
851                            chain_a.__priv_repr.0[i] |= chain_b.__priv_repr.0[i];
852                            i += 1;
853                        }
854                        chain_a
855                    }
856
857                    pub const fn const_intersection(
858                        &self,
859                        mut chain_a: #enumset::EnumSet<#name>,
860                        chain_b: #enumset::EnumSet<#name>,
861                    ) -> #enumset::EnumSet<#name> {
862                        let mut i = 0;
863                        while i < #size {
864                            chain_a.__priv_repr.0[i] &= chain_b.__priv_repr.0[i];
865                            i += 1;
866                        }
867                        chain_a
868                    }
869
870                    pub const fn const_symmetric_difference(
871                        &self,
872                        mut chain_a: #enumset::EnumSet<#name>,
873                        chain_b: #enumset::EnumSet<#name>,
874                    ) -> #enumset::EnumSet<#name> {
875                        let mut i = 0;
876                        while i < #size {
877                            chain_a.__priv_repr.0[i] ^= chain_b.__priv_repr.0[i];
878                            i += 1;
879                        }
880                        chain_a
881                    }
882
883                    pub const fn const_complement(
884                        &self,
885                        mut chain: #enumset::EnumSet<#name>,
886                    ) -> #enumset::EnumSet<#name> {
887                        let mut all = #enumset::EnumSet::<#name>::all();
888                        let mut i = 0;
889                        while i < #size {
890                            let new = !chain.__priv_repr.0[i] & all.__priv_repr.0[i];
891                            chain.__priv_repr.0[i] = new;
892                            i += 1;
893                        }
894                        chain
895                    }
896                }
897            }
898        }
899    };
900
901    let mut generated_warnings = SynTokenStream::new();
902    for (span, warning) in warnings {
903        generated_warnings.extend(quote_spanned! {
904            span => {
905                #[deprecated(note = #warning)]
906                #[allow(non_upper_case_globals)]
907                const _w: () = ();
908                let _ = _w;
909            }
910        });
911    }
912
913    let bit_width = info.max_discrim + 1;
914    let variant_count = info.variants.len() as u32;
915    let vis = &info.vis;
916    quote! {
917        const _: () = {
918            #[automatically_derived]
919            #[doc(hidden)]
920            #vis struct __EnumSetConstHelper;
921
922            #[automatically_derived]
923            #[doc(hidden)]
924            impl #name {
925                /// Creates a new enumset helper.
926                #[deprecated(note = "This method is an internal implementation detail generated \
927                                     by the `enumset` crate's procedural macro. It should not be \
928                                     used directly.")]
929                #[doc(hidden)]
930                pub const fn __impl_enumset_internal__const_helper(
931                    self,
932                ) -> __EnumSetConstHelper {
933                    __EnumSetConstHelper
934                }
935            }
936
937            #[automatically_derived]
938            unsafe impl #internal::EnumSetTypePrivate for #name {
939                type ConstHelper = __EnumSetConstHelper;
940                const CONST_HELPER_INSTANCE: __EnumSetConstHelper = __EnumSetConstHelper;
941
942                type Repr = #repr;
943                const ALL_BITS: Self::Repr = #all_variants;
944                const BIT_WIDTH: u32 = #bit_width;
945                const VARIANT_COUNT: u32 = #variant_count;
946
947                #into_impl
948
949                #internal::__if_serde! {
950                    #serde_ops
951                }
952            }
953
954            #[automatically_derived]
955            unsafe impl #enumset::EnumSetType for #name { }
956
957            #impl_with_repr
958            #super_impls
959            #ops
960            #inherent_impl_blocks
961
962            fn __enumset_derive__generated_warnings() {
963                #generated_warnings
964            }
965        };
966    }
967}
968
969#[proc_macro_derive(EnumSetType, attributes(enumset))]
970pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
971    let input: DeriveInput = parse_macro_input!(input);
972    let input_span = input.span();
973    let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
974        Ok(attrs) => attrs,
975        Err(e) => return e.write_errors().into(),
976    };
977    derive_enum_set_type_0(input, attrs, input_span).unwrap_or_else(|e| e.to_compile_error().into())
978}
979fn derive_enum_set_type_0(
980    input: DeriveInput,
981    attrs: EnumsetAttrs,
982    _input_span: Span,
983) -> Result<TokenStream> {
984    if !input.generics.params.is_empty() {
985        error(
986            input.generics.span(),
987            "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
988        )
989    } else if let Data::Enum(data) = &input.data {
990        let mut info = EnumSetInfo::new(&input, &attrs);
991        let mut warnings = Vec::new();
992
993        // Check enum repr
994        for attr in &input.attrs {
995            if attr.path().is_ident("repr") {
996                let meta: Ident = attr.parse_args()?;
997                match meta.to_string().as_str() {
998                    "C" | "Rust" => {}
999                    "u8" | "u16" | "u32" | "u64" | "u128" | "usize" => {}
1000                    "i8" | "i16" | "i32" | "i64" | "i128" | "isize" => {}
1001                    x => error(
1002                        attr.span(),
1003                        format!("`#[repr({x})]` cannot be used on enumset variants."),
1004                    )?,
1005                }
1006            }
1007        }
1008
1009        // Parse internal representations
1010        if let Some(repr) = &*attrs.repr {
1011            info.push_repr(attrs.repr.span(), repr)?;
1012        }
1013
1014        // Parse serialization representations
1015        if let Some(serialize_repr) = &*attrs.serialize_repr {
1016            info.push_serialize_repr(attrs.serialize_repr.span(), serialize_repr)?;
1017        }
1018        if *attrs.serialize_as_map {
1019            info.explicit_serde_repr = Some(SerdeRepr::Map);
1020            warnings.push((
1021                attrs.serialize_as_map.span(),
1022                "#[enumset(serialize_as_map)] is deprecated. \
1023                 Use `#[enumset(serialize_repr = \"map\")]` instead.",
1024            ));
1025        }
1026        if *attrs.serialize_as_list {
1027            // in old versions, serialize_as_list will override serialize_as_map
1028            info.explicit_serde_repr = Some(SerdeRepr::List);
1029            warnings.push((
1030                attrs.serialize_as_list.span(),
1031                "#[enumset(serialize_as_list)] is deprecated. \
1032                 Use `#[enumset(serialize_repr = \"list\")]` instead.",
1033            ));
1034        }
1035        #[cfg(feature = "std_deprecation_warning")]
1036        {
1037            warnings.push((
1038                _input_span,
1039                "feature = \"std\" is depercated. If you rename `enumset`, use \
1040                 feature = \"proc-macro-crate\" instead. If you don't, remove the feature.",
1041            ));
1042        }
1043        #[cfg(feature = "serde2_deprecation_warning")]
1044        {
1045            warnings.push((
1046                _input_span,
1047                "feature = \"serde2\" was never valid and did nothing. Please remove the feature.",
1048            ));
1049        }
1050
1051        // Parse enum variants
1052        for variant in &data.variants {
1053            info.push_variant(variant)?;
1054        }
1055
1056        // Validate the enumset
1057        info.validate()?;
1058
1059        // Generates the actual `EnumSetType` implementation
1060        Ok(enum_set_type_impl(info, warnings).into())
1061    } else {
1062        error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
1063    }
1064}