enumset_derive/
lib.rs

1#![recursion_limit = "256"]
2
3extern crate proc_macro;
4
5use darling::util::SpannedValue;
6use darling::*;
7use proc_macro::TokenStream;
8use proc_macro2::{Literal, Span, TokenStream as SynTokenStream};
9use quote::*;
10use std::{collections::HashSet, fmt::Display};
11use syn::spanned::Spanned;
12use syn::{Error, Result, *};
13
14/// Helper function for emitting compile errors.
15fn error<T>(span: Span, message: impl Display) -> Result<T> {
16    Err(Error::new(span, message))
17}
18
19/// Decodes the custom attributes for our custom derive.
20#[derive(FromDeriveInput, Default)]
21#[darling(attributes(enumset), default)]
22struct EnumsetAttrs {
23    no_ops: bool,
24    no_super_impls: bool,
25    #[darling(default)]
26    repr: SpannedValue<Option<String>>,
27    #[darling(default)]
28    serialize_repr: SpannedValue<Option<String>>,
29    serialize_deny_unknown: bool,
30    #[darling(default)]
31    crate_name: Option<String>,
32
33    // legacy options
34    serialize_as_list: SpannedValue<bool>, // replaced with serialize_repr
35    serialize_as_map: SpannedValue<bool>,  // replaced with serialize_repr
36}
37
/// The in-memory storage that backs an `EnumSet`.
#[derive(Copy, Clone)]
enum InternalRepr {
    /// A single `u8` bitset.
    U8,
    /// A single `u16` bitset.
    U16,
    /// A single `u32` bitset.
    U32,
    /// A single `u64` bitset.
    U64,
    /// A single `u128` bitset.
    U128,
    /// A `[u64; N]` bitset; the payload is the word count `N`.
    Array(usize),
}
impl InternalRepr {
    /// Returns the number of bits this storage provides, i.e. how many discriminants
    /// (`0..n`) it can hold.
    fn supported_variants(&self) -> usize {
        const WORD_BITS: usize = 64;
        match *self {
            Self::U8 => 8,
            Self::U16 => 16,
            Self::U32 => 32,
            Self::U64 => WORD_BITS,
            Self::U128 => 2 * WORD_BITS,
            Self::Array(words) => words * WORD_BITS,
        }
    }
}
67
/// The shape an `EnumSet` takes when serialized or deserialized through serde.
#[derive(Copy, Clone)]
enum SerdeRepr {
    /// A single `u8` bitset.
    U8,
    /// A single `u16` bitset.
    U16,
    /// A single `u32` bitset.
    U32,
    /// A single `u64` bitset.
    U64,
    /// A single `u128` bitset.
    U128,
    /// A sequence of the enum's variants.
    List,
    /// A map from each variant to a `bool`.
    Map,
    /// A sequence of `u64` words.
    Array,
}
impl SerdeRepr {
    /// Returns how many bits this format can carry, or `None` when the format has no fixed
    /// width limit.
    fn supported_variants(&self) -> Option<usize> {
        let bits = match self {
            SerdeRepr::List | SerdeRepr::Map | SerdeRepr::Array => return None,
            SerdeRepr::U8 => 8,
            SerdeRepr::U16 => 16,
            SerdeRepr::U32 => 32,
            SerdeRepr::U64 => 64,
            SerdeRepr::U128 => 128,
        };
        Some(bits)
    }
}
103
104/// An variant in the enum set type.
105struct EnumSetValue {
106    /// The name of the variant.
107    name: Ident,
108    /// The discriminant of the variant.
109    variant_repr: u32,
110}
111
112/// Stores information about the enum set type.
113#[allow(dead_code)]
114struct EnumSetInfo {
115    /// The name of the enum.
116    name: Ident,
117    /// The crate name to use.
118    crate_name: Option<Ident>,
119    /// The numeric type to represent the `EnumSet` as in memory.
120    explicit_internal_repr: Option<InternalRepr>,
121    /// Forces the internal numeric type of the `EnumSet` to be an array.
122    internal_repr_force_array: bool,
123    /// The numeric type to serialize the enum as.
124    explicit_serde_repr: Option<SerdeRepr>,
125    /// A list of variants in the enum.
126    variants: Vec<EnumSetValue>,
127    /// Visbility
128    vis: Visibility,
129
130    /// The highest encountered variant discriminant.
131    max_discrim: u32,
132    /// The span of the highest encountered variant.
133    max_discrim_span: Option<Span>,
134    /// The current variant discriminant. Used to track, e.g. `A=10,B,C`.
135    cur_discrim: u32,
136    /// A list of variant names that are already in use.
137    used_variant_names: HashSet<String>,
138    /// A list of variant discriminants that are already in use.
139    used_discriminants: HashSet<u32>,
140
141    /// Avoid generating operator overloads on the enum type.
142    no_ops: bool,
143    /// Avoid generating implementations for `Clone`, `Copy`, `Eq`, and `PartialEq`.
144    no_super_impls: bool,
145    /// Disallow unknown bits while deserializing the enum.
146    serialize_deny_unknown: bool,
147}
148impl EnumSetInfo {
149    fn new(input: &DeriveInput, attrs: &EnumsetAttrs) -> EnumSetInfo {
150        EnumSetInfo {
151            name: input.ident.clone(),
152            crate_name: attrs
153                .crate_name
154                .as_ref()
155                .map(|x| Ident::new(x, Span::call_site())),
156            explicit_internal_repr: None,
157            internal_repr_force_array: false,
158            explicit_serde_repr: None,
159            variants: Vec::new(),
160            vis: input.vis.clone(),
161            max_discrim: 0,
162            max_discrim_span: None,
163            cur_discrim: 0,
164            used_variant_names: HashSet::new(),
165            used_discriminants: HashSet::new(),
166            no_ops: attrs.no_ops,
167            no_super_impls: attrs.no_super_impls,
168            serialize_deny_unknown: attrs.serialize_deny_unknown,
169        }
170    }
171
172    /// Explicits sets the serde representation of the enumset from a string.
173    fn push_serialize_repr(&mut self, span: Span, ty: &str) -> Result<()> {
174        match ty {
175            "u8" => self.explicit_serde_repr = Some(SerdeRepr::U8),
176            "u16" => self.explicit_serde_repr = Some(SerdeRepr::U16),
177            "u32" => self.explicit_serde_repr = Some(SerdeRepr::U32),
178            "u64" => self.explicit_serde_repr = Some(SerdeRepr::U64),
179            "u128" => self.explicit_serde_repr = Some(SerdeRepr::U128),
180            "list" => self.explicit_serde_repr = Some(SerdeRepr::List),
181            "map" => self.explicit_serde_repr = Some(SerdeRepr::Map),
182            "array" => self.explicit_serde_repr = Some(SerdeRepr::Array),
183            _ => error(span, format!("`{}` is not a valid serialized representation.", ty))?,
184        }
185        Ok(())
186    }
187
188    /// Explicitly sets the representation of the enumset from a string.
189    fn push_repr(&mut self, span: Span, ty: &str) -> Result<()> {
190        match ty {
191            "u8" => self.explicit_internal_repr = Some(InternalRepr::U8),
192            "u16" => self.explicit_internal_repr = Some(InternalRepr::U16),
193            "u32" => self.explicit_internal_repr = Some(InternalRepr::U32),
194            "u64" => self.explicit_internal_repr = Some(InternalRepr::U64),
195            "u128" => self.explicit_internal_repr = Some(InternalRepr::U128),
196            "array" => self.internal_repr_force_array = true,
197            _ => error(span, format!("`{}` is not a valid internal enumset representation.", ty))?,
198        }
199        Ok(())
200    }
201
202    /// Adds a variant to the enumset.
203    fn push_variant(&mut self, variant: &Variant) -> Result<()> {
204        if self.used_variant_names.contains(&variant.ident.to_string()) {
205            error(variant.span(), "Duplicated variant name.")
206        } else if let Fields::Unit = variant.fields {
207            // Parse the discriminant.
208            if let Some((_, expr)) = &variant.discriminant {
209                if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
210                    match i.base10_parse() {
211                        Ok(val) => self.cur_discrim = val,
212                        Err(_) => error(expr.span(), "Enum discriminants must fit into `u32`.")?,
213                    }
214                } else if let Expr::Unary(ExprUnary { op: UnOp::Neg(_), .. }) = expr {
215                    error(expr.span(), "Enum discriminants must not be negative.")?;
216                } else {
217                    error(variant.span(), "Enum discriminants must be literal expressions.")?;
218                }
219            }
220
221            // Validate the discriminant.
222            let discriminant = self.cur_discrim;
223            if discriminant >= 0xFFFFFFC0 {
224                error(variant.span(), "Maximum discriminant allowed is `0xFFFFFFBF`.")?;
225            }
226            if self.used_discriminants.contains(&discriminant) {
227                error(variant.span(), "Duplicated enum discriminant.")?;
228            }
229
230            // Add the variant to the info.
231            self.cur_discrim += 1;
232            if discriminant > self.max_discrim {
233                self.max_discrim = discriminant;
234                self.max_discrim_span = Some(variant.span());
235            }
236            self.variants
237                .push(EnumSetValue { name: variant.ident.clone(), variant_repr: discriminant });
238            self.used_variant_names.insert(variant.ident.to_string());
239            self.used_discriminants.insert(discriminant);
240
241            Ok(())
242        } else {
243            error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
244        }
245    }
246
247    /// Returns the actual internal representation of the set.
248    fn internal_repr(&self) -> InternalRepr {
249        match self.explicit_internal_repr {
250            Some(x) => x,
251            None => match self.max_discrim {
252                x if x < 8 && !self.internal_repr_force_array => InternalRepr::U8,
253                x if x < 16 && !self.internal_repr_force_array => InternalRepr::U16,
254                x if x < 32 && !self.internal_repr_force_array => InternalRepr::U32,
255                x if x < 64 && !self.internal_repr_force_array => InternalRepr::U64,
256                x => InternalRepr::Array((x as usize + 64) / 64),
257            },
258        }
259    }
260
261    /// Returns the actual serde representation of the set.
262    fn serde_repr(&self) -> SerdeRepr {
263        match self.explicit_serde_repr {
264            Some(x) => x,
265            None => match self.max_discrim {
266                x if x < 8 => SerdeRepr::U8,
267                x if x < 16 => SerdeRepr::U16,
268                x if x < 32 => SerdeRepr::U32,
269                x if x < 64 => SerdeRepr::U64,
270                x if x < 128 => SerdeRepr::U128,
271                _ => SerdeRepr::Array,
272            },
273        }
274    }
275
276    /// Validate the enumset type.
277    fn validate(&self) -> Result<()> {
278        // Gets the span of the maximum value.
279        let largest_discriminant_span = match &self.max_discrim_span {
280            Some(x) => *x,
281            None => Span::call_site(),
282        };
283
284        // Check if all bits of the bitset can fit in the memory representation, if one was given.
285        if self.internal_repr().supported_variants() <= self.max_discrim as usize {
286            error(
287                largest_discriminant_span,
288                "`repr` is too small to contain the largest discriminant.",
289            )?;
290        }
291
292        // Check if all bits of the bitset can fit in the serialization representation.
293        if let Some(supported_variants) = self.serde_repr().supported_variants() {
294            if supported_variants <= self.max_discrim as usize {
295                error(
296                    largest_discriminant_span,
297                    "`serialize_repr` is too small to contain the largest discriminant.",
298                )?;
299            }
300        }
301
302        Ok(())
303    }
304
305    /// Returns a bitmask of all variants in the set.
306    fn variant_map(&self) -> Vec<u64> {
307        let mut vec = vec![0];
308        for variant in &self.variants {
309            let (idx, bit) = (variant.variant_repr as usize / 64, variant.variant_repr % 64);
310            while idx >= vec.len() {
311                vec.push(0);
312            }
313            vec[idx] |= 1u64 << bit;
314        }
315        vec
316    }
317}
318
319/// Generates the actual `EnumSetType` impl.
320fn enum_set_type_impl(info: EnumSetInfo, warnings: Vec<(Span, &'static str)>) -> SynTokenStream {
321    let name = &info.name;
322
323    let enumset = match &info.crate_name {
324        Some(crate_name) => quote!(::#crate_name),
325        None => {
326            #[cfg(feature = "proc-macro-crate")]
327            {
328                use proc_macro_crate::FoundCrate;
329
330                let crate_name = proc_macro_crate::crate_name("enumset");
331                match crate_name {
332                    Ok(FoundCrate::Name(name)) => {
333                        let ident = Ident::new(&name, Span::call_site());
334                        quote!(::#ident)
335                    }
336                    _ => quote!(::enumset),
337                }
338            }
339
340            #[cfg(not(feature = "proc-macro-crate"))]
341            {
342                quote!(::enumset)
343            }
344        }
345    };
346    let typed_enumset = quote!(#enumset::EnumSet<#name>);
347    let core = quote!(#enumset::__internal::core_export);
348    let internal = quote!(#enumset::__internal);
349    #[cfg(feature = "serde")]
350    let serde = quote!(#enumset::__internal::serde);
351
352    let repr = match info.internal_repr() {
353        InternalRepr::U8 => quote! { u8 },
354        InternalRepr::U16 => quote! { u16 },
355        InternalRepr::U32 => quote! { u32 },
356        InternalRepr::U64 => quote! { u64 },
357        InternalRepr::U128 => quote! { u128 },
358        InternalRepr::Array(size) => quote! { #internal::ArrayRepr<{ #size }> },
359    };
360    let variant_map = info.variant_map();
361    let all_variants = match info.internal_repr() {
362        InternalRepr::U8 | InternalRepr::U16 | InternalRepr::U32 | InternalRepr::U64 => {
363            let lit = Literal::u64_unsuffixed(variant_map[0]);
364            quote! { #lit }
365        }
366        InternalRepr::U128 => {
367            let lit = Literal::u128_unsuffixed(
368                variant_map[0] as u128 | variant_map.get(1).map_or(0, |x| (*x as u128) << 64),
369            );
370            quote! { #lit }
371        }
372        InternalRepr::Array(size) => {
373            let mut new = Vec::new();
374            for i in 0..size {
375                new.push(Literal::u64_unsuffixed(*variant_map.get(i).unwrap_or(&0)));
376            }
377            quote! { #internal::ArrayRepr::<{ #size }>([#(#new,)*]) }
378        }
379    };
380
381    let ops = if info.no_ops {
382        quote! {}
383    } else {
384        quote! {
385            #[automatically_derived]
386            impl<O: Into<#typed_enumset>> #core::ops::Sub<O> for #name {
387                type Output = #typed_enumset;
388                fn sub(self, other: O) -> Self::Output {
389                    #enumset::EnumSet::only(self) - other.into()
390                }
391            }
392            #[automatically_derived]
393            impl<O: Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
394                type Output = #typed_enumset;
395                fn bitand(self, other: O) -> Self::Output {
396                    #enumset::EnumSet::only(self) & other.into()
397                }
398            }
399            #[automatically_derived]
400            impl<O: Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
401                type Output = #typed_enumset;
402                fn bitor(self, other: O) -> Self::Output {
403                    #enumset::EnumSet::only(self) | other.into()
404                }
405            }
406            #[automatically_derived]
407            impl<O: Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
408                type Output = #typed_enumset;
409                fn bitxor(self, other: O) -> Self::Output {
410                    #enumset::EnumSet::only(self) ^ other.into()
411                }
412            }
413            #[automatically_derived]
414            impl #core::ops::Not for #name {
415                type Output = #typed_enumset;
416                fn not(self) -> Self::Output {
417                    !#enumset::EnumSet::only(self)
418                }
419            }
420            #[automatically_derived]
421            impl #core::cmp::PartialEq<#typed_enumset> for #name {
422                fn eq(&self, other: &#typed_enumset) -> bool {
423                    #enumset::EnumSet::only(*self) == *other
424                }
425            }
426        }
427    };
428
429    #[cfg(feature = "serde")]
430    let serde_repr = info.serde_repr();
431
432    #[cfg(feature = "serde")]
433    let serde_ops = match serde_repr {
434        SerdeRepr::U8 | SerdeRepr::U16 | SerdeRepr::U32 | SerdeRepr::U64 | SerdeRepr::U128 => {
435            let (serialize_repr, from_fn, to_fn) = match serde_repr {
436                SerdeRepr::U8 => (quote! { u8 }, quote! { from_u8 }, quote! { to_u8 }),
437                SerdeRepr::U16 => (quote! { u16 }, quote! { from_u16 }, quote! { to_u16 }),
438                SerdeRepr::U32 => (quote! { u32 }, quote! { from_u32 }, quote! { to_u32 }),
439                SerdeRepr::U64 => (quote! { u64 }, quote! { from_u64 }, quote! { to_u64 }),
440                SerdeRepr::U128 => (quote! { u128 }, quote! { from_u128 }, quote! { to_u128 }),
441                _ => unreachable!(),
442            };
443            let check_unknown = if info.serialize_deny_unknown {
444                quote! {
445                    if value & !#all_variants != 0 {
446                        use #serde::de::Error;
447                        return #core::prelude::v1::Err(
448                            D::Error::custom("enumset contains unknown bits")
449                        )
450                    }
451                }
452            } else {
453                quote! {}
454            };
455            quote! {
456                fn serialize<S: #serde::Serializer>(
457                    set: #enumset::EnumSet<#name>, ser: S,
458                ) -> #core::result::Result<S::Ok, S::Error> {
459                    let value =
460                        <#repr as #enumset::__internal::EnumSetTypeRepr>::#to_fn(&set.__priv_repr);
461                    #serde::Serialize::serialize(&value, ser)
462                }
463                fn deserialize<'de, D: #serde::Deserializer<'de>>(
464                    de: D,
465                ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
466                    let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
467                    #check_unknown
468                    let value = <#repr as #enumset::__internal::EnumSetTypeRepr>::#from_fn(value);
469                    #core::prelude::v1::Ok(#enumset::EnumSet {
470                        __priv_repr: value & #all_variants,
471                    })
472                }
473            }
474        }
475        SerdeRepr::List => {
476            let expecting_str = format!("a list of {}", name);
477            quote! {
478                fn serialize<S: #serde::Serializer>(
479                    set: #enumset::EnumSet<#name>, ser: S,
480                ) -> #core::result::Result<S::Ok, S::Error> {
481                    use #serde::ser::SerializeSeq;
482                    let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
483                    for bit in set {
484                        seq.serialize_element(&bit)?;
485                    }
486                    seq.end()
487                }
488                fn deserialize<'de, D: #serde::Deserializer<'de>>(
489                    de: D,
490                ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
491                    struct Visitor;
492                    impl <'de> #serde::de::Visitor<'de> for Visitor {
493                        type Value = #enumset::EnumSet<#name>;
494                        fn expecting(
495                            &self, formatter: &mut #core::fmt::Formatter,
496                        ) -> #core::fmt::Result {
497                            write!(formatter, #expecting_str)
498                        }
499                        fn visit_seq<A>(
500                            mut self, mut seq: A,
501                        ) -> #core::result::Result<Self::Value, A::Error> where
502                            A: #serde::de::SeqAccess<'de>
503                        {
504                            let mut accum = #enumset::EnumSet::<#name>::new();
505                            while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
506                                accum |= val;
507                            }
508                            #core::prelude::v1::Ok(accum)
509                        }
510                    }
511                    de.deserialize_seq(Visitor)
512                }
513            }
514        }
515        SerdeRepr::Map => {
516            let expecting_str = format!("a map from {} to bool", name);
517            quote! {
518                fn serialize<S: #serde::Serializer>(
519                    set: #enumset::EnumSet<#name>, ser: S,
520                ) -> #core::result::Result<S::Ok, S::Error> {
521                    use #serde::ser::SerializeMap;
522                    let mut map = ser.serialize_map(#core::prelude::v1::Some(set.len()))?;
523                    for bit in set {
524                        map.serialize_entry(&bit, &true)?;
525                    }
526                    map.end()
527                }
528                fn deserialize<'de, D: #serde::Deserializer<'de>>(
529                    de: D,
530                ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
531                    struct Visitor;
532                    impl <'de> #serde::de::Visitor<'de> for Visitor {
533                        type Value = #enumset::EnumSet<#name>;
534                        fn expecting(
535                            &self, formatter: &mut #core::fmt::Formatter,
536                        ) -> #core::fmt::Result {
537                            write!(formatter, #expecting_str)
538                        }
539                        fn visit_map<A>(
540                            mut self, mut map: A,
541                        ) -> #core::result::Result<Self::Value, A::Error> where
542                            A: #serde::de::MapAccess<'de>
543                        {
544                            let mut accum = #enumset::EnumSet::<#name>::new();
545                            while let #core::prelude::v1::Some((val, is_present)) =
546                                map.next_entry::<#name, bool>()?
547                            {
548                                if is_present {
549                                    accum |= val;
550                                }
551                            }
552                            #core::prelude::v1::Ok(accum)
553                        }
554                    }
555                    de.deserialize_map(Visitor)
556                }
557            }
558        }
559        SerdeRepr::Array => {
560            let preferred_size = quote! {
561                <<#name as #internal::EnumSetTypePrivate>::Repr as #internal::EnumSetTypeRepr>
562                    ::PREFERRED_ARRAY_LEN
563            };
564            let (check_extra, convert_array) = if info.serialize_deny_unknown {
565                (
566                    quote! {
567                        if _val != 0 {
568                            return #core::prelude::v1::Err(
569                                D::Error::custom("enumset contains unknown bits")
570                            )
571                        }
572                    },
573                    quote! {
574                        match #enumset::EnumSet::<#name>::try_from_array(accum) {
575                            Some(x) => x,
576                            None => #core::prelude::v1::Err(
577                                D::Error::custom("enumset contains unknown bits")
578                            ),
579                        }
580                    },
581                )
582            } else {
583                (quote! {}, quote! {
584                    #core::prelude::v1::Ok(#enumset::EnumSet::<#name>::from_array(accum))
585                })
586            };
587            quote! {
588                fn serialize<S: #serde::Serializer>(
589                    set: #enumset::EnumSet<#name>, ser: S,
590                ) -> #core::result::Result<S::Ok, S::Error> {
591                    // read the enum as an array
592                    let array = set.as_array::<{ #preferred_size }>();
593
594                    // find the last non-zero value in the array
595                    let mut end = array.len();
596                    for i in (0..array.len()).rev() {
597                        if array[i] != 0 {
598                            break;
599                        }
600                        end = i + 1;
601                    }
602
603                    // serialize the array
604                    #serde::Serialize::serialize(&array[..end], ser)
605                }
606                fn deserialize<'de, D: #serde::Deserializer<'de>>(
607                    de: D,
608                ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
609                    struct Visitor;
610                    impl <'de> #serde::de::Visitor<'de> for Visitor {
611                        type Value = #enumset::EnumSet<#name>;
612                        fn expecting(
613                            &self, formatter: &mut #core::fmt::Formatter,
614                        ) -> #core::fmt::Result {
615                            write!(formatter, "a list of u64")
616                        }
617                        fn visit_seq<A>(
618                            mut self, mut seq: A,
619                        ) -> #core::result::Result<Self::Value, A::Error> where
620                            A: #serde::de::SeqAccess<'de>
621                        {
622                            let mut accum = [0; #preferred_size];
623
624                            let mut i = 0;
625                            while let #core::prelude::v1::Some(val) = seq.next_element::<u64>()? {
626                                accum[i] = val;
627                                i += 1;
628
629                                if i == accum.len() {
630                                    break;
631                                }
632                            }
633                            while let #core::prelude::v1::Some(_val) = seq.next_element::<u64>()? {
634                                #check_extra
635                            }
636
637                            #convert_array
638                        }
639                    }
640                    de.deserialize_seq(Visitor)
641                }
642            }
643        }
644    };
645
646    #[cfg(not(feature = "serde"))]
647    let serde_ops = quote! {};
648
649    let is_uninhabited = info.variants.is_empty();
650    let is_zst = info.variants.len() == 1;
651    let into_impl = if is_uninhabited {
652        quote! {
653            fn enum_into_u32(self) -> u32 {
654                panic!(concat!(stringify!(#name), " is uninhabited."))
655            }
656            unsafe fn enum_from_u32(val: u32) -> Self {
657                panic!(concat!(stringify!(#name), " is uninhabited."))
658            }
659        }
660    } else if is_zst {
661        let variant = &info.variants[0].name;
662        quote! {
663            fn enum_into_u32(self) -> u32 {
664                self as u32
665            }
666            unsafe fn enum_from_u32(val: u32) -> Self {
667                #name::#variant
668            }
669        }
670    } else {
671        let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
672        let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
673
674        let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
675            .iter()
676            .map(|x| Ident::new(x, Span::call_site()))
677            .collect();
678        let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
679            .iter()
680            .map(|x| Ident::new(x, Span::call_site()))
681            .collect();
682
683        quote! {
684            fn enum_into_u32(self) -> u32 {
685                self as u32
686            }
687            unsafe fn enum_from_u32(val: u32) -> Self {
688                // We put these in const fields so the branches they guard aren't generated even
689                // on -O0
690                #(const #const_field: bool =
691                    #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
692                match val {
693                    // Every valid variant value has an explicit branch. If they get optimized out,
694                    // great. If the representation has changed somehow, and they don't, oh well,
695                    // there's still no UB.
696                    #(#variant_value => #name::#variant_name,)*
697                    // Helps hint to the LLVM that this is a transmute. Note that this branch is
698                    // still unreachable.
699                    #(x if #const_field => {
700                        let x = x as #int_type;
701                        *(&x as *const _ as *const #name)
702                    })*
703                    // Default case. Sometimes causes LLVM to generate a table instead of a simple
704                    // transmute, but, oh well.
705                    _ => #core::hint::unreachable_unchecked(),
706                }
707            }
708        }
709    };
710
711    let eq_impl = if is_uninhabited {
712        quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
713    } else {
714        quote!((*self as u32) == (*other as u32))
715    };
716
717    let super_impls = if info.no_super_impls {
718        quote! {}
719    } else {
720        quote! {
721            #[automatically_derived]
722            impl #core::cmp::PartialEq for #name {
723                fn eq(&self, other: &Self) -> bool {
724                    #eq_impl
725                }
726            }
727            #[automatically_derived]
728            impl #core::cmp::Eq for #name { }
729            #[automatically_derived]
730            #[allow(clippy::expl_impl_clone_on_copy)]
731            impl #core::clone::Clone for #name {
732                fn clone(&self) -> Self {
733                    *self
734                }
735            }
736            #[automatically_derived]
737            impl #core::marker::Copy for #name { }
738        }
739    };
740
741    let impl_with_repr = if info.explicit_internal_repr.is_some() {
742        quote! {
743            #[automatically_derived]
744            unsafe impl #enumset::EnumSetTypeWithRepr for #name {
745                type Repr = #repr;
746            }
747        }
748    } else {
749        quote! {}
750    };
751
752    let inherent_impl_blocks = match info.internal_repr() {
753        InternalRepr::U8
754        | InternalRepr::U16
755        | InternalRepr::U32
756        | InternalRepr::U64
757        | InternalRepr::U128 => {
758            let self_as_repr_mask = if is_uninhabited {
759                quote! { 0 } // impossible anyway
760            } else {
761                quote! { 1 << self as #repr }
762            };
763
764            quote! {
765                #[automatically_derived]
766                #[doc(hidden)]
767                impl #name {
768                    /// Creates a new enumset with only this variant.
769                    #[deprecated(note = "This method is an internal implementation detail \
770                                         generated by the `enumset` crate's procedural macro. It \
771                                         should not be used directly.")]
772                    #[doc(hidden)]
773                    pub const fn __impl_enumset_internal__const_only(
774                        self,
775                    ) -> #enumset::EnumSet<#name> {
776                        #enumset::EnumSet { __priv_repr: #self_as_repr_mask }
777                    }
778                }
779
780                #[automatically_derived]
781                #[doc(hidden)]
782                impl __EnumSetConstHelper {
783                    pub const fn const_union(
784                        &self,
785                        chain_a: #enumset::EnumSet<#name>,
786                        chain_b: #enumset::EnumSet<#name>,
787                    ) -> #enumset::EnumSet<#name> {
788                        #enumset::EnumSet {
789                            __priv_repr: chain_a.__priv_repr | chain_b.__priv_repr,
790                        }
791                    }
792
793                    pub const fn const_intersection(
794                        &self,
795                        chain_a: #enumset::EnumSet<#name>,
796                        chain_b: #enumset::EnumSet<#name>,
797                    ) -> #enumset::EnumSet<#name> {
798                        #enumset::EnumSet {
799                            __priv_repr: chain_a.__priv_repr & chain_b.__priv_repr,
800                        }
801                    }
802
803                    pub const fn const_symmetric_difference(
804                        &self,
805                        chain_a: #enumset::EnumSet<#name>,
806                        chain_b: #enumset::EnumSet<#name>,
807                    ) -> #enumset::EnumSet<#name> {
808                        #enumset::EnumSet {
809                            __priv_repr: chain_a.__priv_repr ^ chain_b.__priv_repr,
810                        }
811                    }
812
813                    pub const fn const_complement(
814                        &self,
815                        chain: #enumset::EnumSet<#name>,
816                    ) -> #enumset::EnumSet<#name> {
817                        let mut all = #enumset::EnumSet::<#name>::all();
818                        #enumset::EnumSet {
819                            __priv_repr: !chain.__priv_repr & all.__priv_repr,
820                        }
821                    }
822                }
823            }
824        }
825        InternalRepr::Array(size) => {
826            quote! {
827                #[automatically_derived]
828                #[doc(hidden)]
829                impl #name {
830                    /// Creates a new enumset with only this variant.
831                    #[deprecated(note = "This method is an internal implementation detail \
832                                         generated by the `enumset` crate's procedural macro. It \
833                                         should not be used directly.")]
834                    #[doc(hidden)]
835                    pub const fn __impl_enumset_internal__const_only(
836                        self,
837                    ) -> #enumset::EnumSet<#name> {
838                        let mut set = #enumset::EnumSet::<#name> {
839                            __priv_repr: #internal::ArrayRepr::<{ #size }>([0; #size]),
840                        };
841                        let bit = self as u32;
842                        let (idx, bit) = (bit as usize / 64, bit % 64);
843                        set.__priv_repr.0[idx] |= 1u64 << bit;
844                        set
845                    }
846                }
847
848                #[automatically_derived]
849                #[doc(hidden)]
850                impl __EnumSetConstHelper {
851                    pub const fn const_union(
852                        &self,
853                        mut chain_a: #enumset::EnumSet<#name>,
854                        chain_b: #enumset::EnumSet<#name>,
855                    ) -> #enumset::EnumSet<#name> {
856                        let mut i = 0;
857                        while i < #size {
858                            chain_a.__priv_repr.0[i] |= chain_b.__priv_repr.0[i];
859                            i += 1;
860                        }
861                        chain_a
862                    }
863
864                    pub const fn const_intersection(
865                        &self,
866                        mut chain_a: #enumset::EnumSet<#name>,
867                        chain_b: #enumset::EnumSet<#name>,
868                    ) -> #enumset::EnumSet<#name> {
869                        let mut i = 0;
870                        while i < #size {
871                            chain_a.__priv_repr.0[i] &= chain_b.__priv_repr.0[i];
872                            i += 1;
873                        }
874                        chain_a
875                    }
876
877                    pub const fn const_symmetric_difference(
878                        &self,
879                        mut chain_a: #enumset::EnumSet<#name>,
880                        chain_b: #enumset::EnumSet<#name>,
881                    ) -> #enumset::EnumSet<#name> {
882                        let mut i = 0;
883                        while i < #size {
884                            chain_a.__priv_repr.0[i] ^= chain_b.__priv_repr.0[i];
885                            i += 1;
886                        }
887                        chain_a
888                    }
889
890                    pub const fn const_complement(
891                        &self,
892                        mut chain: #enumset::EnumSet<#name>,
893                    ) -> #enumset::EnumSet<#name> {
894                        let mut all = #enumset::EnumSet::<#name>::all();
895                        let mut i = 0;
896                        while i < #size {
897                            let new = !chain.__priv_repr.0[i] & all.__priv_repr.0[i];
898                            chain.__priv_repr.0[i] = new;
899                            i += 1;
900                        }
901                        chain
902                    }
903                }
904            }
905        }
906    };
907
908    let mut generated_warnings = SynTokenStream::new();
909    for (span, warning) in warnings {
910        generated_warnings.extend(quote_spanned! {
911            span => {
912                #[deprecated(note = #warning)]
913                #[allow(non_upper_case_globals)]
914                const _w: () = ();
915                let _ = _w;
916            }
917        });
918    }
919
920    let bit_width = info.max_discrim + 1;
921    let variant_count = info.variants.len() as u32;
922    let vis = &info.vis;
923    quote! {
924        const _: () = {
925            #[automatically_derived]
926            #[doc(hidden)]
927            #vis struct __EnumSetConstHelper;
928
929            #[automatically_derived]
930            #[doc(hidden)]
931            impl #name {
932                /// Creates a new enumset helper.
933                #[deprecated(note = "This method is an internal implementation detail generated \
934                                     by the `enumset` crate's procedural macro. It should not be \
935                                     used directly.")]
936                #[doc(hidden)]
937                pub const fn __impl_enumset_internal__const_helper(
938                    self,
939                ) -> __EnumSetConstHelper {
940                    __EnumSetConstHelper
941                }
942            }
943
944            #[automatically_derived]
945            unsafe impl #internal::EnumSetTypePrivate for #name {
946                type ConstHelper = __EnumSetConstHelper;
947                const CONST_HELPER_INSTANCE: __EnumSetConstHelper = __EnumSetConstHelper;
948
949                type Repr = #repr;
950                const ALL_BITS: Self::Repr = #all_variants;
951                const BIT_WIDTH: u32 = #bit_width;
952                const VARIANT_COUNT: u32 = #variant_count;
953
954                #into_impl
955                #serde_ops
956            }
957
958            #[automatically_derived]
959            unsafe impl #enumset::EnumSetType for #name { }
960
961            #impl_with_repr
962            #super_impls
963            #ops
964            #inherent_impl_blocks
965
966            fn __enumset_derive__generated_warnings() {
967                #generated_warnings
968            }
969        };
970    }
971}
972
973#[proc_macro_derive(EnumSetType, attributes(enumset))]
974pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
975    let input: DeriveInput = parse_macro_input!(input);
976    let input_span = input.span();
977    let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
978        Ok(attrs) => attrs,
979        Err(e) => return e.write_errors().into(),
980    };
981    derive_enum_set_type_0(input, attrs, input_span).unwrap_or_else(|e| e.to_compile_error().into())
982}
983fn derive_enum_set_type_0(
984    input: DeriveInput,
985    attrs: EnumsetAttrs,
986    _input_span: Span,
987) -> Result<TokenStream> {
988    if !input.generics.params.is_empty() {
989        error(
990            input.generics.span(),
991            "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
992        )
993    } else if let Data::Enum(data) = &input.data {
994        let mut info = EnumSetInfo::new(&input, &attrs);
995        let mut warnings = Vec::new();
996
997        // Check enum repr
998        for attr in &input.attrs {
999            if attr.path().is_ident("repr") {
1000                let meta: Ident = attr.parse_args()?;
1001                match meta.to_string().as_str() {
1002                    "C" | "Rust" => {}
1003                    "u8" | "u16" | "u32" | "u64" | "u128" | "usize" => {}
1004                    "i8" | "i16" | "i32" | "i64" | "i128" | "isize" => {}
1005                    x => error(
1006                        attr.span(),
1007                        format!("`#[repr({})]` cannot be used on enumset variants.", x),
1008                    )?,
1009                }
1010            }
1011        }
1012
1013        // Parse internal representations
1014        if let Some(repr) = &*attrs.repr {
1015            info.push_repr(attrs.repr.span(), repr)?;
1016        }
1017
1018        // Parse serialization representations
1019        if let Some(serialize_repr) = &*attrs.serialize_repr {
1020            info.push_serialize_repr(attrs.serialize_repr.span(), serialize_repr)?;
1021        }
1022        if *attrs.serialize_as_map {
1023            info.explicit_serde_repr = Some(SerdeRepr::Map);
1024            warnings.push((
1025                attrs.serialize_as_map.span(),
1026                "#[enumset(serialize_as_map)] is deprecated. \
1027                 Use `#[enumset(serialize_repr = \"map\")]` instead.",
1028            ));
1029        }
1030        if *attrs.serialize_as_list {
1031            // in old versions, serialize_as_list will override serialize_as_map
1032            info.explicit_serde_repr = Some(SerdeRepr::List);
1033            warnings.push((
1034                attrs.serialize_as_list.span(),
1035                "#[enumset(serialize_as_list)] is deprecated. \
1036                 Use `#[enumset(serialize_repr = \"list\")]` instead.",
1037            ));
1038        }
1039        #[cfg(feature = "std_deprecation_warning")]
1040        {
1041            warnings.push((
1042                _input_span,
1043                "feature = \"std\" is depercated. If you rename `enumset`, use \
1044                 feature = \"proc-macro-crate\" instead. If you don't, remove the feature.",
1045            ));
1046        }
1047
1048        // Parse enum variants
1049        for variant in &data.variants {
1050            info.push_variant(variant)?;
1051        }
1052
1053        // Validate the enumset
1054        info.validate()?;
1055
1056        // Generates the actual `EnumSetType` implementation
1057        Ok(enum_set_type_impl(info, warnings).into())
1058    } else {
1059        error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
1060    }
1061}