1use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use quote::{format_ident, quote};
9use syn::{parse_macro_input, parse_quote, ItemStruct, Type};
10
11#[derive(Debug, deluxe::ParseAttributes)]
12#[deluxe(attributes(view))]
13struct StructAttrs {
14 context: Option<syn::Type>,
15}
16
17struct Constraints<'a> {
18 input_constraints: Vec<&'a syn::WherePredicate>,
19 impl_generics: syn::ImplGenerics<'a>,
20 type_generics: syn::TypeGenerics<'a>,
21}
22
23impl<'a> Constraints<'a> {
24 fn get(item: &'a syn::ItemStruct) -> Self {
25 let (impl_generics, type_generics, maybe_where_clause) = item.generics.split_for_impl();
26 let input_constraints = maybe_where_clause
27 .map(|w| w.predicates.iter())
28 .into_iter()
29 .flatten()
30 .collect();
31
32 Self {
33 input_constraints,
34 impl_generics,
35 type_generics,
36 }
37 }
38}
39
40fn get_extended_entry(e: Type) -> TokenStream2 {
41 let syn::Type::Path(typepath) = e else {
42 panic!("The type should be a path");
43 };
44 let path_segment = typepath.path.segments.into_iter().next().unwrap();
45 let ident = path_segment.ident;
46 let arguments = path_segment.arguments;
47 quote! { #ident :: #arguments }
48}
49
50fn generate_view_code(input: ItemStruct, root: bool) -> TokenStream2 {
51 let Constraints {
52 input_constraints,
53 impl_generics,
54 type_generics,
55 } = Constraints::get(&input);
56
57 let attrs: StructAttrs = deluxe::parse_attributes(&input).unwrap();
58 let context = attrs.context.unwrap_or_else(|| {
59 let ident = &input
60 .generics
61 .type_params()
62 .next()
63 .expect("no `context` given and no type parameters")
64 .ident;
65 parse_quote! { #ident }
66 });
67
68 let struct_name = &input.ident;
69 let field_types: Vec<_> = input.fields.iter().map(|field| &field.ty).collect();
70
71 let mut name_quotes = Vec::new();
72 let mut rollback_quotes = Vec::new();
73 let mut flush_quotes = Vec::new();
74 let mut test_flush_quotes = Vec::new();
75 let mut clear_quotes = Vec::new();
76 let mut has_pending_changes_quotes = Vec::new();
77 let mut num_init_keys_quotes = Vec::new();
78 let mut pre_load_keys_quotes = Vec::new();
79 let mut post_load_keys_quotes = Vec::new();
80 for (idx, e) in input.fields.iter().enumerate() {
81 let name = e.ident.clone().unwrap();
82 let test_flush_ident = format_ident!("deleted{}", idx);
83 let idx_lit = syn::LitInt::new(&idx.to_string(), Span::call_site());
84 let g = get_extended_entry(e.ty.clone());
85 name_quotes.push(quote! { #name });
86 rollback_quotes.push(quote! { self.#name.rollback(); });
87 flush_quotes.push(quote! { let #test_flush_ident = self.#name.flush(batch)?; });
88 test_flush_quotes.push(quote! { #test_flush_ident });
89 clear_quotes.push(quote! { self.#name.clear(); });
90 has_pending_changes_quotes.push(quote! {
91 if self.#name.has_pending_changes().await {
92 return true;
93 }
94 });
95 num_init_keys_quotes.push(quote! { #g :: NUM_INIT_KEYS });
96 pre_load_keys_quotes.push(quote! {
97 let index = #idx_lit;
98 let base_key = context.base_key().derive_tag_key(linera_views::views::MIN_VIEW_TAG, &index)?;
99 keys.extend(#g :: pre_load(&context.clone_with_base_key(base_key))?);
100 });
101 post_load_keys_quotes.push(quote! {
102 let index = #idx_lit;
103 let pos_next = pos + #g :: NUM_INIT_KEYS;
104 let base_key = context.base_key().derive_tag_key(linera_views::views::MIN_VIEW_TAG, &index)?;
105 let #name = #g :: post_load(context.clone_with_base_key(base_key), &values[pos..pos_next])?;
106 pos = pos_next;
107 });
108 }
109
110 let first_name_quote = name_quotes
111 .first()
112 .expect("list of names should be non-empty");
113
114 let load_metrics = if root && cfg!(feature = "metrics") {
115 quote! {
116 #[cfg(not(target_arch = "wasm32"))]
117 linera_views::metrics::increment_counter(
118 &linera_views::metrics::LOAD_VIEW_COUNTER,
119 stringify!(#struct_name),
120 &context.base_key().bytes,
121 );
122 #[cfg(not(target_arch = "wasm32"))]
123 use linera_views::metrics::prometheus_util::MeasureLatency as _;
124 let _latency = linera_views::metrics::LOAD_VIEW_LATENCY.measure_latency();
125 }
126 } else {
127 quote! {}
128 };
129
130 quote! {
131 impl #impl_generics linera_views::views::View for #struct_name #type_generics
132 where
133 #context: linera_views::context::Context,
134 #(#input_constraints,)*
135 #(#field_types: linera_views::views::View<Context = #context>,)*
136 {
137 const NUM_INIT_KEYS: usize = #(<#field_types as linera_views::views::View>::NUM_INIT_KEYS)+*;
138
139 type Context = #context;
140
141 fn context(&self) -> &#context {
142 use linera_views::views::View;
143 self.#first_name_quote.context()
144 }
145
146 fn pre_load(context: &#context) -> Result<Vec<Vec<u8>>, linera_views::ViewError> {
147 use linera_views::context::Context as _;
148 let mut keys = Vec::new();
149 #(#pre_load_keys_quotes)*
150 Ok(keys)
151 }
152
153 fn post_load(context: #context, values: &[Option<Vec<u8>>]) -> Result<Self, linera_views::ViewError> {
154 use linera_views::context::Context as _;
155 let mut pos = 0;
156 #(#post_load_keys_quotes)*
157 Ok(Self {#(#name_quotes),*})
158 }
159
160 async fn load(context: #context) -> Result<Self, linera_views::ViewError> {
161 use linera_views::{context::Context as _, store::ReadableKeyValueStore as _};
162 #load_metrics
163 if Self::NUM_INIT_KEYS == 0 {
164 Self::post_load(context, &[])
165 } else {
166 let keys = Self::pre_load(&context)?;
167 let values = context.store().read_multi_values_bytes(keys).await?;
168 Self::post_load(context, &values)
169 }
170 }
171
172
173 fn rollback(&mut self) {
174 #(#rollback_quotes)*
175 }
176
177 async fn has_pending_changes(&self) -> bool {
178 #(#has_pending_changes_quotes)*
179 false
180 }
181
182 fn flush(&mut self, batch: &mut linera_views::batch::Batch) -> Result<bool, linera_views::ViewError> {
183 use linera_views::views::View;
184 #(#flush_quotes)*
185 Ok( #(#test_flush_quotes)&&* )
186 }
187
188 fn clear(&mut self) {
189 #(#clear_quotes)*
190 }
191 }
192 }
193}
194
195fn generate_root_view_code(input: ItemStruct) -> TokenStream2 {
196 let Constraints {
197 input_constraints,
198 impl_generics,
199 type_generics,
200 } = Constraints::get(&input);
201 let struct_name = &input.ident;
202
203 let increment_counter = if cfg!(feature = "metrics") {
204 quote! {
205 #[cfg(not(target_arch = "wasm32"))]
206 linera_views::metrics::increment_counter(
207 &linera_views::metrics::SAVE_VIEW_COUNTER,
208 stringify!(#struct_name),
209 &self.context().base_key().bytes,
210 );
211 }
212 } else {
213 quote! {}
214 };
215
216 quote! {
217 impl #impl_generics linera_views::views::RootView for #struct_name #type_generics
218 where
219 #(#input_constraints,)*
220 Self: linera_views::views::View,
221 {
222 async fn save(&mut self) -> Result<(), linera_views::ViewError> {
223 use linera_views::{context::Context, batch::Batch, store::WritableKeyValueStore as _, views::View};
224 #increment_counter
225 let mut batch = Batch::new();
226 self.flush(&mut batch)?;
227 if !batch.is_empty() {
228 self.context().store().write_batch(batch).await?;
229 }
230 Ok(())
231 }
232 }
233 }
234}
235
236fn generate_hash_view_code(input: ItemStruct) -> TokenStream2 {
237 let Constraints {
238 input_constraints,
239 impl_generics,
240 type_generics,
241 } = Constraints::get(&input);
242 let struct_name = &input.ident;
243
244 let field_types = input.fields.iter().map(|field| &field.ty);
245 let mut field_hashes_mut = Vec::new();
246 let mut field_hashes = Vec::new();
247 for e in &input.fields {
248 let name = e.ident.as_ref().unwrap();
249 field_hashes_mut.push(quote! { hasher.write_all(self.#name.hash_mut().await?.as_ref())?; });
250 field_hashes.push(quote! { hasher.write_all(self.#name.hash().await?.as_ref())?; });
251 }
252
253 quote! {
254 impl #impl_generics linera_views::views::HashableView for #struct_name #type_generics
255 where
256 #(#field_types: linera_views::views::HashableView,)*
257 #(#input_constraints,)*
258 Self: linera_views::views::View,
259 {
260 type Hasher = linera_views::sha3::Sha3_256;
261
262 async fn hash_mut(&mut self) -> Result<<Self::Hasher as linera_views::views::Hasher>::Output, linera_views::ViewError> {
263 use linera_views::views::{Hasher, HashableView};
264 use std::io::Write;
265 let mut hasher = Self::Hasher::default();
266 #(#field_hashes_mut)*
267 Ok(hasher.finalize())
268 }
269
270 async fn hash(&self) -> Result<<Self::Hasher as linera_views::views::Hasher>::Output, linera_views::ViewError> {
271 use linera_views::views::{Hasher, HashableView};
272 use std::io::Write;
273 let mut hasher = Self::Hasher::default();
274 #(#field_hashes)*
275 Ok(hasher.finalize())
276 }
277 }
278 }
279}
280
281fn generate_crypto_hash_code(input: ItemStruct) -> TokenStream2 {
282 let Constraints {
283 input_constraints,
284 impl_generics,
285 type_generics,
286 } = Constraints::get(&input);
287 let field_types = input.fields.iter().map(|field| &field.ty);
288 let struct_name = &input.ident;
289 let hash_type = syn::Ident::new(&format!("{struct_name}Hash"), Span::call_site());
290 quote! {
291 impl #impl_generics linera_views::views::CryptoHashView
292 for #struct_name #type_generics
293 where
294 #(#field_types: linera_views::views::HashableView,)*
295 #(#input_constraints,)*
296 Self: linera_views::views::View,
297 {
298 async fn crypto_hash(&self) -> Result<linera_base::crypto::CryptoHash, linera_views::ViewError> {
299 use linera_base::crypto::{BcsHashable, CryptoHash};
300 use linera_views::{
301 batch::Batch,
302 generic_array::GenericArray,
303 sha3::{digest::OutputSizeUser, Sha3_256},
304 views::HashableView,
305 };
306 use serde::{Serialize, Deserialize};
307 #[derive(Serialize, Deserialize)]
308 struct #hash_type(GenericArray<u8, <Sha3_256 as OutputSizeUser>::OutputSize>);
309 impl<'de> BcsHashable<'de> for #hash_type {}
310 let hash = self.hash().await?;
311 Ok(CryptoHash::new(&#hash_type(hash)))
312 }
313
314 async fn crypto_hash_mut(&mut self) -> Result<linera_base::crypto::CryptoHash, linera_views::ViewError> {
315 use linera_base::crypto::{BcsHashable, CryptoHash};
316 use linera_views::{
317 batch::Batch,
318 generic_array::GenericArray,
319 sha3::{digest::OutputSizeUser, Sha3_256},
320 views::HashableView,
321 };
322 use serde::{Serialize, Deserialize};
323 #[derive(Serialize, Deserialize)]
324 struct #hash_type(GenericArray<u8, <Sha3_256 as OutputSizeUser>::OutputSize>);
325 impl<'de> BcsHashable<'de> for #hash_type {}
326 let hash = self.hash_mut().await?;
327 Ok(CryptoHash::new(&#hash_type(hash)))
328 }
329 }
330 }
331}
332
333fn generate_clonable_view_code(input: ItemStruct) -> TokenStream2 {
334 let Constraints {
335 input_constraints,
336 impl_generics,
337 type_generics,
338 } = Constraints::get(&input);
339 let struct_name = &input.ident;
340
341 let mut clone_constraints = vec![];
342 let mut clone_fields = vec![];
343
344 for field in &input.fields {
345 let name = &field.ident;
346 let ty = &field.ty;
347 clone_constraints.push(quote! { #ty: ClonableView });
348 clone_fields.push(quote! { #name: self.#name.clone_unchecked()? });
349 }
350
351 quote! {
352 impl #impl_generics linera_views::views::ClonableView for #struct_name #type_generics
353 where
354 #(#input_constraints,)*
355 #(#clone_constraints,)*
356 Self: linera_views::views::View,
357 {
358 fn clone_unchecked(&mut self) -> Result<Self, linera_views::ViewError> {
359 Ok(Self {
360 #(#clone_fields,)*
361 })
362 }
363 }
364 }
365}
366
367#[proc_macro_derive(View, attributes(view))]
368pub fn derive_view(input: TokenStream) -> TokenStream {
369 let input = parse_macro_input!(input as ItemStruct);
370 generate_view_code(input, false).into()
371}
372
373#[proc_macro_derive(HashableView, attributes(view))]
374pub fn derive_hash_view(input: TokenStream) -> TokenStream {
375 let input = parse_macro_input!(input as ItemStruct);
376 let mut stream = generate_view_code(input.clone(), false);
377 stream.extend(generate_hash_view_code(input));
378 stream.into()
379}
380
381#[proc_macro_derive(RootView, attributes(view))]
382pub fn derive_root_view(input: TokenStream) -> TokenStream {
383 let input = parse_macro_input!(input as ItemStruct);
384 let mut stream = generate_view_code(input.clone(), true);
385 stream.extend(generate_root_view_code(input));
386 stream.into()
387}
388
389#[proc_macro_derive(CryptoHashView, attributes(view))]
390pub fn derive_crypto_hash_view(input: TokenStream) -> TokenStream {
391 let input = parse_macro_input!(input as ItemStruct);
392 let mut stream = generate_view_code(input.clone(), false);
393 stream.extend(generate_hash_view_code(input.clone()));
394 stream.extend(generate_crypto_hash_code(input));
395 stream.into()
396}
397
398#[proc_macro_derive(CryptoHashRootView, attributes(view))]
399pub fn derive_crypto_hash_root_view(input: TokenStream) -> TokenStream {
400 let input = parse_macro_input!(input as ItemStruct);
401 let mut stream = generate_view_code(input.clone(), true);
402 stream.extend(generate_root_view_code(input.clone()));
403 stream.extend(generate_hash_view_code(input.clone()));
404 stream.extend(generate_crypto_hash_code(input));
405 stream.into()
406}
407
/// Derives `View` (with root-level load metrics), `RootView` and `HashableView`.
// NOTE(review): `#[cfg(test)]` removes this derive unless this proc-macro crate itself is
// compiled for its own tests. Downstream crates therefore cannot use it. Confirm this is
// intended.
#[proc_macro_derive(HashableRootView, attributes(view))]
#[cfg(test)]
pub fn derive_hashable_root_view(input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as ItemStruct);
    let view_impl = generate_view_code(item.clone(), true);
    let root_impl = generate_root_view_code(item.clone());
    let hash_impl = generate_hash_view_code(item);
    TokenStream::from(quote! { #view_impl #root_impl #hash_impl })
}
417
418#[proc_macro_derive(ClonableView, attributes(view))]
419pub fn derive_clonable_view(input: TokenStream) -> TokenStream {
420 let input = parse_macro_input!(input as ItemStruct);
421 generate_clonable_view_code(input).into()
422}
423
/// Snapshot tests for the code generators in this file.
///
/// Each test does the following for every case from `SpecificContextInfo::test_cases`:
/// - runs one generator on the `TestView` struct built by
///   `SpecificContextInfo::test_view_input`;
/// - pretty-prints the result;
/// - compares it against an `insta` snapshot.
#[cfg(test)]
pub mod tests {

    use quote::quote;
    use syn::{parse_quote, AngleBracketedGenericArguments};

    use crate::*;

    /// Parses `tokens` as a whole Rust file and pretty-prints it with `prettyplease`.
    ///
    /// Panics with "failed to parse test output" if the tokens do not form a valid file.
    /// Every snapshot test therefore also checks that the generated code is syntactically
    /// well-formed.
    fn pretty(tokens: TokenStream2) -> String {
        prettyplease::unparse(
            &syn::parse2::<syn::File>(tokens).expect("failed to parse test output"),
        )
    }

    /// Snapshots `generate_view_code` with `root = true`.
    ///
    /// The snapshot name gains a `_metrics` marker when this crate is built with the
    /// `metrics` feature, since the generated `load` then contains metrics code.
    #[test]
    fn test_generate_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                format!(
                    "test_generate_view_code{}_{}",
                    if cfg!(feature = "metrics") {
                        "_metrics"
                    } else {
                        ""
                    },
                    context.name,
                ),
                pretty(generate_view_code(input, true))
            );
        }
    }

    /// Snapshots `generate_hash_view_code`, with one named snapshot per test case.
    #[test]
    fn test_generate_hash_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                format!("test_generate_hash_view_code_{}", context.name),
                pretty(generate_hash_view_code(input))
            );
        }
    }

    /// Snapshots `generate_root_view_code`.
    ///
    /// The name records whether the `metrics` feature is enabled, because the generated
    /// `save` then contains a counter call.
    #[test]
    fn test_generate_root_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                format!(
                    "test_generate_root_view_code{}_{}",
                    if cfg!(feature = "metrics") {
                        "_metrics"
                    } else {
                        ""
                    },
                    context.name,
                ),
                pretty(generate_root_view_code(input))
            );
        }
    }

    /// Snapshots `generate_crypto_hash_code`.
    // NOTE(review): these snapshots are unnamed, so insta presumably numbers them in call
    // order. Reordering `test_cases` would then reshuffle the stored snapshots.
    #[test]
    fn test_generate_crypto_hash_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(pretty(generate_crypto_hash_code(input)));
        }
    }

    /// Snapshots `generate_clonable_view_code` (unnamed snapshots, as above).
    #[test]
    fn test_generate_clonable_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(pretty(generate_clonable_view_code(input)));
        }
    }

    /// One way of declaring the context of the test struct.
    #[derive(Clone)]
    pub struct SpecificContextInfo {
        /// Suffix identifying this case in snapshot names.
        name: String,
        /// Optional `#[view(context = ...)]` attribute placed on the test struct.
        attribute: Option<TokenStream2>,
        /// Context type used as the first argument of every field's view type.
        context: Type,
        /// Generic parameters of the test struct.
        generics: AngleBracketedGenericArguments,
        /// Optional `where` clause of the test struct.
        where_clause: Option<TokenStream2>,
    }

    impl SpecificContextInfo {
        /// Case without a `view` attribute. The struct is generic over `C`, and `C` is also
        /// the context type.
        pub fn empty() -> Self {
            SpecificContextInfo {
                name: "C".to_string(),
                attribute: None,
                context: syn::parse_quote! { C },
                generics: syn::parse_quote! { <C> },
                where_clause: None,
            }
        }

        /// Case with an explicit `#[view(context = #context)]` attribute and no generic
        /// parameters.
        ///
        /// The case name is the context's tokens with spaces removed and `:`, `<` and `>`
        /// replaced by `_`.
        pub fn new(context: syn::Type) -> Self {
            let name = quote! { #context };
            SpecificContextInfo {
                name: format!("{name}")
                    .replace(' ', "")
                    .replace([':', '<', '>'], "_"),
                attribute: Some(quote! { #[view(context = #context)] }),
                context,
                generics: parse_quote! { <> },
                where_clause: None,
            }
        }

        /// Adds a `MyParam` type parameter bounded by `where MyParam: Send + Sync + 'static`.
        ///
        /// This exercises how the generators carry the struct's own `where` predicates into
        /// the generated impls. It also appends `_with_where` to the case name.
        pub fn with_dummy_where_clause(mut self) -> Self {
            self.generics.args.push(parse_quote! { MyParam });
            self.where_clause = Some(quote! {
                where MyParam: Send + Sync + 'static,
            });
            self.name.push_str("_with_where");

            self
        }

        /// Returns all cases, each both with and without the dummy `where` clause:
        /// - the `C`-generic case;
        /// - three explicit contexts: a plain ident, a multi-segment path and a generic path.
        pub fn test_cases() -> impl Iterator<Item = Self> {
            Some(Self::empty())
                .into_iter()
                .chain(
                    [
                        syn::parse_quote! { CustomContext },
                        syn::parse_quote! { custom::path::to::ContextType },
                        syn::parse_quote! { custom::GenericContext<T> },
                    ]
                    .into_iter()
                    .map(Self::new),
                )
                .flat_map(|case| [case.clone(), case.with_dummy_where_clause()])
        }

        /// Builds the `TestView` struct used as generator input.
        ///
        /// It has a `RegisterView` field and a `CollectionView` of `RegisterView`s, both over
        /// this case's context.
        pub fn test_view_input(&self) -> ItemStruct {
            let SpecificContextInfo {
                attribute,
                context,
                generics,
                where_clause,
                ..
            } = self;

            parse_quote! {
                #attribute
                struct TestView #generics
                #where_clause
                {
                    register: RegisterView<#context, usize>,
                    collection: CollectionView<#context, usize, RegisterView<#context, usize>>,
                }
            }
        }
    }
}