1use crate::{algorithms, Uint};
2
3impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
11 #[inline]
18 #[must_use]
19 pub fn reduce_mod(mut self, modulus: Self) -> Self {
20 if modulus.is_zero() {
21 return Self::ZERO;
22 }
23 if self >= modulus {
24 self %= modulus;
25 }
26 self
27 }
28
29 #[inline]
33 #[must_use]
34 pub fn add_mod(mut self, rhs: Self, mut modulus: Self) -> Self {
35 if modulus.is_zero() {
36 return Self::ZERO;
37 }
38
39 #[allow(clippy::cast_possible_truncation)]
42 if BITS <= 64 {
43 self.limbs[0] =
44 ((self.limbs[0] as u128 + rhs.limbs[0] as u128) % modulus.limbs[0] as u128) as u64;
45 return self;
46 }
47
48 let (result, overflow) = self.overflowing_add(rhs);
50 if overflow {
51 let_double_bits!(numerator);
53 let (limb, bit) = (BITS / 64, BITS % 64);
54 let numerator = &mut numerator[..=limb];
55 numerator[..LIMBS].copy_from_slice(result.as_limbs());
56 numerator[limb] |= 1 << bit;
57
58 if crate::nlimbs(BITS + 1) == LIMBS {
61 let numerator = unsafe { &mut *numerator.as_mut_ptr().cast::<Self>() };
62 Self::div_rem_by_ref(numerator, &mut modulus);
63 } else {
64 Self::div_rem_bits_plus_one(numerator.as_mut_ptr(), &mut modulus);
65 }
66
67 modulus
68 } else {
69 result.reduce_mod(modulus)
70 }
71 }
72
73 #[inline(never)]
74 fn div_rem_bits_plus_one(numerator: *mut u64, modulus: &mut Self) {
75 let numerator = unsafe { core::slice::from_raw_parts_mut(numerator, LIMBS + 1) };
78 algorithms::div::div_inlined(numerator, &mut modulus.limbs);
79 }
80
81 #[inline(always)]
88 #[must_use]
89 pub fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self {
90 self.mul_mod_by_ref(&rhs, &mut modulus);
91 modulus
92 }
93
94 #[inline(never)]
95 fn mul_mod_by_ref(&self, rhs: &Self, modulus: &mut Self) {
96 if modulus.is_zero() {
97 return;
98 }
99 let_double_bits!(product);
100 let overflow = algorithms::addmul(product, self.as_limbs(), rhs.as_limbs());
101 debug_assert!(!overflow);
102 Self::div_rem_double_bits(product, modulus);
103 }
104
105 #[inline]
106 fn div_rem_double_bits(numerator: &mut [u64], modulus: &mut Self) {
107 assume!(numerator.len() == crate::nlimbs(BITS * 2));
108 algorithms::div::div_inlined(numerator, &mut modulus.limbs);
109 }
110
111 #[inline]
115 #[must_use]
116 pub fn pow_mod(mut self, mut exp: Self, modulus: Self) -> Self {
117 if BITS == 0 || modulus <= Self::ONE {
118 return Self::ZERO;
119 }
120
121 let mut result = Self::ONE;
123 while exp > Self::ZERO {
124 if exp.limbs[0] & 1 == 1 {
126 result = result.mul_mod(self, modulus);
127 }
128
129 self = self.mul_mod(self, modulus);
131 exp >>= 1;
132 }
133 result
134 }
135
136 #[inline]
140 #[must_use]
141 pub fn inv_mod(self, modulus: Self) -> Option<Self> {
142 algorithms::inv_mod(self, modulus)
143 }
144
145 #[inline]
186 #[must_use]
187 pub fn mul_redc(self, other: Self, modulus: Self, inv: u64) -> Self {
188 if BITS == 0 {
189 return Self::ZERO;
190 }
191 let result = algorithms::mul_redc(self.limbs, other.limbs, modulus.limbs, inv);
192 let result = Self::from_limbs(result);
193 debug_assert!(result < modulus);
194 result
195 }
196
197 #[inline]
201 #[must_use]
202 pub fn square_redc(self, modulus: Self, inv: u64) -> Self {
203 if BITS == 0 {
204 return Self::ZERO;
205 }
206 let result = algorithms::square_redc(self.limbs, modulus.limbs, inv);
207 let result = Self::from_limbs(result);
208 debug_assert!(result < modulus);
209 result
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{aliases::U64, const_for, nlimbs};
    use proptest::{prop_assume, proptest, test_runner::Config};

    /// Modular multiplication is commutative: `x·y ≡ y·x (mod m)`.
    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U, y: U, modulus: U)| {
                let forward = x.mul_mod(y, modulus);
                let backward = y.mul_mod(x, modulus);
                assert_eq!(forward, backward);
            });
        });
    }

    /// Modular multiplication is associative: `x·(y·z) ≡ (x·y)·z (mod m)`.
    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U, y: U, z: U, modulus: U)| {
                let inner_right = x.mul_mod(y.mul_mod(z, modulus), modulus);
                let inner_left = x.mul_mod(y, modulus).mul_mod(z, modulus);
                assert_eq!(inner_right, inner_left);
            });
        });
    }

    /// Multiplication distributes over addition: `x·(y+z) ≡ x·y + x·z (mod m)`.
    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U, y: U, z: U, modulus: U)| {
                let sum = y.add_mod(z, modulus);
                let lhs = x.mul_mod(sum, modulus);
                let xy = x.mul_mod(y, modulus);
                let xz = x.mul_mod(z, modulus);
                let rhs = xy.add_mod(xz, modulus);
                assert_eq!(lhs, rhs);
            });
        });
    }

    /// Adding zero only reduces the value.
    #[test]
    fn test_add_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, modulus: U)| {
                let zero = U::from(0);
                let reduced = value.reduce_mod(modulus);
                assert_eq!(value.add_mod(zero, modulus), reduced);
            });
        });
    }

    /// Multiplying by zero yields zero; multiplying by one only reduces the value.
    #[test]
    fn test_mul_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, modulus: U)| {
                let (zero, one) = (U::from(0), U::from(1));
                let by_zero = value.mul_mod(zero, modulus);
                assert_eq!(by_zero, U::ZERO);
                let by_one = value.mul_mod(one, modulus);
                assert_eq!(by_one, value.reduce_mod(modulus));
            });
        });
    }

    /// `x^0 ≡ 1` and `x^1 ≡ x`, both taken modulo `m`.
    #[test]
    fn test_pow_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(base: U, modulus: U)| {
                let (zero, one) = (U::from(0), U::from(1));
                let to_zero = base.pow_mod(zero, modulus);
                assert_eq!(to_zero, one.reduce_mod(modulus));
                let to_one = base.pow_mod(one, modulus);
                assert_eq!(to_one, base.reduce_mod(modulus));
            });
        });
    }

    /// Exponentiation distributes over multiplication: `(x·y)^e ≡ x^e · y^e (mod m)`.
    #[test]
    fn test_pow_rules() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;

            // NOTE(review): larger sizes are skipped, presumably because `pow_mod`
            // is too slow there for a property test.
            if LIMBS > 8 {
                return;
            }

            let config = Config { cases: 5, ..Default::default() };
            proptest!(config, |(x: U, y: U, exp: U, modulus: U)| {
                let product_pow = x.mul_mod(y, modulus).pow_mod(exp, modulus);
                let x_pow = x.pow_mod(exp, modulus);
                let y_pow = y.pow_mod(exp, modulus);
                assert_eq!(product_pow, x_pow.mul_mod(y_pow, modulus));
            });
        });
    }

    /// Whenever an inverse is reported, multiplying by it yields one.
    #[test]
    fn test_inv() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, modulus: U)| {
                if let Some(inverse) = value.inv_mod(modulus) {
                    let unit = value.mul_mod(inverse, modulus);
                    assert_eq!(unit, U::from(1));
                }
            });
        });
    }

    /// `mul_redc` on Montgomery-form operands matches plain `mul_mod` in Montgomery form.
    #[test]
    fn test_mul_redc() {
        const_for!(BITS in NON_ZERO if BITS >= 16 {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U, y: U, modulus: U)| {
                prop_assume!(modulus >= U::from(2));
                if let Some(low_inv) = U64::from(modulus.as_limbs()[0]).inv_ring() {
                    // Montgomery constant: -modulus^{-1} mod 2^64.
                    let neg_inv = (-low_inv).as_limbs()[0];

                    // R = 2^(64 * LIMBS) reduced modulo `modulus`.
                    let montgomery_r = U::from(2).pow_mod(U::from(64 * LIMBS), modulus);
                    let x_mont = x.mul_mod(montgomery_r, modulus);
                    let y_mont = y.mul_mod(montgomery_r, modulus);
                    let expected = x.mul_mod(y, modulus).mul_mod(montgomery_r, modulus);

                    assert_eq!(x_mont.mul_redc(y_mont, modulus, neg_inv), expected);
                }
            });
        });
    }

    /// `square_redc` on a Montgomery-form operand matches plain squaring in Montgomery form.
    #[test]
    fn test_square_redc() {
        const_for!(BITS in NON_ZERO if BITS >= 16 {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U, modulus: U)| {
                prop_assume!(modulus >= U::from(2));
                if let Some(low_inv) = U64::from(modulus.as_limbs()[0]).inv_ring() {
                    // Montgomery constant: -modulus^{-1} mod 2^64.
                    let neg_inv = (-low_inv).as_limbs()[0];

                    // R = 2^(64 * LIMBS) reduced modulo `modulus`.
                    let montgomery_r = U::from(2).pow_mod(U::from(64 * LIMBS), modulus);
                    let x_mont = x.mul_mod(montgomery_r, modulus);
                    let expected = x.mul_mod(x, modulus).mul_mod(montgomery_r, modulus);

                    assert_eq!(x_mont.square_redc(modulus, neg_inv), expected);
                }
            });
        });
    }
}