ruint/
modular.rs

1use crate::{algorithms, Uint};
2
// FEATURE: sub_mod, neg_mod, div_mod, root_mod (`inv_mod` is already implemented below).
// For `root_mod` (modular square roots), see <https://en.wikipedia.org/wiki/Cipolla's_algorithm>
// FEATURE: mul_mod_redc, and maybe Barrett reduction.
// See also <https://static1.squarespace.com/static/61f7cacf2d7af938cad5b81c/t/62deb4e0c434f7134c2730ee/1658762465114/modular_multiplication.pdf>
// FEATURE: Modular wrapper type, like `Wrapping`.
9
10impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
11    /// ⚠️ Compute $\mod{\mathtt{self}}_{\mathtt{modulus}}$.
12    ///
13    /// **Warning.** This function is not part of the stable API.
14    ///
15    /// Returns zero if the modulus is zero.
16    // FEATURE: Reduce larger bit-sizes to smaller ones.
17    #[inline]
18    #[must_use]
19    pub fn reduce_mod(mut self, modulus: Self) -> Self {
20        if modulus.is_zero() {
21            return Self::ZERO;
22        }
23        if self >= modulus {
24            self %= modulus;
25        }
26        self
27    }
28
29    /// Compute $\mod{\mathtt{self} + \mathtt{rhs}}_{\mathtt{modulus}}$.
30    ///
31    /// Returns zero if the modulus is zero.
32    #[inline]
33    #[must_use]
34    pub fn add_mod(mut self, rhs: Self, mut modulus: Self) -> Self {
35        if modulus.is_zero() {
36            return Self::ZERO;
37        }
38
39        // This is not going to truncate with the final cast because the modulus value
40        // is 64 bits.
41        #[allow(clippy::cast_possible_truncation)]
42        if BITS <= 64 {
43            self.limbs[0] =
44                ((self.limbs[0] as u128 + rhs.limbs[0] as u128) % modulus.limbs[0] as u128) as u64;
45            return self;
46        }
47
48        // do overflowing add, then check if we should divrem
49        let (result, overflow) = self.overflowing_add(rhs);
50        if overflow {
51            // Add carry bit to the result. We might need an extra limb.
52            let_double_bits!(numerator);
53            let (limb, bit) = (BITS / 64, BITS % 64);
54            let numerator = &mut numerator[..=limb];
55            numerator[..LIMBS].copy_from_slice(result.as_limbs());
56            numerator[limb] |= 1 << bit;
57
58            // TODO(dani): const block
59            // Reuse `div_rem` if we don't need an extra limb.
60            if crate::nlimbs(BITS + 1) == LIMBS {
61                let numerator = unsafe { &mut *numerator.as_mut_ptr().cast::<Self>() };
62                Self::div_rem_by_ref(numerator, &mut modulus);
63            } else {
64                Self::div_rem_bits_plus_one(numerator.as_mut_ptr(), &mut modulus);
65            }
66
67            modulus
68        } else {
69            result.reduce_mod(modulus)
70        }
71    }
72
73    #[inline(never)]
74    fn div_rem_bits_plus_one(numerator: *mut u64, modulus: &mut Self) {
75        // TODO(dani): check if this is worth special casing over just using
76        // div_rem_double_bits
77        let numerator = unsafe { core::slice::from_raw_parts_mut(numerator, LIMBS + 1) };
78        algorithms::div::div_inlined(numerator, &mut modulus.limbs);
79    }
80
81    /// Compute $\mod{\mathtt{self} ⋅ \mathtt{rhs}}_{\mathtt{modulus}}$.
82    ///
83    /// Returns zero if the modulus is zero.
84    ///
85    /// See [`mul_redc`](Self::mul_redc) for a faster variant at the cost of
86    /// some pre-computation.
87    #[inline(always)]
88    #[must_use]
89    pub fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self {
90        self.mul_mod_by_ref(&rhs, &mut modulus);
91        modulus
92    }
93
94    #[inline(never)]
95    fn mul_mod_by_ref(&self, rhs: &Self, modulus: &mut Self) {
96        if modulus.is_zero() {
97            return;
98        }
99        let_double_bits!(product);
100        let overflow = algorithms::addmul(product, self.as_limbs(), rhs.as_limbs());
101        debug_assert!(!overflow);
102        Self::div_rem_double_bits(product, modulus);
103    }
104
105    #[inline]
106    fn div_rem_double_bits(numerator: &mut [u64], modulus: &mut Self) {
107        assume!(numerator.len() == crate::nlimbs(BITS * 2));
108        algorithms::div::div_inlined(numerator, &mut modulus.limbs);
109    }
110
111    /// Compute $\mod{\mathtt{self}^{\mathtt{rhs}}}_{\mathtt{modulus}}$.
112    ///
113    /// Returns zero if the modulus is zero.
114    #[inline]
115    #[must_use]
116    pub fn pow_mod(mut self, mut exp: Self, modulus: Self) -> Self {
117        if BITS == 0 || modulus <= Self::ONE {
118            return Self::ZERO;
119        }
120
121        // Exponentiation by squaring
122        let mut result = Self::ONE;
123        while exp > Self::ZERO {
124            // Multiply by base
125            if exp.limbs[0] & 1 == 1 {
126                result = result.mul_mod(self, modulus);
127            }
128
129            // Square base
130            self = self.mul_mod(self, modulus);
131            exp >>= 1;
132        }
133        result
134    }
135
136    /// Compute $\mod{\mathtt{self}^{-1}}_{\mathtt{modulus}}$.
137    ///
138    /// Returns `None` if the inverse does not exist.
139    #[inline]
140    #[must_use]
141    pub fn inv_mod(self, modulus: Self) -> Option<Self> {
142        algorithms::inv_mod(self, modulus)
143    }
144
145    /// Montgomery multiplication.
146    ///
147    /// Requires `self` and `other` to be less than `modulus`.
148    ///
149    /// Computes
150    ///
151    /// $$
152    /// \mod{\frac{\mathtt{self} ⋅ \mathtt{other}}{ 2^{64 ·
153    /// \mathtt{LIMBS}}}}_{\mathtt{modulus}} $$
154    ///
155    /// This is useful because it can be computed notably faster than
156    /// [`mul_mod`](Self::mul_mod). Many computations can be done by
157    /// pre-multiplying values with $R = 2^{64 · \mathtt{LIMBS}}$
158    /// and then using [`mul_redc`](Self::mul_redc) instead of
159    /// [`mul_mod`](Self::mul_mod).
160    ///
161    /// For this algorithm to work, it needs an extra parameter `inv` which must
162    /// be set to
163    ///
164    /// $$
165    /// \mathtt{inv} = \mod{\frac{-1}{\mathtt{modulus}} }_{2^{64}}
166    /// $$
167    ///
168    /// The `inv` value only exists for odd values of `modulus`. It can be
169    /// computed using [`inv_ring`](Self::inv_ring) from `U64`.
170    ///
171    /// ```
172    /// # use ruint::{uint, Uint, aliases::*};
173    /// # uint!{
174    /// # let modulus = 21888242871839275222246405745257275088548364400416034343698204186575808495617_U256;
175    /// let inv = U64::wrapping_from(modulus).inv_ring().unwrap().wrapping_neg().to();
176    /// let prod = 5_U256.mul_redc(6_U256, modulus, inv);
177    /// # assert_eq!(inv.wrapping_mul(modulus.wrapping_to()), u64::MAX);
178    /// # assert_eq!(inv, 0xc2e1f593efffffff);
179    /// # }
180    /// ```
181    ///
182    /// # Panics
183    ///
184    /// Panics if `inv` is not correct in debug mode.
185    #[inline]
186    #[must_use]
187    pub fn mul_redc(self, other: Self, modulus: Self, inv: u64) -> Self {
188        if BITS == 0 {
189            return Self::ZERO;
190        }
191        let result = algorithms::mul_redc(self.limbs, other.limbs, modulus.limbs, inv);
192        let result = Self::from_limbs(result);
193        debug_assert!(result < modulus);
194        result
195    }
196
197    /// Montgomery squaring.
198    ///
199    /// See [Self::mul_redc].
200    #[inline]
201    #[must_use]
202    pub fn square_redc(self, modulus: Self, inv: u64) -> Self {
203        if BITS == 0 {
204            return Self::ZERO;
205        }
206        let result = algorithms::square_redc(self.limbs, modulus.limbs, inv);
207        let result = Self::from_limbs(result);
208        debug_assert!(result < modulus);
209        result
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{aliases::U64, const_for, nlimbs};
    use proptest::{prop_assume, proptest, test_runner::Config};

    /// `mul_mod` is commutative: a·b ≡ b·a.
    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                assert_eq!(a.mul_mod(b, m), b.mul_mod(a, m));
            });
        });
    }

    /// `mul_mod` is associative: a·(b·c) ≡ (a·b)·c.
    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.mul_mod(c, m), m), a.mul_mod(b, m).mul_mod(c, m));
            });
        });
    }

    /// `mul_mod` distributes over `add_mod`: a·(b+c) ≡ a·b + a·c.
    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.add_mod(c, m), m), a.mul_mod(b, m).add_mod(a.mul_mod(c, m), m));
            });
        });
    }

    /// Adding zero is the same as reducing.
    #[test]
    fn test_add_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.add_mod(U::from(0), m), value.reduce_mod(m));
            });
        });
    }

    /// Multiplying by zero gives zero; multiplying by one is the same as reducing.
    #[test]
    fn test_mul_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.mul_mod(U::from(0), m), U::ZERO);
                assert_eq!(value.mul_mod(U::from(1), m), value.reduce_mod(m));
            });
        });
    }

    /// a^0 ≡ 1 and a^1 ≡ a.
    #[test]
    fn test_pow_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                assert_eq!(a.pow_mod(U::from(0), m), U::from(1).reduce_mod(m));
                assert_eq!(a.pow_mod(U::from(1), m), a.reduce_mod(m));
            });
        });
    }

    /// Exponentiation distributes over multiplication: (a·b)^c ≡ a^c · b^c.
    #[test]
    fn test_pow_rules() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;

            // Too slow for large sizes. Skip only this size instead of `return`ing.
            // If `const_for!` expands every size into the same function body, a
            // `return` would exit the whole test and silently skip every later size.
            if LIMBS <= 8 {
                let config = Config { cases: 5, ..Default::default() };
                proptest!(config, |(a: U, b: U, c: U, m: U)| {
                    // TODO: a^(b+c) = a^b * a^c. Which requires carmichael fn.
                    // TODO: (a^b)^c = a^(b * c). Which requires carmichael fn.
                    assert_eq!(a.mul_mod(b, m).pow_mod(c, m), a.pow_mod(c, m).mul_mod(b.pow_mod(c, m), m));
                });
            }
        });
    }

    /// Whenever `inv_mod` returns an inverse, multiplying by it yields one.
    #[test]
    fn test_inv() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                if let Some(inv) = a.inv_mod(m) {
                    assert_eq!(a.mul_mod(inv, m), U::from(1));
                }
            });
        });
    }

    /// `mul_redc` on Montgomery forms agrees with `mul_mod` followed by scaling by R.
    /// Sizes below 16 bits are excluded because `U::from(64 * LIMBS)` must fit in `U`.
    #[test]
    fn test_mul_redc() {
        const_for!(BITS in NON_ZERO if BITS >= 16 {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                prop_assume!(m >= U::from(2));
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    let inv = (-inv).as_limbs()[0];

                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    let br = b.mul_mod(r, m);
                    // TODO: Test for larger (>= m) values of a, b.

                    let expected = a.mul_mod(b, m).mul_mod(r, m);

                    assert_eq!(ar.mul_redc(br, m, inv), expected);
                }
            });
        });
    }

    /// `square_redc` on a Montgomery form agrees with `mul_mod` followed by scaling by R.
    #[test]
    fn test_square_redc() {
        const_for!(BITS in NON_ZERO if BITS >= 16 {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                prop_assume!(m >= U::from(2));
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    let inv = (-inv).as_limbs()[0];

                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    // TODO: Test for larger (>= m) values of a, b.

                    let expected = a.mul_mod(a, m).mul_mod(r, m);

                    assert_eq!(ar.square_redc(m, inv), expected);
                }
            });
        });
    }
}