enum_iterator_derive/
lib.rs

1// Copyright (C) 2018-2021 Stephane Raux. Distributed under the 0BSD license.
2
3//! # Overview
4//! - [📦 crates.io](https://crates.io/crates/enum-iterator-derive)
5//! - [📖 Documentation](https://docs.rs/enum-iterator-derive)
//! - [⚖ 0BSD license](https://spdx.org/licenses/0BSD.html)
7//!
8//! Procedural macro to derive `IntoEnumIterator` for field-less enums.
9//!
10//! See crate [enum-iterator](https://docs.rs/enum-iterator) for details.
11//!
12//! # Contribute
13//! All contributions shall be licensed under the [0BSD license](https://spdx.org/licenses/0BSD.html).
14
15#![recursion_limit = "128"]
16#![deny(warnings)]
17
18extern crate proc_macro;
19
20use proc_macro2::{Span, TokenStream};
21use quote::{quote, ToTokens};
22use std::fmt::{self, Display};
23use syn::{DeriveInput, Ident};
24
25/// Derives `IntoEnumIterator` for field-less enums.
26#[proc_macro_derive(IntoEnumIterator)]
27pub fn into_enum_iterator(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
28    derive(input)
29        .unwrap_or_else(|e| e.to_compile_error())
30        .into()
31}
32
33fn derive(input: proc_macro::TokenStream) -> Result<TokenStream, syn::Error> {
34    let ast = syn::parse::<DeriveInput>(input)?;
35    if !ast.generics.params.is_empty() {
36        return Err(Error::GenericsUnsupported.with_tokens(&ast.generics));
37    }
38    let ty = &ast.ident;
39    let vis = &ast.vis;
40    let ty_doc = format!("Iterator over the variants of {}", ty);
41    let iter_ty = Ident::new(&(ty.to_string() + "EnumIterator"), Span::call_site());
42    let variants = match &ast.data {
43        syn::Data::Enum(e) => &e.variants,
44        _ => return Err(Error::ExpectedEnum.with_tokens(&ast)),
45    };
46    let arms = variants
47        .iter()
48        .enumerate()
49        .map(|(idx, v)| {
50            let id = &v.ident;
51            match v.fields {
52                syn::Fields::Unit => Ok(quote! { #idx => #ty::#id, }),
53                _ => Err(Error::ExpectedUnitVariant.with_tokens(v)),
54            }
55        })
56        .collect::<Result<Vec<_>, _>>()?;
57    let nb_variants = arms.len();
58    let tokens = quote! {
59        #[doc = #ty_doc]
60        #[derive(Clone, Copy, Debug)]
61        #vis struct #iter_ty {
62            idx: usize,
63        }
64
65        impl ::core::iter::Iterator for #iter_ty {
66            type Item = #ty;
67
68            fn next(&mut self) -> ::core::option::Option<Self::Item> {
69                let id = match self.idx {
70                    #(#arms)*
71                    _ => return ::core::option::Option::None,
72                };
73                self.idx += 1;
74                ::core::option::Option::Some(id)
75            }
76
77            fn size_hint(&self) -> (usize, ::core::option::Option<usize>) {
78                let n = #nb_variants - self.idx;
79                (n, ::core::option::Option::Some(n))
80            }
81        }
82
83        impl ::core::iter::ExactSizeIterator for #iter_ty {}
84        impl ::core::iter::FusedIterator for #iter_ty {}
85
86        impl ::enum_iterator::IntoEnumIterator for #ty {
87            type Iterator = #iter_ty;
88
89            const VARIANT_COUNT: usize = #nb_variants;
90
91            fn into_enum_iter() -> Self::Iterator {
92                #iter_ty { idx: 0 }
93            }
94        }
95    };
96    let tokens = quote! {
97        const _: () = {
98            #tokens
99        };
100    };
101    Ok(tokens)
102}
103
/// Reasons the derive rejects its input.
///
/// Variants are listed in the order in which `derive` checks for them.
#[derive(Debug)]
enum Error {
    /// The input type declares generic parameters.
    GenericsUnsupported,
    /// The input type is not an enum.
    ExpectedEnum,
    /// One of the enum's variants has fields.
    ExpectedUnitVariant,
}
110
111impl Error {
112    fn with_tokens<T: ToTokens>(self, tokens: T) -> syn::Error {
113        syn::Error::new_spanned(tokens, self)
114    }
115}
116
117impl Display for Error {
118    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119        match self {
120            Error::ExpectedEnum => {
121                f.write_str("IntoEnumIterator can only be derived for enum types")
122            }
123            Error::ExpectedUnitVariant => f.write_str(
124                "IntoEnumIterator can only be derived for enum types with unit \
125                    variants only",
126            ),
127            Error::GenericsUnsupported => {
128                f.write_str("IntoEnumIterator cannot be derived for generic types")
129            }
130        }
131    }
132}