ruint/
mul.rs

1use crate::{algorithms, nlimbs, Uint};
2use core::{
3    iter::Product,
4    num::Wrapping,
5    ops::{Mul, MulAssign},
6};
7
8impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
9    /// Computes `self * rhs`, returning [`None`] if overflow occurred.
10    #[inline(always)]
11    #[must_use]
12    pub fn checked_mul(self, rhs: Self) -> Option<Self> {
13        match self.overflowing_mul(rhs) {
14            (value, false) => Some(value),
15            _ => None,
16        }
17    }
18
19    /// Calculates the multiplication of self and rhs.
20    ///
21    /// Returns a tuple of the multiplication along with a boolean indicating
22    /// whether an arithmetic overflow would occur. If an overflow would have
23    /// occurred then the wrapped value is returned.
24    ///
25    /// # Examples
26    ///
27    /// ```
28    /// # use ruint::{Uint, uint};
29    /// # uint!{
30    /// assert_eq!(1_U1.overflowing_mul(1_U1), (1_U1, false));
31    /// assert_eq!(
32    ///     0x010000000000000000_U65.overflowing_mul(0x010000000000000000_U65),
33    ///     (0x000000000000000000_U65, true)
34    /// );
35    /// # }
36    /// ```
37    #[inline]
38    #[must_use]
39    pub fn overflowing_mul(self, rhs: Self) -> (Self, bool) {
40        let mut result = Self::ZERO;
41        let mut overflow = algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
42        if Self::SHOULD_MASK {
43            overflow |= result.limbs[LIMBS - 1] > Self::MASK;
44            result.apply_mask();
45        }
46        (result, overflow)
47    }
48
49    /// Computes `self * rhs`, saturating at the numeric bounds instead of
50    /// overflowing.
51    #[inline(always)]
52    #[must_use]
53    pub fn saturating_mul(self, rhs: Self) -> Self {
54        match self.overflowing_mul(rhs) {
55            (value, false) => value,
56            _ => Self::MAX,
57        }
58    }
59
60    /// Computes `self * rhs`, wrapping around at the boundary of the type.
61    #[inline(always)]
62    #[must_use]
63    pub fn wrapping_mul(self, rhs: Self) -> Self {
64        let mut result = Self::ZERO;
65        algorithms::addmul_n(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
66        result.apply_mask();
67        result
68    }
69
70    /// Computes the inverse modulo $2^{\mathtt{BITS}}$ of `self`, returning
71    /// [`None`] if the inverse does not exist.
72    #[inline]
73    #[must_use]
74    pub fn inv_ring(self) -> Option<Self> {
75        if BITS == 0 || self.limbs[0] & 1 == 0 {
76            return None;
77        }
78
79        // Compute inverse of first limb
80        let mut result = Self::ZERO;
81        result.limbs[0] = {
82            const W2: Wrapping<u64> = Wrapping(2);
83            const W3: Wrapping<u64> = Wrapping(3);
84            let n = Wrapping(self.limbs[0]);
85            let mut inv = (n * W3) ^ W2; // Correct on 4 bits.
86            inv *= W2 - n * inv; // Correct on 8 bits.
87            inv *= W2 - n * inv; // Correct on 16 bits.
88            inv *= W2 - n * inv; // Correct on 32 bits.
89            inv *= W2 - n * inv; // Correct on 64 bits.
90            debug_assert_eq!(n.0.wrapping_mul(inv.0), 1);
91            inv.0
92        };
93
94        // Continue with rest of limbs
95        let mut correct_limbs = 1;
96        while correct_limbs < LIMBS {
97            result *= Self::from(2) - self * result;
98            correct_limbs *= 2;
99        }
100        result.apply_mask();
101
102        Some(result)
103    }
104
105    /// Calculates the complete product `self * rhs` without the possibility to
106    /// overflow.
107    ///
108    /// The argument `rhs` can be any size [`Uint`], the result size is the sum
109    /// of the bit-sizes of `self` and `rhs`.
110    ///
111    /// # Panics
112    ///
113    /// This function will runtime panic of the const generic arguments are
114    /// incorrect.
115    ///
116    /// # Examples
117    ///
118    /// ```
119    /// # use ruint::{Uint, uint};
120    /// # uint!{
121    /// assert_eq!(0_U0.widening_mul(0_U0), 0_U0);
122    /// assert_eq!(1_U1.widening_mul(1_U1), 1_U2);
123    /// assert_eq!(3_U2.widening_mul(7_U3), 21_U5);
124    /// # }
125    /// ```
126    #[inline]
127    #[must_use]
128    #[allow(clippy::similar_names)] // Don't confuse `res` and `rhs`.
129    pub fn widening_mul<
130        const BITS_RHS: usize,
131        const LIMBS_RHS: usize,
132        const BITS_RES: usize,
133        const LIMBS_RES: usize,
134    >(
135        self,
136        rhs: Uint<BITS_RHS, LIMBS_RHS>,
137    ) -> Uint<BITS_RES, LIMBS_RES> {
138        assert_eq!(BITS_RES, BITS + BITS_RHS);
139        assert_eq!(LIMBS_RES, nlimbs(BITS_RES));
140        let mut result = Uint::<BITS_RES, LIMBS_RES>::ZERO;
141        algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
142        if LIMBS_RES > 0 {
143            debug_assert!(result.limbs[LIMBS_RES - 1] <= Uint::<BITS_RES, LIMBS_RES>::MASK);
144        }
145
146        result
147    }
148}
149
150impl<const BITS: usize, const LIMBS: usize> Product<Self> for Uint<BITS, LIMBS> {
151    #[inline]
152    fn product<I>(iter: I) -> Self
153    where
154        I: Iterator<Item = Self>,
155    {
156        if BITS == 0 {
157            return Self::ZERO;
158        }
159        iter.fold(Self::ONE, Self::wrapping_mul)
160    }
161}
162
163impl<'a, const BITS: usize, const LIMBS: usize> Product<&'a Self> for Uint<BITS, LIMBS> {
164    #[inline]
165    fn product<I>(iter: I) -> Self
166    where
167        I: Iterator<Item = &'a Self>,
168    {
169        if BITS == 0 {
170            return Self::ZERO;
171        }
172        iter.copied().fold(Self::ONE, Self::wrapping_mul)
173    }
174}
175
// NOTE(review): presumably generates the `Mul` / `MulAssign` operator impls
// (`mul`, `mul_assign`) in terms of `wrapping_mul`; the macro is defined
// outside this file -- confirm against its definition.
impl_bin_op!(Mul, mul, MulAssign, mul_assign, wrapping_mul);
177
178#[cfg(test)]
179mod tests {
180    use super::*;
181    use crate::const_for;
182    use proptest::proptest;
183
184    #[test]
185    fn test_commutative() {
186        const_for!(BITS in SIZES {
187            const LIMBS: usize = nlimbs(BITS);
188            type U = Uint<BITS, LIMBS>;
189            proptest!(|(a: U, b: U)| {
190                assert_eq!(a * b, b * a);
191            });
192        });
193    }
194
195    #[test]
196    fn test_associative() {
197        const_for!(BITS in SIZES {
198            const LIMBS: usize = nlimbs(BITS);
199            type U = Uint<BITS, LIMBS>;
200            proptest!(|(a: U, b: U, c: U)| {
201                assert_eq!(a * (b * c), (a * b) * c);
202            });
203        });
204    }
205
206    #[test]
207    fn test_distributive() {
208        const_for!(BITS in SIZES {
209            const LIMBS: usize = nlimbs(BITS);
210            type U = Uint<BITS, LIMBS>;
211            proptest!(|(a: U, b: U, c: U)| {
212                assert_eq!(a * (b + c), (a * b) + (a *c));
213            });
214        });
215    }
216
217    #[test]
218    fn test_identity() {
219        const_for!(BITS in NON_ZERO {
220            const LIMBS: usize = nlimbs(BITS);
221            type U = Uint<BITS, LIMBS>;
222            proptest!(|(value: U)| {
223                assert_eq!(value * U::from(0), U::ZERO);
224                assert_eq!(value * U::from(1), value);
225            });
226        });
227    }
228
229    #[test]
230    fn test_inverse() {
231        const_for!(BITS in NON_ZERO {
232            const LIMBS: usize = nlimbs(BITS);
233            type U = Uint<BITS, LIMBS>;
234            proptest!(|(mut a: U)| {
235                a |= U::from(1); // Make sure a is invertible
236                assert_eq!(a * a.inv_ring().unwrap(), U::from(1));
237                assert_eq!(a.inv_ring().unwrap().inv_ring().unwrap(), a);
238            });
239        });
240    }
241
242    #[test]
243    fn test_widening_mul() {
244        // Left hand side
245        const_for!(BITS_LHS in BENCH {
246            const LIMBS_LHS: usize = nlimbs(BITS_LHS);
247            type Lhs = Uint<BITS_LHS, LIMBS_LHS>;
248
249            // Right hand side
250            const_for!(BITS_RHS in BENCH {
251                const LIMBS_RHS: usize = nlimbs(BITS_RHS);
252                type Rhs = Uint<BITS_RHS, LIMBS_RHS>;
253
254                // Result
255                const BITS_RES: usize = BITS_LHS + BITS_RHS;
256                const LIMBS_RES: usize = nlimbs(BITS_RES);
257                type Res = Uint<BITS_RES, LIMBS_RES>;
258
259                proptest!(|(lhs: Lhs, rhs: Rhs)| {
260                    // Compute the result using the target size
261                    let expected = Res::from(lhs) * Res::from(rhs);
262                    assert_eq!(lhs.widening_mul(rhs), expected);
263                });
264            });
265        });
266    }
267}