1mod utils;
7
8use proc_macro::TokenStream;
9use proc_macro2::{Ident, Span};
10use syn::{
11 parse_macro_input, Fields, ItemEnum,
12 __private::{quote::quote, TokenStream2},
13};
14
15use crate::utils::{concat, snakify};
16
17#[proc_macro_derive(GraphQLMutationRoot)]
18pub fn derive_mutation_root(input: TokenStream) -> TokenStream {
19 let input = parse_macro_input!(input as ItemEnum);
20 generate_mutation_root_code(input, "linera_sdk").into()
21}
22
23#[proc_macro_derive(GraphQLMutationRootInCrate)]
24pub fn derive_mutation_root_in_crate(input: TokenStream) -> TokenStream {
25 let input = parse_macro_input!(input as ItemEnum);
26 generate_mutation_root_code(input, "crate").into()
27}
28
29fn generate_mutation_root_code(input: ItemEnum, crate_root: &str) -> TokenStream2 {
30 let crate_root = Ident::new(crate_root, Span::call_site());
31 let enum_name = input.ident;
32 let mutation_root_name = concat(&enum_name, "MutationRoot");
33 let mut methods = vec![];
34
35 for variant in input.variants {
36 let variant_name = &variant.ident;
37 let function_name = snakify(variant_name);
38 match variant.fields {
39 Fields::Named(named) => {
40 let mut fields = vec![];
41 let mut field_names = vec![];
42 for field in named.named {
43 let name = field.ident.expect("named fields always have names");
44 let ty = field.ty;
45 fields.push(quote! {#name: #ty});
46 field_names.push(name);
47 }
48 methods.push(quote! {
49 async fn #function_name(&self, #(#fields,)*) -> [u8; 0] {
50 let operation = #enum_name::#variant_name {
51 #(#field_names,)*
52 };
53
54 self.runtime.schedule_operation(&operation);
55
56 []
57 }
58 });
59 }
60 Fields::Unnamed(unnamed) => {
61 let mut fields = vec![];
62 let mut field_names = vec![];
63 for (i, field) in unnamed.unnamed.iter().enumerate() {
64 let name = concat(&syn::parse_str::<Ident>("field").unwrap(), &i.to_string());
65 let ty = &field.ty;
66 fields.push(quote! {#name: #ty});
67 field_names.push(name);
68 }
69 methods.push(quote! {
70 async fn #function_name(&self, #(#fields,)*) -> [u8; 0] {
71 let operation = #enum_name::#variant_name(
72 #(#field_names,)*
73 );
74
75 self.runtime.schedule_operation(&operation);
76
77 []
78 }
79 });
80 }
81 Fields::Unit => {
82 methods.push(quote! {
83 async fn #function_name(&self) -> [u8; 0] {
84 let operation = #enum_name::#variant_name;
85
86 self.runtime.schedule_operation(&operation);
87
88 []
89 }
90 });
91 }
92 };
93 }
94
95 quote! {
96 pub struct #mutation_root_name<Application>
98 where
99 Application: #crate_root::Service,
100 #crate_root::ServiceRuntime<Application>: Send + Sync,
101 {
102 runtime: ::std::sync::Arc<#crate_root::ServiceRuntime<Application>>,
103 }
104
105 #[async_graphql::Object]
106 impl<Application> #mutation_root_name<Application>
107 where
108 Application: #crate_root::Service,
109 #crate_root::ServiceRuntime<Application>: Send + Sync,
110 {
111 #(#methods)*
112 }
113
114 impl<Application> #crate_root::graphql::GraphQLMutationRoot<Application> for #enum_name
115 where
116 Application: #crate_root::Service,
117 #crate_root::ServiceRuntime<Application>: Send + Sync,
118 {
119 type MutationRoot = #mutation_root_name<Application>;
120
121 fn mutation_root(
122 runtime: ::std::sync::Arc<#crate_root::ServiceRuntime<Application>>,
123 ) -> Self::MutationRoot {
124 #mutation_root_name { runtime }
125 }
126 }
127 }
128}
129
#[cfg(test)]
mod tests {
    use syn::{parse_quote, ItemEnum, __private::quote::quote};

    use crate::generate_mutation_root_code;

    /// Asserts that `actual` and `expected` are equal once all whitespace is
    /// removed, so spacing differences between token renderings do not matter.
    ///
    /// The unmodified `actual` text goes into the failure message, so the
    /// generated code is still readable when the assertion fails.
    fn assert_eq_no_whitespace(actual: String, expected: String) {
        let strip = |text: &str| -> String { text.chars().filter(|c| !c.is_whitespace()).collect() };

        assert_eq!(strip(&actual), strip(&expected), "generated code:\n{}", actual);
    }

    #[test]
    fn test_derive_mutation_root() {
        let operation: ItemEnum = parse_quote! {
            enum SomeOperation {
                TupleVariant(String),
                StructVariant {
                    a: u32,
                    b: u64
                },
                EmptyVariant
            }
        };

        let output = generate_mutation_root_code(operation, "linera_sdk");

        let expected = quote! {
            pub struct SomeOperationMutationRoot<Application>
            where
                Application: linera_sdk::Service,
                linera_sdk::ServiceRuntime<Application>: Send + Sync,
            {
                runtime: ::std::sync::Arc<linera_sdk::ServiceRuntime<Application>>,
            }

            #[async_graphql::Object]
            impl<Application> SomeOperationMutationRoot<Application>
            where
                Application: linera_sdk::Service,
                linera_sdk::ServiceRuntime<Application>: Send + Sync,
            {
                async fn tuple_variant(&self, field0: String,) -> [u8; 0] {
                    let operation = SomeOperation::TupleVariant(field0,);
                    self.runtime.schedule_operation(&operation);
                    []
                }

                async fn struct_variant(&self, a: u32, b: u64,) -> [u8; 0] {
                    let operation = SomeOperation::StructVariant { a, b, };
                    self.runtime.schedule_operation(&operation);
                    []
                }

                async fn empty_variant(&self) -> [u8; 0] {
                    let operation = SomeOperation::EmptyVariant;
                    self.runtime.schedule_operation(&operation);
                    []
                }
            }

            impl<Application> linera_sdk::graphql::GraphQLMutationRoot<Application>
                for SomeOperation
            where
                Application: linera_sdk::Service,
                linera_sdk::ServiceRuntime<Application>: Send + Sync,
            {
                type MutationRoot = SomeOperationMutationRoot<Application>;

                fn mutation_root(
                    runtime: ::std::sync::Arc<linera_sdk::ServiceRuntime<Application>>,
                ) -> Self::MutationRoot {
                    SomeOperationMutationRoot { runtime }
                }
            }
        };

        assert_eq_no_whitespace(output.to_string(), expected.to_string());
    }

    /// The in-crate derive must resolve every generated path under `crate`
    /// and must not mention `linera_sdk` at all.
    #[test]
    fn test_derive_mutation_root_in_crate() {
        let operation: ItemEnum = parse_quote! {
            enum SomeOperation {
                EmptyVariant
            }
        };

        let mut output = generate_mutation_root_code(operation, "crate").to_string();
        output.retain(|c| !c.is_whitespace());

        assert!(output.contains("Application:crate::Service,"), "{}", output);
        assert!(output.contains("crate::ServiceRuntime<Application>:Send+Sync,"), "{}", output);
        assert!(
            output.contains(
                "impl<Application>crate::graphql::GraphQLMutationRoot<Application>forSomeOperation"
            ),
            "{}",
            output
        );
        assert!(!output.contains("linera_sdk"), "{}", output);
    }
}