1use crate::{algorithms, Uint};
2
3impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
11 #[inline]
18 #[must_use]
19 pub fn reduce_mod(mut self, modulus: Self) -> Self {
20 if modulus.is_zero() {
21 return Self::ZERO;
22 }
23 if self >= modulus {
24 self %= modulus;
25 }
26 self
27 }
28
29 #[inline]
33 #[must_use]
34 pub fn add_mod(self, rhs: Self, modulus: Self) -> Self {
35 let lhs = self.reduce_mod(modulus);
37 let rhs = rhs.reduce_mod(modulus);
38
39 let (mut result, overflow) = lhs.overflowing_add(rhs);
41 if overflow || result >= modulus {
42 result -= modulus;
43 }
44 result
45 }
46
47 #[inline]
54 #[must_use]
55 pub fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self {
56 if modulus.is_zero() {
57 return Self::ZERO;
58 }
59
60 let mut product = [[0u64; 2]; LIMBS];
63 let product_len = crate::nlimbs(2 * BITS);
64 debug_assert!(2 * LIMBS >= product_len);
65 let product = unsafe {
67 core::slice::from_raw_parts_mut(product.as_mut_ptr().cast::<u64>(), product_len)
68 };
69
70 let overflow = algorithms::addmul(product, self.as_limbs(), rhs.as_limbs());
72 debug_assert!(!overflow);
73
74 algorithms::div(product, &mut modulus.limbs);
77
78 modulus
79 }
80
81 #[inline]
85 #[must_use]
86 pub fn pow_mod(mut self, mut exp: Self, modulus: Self) -> Self {
87 if BITS == 0 || modulus <= Self::ONE {
88 return Self::ZERO;
89 }
90
91 let mut result = Self::ONE;
93 while exp > Self::ZERO {
94 if exp.limbs[0] & 1 == 1 {
96 result = result.mul_mod(self, modulus);
97 }
98
99 self = self.mul_mod(self, modulus);
101 exp >>= 1;
102 }
103 result
104 }
105
106 #[inline]
110 #[must_use]
111 pub fn inv_mod(self, modulus: Self) -> Option<Self> {
112 algorithms::inv_mod(self, modulus)
113 }
114
115 #[inline]
156 #[must_use]
157 pub fn mul_redc(self, other: Self, modulus: Self, inv: u64) -> Self {
158 if BITS == 0 {
159 return Self::ZERO;
160 }
161 let result = algorithms::mul_redc(self.limbs, other.limbs, modulus.limbs, inv);
162 let result = Self::from_limbs(result);
163 debug_assert!(result < modulus);
164 result
165 }
166
167 #[inline]
171 #[must_use]
172 pub fn square_redc(self, modulus: Self, inv: u64) -> Self {
173 if BITS == 0 {
174 return Self::ZERO;
175 }
176 let result = algorithms::square_redc(self.limbs, modulus.limbs, inv);
177 let result = Self::from_limbs(result);
178 debug_assert!(result < modulus);
179 result
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{aliases::U64, const_for, nlimbs};
    use proptest::{prop_assume, proptest, test_runner::Config};

    /// a * b == b * a  (mod m)
    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                assert_eq!(a.mul_mod(b, m), b.mul_mod(a, m));
            });
        });
    }

    /// a * (b * c) == (a * b) * c  (mod m)
    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.mul_mod(c, m), m), a.mul_mod(b, m).mul_mod(c, m));
            });
        });
    }

    /// a * (b + c) == a * b + a * c  (mod m)
    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(
                    a.mul_mod(b.add_mod(c, m), m),
                    a.mul_mod(b, m).add_mod(a.mul_mod(c, m), m)
                );
            });
        });
    }

    /// x + 0 == x  (mod m)
    #[test]
    fn test_add_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.add_mod(U::from(0), m), value.reduce_mod(m));
            });
        });
    }

    /// x * 0 == 0 and x * 1 == x  (mod m)
    #[test]
    fn test_mul_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.mul_mod(U::from(0), m), U::ZERO);
                assert_eq!(value.mul_mod(U::from(1), m), value.reduce_mod(m));
            });
        });
    }

    /// a^0 == 1 and a^1 == a  (mod m)
    #[test]
    fn test_pow_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                assert_eq!(a.pow_mod(U::from(0), m), U::from(1).reduce_mod(m));
                assert_eq!(a.pow_mod(U::from(1), m), a.reduce_mod(m));
            });
        });
    }

    /// (a * b)^c == a^c * b^c  (mod m)
    #[test]
    fn test_pow_rules() {
        // Limited to sizes of at most 8 limbs and few cases to keep runtime
        // bounded. The size filter uses `const_for!`'s `if (...)` form (as the
        // REDC tests do) instead of an early `return`. Depending on how
        // `const_for!` expands, a `return` could exit the whole test function
        // rather than skip a single size.
        const_for!(BITS in NON_ZERO if (nlimbs(BITS) <= 8) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;

            let config = Config { cases: 5, ..Default::default() };
            proptest!(config, |(a: U, b: U, c: U, m: U)| {
                assert_eq!(
                    a.mul_mod(b, m).pow_mod(c, m),
                    a.pow_mod(c, m).mul_mod(b.pow_mod(c, m), m)
                );
            });
        });
    }

    /// Whenever an inverse is returned, a * inv == 1  (mod m)
    #[test]
    fn test_inv() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                if let Some(inv) = a.inv_mod(m) {
                    assert_eq!(a.mul_mod(inv, m), U::from(1));
                }
            });
        });
    }

    /// Montgomery multiplication of `a*R` and `b*R` yields `(a*b)*R`
    /// (mod m), with `R = 2^(64 * LIMBS)`.
    #[test]
    fn test_mul_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                prop_assume!(m >= U::from(2));
                // The low limb of `m` is invertible mod 2^64 only when `m` is odd.
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    // REDC expects `-m^-1 mod 2^64`.
                    let inv = (-inv).as_limbs()[0];

                    // Montgomery radix R reduced mod m, and the operands in
                    // Montgomery form.
                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    let br = b.mul_mod(r, m);
                    let expected = a.mul_mod(b, m).mul_mod(r, m);

                    assert_eq!(ar.mul_redc(br, m, inv), expected);
                }
            });
        });
    }

    /// Montgomery squaring of `a*R` yields `(a*a)*R` (mod m), with
    /// `R = 2^(64 * LIMBS)`.
    #[test]
    fn test_square_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                prop_assume!(m >= U::from(2));
                // The low limb of `m` is invertible mod 2^64 only when `m` is odd.
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    // REDC expects `-m^-1 mod 2^64`.
                    let inv = (-inv).as_limbs()[0];

                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    let expected = a.mul_mod(a, m).mul_mod(r, m);

                    assert_eq!(ar.square_redc(m, inv), expected);
                }
            });
        });
    }
}