1use crate::{algorithms, nlimbs, Uint};
2use core::{
3 iter::Product,
4 num::Wrapping,
5 ops::{Mul, MulAssign},
6};
7
8impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
9 #[inline(always)]
11 #[must_use]
12 pub fn checked_mul(self, rhs: Self) -> Option<Self> {
13 match self.overflowing_mul(rhs) {
14 (value, false) => Some(value),
15 _ => None,
16 }
17 }
18
19 #[inline]
38 #[must_use]
39 pub fn overflowing_mul(self, rhs: Self) -> (Self, bool) {
40 let mut result = Self::ZERO;
41 let mut overflow = algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
42 if Self::SHOULD_MASK {
43 overflow |= result.limbs[LIMBS - 1] > Self::MASK;
44 result.apply_mask();
45 }
46 (result, overflow)
47 }
48
49 #[inline(always)]
52 #[must_use]
53 pub fn saturating_mul(self, rhs: Self) -> Self {
54 match self.overflowing_mul(rhs) {
55 (value, false) => value,
56 _ => Self::MAX,
57 }
58 }
59
60 #[inline(always)]
62 #[must_use]
63 pub fn wrapping_mul(self, rhs: Self) -> Self {
64 let mut result = Self::ZERO;
65 algorithms::addmul_n(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
66 result.apply_mask();
67 result
68 }
69
70 #[inline]
73 #[must_use]
74 pub fn inv_ring(self) -> Option<Self> {
75 if BITS == 0 || self.limbs[0] & 1 == 0 {
76 return None;
77 }
78
79 let mut result = Self::ZERO;
81 result.limbs[0] = {
82 const W2: Wrapping<u64> = Wrapping(2);
83 const W3: Wrapping<u64> = Wrapping(3);
84 let n = Wrapping(self.limbs[0]);
85 let mut inv = (n * W3) ^ W2; inv *= W2 - n * inv; inv *= W2 - n * inv; inv *= W2 - n * inv; inv *= W2 - n * inv; debug_assert_eq!(n.0.wrapping_mul(inv.0), 1);
91 inv.0
92 };
93
94 let mut correct_limbs = 1;
96 while correct_limbs < LIMBS {
97 result *= Self::from(2) - self * result;
98 correct_limbs *= 2;
99 }
100 result.apply_mask();
101
102 Some(result)
103 }
104
105 #[inline]
127 #[must_use]
128 #[allow(clippy::similar_names)] pub fn widening_mul<
130 const BITS_RHS: usize,
131 const LIMBS_RHS: usize,
132 const BITS_RES: usize,
133 const LIMBS_RES: usize,
134 >(
135 self,
136 rhs: Uint<BITS_RHS, LIMBS_RHS>,
137 ) -> Uint<BITS_RES, LIMBS_RES> {
138 assert_eq!(BITS_RES, BITS + BITS_RHS);
139 assert_eq!(LIMBS_RES, nlimbs(BITS_RES));
140 let mut result = Uint::<BITS_RES, LIMBS_RES>::ZERO;
141 algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
142 if LIMBS_RES > 0 {
143 debug_assert!(result.limbs[LIMBS_RES - 1] <= Uint::<BITS_RES, LIMBS_RES>::MASK);
144 }
145
146 result
147 }
148}
149
150impl<const BITS: usize, const LIMBS: usize> Product<Self> for Uint<BITS, LIMBS> {
151 #[inline]
152 fn product<I>(iter: I) -> Self
153 where
154 I: Iterator<Item = Self>,
155 {
156 if BITS == 0 {
157 return Self::ZERO;
158 }
159 iter.fold(Self::ONE, Self::wrapping_mul)
160 }
161}
162
163impl<'a, const BITS: usize, const LIMBS: usize> Product<&'a Self> for Uint<BITS, LIMBS> {
164 #[inline]
165 fn product<I>(iter: I) -> Self
166 where
167 I: Iterator<Item = &'a Self>,
168 {
169 if BITS == 0 {
170 return Self::ZERO;
171 }
172 iter.copied().fold(Self::ONE, Self::wrapping_mul)
173 }
174}
175
// Operator impls for `*` and `*=`. The arguments name the `Mul`/`MulAssign`
// traits, their methods and `wrapping_mul`. NOTE(review): the macro is defined
// elsewhere; presumably the operators delegate to `wrapping_mul` (and thus wrap
// on overflow) — confirm against `impl_bin_op!`.
impl_bin_op!(Mul, mul, MulAssign, mul_assign, wrapping_mul);
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::const_for;
    use proptest::proptest;

    /// `a * b == b * a`.
    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U)| {
                assert_eq!(a * b, b * a);
            });
        });
    }

    /// `a * (b * c) == (a * b) * c`.
    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U)| {
                assert_eq!(a * (b * c), (a * b) * c);
            });
        });
    }

    /// `a * (b + c) == a * b + a * c`.
    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U)| {
                assert_eq!(a * (b + c), (a * b) + (a * c));
            });
        });
    }

    /// Zero annihilates and one is the multiplicative identity.
    #[test]
    fn test_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U)| {
                assert_eq!(value * U::from(0), U::ZERO);
                assert_eq!(value * U::from(1), value);
            });
        });
    }

    /// `inv_ring` returns a true inverse for odd values and is an involution.
    #[test]
    fn test_inverse() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U)| {
                // Only odd values are invertible modulo 2^BITS.
                a |= U::from(1);
                assert_eq!(a * a.inv_ring().unwrap(), U::from(1));
                assert_eq!(a.inv_ring().unwrap().inv_ring().unwrap(), a);
            });
        });
    }

    /// The checked / saturating / wrapping variants agree with
    /// `overflowing_mul` on both the wrapped value and the overflow flag.
    #[test]
    fn test_overflow_variants_consistent() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U)| {
                let (wrapped, overflow) = a.overflowing_mul(b);
                assert_eq!(wrapped, a.wrapping_mul(b));
                assert_eq!(a.checked_mul(b), if overflow { None } else { Some(wrapped) });
                assert_eq!(a.saturating_mul(b), if overflow { U::MAX } else { wrapped });
            });
        });
    }

    /// The overflow flag matches the exact product computed by `widening_mul`,
    /// both for a limb-aligned width and for a width that needs masking.
    #[test]
    fn test_overflowing_mul_flag() {
        type U64 = Uint<64, 1>;
        type W128 = Uint<128, 2>;
        proptest!(|(a: U64, b: U64)| {
            let wide: W128 = a.widening_mul(b);
            let (wrapped, overflow) = a.overflowing_mul(b);
            assert_eq!(overflow, wide.as_limbs()[1] != 0);
            assert_eq!(wrapped.as_limbs()[0], wide.as_limbs()[0]);
        });

        type U7 = Uint<7, 1>;
        type W14 = Uint<14, 1>;
        proptest!(|(a: U7, b: U7)| {
            let wide: W14 = a.widening_mul(b);
            let (wrapped, overflow) = a.overflowing_mul(b);
            assert_eq!(overflow, wide.as_limbs()[0] >> 7 != 0);
            assert_eq!(wrapped.as_limbs()[0], wide.as_limbs()[0] & 0x7f);
        });
    }

    /// `widening_mul` agrees with multiplication in the (non-overflowing)
    /// result type.
    #[test]
    fn test_widening_mul() {
        const_for!(BITS_LHS in BENCH {
            const LIMBS_LHS: usize = nlimbs(BITS_LHS);
            type Lhs = Uint<BITS_LHS, LIMBS_LHS>;

            const_for!(BITS_RHS in BENCH {
                const LIMBS_RHS: usize = nlimbs(BITS_RHS);
                type Rhs = Uint<BITS_RHS, LIMBS_RHS>;

                const BITS_RES: usize = BITS_LHS + BITS_RHS;
                const LIMBS_RES: usize = nlimbs(BITS_RES);
                type Res = Uint<BITS_RES, LIMBS_RES>;

                proptest!(|(lhs: Lhs, rhs: Rhs)| {
                    let expected = Res::from(lhs) * Res::from(rhs);
                    assert_eq!(lhs.widening_mul(rhs), expected);
                });
            });
        });
    }
}