1use super::{borrowing_sub, carrying_add, cmp, DoubleWord};
4use core::{cmp::Ordering, iter::zip};
5
6#[doc = crate::algorithms::unstable_warning!()]
8#[inline]
11#[must_use]
12pub fn mul_redc<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N], inv: u64) -> [u64; N] {
13 debug_assert_eq!(inv.wrapping_mul(modulus[0]), u64::MAX);
14 debug_assert_eq!(cmp(&a, &modulus), Ordering::Less);
15 debug_assert_eq!(cmp(&b, &modulus), Ordering::Less);
16
17 let mut result = [0; N];
22 let mut carry = false;
23 for b in b {
24 let mut m = 0;
25 let mut carry_1 = 0;
26 let mut carry_2 = 0;
27 for i in 0..N {
28 let (value, next_carry) = carrying_mul_add(a[i], b, result[i], carry_1);
30 carry_1 = next_carry;
31
32 if i == 0 {
33 m = value.wrapping_mul(inv);
35 }
36
37 let (value, next_carry) = carrying_mul_add(modulus[i], m, value, carry_2);
39 carry_2 = next_carry;
40
41 if i > 0 {
43 result[i - 1] = value;
44 } else {
45 debug_assert_eq!(value, 0);
46 }
47 }
48
49 let (value, next_carry) = carrying_add(carry_1, carry_2, carry);
51 result[N - 1] = value;
52 if modulus[N - 1] >= 0x7fff_ffff_ffff_ffff {
53 carry = next_carry;
54 } else {
55 debug_assert!(!next_carry);
56 }
57 }
58
59 reduce1_carry(result, modulus, carry)
61}
62
63#[doc = crate::algorithms::unstable_warning!()]
65#[inline]
68#[must_use]
69#[allow(clippy::cast_possible_truncation)]
70pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) -> [u64; N] {
71 debug_assert_eq!(inv.wrapping_mul(modulus[0]), u64::MAX);
72 debug_assert_eq!(cmp(&a, &modulus), Ordering::Less);
73
74 let mut result = [0; N];
75 let mut carry_outer = 0;
76 for i in 0..N {
77 let (value, mut carry_lo) = carrying_mul_add(a[i], a[i], result[i], 0);
79 let mut carry_hi = false;
80 result[i] = value;
81 for j in (i + 1)..N {
82 let (value, next_carry_lo, next_carry_hi) =
83 carrying_double_mul_add(a[i], a[j], result[j], carry_lo, carry_hi);
84 result[j] = value;
85 carry_lo = next_carry_lo;
86 carry_hi = next_carry_hi;
87 }
88
89 let m = result[0].wrapping_mul(inv);
91 let (value, mut carry) = carrying_mul_add(m, modulus[0], result[0], 0);
92 debug_assert_eq!(value, 0);
93 for j in 1..N {
94 let (value, next_carry) = carrying_mul_add(modulus[j], m, result[j], carry);
95 result[j - 1] = value;
96 carry = next_carry;
97 }
98
99 if modulus[N - 1] >= 0x3fff_ffff_ffff_ffff {
101 let wide = (carry_outer as u128)
102 .wrapping_add(carry_lo as u128)
103 .wrapping_add((carry_hi as u128) << 64)
104 .wrapping_add(carry as u128);
105 result[N - 1] = wide as u64;
106
107 carry_outer = (wide >> 64) as u64;
109 debug_assert!(carry_outer <= 2);
110 } else {
111 debug_assert!(!carry_hi);
113 debug_assert_eq!(carry_outer, 0);
114 let (value, carry) = carry_lo.overflowing_add(carry);
115 debug_assert!(!carry);
116 result[N - 1] = value;
117 }
118 }
119
120 debug_assert!(carry_outer <= 1);
122 reduce1_carry(result, modulus, carry_outer > 0)
123}
124
125#[inline]
126#[must_use]
127#[allow(clippy::needless_bitwise_bool)]
128fn reduce1_carry<const N: usize>(value: [u64; N], modulus: [u64; N], carry: bool) -> [u64; N] {
129 let (reduced, borrow) = sub(value, modulus);
130 if carry | !borrow {
133 reduced
134 } else {
135 value
136 }
137}
138
139#[inline]
140#[must_use]
141fn sub<const N: usize>(lhs: [u64; N], rhs: [u64; N]) -> ([u64; N], bool) {
142 let mut result = [0; N];
143 let mut borrow = false;
144 for (result, (lhs, rhs)) in zip(&mut result, zip(lhs, rhs)) {
145 let (value, next_borrow) = borrowing_sub(lhs, rhs, borrow);
146 *result = value;
147 borrow = next_borrow;
148 }
149 (result, borrow)
150}
151
152#[inline]
155#[must_use]
156#[allow(clippy::cast_possible_truncation)]
157fn carrying_mul_add(lhs: u64, rhs: u64, add: u64, carry: u64) -> (u64, u64) {
158 u128::muladd2(lhs, rhs, add, carry).split()
159}
160
/// Computes `2 * lhs * rhs + add + carry_lo + 2^64 * carry_hi`.
///
/// The exact value is at most `2^129 - 2^64`, i.e. it needs up to 129 bits.
/// It is returned as `(bits 0..64, bits 64..128, bit 128)`.
#[inline]
#[must_use]
#[allow(clippy::cast_possible_truncation)] // `u128 -> u64` casts deliberately keep the low limb.
const fn carrying_double_mul_add(
    lhs: u64,
    rhs: u64,
    add: u64,
    carry_lo: u64,
    carry_hi: bool,
) -> (u64, u64, bool) {
    // `lhs * rhs` always fits in 128 bits; doubling it may overflow once.
    let product = (lhs as u128).wrapping_mul(rhs as u128);
    let (doubled, carry_1) = product.overflowing_add(product);

    // The addends sum to less than `2^66`, so this cannot wrap.
    let addends = (add as u128)
        .wrapping_add(carry_lo as u128)
        .wrapping_add((carry_hi as u128) << 64);
    let (total, carry_2) = doubled.overflowing_add(addends);

    // The exact value is below `2^129`, so at most one of the two additions
    // can have overflowed and OR-ing the flags yields bit 128.
    (total as u64, (total >> 64) as u64, carry_1 | carry_2)
}
181
#[cfg(test)]
mod test {
    use super::{
        super::{addmul, div},
        *,
    };
    use crate::{aliases::U64, const_for, nlimbs, Uint};
    use core::ops::Neg;
    use proptest::{prop_assert_eq, proptest};

    /// Reference implementation of `a * b mod modulus` built from the
    /// generic `addmul` and `div` routines.
    ///
    /// NOTE(review): relies on `div` leaving the remainder in its second
    /// argument — confirm against `div`'s documentation.
    fn modmul<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N]) -> [u64; N] {
        // Full double-width product `a * b`.
        let mut product = vec![0; 2 * N];
        addmul(&mut product, &a, &b);

        // Reduce the product modulo `modulus`.
        let mut reduced = modulus;
        div(&mut product, &mut reduced);
        reduced
    }

    /// Reference implementation of `a * 2^(64 * N) mod modulus`, used to undo
    /// the `2^(-64 * N)` factor introduced by the `*_redc` functions.
    ///
    /// NOTE(review): relies on `div` leaving the remainder in its second
    /// argument — confirm against `div`'s documentation.
    fn mul_base<const N: usize>(a: [u64; N], modulus: [u64; N]) -> [u64; N] {
        // Place `a` in the upper half, i.e. `a * 2^(64 * N)`.
        let mut product = vec![0; 2 * N];
        product[N..].copy_from_slice(&a);

        // Reduce the product modulo `modulus`.
        let mut reduced = modulus;
        div(&mut product, &mut reduced);
        reduced
    }

    #[test]
    fn test_mul_redc() {
        // The condition is parenthesised so it is a single token tree ahead
        // of the block, as `macro_rules!` matchers require.
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U, mut b: U, mut m: U)| {
                m |= U::from(1_u64); // Force the modulus to be odd.
                a %= m; // Force `a < m`.
                b %= m; // Force `b < m`.
                let a = *a.as_limbs();
                let b = *b.as_limbs();
                let m = *m.as_limbs();
                // `inv = -(m[0]^-1) mod 2^64`.
                let inv = U64::from(m[0]).inv_ring().unwrap().neg().as_limbs()[0];

                let result = mul_base(mul_redc(a, b, m, inv), m);
                let expected = modmul(a, b, m);

                prop_assert_eq!(result, expected);
            });
        });
    }

    #[test]
    fn test_square_redc() {
        // The condition is parenthesised so it is a single token tree ahead
        // of the block, as `macro_rules!` matchers require.
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U, mut m: U)| {
                m |= U::from(1_u64); // Force the modulus to be odd.
                a %= m; // Force `a < m`.
                let a = *a.as_limbs();
                let m = *m.as_limbs();
                // `inv = -(m[0]^-1) mod 2^64`.
                let inv = U64::from(m[0]).inv_ring().unwrap().neg().as_limbs()[0];

                let result = mul_base(square_redc(a, m, inv), m);
                let expected = modmul(a, a, m);

                prop_assert_eq!(result, expected);
            });
        });
    }
}