ff_derive/
lib.rs

1#![recursion_limit = "1024"]
2
3extern crate proc_macro;
4extern crate proc_macro2;
5
6use num_bigint::BigUint;
7use num_integer::Integer;
8use num_traits::{One, ToPrimitive, Zero};
9use quote::quote;
10use quote::TokenStreamExt;
11use std::iter;
12use std::str::FromStr;
13
14mod pow_fixed;
15
/// Byte order of the generated `<Name>Repr` encoding, selected by the
/// `PrimeFieldReprEndianness` attribute (`"big"` or `"little"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ReprEndianness {
    /// Most significant byte first.
    Big,
    /// Least significant byte first.
    Little,
}
20
21impl FromStr for ReprEndianness {
22    type Err = ();
23
24    fn from_str(s: &str) -> Result<Self, Self::Err> {
25        match s {
26            "big" => Ok(ReprEndianness::Big),
27            "little" => Ok(ReprEndianness::Little),
28            _ => Err(()),
29        }
30    }
31}
32
33impl ReprEndianness {
34    fn modulus_repr(&self, modulus: &BigUint, bytes: usize) -> Vec<u8> {
35        match self {
36            ReprEndianness::Big => {
37                let buf = modulus.to_bytes_be();
38                iter::repeat(0)
39                    .take(bytes - buf.len())
40                    .chain(buf.into_iter())
41                    .collect()
42            }
43            ReprEndianness::Little => {
44                let mut buf = modulus.to_bytes_le();
45                buf.extend(iter::repeat(0).take(bytes - buf.len()));
46                buf
47            }
48        }
49    }
50
51    fn from_repr(&self, name: &syn::Ident, limbs: usize) -> proc_macro2::TokenStream {
52        let read_repr = match self {
53            ReprEndianness::Big => quote! {
54                ::ff::derive::byteorder::BigEndian::read_u64_into(r.as_ref(), &mut inner[..]);
55                inner.reverse();
56            },
57            ReprEndianness::Little => quote! {
58                ::ff::derive::byteorder::LittleEndian::read_u64_into(r.as_ref(), &mut inner[..]);
59            },
60        };
61
62        quote! {
63            use ::ff::derive::byteorder::ByteOrder;
64
65            let r = {
66                let mut inner = [0u64; #limbs];
67                #read_repr
68                #name(inner)
69            };
70        }
71    }
72
73    fn to_repr(
74        &self,
75        repr: proc_macro2::TokenStream,
76        mont_reduce_self_params: &proc_macro2::TokenStream,
77        limbs: usize,
78    ) -> proc_macro2::TokenStream {
79        let bytes = limbs * 8;
80
81        let write_repr = match self {
82            ReprEndianness::Big => quote! {
83                r.0.reverse();
84                ::ff::derive::byteorder::BigEndian::write_u64_into(&r.0, &mut repr[..]);
85            },
86            ReprEndianness::Little => quote! {
87                ::ff::derive::byteorder::LittleEndian::write_u64_into(&r.0, &mut repr[..]);
88            },
89        };
90
91        quote! {
92            use ::ff::derive::byteorder::ByteOrder;
93
94            let mut r = *self;
95            r.mont_reduce(
96                #mont_reduce_self_params
97            );
98
99            let mut repr = [0u8; #bytes];
100            #write_repr
101            #repr(repr)
102        }
103    }
104
105    fn iter_be(&self) -> proc_macro2::TokenStream {
106        match self {
107            ReprEndianness::Big => quote! {self.0.iter()},
108            ReprEndianness::Little => quote! {self.0.iter().rev()},
109        }
110    }
111}
112
113/// Derive the `PrimeField` trait.
114#[proc_macro_derive(
115    PrimeField,
116    attributes(PrimeFieldModulus, PrimeFieldGenerator, PrimeFieldReprEndianness)
117)]
118pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
119    // Parse the type definition
120    let ast: syn::DeriveInput = syn::parse(input).unwrap();
121
122    // We're given the modulus p of the prime field
123    let modulus: BigUint = fetch_attr("PrimeFieldModulus", &ast.attrs)
124        .expect("Please supply a PrimeFieldModulus attribute")
125        .parse()
126        .expect("PrimeFieldModulus should be a number");
127
128    // We may be provided with a generator of p - 1 order. It is required that this generator be quadratic
129    // nonresidue.
130    // TODO: Compute this ourselves.
131    let generator: BigUint = fetch_attr("PrimeFieldGenerator", &ast.attrs)
132        .expect("Please supply a PrimeFieldGenerator attribute")
133        .parse()
134        .expect("PrimeFieldGenerator should be a number");
135
136    // Field element representations may be in little-endian or big-endian.
137    let endianness = fetch_attr("PrimeFieldReprEndianness", &ast.attrs)
138        .expect("Please supply a PrimeFieldReprEndianness attribute")
139        .parse()
140        .expect("PrimeFieldReprEndianness should be 'big' or 'little'");
141
142    // The arithmetic in this library only works if the modulus*2 is smaller than the backing
143    // representation. Compute the number of limbs we need.
144    let mut limbs = 1;
145    {
146        let mod2 = (&modulus) << 1; // modulus * 2
147        let mut cur = BigUint::one() << 64; // always 64-bit limbs for now
148        while cur < mod2 {
149            limbs += 1;
150            cur <<= 64;
151        }
152    }
153
154    // The struct we're deriving for must be a wrapper around `pub [u64; limbs]`.
155    if let Some(err) = validate_struct(&ast, limbs) {
156        return err.into();
157    }
158
159    // Generate the identifier for the "Repr" type we must construct.
160    let repr_ident = syn::Ident::new(
161        &format!("{}Repr", ast.ident),
162        proc_macro2::Span::call_site(),
163    );
164
165    let mut gen = proc_macro2::TokenStream::new();
166
167    let (constants_impl, sqrt_impl) =
168        prime_field_constants_and_sqrt(&ast.ident, &modulus, limbs, generator);
169
170    gen.extend(constants_impl);
171    gen.extend(prime_field_repr_impl(&repr_ident, &endianness, limbs * 8));
172    gen.extend(prime_field_impl(
173        &ast.ident,
174        &repr_ident,
175        &modulus,
176        &endianness,
177        limbs,
178        sqrt_impl,
179    ));
180
181    // Return the generated impl
182    gen.into()
183}
184
185/// Checks that `body` contains `pub [u64; limbs]`.
186fn validate_struct(ast: &syn::DeriveInput, limbs: usize) -> Option<proc_macro2::TokenStream> {
187    // The body should be a struct.
188    let variant_data = match &ast.data {
189        syn::Data::Struct(x) => x,
190        _ => {
191            return Some(
192                syn::Error::new_spanned(ast, "PrimeField derive only works for structs.")
193                    .to_compile_error(),
194            )
195        }
196    };
197
198    // The struct should contain a single unnamed field.
199    let fields = match &variant_data.fields {
200        syn::Fields::Unnamed(x) if x.unnamed.len() == 1 => x,
201        _ => {
202            return Some(
203                syn::Error::new_spanned(
204                    &ast.ident,
205                    format!(
206                        "The struct must contain an array of limbs. Change this to `{}([u64; {}])`",
207                        ast.ident, limbs,
208                    ),
209                )
210                .to_compile_error(),
211            )
212        }
213    };
214    let field = &fields.unnamed[0];
215
216    // The field should be an array.
217    let arr = match &field.ty {
218        syn::Type::Array(x) => x,
219        _ => {
220            return Some(
221                syn::Error::new_spanned(
222                    field,
223                    format!(
224                        "The inner field must be an array of limbs. Change this to `[u64; {}]`",
225                        limbs,
226                    ),
227                )
228                .to_compile_error(),
229            )
230        }
231    };
232
233    // The array's element type should be `u64`.
234    if match arr.elem.as_ref() {
235        syn::Type::Path(path) => path
236            .path
237            .get_ident()
238            .map(|x| x.to_string() != "u64")
239            .unwrap_or(true),
240        _ => true,
241    } {
242        return Some(
243            syn::Error::new_spanned(
244                arr,
245                format!(
246                    "PrimeField derive requires 64-bit limbs. Change this to `[u64; {}]",
247                    limbs
248                ),
249            )
250            .to_compile_error(),
251        );
252    }
253
254    // The array's length should be a literal int equal to `limbs`.
255    let expr_lit = match &arr.len {
256        syn::Expr::Lit(expr_lit) => Some(&expr_lit.lit),
257        syn::Expr::Group(expr_group) => match &*expr_group.expr {
258            syn::Expr::Lit(expr_lit) => Some(&expr_lit.lit),
259            _ => None,
260        },
261        _ => None,
262    };
263    let lit_int = match match expr_lit {
264        Some(syn::Lit::Int(lit_int)) => Some(lit_int),
265        _ => None,
266    } {
267        Some(x) => x,
268        _ => {
269            return Some(
270                syn::Error::new_spanned(
271                    arr,
272                    format!("To derive PrimeField, change this to `[u64; {}]`.", limbs),
273                )
274                .to_compile_error(),
275            )
276        }
277    };
278    if lit_int.base10_digits() != limbs.to_string() {
279        return Some(
280            syn::Error::new_spanned(
281                lit_int,
282                format!("The given modulus requires {} limbs.", limbs),
283            )
284            .to_compile_error(),
285        );
286    }
287
288    // The field should not be public.
289    match &field.vis {
290        syn::Visibility::Inherited => (),
291        _ => {
292            return Some(
293                syn::Error::new_spanned(&field.vis, "Field must not be public.").to_compile_error(),
294            )
295        }
296    }
297
298    // Valid!
299    None
300}
301
302/// Fetch an attribute string from the derived struct.
303fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
304    for attr in attrs {
305        if let Ok(meta) = attr.parse_meta() {
306            match meta {
307                syn::Meta::NameValue(nv) => {
308                    if nv.path.get_ident().map(|i| i.to_string()) == Some(name.to_string()) {
309                        match nv.lit {
310                            syn::Lit::Str(ref s) => return Some(s.value()),
311                            _ => {
312                                panic!("attribute {} should be a string", name);
313                            }
314                        }
315                    }
316                }
317                _ => {
318                    panic!("attribute {} should be a string", name);
319                }
320            }
321        }
322    }
323
324    None
325}
326
327// Implement the wrapped ident `repr` with `bytes` bytes.
328fn prime_field_repr_impl(
329    repr: &syn::Ident,
330    endianness: &ReprEndianness,
331    bytes: usize,
332) -> proc_macro2::TokenStream {
333    let repr_iter_be = endianness.iter_be();
334
335    quote! {
336        #[derive(Copy, Clone)]
337        pub struct #repr(pub [u8; #bytes]);
338
339        impl ::ff::derive::subtle::ConstantTimeEq for #repr {
340            fn ct_eq(&self, other: &#repr) -> ::ff::derive::subtle::Choice {
341                self.0
342                    .iter()
343                    .zip(other.0.iter())
344                    .map(|(a, b)| a.ct_eq(b))
345                    .fold(1.into(), |acc, x| acc & x)
346            }
347        }
348
349        impl ::core::cmp::PartialEq for #repr {
350            fn eq(&self, other: &#repr) -> bool {
351                use ::ff::derive::subtle::ConstantTimeEq;
352                self.ct_eq(other).into()
353            }
354        }
355
356        impl ::core::cmp::Eq for #repr { }
357
358        impl ::core::default::Default for #repr {
359            fn default() -> #repr {
360                #repr([0u8; #bytes])
361            }
362        }
363
364        impl ::core::fmt::Debug for #repr
365        {
366            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
367                write!(f, "0x")?;
368                for i in #repr_iter_be {
369                    write!(f, "{:02x}", *i)?;
370                }
371
372                Ok(())
373            }
374        }
375
376        impl AsRef<[u8]> for #repr {
377            #[inline(always)]
378            fn as_ref(&self) -> &[u8] {
379                &self.0
380            }
381        }
382
383        impl AsMut<[u8]> for #repr {
384            #[inline(always)]
385            fn as_mut(&mut self) -> &mut [u8] {
386                &mut self.0
387            }
388        }
389    }
390}
391
392/// Convert BigUint into a vector of 64-bit limbs.
393fn biguint_to_real_u64_vec(mut v: BigUint, limbs: usize) -> Vec<u64> {
394    let m = BigUint::one() << 64;
395    let mut ret = vec![];
396
397    while v > BigUint::zero() {
398        let limb: BigUint = &v % &m;
399        ret.push(limb.to_u64().unwrap());
400        v >>= 64;
401    }
402
403    while ret.len() < limbs {
404        ret.push(0);
405    }
406
407    assert!(ret.len() == limbs);
408
409    ret
410}
411
412/// Convert BigUint into a tokenized vector of 64-bit limbs.
413fn biguint_to_u64_vec(v: BigUint, limbs: usize) -> proc_macro2::TokenStream {
414    let ret = biguint_to_real_u64_vec(v, limbs);
415    quote!([#(#ret,)*])
416}
417
418fn biguint_num_bits(mut v: BigUint) -> u32 {
419    let mut bits = 0;
420
421    while v != BigUint::zero() {
422        v >>= 1;
423        bits += 1;
424    }
425
426    bits
427}
428
429/// BigUint modular exponentiation by square-and-multiply.
430fn exp(base: BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
431    let mut ret = BigUint::one();
432
433    for i in exp
434        .to_bytes_be()
435        .into_iter()
436        .flat_map(|x| (0..8).rev().map(move |i| (x >> i).is_odd()))
437    {
438        ret = (&ret * &ret) % modulus;
439        if i {
440            ret = (ret * &base) % modulus;
441        }
442    }
443
444    ret
445}
446
#[test]
fn test_exp() {
    // Checks `exp` against a precomputed modular exponentiation with a large modulus.
    let base = BigUint::from_str("4398572349857239485729348572983472345").unwrap();
    let exponent = BigUint::from_str("5489673498567349856734895").unwrap();
    let modulus = BigUint::from_str(
        "52435875175126190479447740508185965837690552500527637822603658699938581184513",
    )
    .unwrap();
    let expected = BigUint::from_str(
        "4371221214068404307866768905142520595925044802278091865033317963560480051536",
    )
    .unwrap();

    assert_eq!(exp(base, &exponent, &modulus), expected);
}
464
465fn prime_field_constants_and_sqrt(
466    name: &syn::Ident,
467    modulus: &BigUint,
468    limbs: usize,
469    generator: BigUint,
470) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
471    let bytes = limbs * 8;
472    let modulus_num_bits = biguint_num_bits(modulus.clone());
473
474    // The number of bits we should "shave" from a randomly sampled representation, i.e.,
475    // if our modulus is 381 bits and our representation is 384 bits, we should shave
476    // 3 bits from the beginning of a randomly sampled 384 bit representation to
477    // reduce the cost of rejection sampling.
478    let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone());
479
480    // Compute R = 2**(64 * limbs) mod m
481    let r = (BigUint::one() << (limbs * 64)) % modulus;
482    let to_mont = |v| (v * &r) % modulus;
483
484    let two = BigUint::from_str("2").unwrap();
485    let p_minus_2 = modulus - &two;
486    let invert = |v| exp(v, &p_minus_2, &modulus);
487
488    // 2^-1 mod m
489    let two_inv = biguint_to_u64_vec(to_mont(invert(two)), limbs);
490
491    // modulus - 1 = 2^s * t
492    let mut s: u32 = 0;
493    let mut t = modulus - BigUint::from_str("1").unwrap();
494    while t.is_even() {
495        t >>= 1;
496        s += 1;
497    }
498
499    // Compute 2^s root of unity given the generator
500    let root_of_unity = exp(generator.clone(), &t, &modulus);
501    let root_of_unity_inv = biguint_to_u64_vec(to_mont(invert(root_of_unity.clone())), limbs);
502    let root_of_unity = biguint_to_u64_vec(to_mont(root_of_unity), limbs);
503    let delta = biguint_to_u64_vec(
504        to_mont(exp(generator.clone(), &(BigUint::one() << s), &modulus)),
505        limbs,
506    );
507    let generator = biguint_to_u64_vec(to_mont(generator), limbs);
508
509    let sqrt_impl =
510        if (modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
511            // Addition chain for (r + 1) // 4
512            let mod_plus_1_over_4 = pow_fixed::generate(
513                &quote! {self},
514                (modulus + BigUint::from_str("1").unwrap()) >> 2,
515            );
516
517            quote! {
518                use ::ff::derive::subtle::ConstantTimeEq;
519
520                // Because r = 3 (mod 4)
521                // sqrt can be done with only one exponentiation,
522                // via the computation of  self^((r + 1) // 4) (mod r)
523                let sqrt = {
524                    #mod_plus_1_over_4
525                };
526
527                ::ff::derive::subtle::CtOption::new(
528                    sqrt,
529                    (sqrt * &sqrt).ct_eq(self), // Only return Some if it's the square root.
530                )
531            }
532        } else {
533            // Addition chain for (t - 1) // 2
534            let t_minus_1_over_2 = if t == BigUint::one() {
535                quote!( #name::ONE )
536            } else {
537                pow_fixed::generate(&quote! {self}, (&t - BigUint::one()) >> 1)
538            };
539
540            quote! {
541                // Tonelli-Shanks algorithm works for every remaining odd prime.
542                // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
543                use ::ff::derive::subtle::{ConditionallySelectable, ConstantTimeEq};
544
545                // w = self^((t - 1) // 2)
546                let w = {
547                    #t_minus_1_over_2
548                };
549
550                let mut v = S;
551                let mut x = *self * &w;
552                let mut b = x * &w;
553
554                // Initialize z as the 2^S root of unity.
555                let mut z = ROOT_OF_UNITY;
556
557                for max_v in (1..=S).rev() {
558                    let mut k = 1;
559                    let mut tmp = b.square();
560                    let mut j_less_than_v: ::ff::derive::subtle::Choice = 1.into();
561
562                    for j in 2..max_v {
563                        let tmp_is_one = tmp.ct_eq(&#name::ONE);
564                        let squared = #name::conditional_select(&tmp, &z, tmp_is_one).square();
565                        tmp = #name::conditional_select(&squared, &tmp, tmp_is_one);
566                        let new_z = #name::conditional_select(&z, &squared, tmp_is_one);
567                        j_less_than_v &= !j.ct_eq(&v);
568                        k = u32::conditional_select(&j, &k, tmp_is_one);
569                        z = #name::conditional_select(&z, &new_z, j_less_than_v);
570                    }
571
572                    let result = x * &z;
573                    x = #name::conditional_select(&result, &x, b.ct_eq(&#name::ONE));
574                    z = z.square();
575                    b *= &z;
576                    v = k;
577                }
578
579                ::ff::derive::subtle::CtOption::new(
580                    x,
581                    (x * &x).ct_eq(self), // Only return Some if it's the square root.
582                )
583            }
584        };
585
586    // Compute R^2 mod m
587    let r2 = biguint_to_u64_vec((&r * &r) % modulus, limbs);
588
589    let r = biguint_to_u64_vec(r, limbs);
590    let modulus_le_bytes = ReprEndianness::Little.modulus_repr(modulus, limbs * 8);
591    let modulus_str = format!("0x{}", modulus.to_str_radix(16));
592    let modulus = biguint_to_real_u64_vec(modulus.clone(), limbs);
593
594    // Compute -m^-1 mod 2**64 by exponentiating by totient(2**64) - 1
595    let mut inv = 1u64;
596    for _ in 0..63 {
597        inv = inv.wrapping_mul(inv);
598        inv = inv.wrapping_mul(modulus[0]);
599    }
600    inv = inv.wrapping_neg();
601
602    (
603        quote! {
604            type REPR_BYTES = [u8; #bytes];
605            type REPR_BITS = REPR_BYTES;
606
607            /// This is the modulus m of the prime field
608            const MODULUS: REPR_BITS = [#(#modulus_le_bytes,)*];
609
610            /// This is the modulus m of the prime field in limb form
611            const MODULUS_LIMBS: #name = #name([#(#modulus,)*]);
612
613            /// This is the modulus m of the prime field in hex string form
614            const MODULUS_STR: &'static str = #modulus_str;
615
616            /// The number of bits needed to represent the modulus.
617            const MODULUS_BITS: u32 = #modulus_num_bits;
618
619            /// The number of bits that must be shaved from the beginning of
620            /// the representation when randomly sampling.
621            const REPR_SHAVE_BITS: u32 = #repr_shave_bits;
622
623            /// 2^{limbs*64} mod m
624            const R: #name = #name(#r);
625
626            /// 2^{limbs*64*2} mod m
627            const R2: #name = #name(#r2);
628
629            /// -(m^{-1} mod m) mod m
630            const INV: u64 = #inv;
631
632            /// 2^{-1} mod m
633            const TWO_INV: #name = #name(#two_inv);
634
635            /// Multiplicative generator of `MODULUS` - 1 order, also quadratic
636            /// nonresidue.
637            const GENERATOR: #name = #name(#generator);
638
639            /// 2^s * t = MODULUS - 1 with t odd
640            const S: u32 = #s;
641
642            /// 2^s root of unity computed by GENERATOR^t
643            const ROOT_OF_UNITY: #name = #name(#root_of_unity);
644
645            /// (2^s)^{-1} mod m
646            const ROOT_OF_UNITY_INV: #name = #name(#root_of_unity_inv);
647
648            /// GENERATOR^{2^s}
649            const DELTA: #name = #name(#delta);
650        },
651        sqrt_impl,
652    )
653}
654
655/// Implement PrimeField for the derived type.
656fn prime_field_impl(
657    name: &syn::Ident,
658    repr: &syn::Ident,
659    modulus: &BigUint,
660    endianness: &ReprEndianness,
661    limbs: usize,
662    sqrt_impl: proc_macro2::TokenStream,
663) -> proc_macro2::TokenStream {
664    // Returns r{n} as an ident.
665    fn get_temp(n: usize) -> syn::Ident {
666        syn::Ident::new(&format!("r{}", n), proc_macro2::Span::call_site())
667    }
668
669    // The parameter list for the mont_reduce() internal method.
670    // r0: u64, mut r1: u64, mut r2: u64, ...
671    let mut mont_paramlist = proc_macro2::TokenStream::new();
672    mont_paramlist.append_separated(
673        (0..(limbs * 2)).map(|i| (i, get_temp(i))).map(|(i, x)| {
674            if i != 0 {
675                quote! {mut #x: u64}
676            } else {
677                quote! {#x: u64}
678            }
679        }),
680        proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
681    );
682
683    // Implement montgomery reduction for some number of limbs
684    fn mont_impl(limbs: usize) -> proc_macro2::TokenStream {
685        let mut gen = proc_macro2::TokenStream::new();
686
687        for i in 0..limbs {
688            {
689                let temp = get_temp(i);
690                gen.extend(quote! {
691                    let k = #temp.wrapping_mul(INV);
692                    let (_, carry) = ::ff::derive::mac(#temp, k, MODULUS_LIMBS.0[0], 0);
693                });
694            }
695
696            for j in 1..limbs {
697                let temp = get_temp(i + j);
698                gen.extend(quote! {
699                    let (#temp, carry) = ::ff::derive::mac(#temp, k, MODULUS_LIMBS.0[#j], carry);
700                });
701            }
702
703            let temp = get_temp(i + limbs);
704
705            if i == 0 {
706                gen.extend(quote! {
707                    let (#temp, carry2) = ::ff::derive::adc(#temp, 0, carry);
708                });
709            } else {
710                gen.extend(quote! {
711                    let (#temp, carry2) = ::ff::derive::adc(#temp, carry2, carry);
712                });
713            }
714        }
715
716        for i in 0..limbs {
717            let temp = get_temp(limbs + i);
718
719            gen.extend(quote! {
720                self.0[#i] = #temp;
721            });
722        }
723
724        gen
725    }
726
727    fn sqr_impl(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
728        let mut gen = proc_macro2::TokenStream::new();
729
730        if limbs > 1 {
731            for i in 0..(limbs - 1) {
732                gen.extend(quote! {
733                    let carry = 0;
734                });
735
736                for j in (i + 1)..limbs {
737                    let temp = get_temp(i + j);
738                    if i == 0 {
739                        gen.extend(quote! {
740                            let (#temp, carry) = ::ff::derive::mac(0, #a.0[#i], #a.0[#j], carry);
741                        });
742                    } else {
743                        gen.extend(quote! {
744                            let (#temp, carry) = ::ff::derive::mac(#temp, #a.0[#i], #a.0[#j], carry);
745                        });
746                    }
747                }
748
749                let temp = get_temp(i + limbs);
750
751                gen.extend(quote! {
752                    let #temp = carry;
753                });
754            }
755
756            for i in 1..(limbs * 2) {
757                let temp0 = get_temp(limbs * 2 - i);
758                let temp1 = get_temp(limbs * 2 - i - 1);
759
760                if i == 1 {
761                    gen.extend(quote! {
762                        let #temp0 = #temp1 >> 63;
763                    });
764                } else if i == (limbs * 2 - 1) {
765                    gen.extend(quote! {
766                        let #temp0 = #temp0 << 1;
767                    });
768                } else {
769                    gen.extend(quote! {
770                        let #temp0 = (#temp0 << 1) | (#temp1 >> 63);
771                    });
772                }
773            }
774        } else {
775            let temp1 = get_temp(1);
776            gen.extend(quote! {
777                let #temp1 = 0;
778            });
779        }
780
781        for i in 0..limbs {
782            let temp0 = get_temp(i * 2);
783            let temp1 = get_temp(i * 2 + 1);
784            if i == 0 {
785                gen.extend(quote! {
786                    let (#temp0, carry) = ::ff::derive::mac(0, #a.0[#i], #a.0[#i], 0);
787                });
788            } else {
789                gen.extend(quote! {
790                    let (#temp0, carry) = ::ff::derive::mac(#temp0, #a.0[#i], #a.0[#i], carry);
791                });
792            }
793
794            gen.extend(quote! {
795                let (#temp1, carry) = ::ff::derive::adc(#temp1, 0, carry);
796            });
797        }
798
799        let mut mont_calling = proc_macro2::TokenStream::new();
800        mont_calling.append_separated(
801            (0..(limbs * 2)).map(get_temp),
802            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
803        );
804
805        gen.extend(quote! {
806            let mut ret = *self;
807            ret.mont_reduce(#mont_calling);
808            ret
809        });
810
811        gen
812    }
813
814    fn mul_impl(
815        a: proc_macro2::TokenStream,
816        b: proc_macro2::TokenStream,
817        limbs: usize,
818    ) -> proc_macro2::TokenStream {
819        let mut gen = proc_macro2::TokenStream::new();
820
821        for i in 0..limbs {
822            gen.extend(quote! {
823                let carry = 0;
824            });
825
826            for j in 0..limbs {
827                let temp = get_temp(i + j);
828
829                if i == 0 {
830                    gen.extend(quote! {
831                        let (#temp, carry) = ::ff::derive::mac(0, #a.0[#i], #b.0[#j], carry);
832                    });
833                } else {
834                    gen.extend(quote! {
835                        let (#temp, carry) = ::ff::derive::mac(#temp, #a.0[#i], #b.0[#j], carry);
836                    });
837                }
838            }
839
840            let temp = get_temp(i + limbs);
841
842            gen.extend(quote! {
843                let #temp = carry;
844            });
845        }
846
847        let mut mont_calling = proc_macro2::TokenStream::new();
848        mont_calling.append_separated(
849            (0..(limbs * 2)).map(get_temp),
850            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
851        );
852
853        gen.extend(quote! {
854            self.mont_reduce(#mont_calling);
855        });
856
857        gen
858    }
859
860    /// Generates an implementation of multiplicative inversion within the target prime
861    /// field.
862    fn inv_impl(a: proc_macro2::TokenStream, modulus: &BigUint) -> proc_macro2::TokenStream {
863        // Addition chain for p - 2
864        let mod_minus_2 = pow_fixed::generate(&a, modulus - BigUint::from(2u64));
865
866        quote! {
867            use ::ff::derive::subtle::ConstantTimeEq;
868
869            // By Euler's theorem, if `a` is coprime to `p` (i.e. `gcd(a, p) = 1`), then:
870            //     a^-1 ≡ a^(phi(p) - 1) mod p
871            //
872            // `ff_derive` requires that `p` is prime; in this case, `phi(p) = p - 1`, and
873            // thus:
874            //     a^-1 ≡ a^(p - 2) mod p
875            let inv = {
876                #mod_minus_2
877            };
878
879            ::ff::derive::subtle::CtOption::new(inv, !#a.is_zero())
880        }
881    }
882
883    let squaring_impl = sqr_impl(quote! {self}, limbs);
884    let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs);
885    let invert_impl = inv_impl(quote! {self}, modulus);
886    let montgomery_impl = mont_impl(limbs);
887
888    // self.0[0].ct_eq(&other.0[0]) & self.0[1].ct_eq(&other.0[1]) & ...
889    let mut ct_eq_impl = proc_macro2::TokenStream::new();
890    ct_eq_impl.append_separated(
891        (0..limbs).map(|i| quote! { self.0[#i].ct_eq(&other.0[#i]) }),
892        proc_macro2::Punct::new('&', proc_macro2::Spacing::Alone),
893    );
894
895    fn mont_reduce_params(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
896        // a.0[0], a.0[1], ..., 0, 0, 0, 0, ...
897        let mut mont_reduce_params = proc_macro2::TokenStream::new();
898        mont_reduce_params.append_separated(
899            (0..limbs)
900                .map(|i| quote! { #a.0[#i] })
901                .chain((0..limbs).map(|_| quote! {0})),
902            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
903        );
904        mont_reduce_params
905    }
906
907    let mont_reduce_self_params = mont_reduce_params(quote! {self}, limbs);
908    let mont_reduce_other_params = mont_reduce_params(quote! {other}, limbs);
909
910    let from_repr_impl = endianness.from_repr(name, limbs);
911    let to_repr_impl = endianness.to_repr(quote! {#repr}, &mont_reduce_self_params, limbs);
912
913    let prime_field_bits_impl = if cfg!(feature = "bits") {
914        let to_le_bits_impl = ReprEndianness::Little.to_repr(
915            quote! {::ff::derive::bitvec::array::BitArray::new},
916            &mont_reduce_self_params,
917            limbs,
918        );
919
920        Some(quote! {
921            impl ::ff::PrimeFieldBits for #name {
922                type ReprBits = REPR_BITS;
923
924                fn to_le_bits(&self) -> ::ff::FieldBits<REPR_BITS> {
925                    #to_le_bits_impl
926                }
927
928                fn char_le_bits() -> ::ff::FieldBits<REPR_BITS> {
929                    ::ff::FieldBits::new(MODULUS)
930                }
931            }
932        })
933    } else {
934        None
935    };
936
937    let top_limb_index = limbs - 1;
938
939    quote! {
940        impl ::core::marker::Copy for #name { }
941
942        impl ::core::clone::Clone for #name {
943            fn clone(&self) -> #name {
944                *self
945            }
946        }
947
948        impl ::core::default::Default for #name {
949            fn default() -> #name {
950                use ::ff::Field;
951                #name::ZERO
952            }
953        }
954
955        impl ::ff::derive::subtle::ConstantTimeEq for #name {
956            fn ct_eq(&self, other: &#name) -> ::ff::derive::subtle::Choice {
957                use ::ff::PrimeField;
958                self.to_repr().ct_eq(&other.to_repr())
959            }
960        }
961
962        impl ::core::cmp::PartialEq for #name {
963            fn eq(&self, other: &#name) -> bool {
964                use ::ff::derive::subtle::ConstantTimeEq;
965                self.ct_eq(other).into()
966            }
967        }
968
969        impl ::core::cmp::Eq for #name { }
970
971        impl ::core::fmt::Debug for #name
972        {
973            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
974                use ::ff::PrimeField;
975                write!(f, "{}({:?})", stringify!(#name), self.to_repr())
976            }
977        }
978
979        /// Elements are ordered lexicographically.
980        impl Ord for #name {
981            #[inline(always)]
982            fn cmp(&self, other: &#name) -> ::core::cmp::Ordering {
983                let mut a = *self;
984                a.mont_reduce(
985                    #mont_reduce_self_params
986                );
987
988                let mut b = *other;
989                b.mont_reduce(
990                    #mont_reduce_other_params
991                );
992
993                a.cmp_native(&b)
994            }
995        }
996
997        impl PartialOrd for #name {
998            #[inline(always)]
999            fn partial_cmp(&self, other: &#name) -> Option<::core::cmp::Ordering> {
1000                Some(self.cmp(other))
1001            }
1002        }
1003
1004        impl From<u64> for #name {
1005            #[inline(always)]
1006            fn from(val: u64) -> #name {
1007                let mut raw = [0u64; #limbs];
1008                raw[0] = val;
1009                #name(raw) * R2
1010            }
1011        }
1012
1013        impl From<#name> for #repr {
1014            fn from(e: #name) -> #repr {
1015                use ::ff::PrimeField;
1016                e.to_repr()
1017            }
1018        }
1019
1020        impl<'a> From<&'a #name> for #repr {
1021            fn from(e: &'a #name) -> #repr {
1022                use ::ff::PrimeField;
1023                e.to_repr()
1024            }
1025        }
1026
1027        impl ::ff::derive::subtle::ConditionallySelectable for #name {
1028            fn conditional_select(a: &#name, b: &#name, choice: ::ff::derive::subtle::Choice) -> #name {
1029                let mut res = [0u64; #limbs];
1030                for i in 0..#limbs {
1031                    res[i] = u64::conditional_select(&a.0[i], &b.0[i], choice);
1032                }
1033                #name(res)
1034            }
1035        }
1036
1037        impl ::core::ops::Neg for #name {
1038            type Output = #name;
1039
1040            #[inline]
1041            fn neg(self) -> #name {
1042                use ::ff::Field;
1043
1044                let mut ret = self;
1045                if !ret.is_zero_vartime() {
1046                    let mut tmp = MODULUS_LIMBS;
1047                    tmp.sub_noborrow(&ret);
1048                    ret = tmp;
1049                }
1050                ret
1051            }
1052        }
1053
1054        impl<'r> ::core::ops::Add<&'r #name> for #name {
1055            type Output = #name;
1056
1057            #[inline]
1058            fn add(self, other: &#name) -> #name {
1059                use ::core::ops::AddAssign;
1060
1061                let mut ret = self;
1062                ret.add_assign(other);
1063                ret
1064            }
1065        }
1066
1067        impl ::core::ops::Add for #name {
1068            type Output = #name;
1069
1070            #[inline]
1071            fn add(self, other: #name) -> Self {
1072                self + &other
1073            }
1074        }
1075
1076        impl<'r> ::core::ops::AddAssign<&'r #name> for #name {
1077            #[inline]
1078            fn add_assign(&mut self, other: &#name) {
1079                // This cannot exceed the backing capacity.
1080                self.add_nocarry(other);
1081
1082                // However, it may need to be reduced.
1083                self.reduce();
1084            }
1085        }
1086
1087        impl ::core::ops::AddAssign for #name {
1088            #[inline]
1089            fn add_assign(&mut self, other: #name) {
1090                self.add_assign(&other);
1091            }
1092        }
1093
1094        impl<'r> ::core::ops::Sub<&'r #name> for #name {
1095            type Output = #name;
1096
1097            #[inline]
1098            fn sub(self, other: &#name) -> Self {
1099                use ::core::ops::SubAssign;
1100
1101                let mut ret = self;
1102                ret.sub_assign(other);
1103                ret
1104            }
1105        }
1106
1107        impl ::core::ops::Sub for #name {
1108            type Output = #name;
1109
1110            #[inline]
1111            fn sub(self, other: #name) -> Self {
1112                self - &other
1113            }
1114        }
1115
1116        impl<'r> ::core::ops::SubAssign<&'r #name> for #name {
1117            #[inline]
1118            fn sub_assign(&mut self, other: &#name) {
1119                // If `other` is larger than `self`, we'll need to add the modulus to self first.
1120                if other.cmp_native(self) == ::core::cmp::Ordering::Greater {
1121                    self.add_nocarry(&MODULUS_LIMBS);
1122                }
1123
1124                self.sub_noborrow(other);
1125            }
1126        }
1127
1128        impl ::core::ops::SubAssign for #name {
1129            #[inline]
1130            fn sub_assign(&mut self, other: #name) {
1131                self.sub_assign(&other);
1132            }
1133        }
1134
1135        impl<'r> ::core::ops::Mul<&'r #name> for #name {
1136            type Output = #name;
1137
1138            #[inline]
1139            fn mul(self, other: &#name) -> Self {
1140                use ::core::ops::MulAssign;
1141
1142                let mut ret = self;
1143                ret.mul_assign(other);
1144                ret
1145            }
1146        }
1147
1148        impl ::core::ops::Mul for #name {
1149            type Output = #name;
1150
1151            #[inline]
1152            fn mul(self, other: #name) -> Self {
1153                self * &other
1154            }
1155        }
1156
1157        impl<'r> ::core::ops::MulAssign<&'r #name> for #name {
1158            #[inline]
1159            fn mul_assign(&mut self, other: &#name)
1160            {
1161                #multiply_impl
1162            }
1163        }
1164
1165        impl ::core::ops::MulAssign for #name {
1166            #[inline]
1167            fn mul_assign(&mut self, other: #name)
1168            {
1169                self.mul_assign(&other);
1170            }
1171        }
1172
1173        impl<T: ::core::borrow::Borrow<#name>> ::core::iter::Sum<T> for #name {
1174            fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
1175                use ::ff::Field;
1176
1177                iter.fold(Self::ZERO, |acc, item| acc + item.borrow())
1178            }
1179        }
1180
1181        impl<T: ::core::borrow::Borrow<#name>> ::core::iter::Product<T> for #name {
1182            fn product<I: Iterator<Item = T>>(iter: I) -> Self {
1183                use ::ff::Field;
1184
1185                iter.fold(Self::ONE, |acc, item| acc * item.borrow())
1186            }
1187        }
1188
1189        impl ::ff::PrimeField for #name {
1190            type Repr = #repr;
1191
1192            fn from_repr(r: #repr) -> ::ff::derive::subtle::CtOption<#name> {
1193                #from_repr_impl
1194
1195                // Try to subtract the modulus
1196                let borrow = r.0.iter().zip(MODULUS_LIMBS.0.iter()).fold(0, |borrow, (a, b)| {
1197                    ::ff::derive::sbb(*a, *b, borrow).1
1198                });
1199
1200                // If the element is smaller than MODULUS then the
1201                // subtraction will underflow, producing a borrow value
1202                // of 0xffff...ffff. Otherwise, it'll be zero.
1203                let is_some = ::ff::derive::subtle::Choice::from((borrow as u8) & 1);
1204
1205                // Convert to Montgomery form by computing
1206                // (a.R^0 * R^2) / R = a.R
1207                ::ff::derive::subtle::CtOption::new(r * &R2, is_some)
1208            }
1209
1210            fn from_repr_vartime(r: #repr) -> Option<#name> {
1211                #from_repr_impl
1212
1213                if r.is_valid() {
1214                    Some(r * R2)
1215                } else {
1216                    None
1217                }
1218            }
1219
1220            fn to_repr(&self) -> #repr {
1221                #to_repr_impl
1222            }
1223
1224            #[inline(always)]
1225            fn is_odd(&self) -> ::ff::derive::subtle::Choice {
1226                let mut r = *self;
1227                r.mont_reduce(
1228                    #mont_reduce_self_params
1229                );
1230
1231                // TODO: This looks like a constant-time result, but r.mont_reduce() is
1232                // currently implemented using variable-time code.
1233                ::ff::derive::subtle::Choice::from((r.0[0] & 1) as u8)
1234            }
1235
1236            const MODULUS: &'static str = MODULUS_STR;
1237
1238            const NUM_BITS: u32 = MODULUS_BITS;
1239
1240            const CAPACITY: u32 = Self::NUM_BITS - 1;
1241
1242            const TWO_INV: Self = TWO_INV;
1243
1244            const MULTIPLICATIVE_GENERATOR: Self = GENERATOR;
1245
1246            const S: u32 = S;
1247
1248            const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
1249
1250            const ROOT_OF_UNITY_INV: Self = ROOT_OF_UNITY_INV;
1251
1252            const DELTA: Self = DELTA;
1253        }
1254
1255        #prime_field_bits_impl
1256
1257        impl ::ff::Field for #name {
1258            const ZERO: Self = #name([0; #limbs]);
1259            const ONE: Self = R;
1260
1261            /// Computes a uniformly random element using rejection sampling.
1262            fn random(mut rng: impl ::ff::derive::rand_core::RngCore) -> Self {
1263                loop {
1264                    let mut tmp = {
1265                        let mut repr = [0u64; #limbs];
1266                        for i in 0..#limbs {
1267                            repr[i] = rng.next_u64();
1268                        }
1269                        #name(repr)
1270                    };
1271
1272                    // Mask away the unused most-significant bits.
1273                    // Note: In some edge cases, `REPR_SHAVE_BITS` could be 64, in which case
1274                    // `0xfff... >> REPR_SHAVE_BITS` overflows. So use `checked_shr` instead.
1275                    // This is always sufficient because we will have at most one spare limb
1276                    // to accommodate values of up to twice the modulus.
1277                    tmp.0[#top_limb_index] &= 0xffffffffffffffffu64.checked_shr(REPR_SHAVE_BITS).unwrap_or(0);
1278
1279                    if tmp.is_valid() {
1280                        return tmp
1281                    }
1282                }
1283            }
1284
1285            #[inline]
1286            fn is_zero_vartime(&self) -> bool {
1287                self.0.iter().all(|&e| e == 0)
1288            }
1289
1290            #[inline]
1291            fn double(&self) -> Self {
1292                let mut ret = *self;
1293
1294                // This cannot exceed the backing capacity.
1295                let mut last = 0;
1296                for i in &mut ret.0 {
1297                    let tmp = *i >> 63;
1298                    *i <<= 1;
1299                    *i |= last;
1300                    last = tmp;
1301                }
1302
1303                // However, it may need to be reduced.
1304                ret.reduce();
1305
1306                ret
1307            }
1308
1309            fn invert(&self) -> ::ff::derive::subtle::CtOption<Self> {
1310                #invert_impl
1311            }
1312
1313            #[inline]
1314            fn square(&self) -> Self
1315            {
1316                #squaring_impl
1317            }
1318
1319            fn sqrt_ratio(num: &Self, div: &Self) -> (::ff::derive::subtle::Choice, Self) {
1320                ::ff::helpers::sqrt_ratio_generic(num, div)
1321            }
1322
1323            fn sqrt(&self) -> ::ff::derive::subtle::CtOption<Self> {
1324                #sqrt_impl
1325            }
1326        }
1327
1328        impl #name {
1329            /// Compares two elements in native representation. This is only used
1330            /// internally.
1331            #[inline(always)]
1332            fn cmp_native(&self, other: &#name) -> ::core::cmp::Ordering {
1333                for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) {
1334                    if a < b {
1335                        return ::core::cmp::Ordering::Less
1336                    } else if a > b {
1337                        return ::core::cmp::Ordering::Greater
1338                    }
1339                }
1340
1341                ::core::cmp::Ordering::Equal
1342            }
1343
1344            /// Determines if the element is really in the field. This is only used
1345            /// internally.
1346            #[inline(always)]
1347            fn is_valid(&self) -> bool {
1348                // The Ord impl calls `reduce`, which in turn calls `is_valid`, so we use
1349                // this internal function to eliminate the cycle.
1350                self.cmp_native(&MODULUS_LIMBS) == ::core::cmp::Ordering::Less
1351            }
1352
1353            #[inline(always)]
1354            fn add_nocarry(&mut self, other: &#name) {
1355                let mut carry = 0;
1356
1357                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
1358                    let (new_a, new_carry) = ::ff::derive::adc(*a, *b, carry);
1359                    *a = new_a;
1360                    carry = new_carry;
1361                }
1362            }
1363
1364            #[inline(always)]
1365            fn sub_noborrow(&mut self, other: &#name) {
1366                let mut borrow = 0;
1367
1368                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
1369                    let (new_a, new_borrow) = ::ff::derive::sbb(*a, *b, borrow);
1370                    *a = new_a;
1371                    borrow = new_borrow;
1372                }
1373            }
1374
1375            /// Subtracts the modulus from this element if this element is not in the
1376            /// field. Only used internally.
1377            #[inline(always)]
1378            fn reduce(&mut self) {
1379                if !self.is_valid() {
1380                    self.sub_noborrow(&MODULUS_LIMBS);
1381                }
1382            }
1383
1384            #[allow(clippy::too_many_arguments)]
1385            #[inline(always)]
1386            fn mont_reduce(
1387                &mut self,
1388                #mont_paramlist
1389            )
1390            {
1391                // The Montgomery reduction here is based on Algorithm 14.32 in
1392                // Handbook of Applied Cryptography
1393                // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
1394
1395                #montgomery_impl
1396
1397                self.reduce();
1398            }
1399        }
1400    }
1401}