linera_sdk_derive/
lib.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! The procedural macros for the crate `linera-sdk`.
5
6mod utils;
7
8use proc_macro::TokenStream;
9use proc_macro2::{Ident, Span};
10use syn::{
11    parse_macro_input, Fields, ItemEnum,
12    __private::{quote::quote, TokenStream2},
13};
14
15use crate::utils::{concat, snakify};
16
17#[proc_macro_derive(GraphQLMutationRoot)]
18pub fn derive_mutation_root(input: TokenStream) -> TokenStream {
19    let input = parse_macro_input!(input as ItemEnum);
20    generate_mutation_root_code(input, "linera_sdk").into()
21}
22
23#[proc_macro_derive(GraphQLMutationRootInCrate)]
24pub fn derive_mutation_root_in_crate(input: TokenStream) -> TokenStream {
25    let input = parse_macro_input!(input as ItemEnum);
26    generate_mutation_root_code(input, "crate").into()
27}
28
29fn generate_mutation_root_code(input: ItemEnum, crate_root: &str) -> TokenStream2 {
30    let crate_root = Ident::new(crate_root, Span::call_site());
31    let enum_name = input.ident;
32    let mutation_root_name = concat(&enum_name, "MutationRoot");
33    let mut methods = vec![];
34
35    for variant in input.variants {
36        let variant_name = &variant.ident;
37        let function_name = snakify(variant_name);
38        match variant.fields {
39            Fields::Named(named) => {
40                let mut fields = vec![];
41                let mut field_names = vec![];
42                for field in named.named {
43                    let name = field.ident.expect("named fields always have names");
44                    let ty = field.ty;
45                    fields.push(quote! {#name: #ty});
46                    field_names.push(name);
47                }
48                methods.push(quote! {
49                    async fn #function_name(&self, #(#fields,)*) -> [u8; 0] {
50                        let operation = #enum_name::#variant_name {
51                            #(#field_names,)*
52                        };
53
54                        self.runtime.schedule_operation(&operation);
55
56                        []
57                    }
58                });
59            }
60            Fields::Unnamed(unnamed) => {
61                let mut fields = vec![];
62                let mut field_names = vec![];
63                for (i, field) in unnamed.unnamed.iter().enumerate() {
64                    let name = concat(&syn::parse_str::<Ident>("field").unwrap(), &i.to_string());
65                    let ty = &field.ty;
66                    fields.push(quote! {#name: #ty});
67                    field_names.push(name);
68                }
69                methods.push(quote! {
70                    async fn #function_name(&self, #(#fields,)*) -> [u8; 0] {
71                        let operation = #enum_name::#variant_name(
72                            #(#field_names,)*
73                        );
74
75                        self.runtime.schedule_operation(&operation);
76
77                        []
78                    }
79                });
80            }
81            Fields::Unit => {
82                methods.push(quote! {
83                    async fn #function_name(&self) -> [u8; 0] {
84                        let operation = #enum_name::#variant_name;
85
86                        self.runtime.schedule_operation(&operation);
87
88                        []
89                    }
90                });
91            }
92        };
93    }
94
95    quote! {
96        /// Mutation root
97        pub struct #mutation_root_name<Application>
98        where
99            Application: #crate_root::Service,
100            #crate_root::ServiceRuntime<Application>: Send + Sync,
101        {
102            runtime: ::std::sync::Arc<#crate_root::ServiceRuntime<Application>>,
103        }
104
105        #[async_graphql::Object]
106        impl<Application> #mutation_root_name<Application>
107        where
108            Application: #crate_root::Service,
109            #crate_root::ServiceRuntime<Application>: Send + Sync,
110        {
111            #(#methods)*
112        }
113
114        impl<Application> #crate_root::graphql::GraphQLMutationRoot<Application> for #enum_name
115        where
116            Application: #crate_root::Service,
117            #crate_root::ServiceRuntime<Application>: Send + Sync,
118        {
119            type MutationRoot = #mutation_root_name<Application>;
120
121            fn mutation_root(
122                runtime: ::std::sync::Arc<#crate_root::ServiceRuntime<Application>>,
123            ) -> Self::MutationRoot {
124                #mutation_root_name { runtime }
125            }
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use syn::{parse_quote, ItemEnum, __private::quote::quote};

    use crate::generate_mutation_root_code;

    /// Asserts that two token-stream renderings are equal once all whitespace is removed.
    ///
    /// The spacing produced by `TokenStream::to_string` is not something these tests should
    /// depend on, so both sides are normalized before comparison.
    fn assert_eq_no_whitespace(mut actual: String, mut expected: String) {
        // Intentionally left here for debugging purposes
        println!("{}", actual);

        actual.retain(|c| !c.is_whitespace());
        expected.retain(|c| !c.is_whitespace());

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_derive_mutation_root() {
        let operation: ItemEnum = parse_quote! {
            enum SomeOperation {
                TupleVariant(String),
                StructVariant {
                    a: u32,
                    b: u64
                },
                EmptyVariant
            }
        };

        let output = generate_mutation_root_code(operation, "linera_sdk");

        let expected = quote! {
            /// Mutation root
            pub struct SomeOperationMutationRoot<Application>
            where
                Application: linera_sdk::Service,
                linera_sdk::ServiceRuntime<Application>: Send + Sync,
            {
                runtime: ::std::sync::Arc<linera_sdk::ServiceRuntime<Application>>,
            }

            #[async_graphql::Object]
            impl<Application> SomeOperationMutationRoot<Application>
            where
                Application: linera_sdk::Service,
                linera_sdk::ServiceRuntime<Application>: Send + Sync,
            {
                async fn tuple_variant(&self, field0: String,) -> [u8; 0] {
                    let operation = SomeOperation::TupleVariant(field0,);
                    self.runtime.schedule_operation(&operation);
                    []
                }

                async fn struct_variant(&self, a: u32, b: u64,) -> [u8; 0] {
                    let operation = SomeOperation::StructVariant { a, b, };
                    self.runtime.schedule_operation(&operation);
                    []
                }

                async fn empty_variant(&self) -> [u8; 0] {
                    let operation = SomeOperation::EmptyVariant;
                    self.runtime.schedule_operation(&operation);
                    []
                }
            }

            impl<Application> linera_sdk::graphql::GraphQLMutationRoot<Application>
                for SomeOperation
            where
                Application: linera_sdk::Service,
                linera_sdk::ServiceRuntime<Application>: Send + Sync,
            {
                type MutationRoot = SomeOperationMutationRoot<Application>;

                fn mutation_root(
                    runtime: ::std::sync::Arc<linera_sdk::ServiceRuntime<Application>>,
                ) -> Self::MutationRoot {
                    SomeOperationMutationRoot { runtime }
                }
            }
        };

        assert_eq_no_whitespace(output.to_string(), expected.to_string());
    }

    /// Covers the `crate` root used by `GraphQLMutationRootInCrate`, plus a tuple variant with
    /// several fields (checking the `field0`, `field1`, ... naming).
    #[test]
    fn test_derive_mutation_root_in_crate() {
        let operation: ItemEnum = parse_quote! {
            enum Op {
                Transfer(u64, String),
                Ping
            }
        };

        let output = generate_mutation_root_code(operation, "crate");

        let expected = quote! {
            /// Mutation root
            pub struct OpMutationRoot<Application>
            where
                Application: crate::Service,
                crate::ServiceRuntime<Application>: Send + Sync,
            {
                runtime: ::std::sync::Arc<crate::ServiceRuntime<Application>>,
            }

            #[async_graphql::Object]
            impl<Application> OpMutationRoot<Application>
            where
                Application: crate::Service,
                crate::ServiceRuntime<Application>: Send + Sync,
            {
                async fn transfer(&self, field0: u64, field1: String,) -> [u8; 0] {
                    let operation = Op::Transfer(field0, field1,);
                    self.runtime.schedule_operation(&operation);
                    []
                }

                async fn ping(&self) -> [u8; 0] {
                    let operation = Op::Ping;
                    self.runtime.schedule_operation(&operation);
                    []
                }
            }

            impl<Application> crate::graphql::GraphQLMutationRoot<Application> for Op
            where
                Application: crate::Service,
                crate::ServiceRuntime<Application>: Send + Sync,
            {
                type MutationRoot = OpMutationRoot<Application>;

                fn mutation_root(
                    runtime: ::std::sync::Arc<crate::ServiceRuntime<Application>>,
                ) -> Self::MutationRoot {
                    OpMutationRoot { runtime }
                }
            }
        };

        assert_eq_no_whitespace(output.to_string(), expected.to_string());
    }
}