linera_views_derive/
lib.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! The procedural macros for the crate `linera-views`.
5
6use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use quote::{format_ident, quote};
9use syn::{parse_macro_input, parse_quote, ItemStruct, Type};
10
11#[derive(Debug, deluxe::ParseAttributes)]
12#[deluxe(attributes(view))]
13struct StructAttrs {
14    context: Option<syn::Type>,
15}
16
17struct Constraints<'a> {
18    input_constraints: Vec<&'a syn::WherePredicate>,
19    impl_generics: syn::ImplGenerics<'a>,
20    type_generics: syn::TypeGenerics<'a>,
21}
22
23impl<'a> Constraints<'a> {
24    fn get(item: &'a syn::ItemStruct) -> Self {
25        let (impl_generics, type_generics, maybe_where_clause) = item.generics.split_for_impl();
26        let input_constraints = maybe_where_clause
27            .map(|w| w.predicates.iter())
28            .into_iter()
29            .flatten()
30            .collect();
31
32        Self {
33            input_constraints,
34            impl_generics,
35            type_generics,
36        }
37    }
38}
39
40fn get_extended_entry(e: Type) -> TokenStream2 {
41    let syn::Type::Path(typepath) = e else {
42        panic!("The type should be a path");
43    };
44    let path_segment = typepath.path.segments.into_iter().next().unwrap();
45    let ident = path_segment.ident;
46    let arguments = path_segment.arguments;
47    quote! { #ident :: #arguments }
48}
49
50fn generate_view_code(input: ItemStruct, root: bool) -> TokenStream2 {
51    let Constraints {
52        input_constraints,
53        impl_generics,
54        type_generics,
55    } = Constraints::get(&input);
56
57    let attrs: StructAttrs = deluxe::parse_attributes(&input).unwrap();
58    let context = attrs.context.unwrap_or_else(|| {
59        let ident = &input
60            .generics
61            .type_params()
62            .next()
63            .expect("no `context` given and no type parameters")
64            .ident;
65        parse_quote! { #ident }
66    });
67
68    let struct_name = &input.ident;
69    let field_types: Vec<_> = input.fields.iter().map(|field| &field.ty).collect();
70
71    let mut name_quotes = Vec::new();
72    let mut rollback_quotes = Vec::new();
73    let mut flush_quotes = Vec::new();
74    let mut test_flush_quotes = Vec::new();
75    let mut clear_quotes = Vec::new();
76    let mut has_pending_changes_quotes = Vec::new();
77    let mut num_init_keys_quotes = Vec::new();
78    let mut pre_load_keys_quotes = Vec::new();
79    let mut post_load_keys_quotes = Vec::new();
80    for (idx, e) in input.fields.iter().enumerate() {
81        let name = e.ident.clone().unwrap();
82        let test_flush_ident = format_ident!("deleted{}", idx);
83        let idx_lit = syn::LitInt::new(&idx.to_string(), Span::call_site());
84        let g = get_extended_entry(e.ty.clone());
85        name_quotes.push(quote! { #name });
86        rollback_quotes.push(quote! { self.#name.rollback(); });
87        flush_quotes.push(quote! { let #test_flush_ident = self.#name.flush(batch)?; });
88        test_flush_quotes.push(quote! { #test_flush_ident });
89        clear_quotes.push(quote! { self.#name.clear(); });
90        has_pending_changes_quotes.push(quote! {
91            if self.#name.has_pending_changes().await {
92                return true;
93            }
94        });
95        num_init_keys_quotes.push(quote! { #g :: NUM_INIT_KEYS });
96        pre_load_keys_quotes.push(quote! {
97            let index = #idx_lit;
98            let base_key = context.base_key().derive_tag_key(linera_views::views::MIN_VIEW_TAG, &index)?;
99            keys.extend(#g :: pre_load(&context.clone_with_base_key(base_key))?);
100        });
101        post_load_keys_quotes.push(quote! {
102            let index = #idx_lit;
103            let pos_next = pos + #g :: NUM_INIT_KEYS;
104            let base_key = context.base_key().derive_tag_key(linera_views::views::MIN_VIEW_TAG, &index)?;
105            let #name = #g :: post_load(context.clone_with_base_key(base_key), &values[pos..pos_next])?;
106            pos = pos_next;
107        });
108    }
109
110    let first_name_quote = name_quotes
111        .first()
112        .expect("list of names should be non-empty");
113
114    let load_metrics = if root && cfg!(feature = "metrics") {
115        quote! {
116            #[cfg(not(target_arch = "wasm32"))]
117            linera_views::metrics::increment_counter(
118                &linera_views::metrics::LOAD_VIEW_COUNTER,
119                stringify!(#struct_name),
120                &context.base_key().bytes,
121            );
122            #[cfg(not(target_arch = "wasm32"))]
123            use linera_views::metrics::prometheus_util::MeasureLatency as _;
124            let _latency = linera_views::metrics::LOAD_VIEW_LATENCY.measure_latency();
125        }
126    } else {
127        quote! {}
128    };
129
130    quote! {
131        impl #impl_generics linera_views::views::View for #struct_name #type_generics
132        where
133            #context: linera_views::context::Context,
134            #(#input_constraints,)*
135            #(#field_types: linera_views::views::View<Context = #context>,)*
136        {
137            const NUM_INIT_KEYS: usize = #(<#field_types as linera_views::views::View>::NUM_INIT_KEYS)+*;
138
139            type Context = #context;
140
141            fn context(&self) -> &#context {
142                use linera_views::views::View;
143                self.#first_name_quote.context()
144            }
145
146            fn pre_load(context: &#context) -> Result<Vec<Vec<u8>>, linera_views::ViewError> {
147                use linera_views::context::Context as _;
148                let mut keys = Vec::new();
149                #(#pre_load_keys_quotes)*
150                Ok(keys)
151            }
152
153            fn post_load(context: #context, values: &[Option<Vec<u8>>]) -> Result<Self, linera_views::ViewError> {
154                use linera_views::context::Context as _;
155                let mut pos = 0;
156                #(#post_load_keys_quotes)*
157                Ok(Self {#(#name_quotes),*})
158            }
159
160            async fn load(context: #context) -> Result<Self, linera_views::ViewError> {
161                use linera_views::{context::Context as _, store::ReadableKeyValueStore as _};
162                #load_metrics
163                if Self::NUM_INIT_KEYS == 0 {
164                    Self::post_load(context, &[])
165                } else {
166                    let keys = Self::pre_load(&context)?;
167                    let values = context.store().read_multi_values_bytes(keys).await?;
168                    Self::post_load(context, &values)
169                }
170            }
171
172
173            fn rollback(&mut self) {
174                #(#rollback_quotes)*
175            }
176
177            async fn has_pending_changes(&self) -> bool {
178                #(#has_pending_changes_quotes)*
179                false
180            }
181
182            fn flush(&mut self, batch: &mut linera_views::batch::Batch) -> Result<bool, linera_views::ViewError> {
183                use linera_views::views::View;
184                #(#flush_quotes)*
185                Ok( #(#test_flush_quotes)&&* )
186            }
187
188            fn clear(&mut self) {
189                #(#clear_quotes)*
190            }
191        }
192    }
193}
194
195fn generate_root_view_code(input: ItemStruct) -> TokenStream2 {
196    let Constraints {
197        input_constraints,
198        impl_generics,
199        type_generics,
200    } = Constraints::get(&input);
201    let struct_name = &input.ident;
202
203    let increment_counter = if cfg!(feature = "metrics") {
204        quote! {
205            #[cfg(not(target_arch = "wasm32"))]
206            linera_views::metrics::increment_counter(
207                &linera_views::metrics::SAVE_VIEW_COUNTER,
208                stringify!(#struct_name),
209                &self.context().base_key().bytes,
210            );
211        }
212    } else {
213        quote! {}
214    };
215
216    quote! {
217        impl #impl_generics linera_views::views::RootView for #struct_name #type_generics
218        where
219            #(#input_constraints,)*
220            Self: linera_views::views::View,
221        {
222            async fn save(&mut self) -> Result<(), linera_views::ViewError> {
223                use linera_views::{context::Context, batch::Batch, store::WritableKeyValueStore as _, views::View};
224                #increment_counter
225                let mut batch = Batch::new();
226                self.flush(&mut batch)?;
227                if !batch.is_empty() {
228                    self.context().store().write_batch(batch).await?;
229                }
230                Ok(())
231            }
232        }
233    }
234}
235
236fn generate_hash_view_code(input: ItemStruct) -> TokenStream2 {
237    let Constraints {
238        input_constraints,
239        impl_generics,
240        type_generics,
241    } = Constraints::get(&input);
242    let struct_name = &input.ident;
243
244    let field_types = input.fields.iter().map(|field| &field.ty);
245    let mut field_hashes_mut = Vec::new();
246    let mut field_hashes = Vec::new();
247    for e in &input.fields {
248        let name = e.ident.as_ref().unwrap();
249        field_hashes_mut.push(quote! { hasher.write_all(self.#name.hash_mut().await?.as_ref())?; });
250        field_hashes.push(quote! { hasher.write_all(self.#name.hash().await?.as_ref())?; });
251    }
252
253    quote! {
254        impl #impl_generics linera_views::views::HashableView for #struct_name #type_generics
255        where
256            #(#field_types: linera_views::views::HashableView,)*
257            #(#input_constraints,)*
258            Self: linera_views::views::View,
259        {
260            type Hasher = linera_views::sha3::Sha3_256;
261
262            async fn hash_mut(&mut self) -> Result<<Self::Hasher as linera_views::views::Hasher>::Output, linera_views::ViewError> {
263                use linera_views::views::{Hasher, HashableView};
264                use std::io::Write;
265                let mut hasher = Self::Hasher::default();
266                #(#field_hashes_mut)*
267                Ok(hasher.finalize())
268            }
269
270            async fn hash(&self) -> Result<<Self::Hasher as linera_views::views::Hasher>::Output, linera_views::ViewError> {
271                use linera_views::views::{Hasher, HashableView};
272                use std::io::Write;
273                let mut hasher = Self::Hasher::default();
274                #(#field_hashes)*
275                Ok(hasher.finalize())
276            }
277        }
278    }
279}
280
281fn generate_crypto_hash_code(input: ItemStruct) -> TokenStream2 {
282    let Constraints {
283        input_constraints,
284        impl_generics,
285        type_generics,
286    } = Constraints::get(&input);
287    let field_types = input.fields.iter().map(|field| &field.ty);
288    let struct_name = &input.ident;
289    let hash_type = syn::Ident::new(&format!("{struct_name}Hash"), Span::call_site());
290    quote! {
291        impl #impl_generics linera_views::views::CryptoHashView
292        for #struct_name #type_generics
293        where
294            #(#field_types: linera_views::views::HashableView,)*
295            #(#input_constraints,)*
296            Self: linera_views::views::View,
297        {
298            async fn crypto_hash(&self) -> Result<linera_base::crypto::CryptoHash, linera_views::ViewError> {
299                use linera_base::crypto::{BcsHashable, CryptoHash};
300                use linera_views::{
301                    batch::Batch,
302                    generic_array::GenericArray,
303                    sha3::{digest::OutputSizeUser, Sha3_256},
304                    views::HashableView,
305                };
306                use serde::{Serialize, Deserialize};
307                #[derive(Serialize, Deserialize)]
308                struct #hash_type(GenericArray<u8, <Sha3_256 as OutputSizeUser>::OutputSize>);
309                impl<'de> BcsHashable<'de> for #hash_type {}
310                let hash = self.hash().await?;
311                Ok(CryptoHash::new(&#hash_type(hash)))
312            }
313
314            async fn crypto_hash_mut(&mut self) -> Result<linera_base::crypto::CryptoHash, linera_views::ViewError> {
315                use linera_base::crypto::{BcsHashable, CryptoHash};
316                use linera_views::{
317                    batch::Batch,
318                    generic_array::GenericArray,
319                    sha3::{digest::OutputSizeUser, Sha3_256},
320                    views::HashableView,
321                };
322                use serde::{Serialize, Deserialize};
323                #[derive(Serialize, Deserialize)]
324                struct #hash_type(GenericArray<u8, <Sha3_256 as OutputSizeUser>::OutputSize>);
325                impl<'de> BcsHashable<'de> for #hash_type {}
326                let hash = self.hash_mut().await?;
327                Ok(CryptoHash::new(&#hash_type(hash)))
328            }
329        }
330    }
331}
332
333fn generate_clonable_view_code(input: ItemStruct) -> TokenStream2 {
334    let Constraints {
335        input_constraints,
336        impl_generics,
337        type_generics,
338    } = Constraints::get(&input);
339    let struct_name = &input.ident;
340
341    let mut clone_constraints = vec![];
342    let mut clone_fields = vec![];
343
344    for field in &input.fields {
345        let name = &field.ident;
346        let ty = &field.ty;
347        clone_constraints.push(quote! { #ty: ClonableView });
348        clone_fields.push(quote! { #name: self.#name.clone_unchecked()? });
349    }
350
351    quote! {
352        impl #impl_generics linera_views::views::ClonableView for #struct_name #type_generics
353        where
354            #(#input_constraints,)*
355            #(#clone_constraints,)*
356            Self: linera_views::views::View,
357        {
358            fn clone_unchecked(&mut self) -> Result<Self, linera_views::ViewError> {
359                Ok(Self {
360                    #(#clone_fields,)*
361                })
362            }
363        }
364    }
365}
366
367#[proc_macro_derive(View, attributes(view))]
368pub fn derive_view(input: TokenStream) -> TokenStream {
369    let input = parse_macro_input!(input as ItemStruct);
370    generate_view_code(input, false).into()
371}
372
373#[proc_macro_derive(HashableView, attributes(view))]
374pub fn derive_hash_view(input: TokenStream) -> TokenStream {
375    let input = parse_macro_input!(input as ItemStruct);
376    let mut stream = generate_view_code(input.clone(), false);
377    stream.extend(generate_hash_view_code(input));
378    stream.into()
379}
380
381#[proc_macro_derive(RootView, attributes(view))]
382pub fn derive_root_view(input: TokenStream) -> TokenStream {
383    let input = parse_macro_input!(input as ItemStruct);
384    let mut stream = generate_view_code(input.clone(), true);
385    stream.extend(generate_root_view_code(input));
386    stream.into()
387}
388
389#[proc_macro_derive(CryptoHashView, attributes(view))]
390pub fn derive_crypto_hash_view(input: TokenStream) -> TokenStream {
391    let input = parse_macro_input!(input as ItemStruct);
392    let mut stream = generate_view_code(input.clone(), false);
393    stream.extend(generate_hash_view_code(input.clone()));
394    stream.extend(generate_crypto_hash_code(input));
395    stream.into()
396}
397
398#[proc_macro_derive(CryptoHashRootView, attributes(view))]
399pub fn derive_crypto_hash_root_view(input: TokenStream) -> TokenStream {
400    let input = parse_macro_input!(input as ItemStruct);
401    let mut stream = generate_view_code(input.clone(), true);
402    stream.extend(generate_root_view_code(input.clone()));
403    stream.extend(generate_hash_view_code(input.clone()));
404    stream.extend(generate_crypto_hash_code(input));
405    stream.into()
406}
407
/// Derives `View` (root variant), `RootView` and `HashableView` for a struct of views.
// NOTE(review): the `#[cfg(test)]` below removes this function from normal (non-test) builds,
// so the `HashableRootView` derive is never exported to dependent crates. Confirm this is
// intended; if not, drop the `cfg`.
#[proc_macro_derive(HashableRootView, attributes(view))]
#[cfg(test)]
pub fn derive_hashable_root_view(input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as ItemStruct);
    let parts = [
        generate_view_code(item.clone(), true),
        generate_root_view_code(item.clone()),
        generate_hash_view_code(item),
    ];
    parts.into_iter().collect::<TokenStream2>().into()
}
417
418#[proc_macro_derive(ClonableView, attributes(view))]
419pub fn derive_clonable_view(input: TokenStream) -> TokenStream {
420    let input = parse_macro_input!(input as ItemStruct);
421    generate_clonable_view_code(input).into()
422}
423
/// Snapshot tests (using `insta`) for the code generators in this crate.
///
/// Every generator is run on each case produced by `SpecificContextInfo::test_cases`. The
/// generated tokens are pretty-printed before being compared with the stored snapshots.
#[cfg(test)]
pub mod tests {

    use quote::quote;
    use syn::{parse_quote, AngleBracketedGenericArguments};

    use crate::*;

    /// Parses `tokens` as a complete Rust file and pretty-prints it, so snapshots are readable
    /// and do not depend on token spacing.
    ///
    /// Panics if the generated tokens do not form a syntactically valid file.
    fn pretty(tokens: TokenStream2) -> String {
        prettyplease::unparse(
            &syn::parse2::<syn::File>(tokens).expect("failed to parse test output"),
        )
    }

    /// Snapshots `generate_view_code` with `root = true`. The snapshot name records whether
    /// the `metrics` feature is enabled, because that feature changes the generated `load`.
    #[test]
    fn test_generate_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                format!(
                    "test_generate_view_code{}_{}",
                    if cfg!(feature = "metrics") {
                        "_metrics"
                    } else {
                        ""
                    },
                    context.name,
                ),
                pretty(generate_view_code(input, true))
            );
        }
    }

    /// Snapshots `generate_hash_view_code` for every test case.
    #[test]
    fn test_generate_hash_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                format!("test_generate_hash_view_code_{}", context.name),
                pretty(generate_hash_view_code(input))
            );
        }
    }

    /// Snapshots `generate_root_view_code`. The snapshot name records whether the `metrics`
    /// feature is enabled, because that feature changes the generated `save`.
    #[test]
    fn test_generate_root_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                format!(
                    "test_generate_root_view_code{}_{}",
                    if cfg!(feature = "metrics") {
                        "_metrics"
                    } else {
                        ""
                    },
                    context.name,
                ),
                pretty(generate_root_view_code(input))
            );
        }
    }

    // NOTE(review): unlike the tests above, the two tests below take unnamed snapshots inside
    // a loop, so insta names them after the test function (presumably with a per-call
    // suffix) rather than after `context.name`. Naming them explicitly would be more robust,
    // but would require re-accepting the stored snapshots.

    /// Snapshots `generate_crypto_hash_code` for every test case.
    #[test]
    fn test_generate_crypto_hash_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(pretty(generate_crypto_hash_code(input)));
        }
    }

    /// Snapshots `generate_clonable_view_code` for every test case.
    #[test]
    fn test_generate_clonable_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(pretty(generate_clonable_view_code(input)));
        }
    }

    /// One way of specifying the context of the `TestView` input struct.
    #[derive(Clone)]
    pub struct SpecificContextInfo {
        /// Suffix used in snapshot names for this case.
        name: String,
        /// The `#[view(context = ...)]` attribute, if this case uses one.
        attribute: Option<TokenStream2>,
        /// Context type used as the first argument of every field's view type.
        context: Type,
        /// Generic parameters of the `TestView` struct.
        generics: AngleBracketedGenericArguments,
        /// Optional `where` clause of the `TestView` struct.
        where_clause: Option<TokenStream2>,
    }

    impl SpecificContextInfo {
        /// Case with no `#[view(context)]` attribute: the context is the struct's only type
        /// parameter `C`.
        pub fn empty() -> Self {
            SpecificContextInfo {
                name: "C".to_string(),
                attribute: None,
                context: syn::parse_quote! { C },
                generics: syn::parse_quote! { <C> },
                where_clause: None,
            }
        }

        /// Case with an explicit `#[view(context = ...)]` attribute and no generic parameters.
        ///
        /// The case name is the context type's tokens with spaces removed and `:`, `<` and
        /// `>` replaced by `_`, so it can be used in a snapshot name.
        pub fn new(context: syn::Type) -> Self {
            let name = quote! { #context };
            SpecificContextInfo {
                name: format!("{name}")
                    .replace(' ', "")
                    .replace([':', '<', '>'], "_"),
                attribute: Some(quote! { #[view(context = #context)] }),
                context,
                generics: parse_quote! { <> },
                where_clause: None,
            }
        }

        /// Sets the `where_clause` to a dummy value for test cases with a where clause.
        ///
        /// Also adds a `MyParam` generic type parameter to the `generics` field, which is the type
        /// constrained by the dummy predicate in the `where_clause`. The case name gets a
        /// `_with_where` suffix.
        pub fn with_dummy_where_clause(mut self) -> Self {
            self.generics.args.push(parse_quote! { MyParam });
            self.where_clause = Some(quote! {
                where MyParam: Send + Sync + 'static,
            });
            self.name.push_str("_with_where");

            self
        }

        /// All test cases: the implicit-context case plus three explicit contexts (a plain
        /// identifier, a multi-segment path and a generic type). Each one appears both without
        /// and with the dummy `where` clause.
        pub fn test_cases() -> impl Iterator<Item = Self> {
            Some(Self::empty())
                .into_iter()
                .chain(
                    [
                        syn::parse_quote! { CustomContext },
                        syn::parse_quote! { custom::path::to::ContextType },
                        syn::parse_quote! { custom::GenericContext<T> },
                    ]
                    .into_iter()
                    .map(Self::new),
                )
                .flat_map(|case| [case.clone(), case.with_dummy_where_clause()])
        }

        /// Builds the `TestView` input struct for this case: a `RegisterView` field and a
        /// `CollectionView` field of `RegisterView`s, all over this case's context.
        pub fn test_view_input(&self) -> ItemStruct {
            let SpecificContextInfo {
                attribute,
                context,
                generics,
                where_clause,
                ..
            } = self;

            parse_quote! {
                #attribute
                struct TestView #generics
                #where_clause
                {
                    register: RegisterView<#context, usize>,
                    collection: CollectionView<#context, usize, RegisterView<#context, usize>>,
                }
            }
        }
    }
}