ruint/algorithms/
mul.rs

1#![allow(clippy::module_name_repetitions)]
2
3use crate::algorithms::{borrowing_sub, DoubleWord};
4
5/// ⚠️ Computes `result += a * b` and checks for overflow.
6#[doc = crate::algorithms::unstable_warning!()]
7/// Arrays are in little-endian order. All arrays can be arbitrary sized.
8///
9/// # Algorithm
10///
11/// Trims zeros from inputs, then uses the schoolbook multiplication algorithm.
12/// It takes the shortest input as the outer loop.
13///
14/// # Examples
15///
16/// ```
17/// # use ruint::algorithms::addmul;
18/// let mut result = [0];
19/// let overflow = addmul(&mut result, &[3], &[4]);
20/// assert_eq!(overflow, false);
21/// assert_eq!(result, [12]);
22/// ```
23#[inline(always)]
24pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
25    // Trim zeros from `a`
26    while let [0, rest @ ..] = a {
27        a = rest;
28        if let [_, rest @ ..] = lhs {
29            lhs = rest;
30        }
31    }
32    a = super::trim_end_zeros(a);
33    if a.is_empty() {
34        return false;
35    }
36
37    // Trim zeros from `b`
38    while let [0, rest @ ..] = b {
39        b = rest;
40        if let [_, rest @ ..] = lhs {
41            lhs = rest;
42        }
43    }
44    b = super::trim_end_zeros(b);
45    if b.is_empty() {
46        return false;
47    }
48
49    if lhs.is_empty() {
50        return true;
51    }
52
53    let (a, b) = if b.len() > a.len() { (b, a) } else { (a, b) };
54
55    // Iterate over limbs of `b` and add partial products to `lhs`.
56    let mut overflow = false;
57    for &b in b {
58        if lhs.len() >= a.len() {
59            let (target, rest) = lhs.split_at_mut(a.len());
60            let carry = addmul_nx1(target, a, b);
61            let carry = add_nx1(rest, carry);
62            overflow |= carry != 0;
63        } else {
64            overflow = true;
65            if lhs.is_empty() {
66                break;
67            }
68            addmul_nx1(lhs, &a[..lhs.len()], b);
69        }
70        lhs = &mut lhs[1..];
71    }
72    overflow
73}
74
75const ADDMUL_N_SMALL_LIMIT: usize = 8;
76
77/// ⚠️ Computes wrapping `result += a * b`, with a fast-path for when all inputs
78/// are the same length and small enough.
79#[doc = crate::algorithms::unstable_warning!()]
80/// See [`addmul`] for more details.
81#[inline(always)]
82pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) {
83    let n = lhs.len();
84    if n <= ADDMUL_N_SMALL_LIMIT && a.len() == n && b.len() == n {
85        addmul_n_small(lhs, a, b);
86    } else {
87        let _ = addmul(lhs, a, b);
88    }
89}
90
91#[inline(always)]
92fn addmul_n_small(lhs: &mut [u64], a: &[u64], b: &[u64]) {
93    let n = lhs.len();
94    assume!(n <= ADDMUL_N_SMALL_LIMIT);
95    assume!(a.len() == n);
96    assume!(b.len() == n);
97
98    for j in 0..n {
99        let mut carry = 0;
100        for i in 0..(n - j) {
101            (lhs[j + i], carry) = u128::muladd2(a[i], b[j], carry, lhs[j + i]).split();
102        }
103    }
104}
105
106/// ⚠️ Computes `lhs += a` and returns the carry.
107#[doc = crate::algorithms::unstable_warning!()]
108#[inline(always)]
109pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 {
110    if a == 0 {
111        return 0;
112    }
113    for lhs in lhs {
114        (*lhs, a) = u128::add(*lhs, a).split();
115        if a == 0 {
116            return 0;
117        }
118    }
119    a
120}
121
122/// ⚠️ Computes `lhs *= a` and returns the carry.
123#[doc = crate::algorithms::unstable_warning!()]
124#[inline(always)]
125pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
126    let mut carry = 0;
127    for lhs in lhs {
128        (*lhs, carry) = u128::muladd(*lhs, a, carry).split();
129    }
130    carry
131}
132
133/// ⚠️ Computes `lhs += a * b` and returns the carry.
134#[doc = crate::algorithms::unstable_warning!()]
135/// Requires `lhs.len() == a.len()`.
136///
137/// $$
138/// \begin{aligned}
139/// \mathsf{lhs'} &= \mod{\mathsf{lhs} + \mathsf{a} ⋅ \mathsf{b}}_{2^{64⋅N}}
140/// \\\\ \mathsf{carry} &= \floor{\frac{\mathsf{lhs} + \mathsf{a} ⋅ \mathsf{b}
141/// }{2^{64⋅N}}} \end{aligned}
142/// $$
143#[inline(always)]
144pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
145    assume!(lhs.len() == a.len());
146    let mut carry = 0;
147    for i in 0..a.len() {
148        (lhs[i], carry) = u128::muladd2(a[i], b, carry, lhs[i]).split();
149    }
150    carry
151}
152
153/// ⚠️ Computes `lhs -= a * b` and returns the borrow.
154#[doc = crate::algorithms::unstable_warning!()]
155/// Requires `lhs.len() == a.len()`.
156///
157/// $$
158/// \begin{aligned}
159/// \mathsf{lhs'} &= \mod{\mathsf{lhs} - \mathsf{a} ⋅ \mathsf{b}}_{2^{64⋅N}}
160/// \\\\ \mathsf{borrow} &= \floor{\frac{\mathsf{a} ⋅ \mathsf{b} -
161/// \mathsf{lhs}}{2^{64⋅N}}} \end{aligned}
162/// $$
163// OPT: `carry` and `borrow` can probably be merged into a single var.
164#[inline(always)]
165pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
166    assume!(lhs.len() == a.len());
167    let mut carry = 0;
168    let mut borrow = false;
169    for i in 0..a.len() {
170        // Compute product limbs
171        let limb;
172        (limb, carry) = u128::muladd(a[i], b, carry).split();
173
174        // Subtract
175        (lhs[i], borrow) = borrowing_sub(lhs[i], limb, borrow);
176    }
177    borrow as u64 + carry
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::{collection, num::u64, proptest};

    /// Straightforward reference for `result += a * b`, returning whether any
    /// nonzero part of the sum fell outside `result`.
    #[allow(clippy::cast_possible_truncation)] // Intentional truncation.
    fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
        let mut overflow = 0;
        for (i, a) in a.iter().copied().enumerate() {
            let mut result = result.iter_mut().skip(i);
            let mut b = b.iter().copied();
            let mut carry = 0_u128;
            loop {
                match (result.next(), b.next()) {
                    // Partial product.
                    (Some(result), Some(b)) => {
                        carry += u128::from(*result) + u128::from(a) * u128::from(b);
                        *result = carry as u64;
                        carry >>= 64;
                    }
                    // Carry propagation.
                    (Some(result), None) => {
                        carry += u128::from(*result);
                        *result = carry as u64;
                        carry >>= 64;
                    }
                    // Excess product.
                    (None, Some(b)) => {
                        carry += u128::from(a) * u128::from(b);
                        overflow |= carry as u64;
                        carry >>= 64;
                    }
                    // Fin.
                    (None, None) => {
                        break;
                    }
                }
            }
            overflow |= carry as u64;
        }
        overflow != 0
    }

    #[test]
    fn test_addmul() {
        let any_vec = collection::vec(u64::ANY, 0..10);
        proptest!(|(mut lhs in &any_vec, a in &any_vec, b in &any_vec)| {
            // Reference
            let mut ref_lhs = lhs.clone();
            let ref_overflow = addmul_ref(&mut ref_lhs, &a, &b);

            // Test
            let overflow = addmul(&mut lhs, &a, &b);
            assert_eq!(lhs, ref_lhs);
            assert_eq!(overflow, ref_overflow);
        });
    }

    /// `addmul_n` must match the (wrapping) reference both on its small
    /// equal-length fast path and on the general fallback.
    #[test]
    fn test_addmul_n() {
        let any_vec = collection::vec(u64::ANY, 0..ADDMUL_N_SMALL_LIMIT + 3);
        proptest!(|(lhs in &any_vec, a in &any_vec, b in &any_vec)| {
            // Mixed lengths: exercises the fallback to `addmul`.
            let mut expected = lhs.clone();
            let _ = addmul_ref(&mut expected, &a, &b);
            let mut actual = lhs.clone();
            addmul_n(&mut actual, &a, &b);
            assert_eq!(actual, expected);

            // Equal lengths: exercises `addmul_n_small` (and the fallback just
            // above `ADDMUL_N_SMALL_LIMIT`).
            let n = lhs.len().min(a.len()).min(b.len());
            let (lhs, a, b) = (&lhs[..n], &a[..n], &b[..n]);
            let mut expected = lhs.to_vec();
            let _ = addmul_ref(&mut expected, a, b);
            let mut actual = lhs.to_vec();
            addmul_n(&mut actual, a, b);
            assert_eq!(actual, expected);
        });
    }

    fn test_vals(a: &[u64], b: &[u64], expected: &[u64], expected_overflow: bool) {
        let mut result = vec![0; expected.len()];
        let overflow = addmul(&mut result, a, b);
        assert_eq!(overflow, expected_overflow);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_empty() {
        test_vals(&[], &[], &[], false);
        test_vals(&[], &[1], &[], false);
        test_vals(&[1], &[], &[], false);
        test_vals(&[1], &[1], &[], true);
        test_vals(&[], &[], &[0], false);
        test_vals(&[], &[1], &[0], false);
        test_vals(&[1], &[], &[0], false);
        test_vals(&[1], &[1], &[1], false);
    }

    #[test]
    fn test_add_nx1() {
        // Carry ripples through the saturated limbs and stops at the third.
        let mut lhs: [u64; 3] = [!0, !0, 7];
        assert_eq!(add_nx1(&mut lhs, 1), 0);
        assert_eq!(lhs, [0, 0, 8]);

        // Carry escapes the top limb.
        let mut lhs: [u64; 2] = [!0, !0];
        assert_eq!(add_nx1(&mut lhs, 1), 1);
        assert_eq!(lhs, [0, 0]);

        // Nothing to add into: the addend comes back as the carry.
        let mut empty: [u64; 0] = [];
        assert_eq!(add_nx1(&mut empty, 5), 5);
    }

    #[test]
    fn test_mul_nx1() {
        // (2^64 - 1 + 3 * 2^64) * 2 = (2^64 - 2) + 7 * 2^64.
        let mut lhs: [u64; 2] = [!0, 3];
        assert_eq!(mul_nx1(&mut lhs, 2), 0);
        assert_eq!(lhs, [!1, 7]);

        // (2^64 - 1)^2 = 1 + (2^64 - 2) * 2^64.
        let mut lhs: [u64; 1] = [!0];
        assert_eq!(mul_nx1(&mut lhs, !0), !1);
        assert_eq!(lhs, [1]);

        let mut empty: [u64; 0] = [];
        assert_eq!(mul_nx1(&mut empty, 5), 0);
    }

    #[test]
    fn test_submul_nx1() {
        let mut lhs = [
            15520854688669198950,
            13760048731709406392,
            14363314282014368551,
            13263184899940581802,
        ];
        let a = [
            7955980792890017645,
            6297379555503105007,
            2473663400150304794,
            18362433840513668572,
        ];
        let b = 17275533833223164845;
        let borrow = submul_nx1(&mut lhs, &a, b);
        assert_eq!(lhs, [
            2427453526388035261,
            7389014268281543265,
            6670181329660292018,
            8411211985208067428
        ]);
        assert_eq!(borrow, 17196576577663999042);
    }

    /// The borrow is `ceil((a * b - lhs) / 2^(64 * N))`, not the floor.
    #[test]
    fn test_submul_nx1_borrow_rounding() {
        // a * b < lhs: no borrow (a floor would give -1).
        let mut lhs: [u64; 1] = [5];
        assert_eq!(submul_nx1(&mut lhs, &[1], 3), 0);
        assert_eq!(lhs, [2]);

        // a * b exceeds lhs by less than 2^64: borrow of one (floor gives 0).
        let mut lhs: [u64; 1] = [1];
        assert_eq!(submul_nx1(&mut lhs, &[2], 1), 1);
        assert_eq!(lhs, [!0]);

        // a * b - lhs is exactly 2^64: borrow of one.
        let mut lhs: [u64; 1] = [0];
        assert_eq!(submul_nx1(&mut lhs, &[1 << 32], 1 << 32), 1);
        assert_eq!(lhs, [0]);
    }
}