ruint/
modular.rs

1use crate::{algorithms, Uint};
2
// FEATURE: sub_mod, neg_mod, div_mod, root_mod (`inv_mod` is already implemented below).
// For root_mod, see <https://en.wikipedia.org/wiki/Cipolla's_algorithm>
// FEATURE: a REDC-based mul_mod (the Montgomery primitive exists as `mul_redc`),
// and maybe Barrett reduction.
// See also <https://static1.squarespace.com/static/61f7cacf2d7af938cad5b81c/t/62deb4e0c434f7134c2730ee/1658762465114/modular_multiplication.pdf>
// FEATURE: Modular wrapper type, like `Wrapping`.
9
10impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
11    /// ⚠️ Compute $\mod{\mathtt{self}}_{\mathtt{modulus}}$.
12    ///
13    /// **Warning.** This function is not part of the stable API.
14    ///
15    /// Returns zero if the modulus is zero.
16    // FEATURE: Reduce larger bit-sizes to smaller ones.
17    #[inline]
18    #[must_use]
19    pub fn reduce_mod(mut self, modulus: Self) -> Self {
20        if modulus.is_zero() {
21            return Self::ZERO;
22        }
23        if self >= modulus {
24            self %= modulus;
25        }
26        self
27    }
28
29    /// Compute $\mod{\mathtt{self} + \mathtt{rhs}}_{\mathtt{modulus}}$.
30    ///
31    /// Returns zero if the modulus is zero.
32    #[inline]
33    #[must_use]
34    pub fn add_mod(self, rhs: Self, modulus: Self) -> Self {
35        // Reduce inputs
36        let lhs = self.reduce_mod(modulus);
37        let rhs = rhs.reduce_mod(modulus);
38
39        // Compute the sum and conditionally subtract modulus once.
40        let (mut result, overflow) = lhs.overflowing_add(rhs);
41        if overflow || result >= modulus {
42            result -= modulus;
43        }
44        result
45    }
46
47    /// Compute $\mod{\mathtt{self} ⋅ \mathtt{rhs}}_{\mathtt{modulus}}$.
48    ///
49    /// Returns zero if the modulus is zero.
50    ///
51    /// See [`mul_redc`](Self::mul_redc) for a faster variant at the cost of
52    /// some pre-computation.
53    #[inline]
54    #[must_use]
55    pub fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self {
56        if modulus.is_zero() {
57            return Self::ZERO;
58        }
59
60        // Allocate at least `nlimbs(2 * BITS)` limbs to store the product. This array
61        // casting is a workaround for `generic_const_exprs` not being stable.
62        let mut product = [[0u64; 2]; LIMBS];
63        let product_len = crate::nlimbs(2 * BITS);
64        debug_assert!(2 * LIMBS >= product_len);
65        // SAFETY: `[[u64; 2]; LIMBS] == [u64; 2 * LIMBS] >= [u64; nlimbs(2 * BITS)]`.
66        let product = unsafe {
67            core::slice::from_raw_parts_mut(product.as_mut_ptr().cast::<u64>(), product_len)
68        };
69
70        // Compute full product.
71        let overflow = algorithms::addmul(product, self.as_limbs(), rhs.as_limbs());
72        debug_assert!(!overflow);
73
74        // Compute modulus using `div_rem`.
75        // This stores the remainder in the divisor, `modulus`.
76        algorithms::div(product, &mut modulus.limbs);
77
78        modulus
79    }
80
81    /// Compute $\mod{\mathtt{self}^{\mathtt{rhs}}}_{\mathtt{modulus}}$.
82    ///
83    /// Returns zero if the modulus is zero.
84    #[inline]
85    #[must_use]
86    pub fn pow_mod(mut self, mut exp: Self, modulus: Self) -> Self {
87        if BITS == 0 || modulus <= Self::ONE {
88            return Self::ZERO;
89        }
90
91        // Exponentiation by squaring
92        let mut result = Self::ONE;
93        while exp > Self::ZERO {
94            // Multiply by base
95            if exp.limbs[0] & 1 == 1 {
96                result = result.mul_mod(self, modulus);
97            }
98
99            // Square base
100            self = self.mul_mod(self, modulus);
101            exp >>= 1;
102        }
103        result
104    }
105
106    /// Compute $\mod{\mathtt{self}^{-1}}_{\mathtt{modulus}}$.
107    ///
108    /// Returns `None` if the inverse does not exist.
109    #[inline]
110    #[must_use]
111    pub fn inv_mod(self, modulus: Self) -> Option<Self> {
112        algorithms::inv_mod(self, modulus)
113    }
114
115    /// Montgomery multiplication.
116    ///
117    /// Requires `self` and `other` to be less than `modulus`.
118    ///
119    /// Computes
120    ///
121    /// $$
122    /// \mod{\frac{\mathtt{self} ⋅ \mathtt{other}}{ 2^{64 ·
123    /// \mathtt{LIMBS}}}}_{\mathtt{modulus}} $$
124    ///
125    /// This is useful because it can be computed notably faster than
126    /// [`mul_mod`](Self::mul_mod). Many computations can be done by
127    /// pre-multiplying values with $R = 2^{64 · \mathtt{LIMBS}}$
128    /// and then using [`mul_redc`](Self::mul_redc) instead of
129    /// [`mul_mod`](Self::mul_mod).
130    ///
131    /// For this algorithm to work, it needs an extra parameter `inv` which must
132    /// be set to
133    ///
134    /// $$
135    /// \mathtt{inv} = \mod{\frac{-1}{\mathtt{modulus}} }_{2^{64}}
136    /// $$
137    ///
138    /// The `inv` value only exists for odd values of `modulus`. It can be
139    /// computed using [`inv_ring`](Self::inv_ring) from `U64`.
140    ///
141    /// ```
142    /// # use ruint::{uint, Uint, aliases::*};
143    /// # uint!{
144    /// # let modulus = 21888242871839275222246405745257275088548364400416034343698204186575808495617_U256;
145    /// let inv = U64::wrapping_from(modulus).inv_ring().unwrap().wrapping_neg().to();
146    /// let prod = 5_U256.mul_redc(6_U256, modulus, inv);
147    /// # assert_eq!(inv.wrapping_mul(modulus.wrapping_to()), u64::MAX);
148    /// # assert_eq!(inv, 0xc2e1f593efffffff);
149    /// # }
150    /// ```
151    ///
152    /// # Panics
153    ///
154    /// Panics if `inv` is not correct in debug mode.
155    #[inline]
156    #[must_use]
157    pub fn mul_redc(self, other: Self, modulus: Self, inv: u64) -> Self {
158        if BITS == 0 {
159            return Self::ZERO;
160        }
161        let result = algorithms::mul_redc(self.limbs, other.limbs, modulus.limbs, inv);
162        let result = Self::from_limbs(result);
163        debug_assert!(result < modulus);
164        result
165    }
166
167    /// Montgomery squaring.
168    ///
169    /// See [Self::mul_redc].
170    #[inline]
171    #[must_use]
172    pub fn square_redc(self, modulus: Self, inv: u64) -> Self {
173        if BITS == 0 {
174            return Self::ZERO;
175        }
176        let result = algorithms::square_redc(self.limbs, modulus.limbs, inv);
177        let result = Self::from_limbs(result);
178        debug_assert!(result < modulus);
179        result
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{aliases::U64, const_for, nlimbs};
    use proptest::{prop_assume, proptest, test_runner::Config};

    /// `mul_mod` is commutative: `a·b ≡ b·a (mod m)`.
    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                assert_eq!(a.mul_mod(b, m), b.mul_mod(a, m));
            });
        });
    }

    /// `mul_mod` is associative: `a·(b·c) ≡ (a·b)·c (mod m)`.
    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.mul_mod(c, m), m), a.mul_mod(b, m).mul_mod(c, m));
            });
        });
    }

    /// `mul_mod` distributes over `add_mod`: `a·(b+c) ≡ a·b + a·c (mod m)`.
    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.add_mod(c, m), m), a.mul_mod(b, m).add_mod(a.mul_mod(c, m), m));
            });
        });
    }

    /// Adding zero only reduces the value.
    #[test]
    fn test_add_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.add_mod(U::from(0), m), value.reduce_mod(m));
            });
        });
    }

    /// Multiplying by zero gives zero; multiplying by one only reduces the value.
    #[test]
    fn test_mul_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.mul_mod(U::from(0), m), U::ZERO);
                assert_eq!(value.mul_mod(U::from(1), m), value.reduce_mod(m));
            });
        });
    }

    /// `a^0 ≡ 1` and `a^1 ≡ a (mod m)`.
    #[test]
    fn test_pow_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                assert_eq!(a.pow_mod(U::from(0), m), U::from(1).reduce_mod(m));
                assert_eq!(a.pow_mod(U::from(1), m), a.reduce_mod(m));
            });
        });
    }

    /// Exponentiation distributes over multiplication: `(a·b)^c ≡ a^c · b^c (mod m)`.
    #[test]
    fn test_pow_rules() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;

            // Too slow for wide integers.
            // NOTE(review): if `const_for!` expands to plain blocks, this `return`
            // exits the whole test and skips every later size; presumably fine
            // because sizes are visited in increasing order — confirm.
            if LIMBS > 8 {
                return;
            }

            let config = Config { cases: 5, ..Default::default() };
            proptest!(config, |(a: U, b: U, c: U, m: U)| {
                // TODO: a^(b+c) = a^b · a^c, which requires the Carmichael function.
                // TODO: (a^b)^c = a^(b·c), which requires the Carmichael function.
                assert_eq!(a.mul_mod(b, m).pow_mod(c, m), a.pow_mod(c, m).mul_mod(b.pow_mod(c, m), m));
            });
        });
    }

    /// Whenever `inv_mod` returns an inverse, it really is one.
    #[test]
    fn test_inv() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                if let Some(inv) = a.inv_mod(m) {
                    assert_eq!(a.mul_mod(inv, m), U::from(1));
                }
            });
        });
    }

    /// `mul_redc` on Montgomery-form inputs matches `mul_mod` followed by a
    /// multiplication with `R = 2^(64·LIMBS) mod m`.
    #[test]
    fn test_mul_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                prop_assume!(m >= U::from(2));
                // `inv_ring` only succeeds for odd moduli.
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    let inv = (-inv).as_limbs()[0];

                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    let br = b.mul_mod(r, m);
                    // TODO: Test for larger (>= m) values of a, b.

                    let expected = a.mul_mod(b, m).mul_mod(r, m);

                    assert_eq!(ar.mul_redc(br, m, inv), expected);
                }
            });
        });
    }

    /// `square_redc` on a Montgomery-form input matches `a·a·R mod m`.
    #[test]
    fn test_square_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                prop_assume!(m >= U::from(2));
                // `inv_ring` only succeeds for odd moduli.
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    let inv = (-inv).as_limbs()[0];

                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    // TODO: Test for larger (>= m) values of a.

                    let expected = a.mul_mod(a, m).mul_mod(r, m);

                    assert_eq!(ar.square_redc(m, inv), expected);
                }
            });
        });
    }
}