ruint/algorithms/gcd/
matrix.rs

1#![allow(clippy::use_self)]
2
3use crate::Uint;
4
/// ⚠️ Lehmer update matrix
///
/// **Warning.** This struct is not part of the stable API.
///
/// A 2×2 matrix of unsigned `u64` cofactors as produced by (partial) runs of
/// Euclid's algorithm (see [`Matrix::from_u64`] and friends).
///
/// Signs are implicit, the boolean `.4` encodes which of two sign
/// patterns applies. The signs and layout of the matrix are:
///
/// ```text
///     true          false
///  [ .0  -.1]    [-.0   .1]
///  [-.2   .3]    [ .2  -.3]
/// ```
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Matrix(pub u64, pub u64, pub u64, pub u64, pub bool);
19
impl Matrix {
    /// The do-nothing update matrix: an identity with the `true` sign
    /// pattern.
    pub const IDENTITY: Self = Self(1, 0, 0, 1, true);

    /// Returns the matrix product `self * other`.
    ///
    /// The sign flag of the product is the XNOR of the operands' flags:
    /// two matrices with equal sign patterns compose to the `true` pattern,
    /// mixed patterns compose to `false`.
    #[inline]
    #[allow(clippy::suspicious_operation_groupings)]
    #[must_use]
    pub const fn compose(self, other: Self) -> Self {
        // Top row of the product.
        let p00 = self.0 * other.0 + self.1 * other.2;
        let p01 = self.0 * other.1 + self.1 * other.3;
        // Bottom row of the product.
        let p10 = self.2 * other.0 + self.3 * other.2;
        let p11 = self.2 * other.1 + self.3 * other.3;
        Self(p00, p01, p10, p11, self.4 ^ !other.4)
    }
36
37    /// Applies the matrix to a `Uint`.
38    #[inline]
39    pub fn apply<const BITS: usize, const LIMBS: usize>(
40        &self,
41        a: &mut Uint<BITS, LIMBS>,
42        b: &mut Uint<BITS, LIMBS>,
43    ) {
44        if BITS == 0 {
45            return;
46        }
47        // OPT: We can avoid the temporary if we implement a dedicated matrix
48        // multiplication.
49        let (c, d) = if self.4 {
50            (
51                Uint::from(self.0) * *a - Uint::from(self.1) * *b,
52                Uint::from(self.3) * *b - Uint::from(self.2) * *a,
53            )
54        } else {
55            (
56                Uint::from(self.1) * *b - Uint::from(self.0) * *a,
57                Uint::from(self.2) * *a - Uint::from(self.3) * *b,
58            )
59        };
60        *a = c;
61        *b = d;
62    }
63
64    /// Applies the matrix to a `u128`.
65    #[inline]
66    #[must_use]
67    pub const fn apply_u128(&self, a: u128, b: u128) -> (u128, u128) {
68        // Intermediate values can overflow but the final result will fit, so we
69        // compute mod 2^128.
70        if self.4 {
71            (
72                (self.0 as u128)
73                    .wrapping_mul(a)
74                    .wrapping_sub((self.1 as u128).wrapping_mul(b)),
75                (self.3 as u128)
76                    .wrapping_mul(b)
77                    .wrapping_sub((self.2 as u128).wrapping_mul(a)),
78            )
79        } else {
80            (
81                (self.1 as u128)
82                    .wrapping_mul(b)
83                    .wrapping_sub((self.0 as u128).wrapping_mul(a)),
84                (self.2 as u128)
85                    .wrapping_mul(a)
86                    .wrapping_sub((self.3 as u128).wrapping_mul(b)),
87            )
88        }
89    }
90
91    /// Compute a Lehmer update matrix from two `Uint`s.
92    ///
93    /// # Panics
94    ///
95    /// Panics if `b > a`.
96    #[inline]
97    #[must_use]
98    pub fn from<const BITS: usize, const LIMBS: usize>(
99        a: Uint<BITS, LIMBS>,
100        b: Uint<BITS, LIMBS>,
101    ) -> Self {
102        assert!(a >= b);
103
104        // Grab the first 128 bits.
105        let s = a.bit_len();
106        if s <= 64 {
107            Self::from_u64(a.try_into().unwrap(), b.try_into().unwrap())
108        } else if s <= 128 {
109            Self::from_u128_prefix(a.try_into().unwrap(), b.try_into().unwrap())
110        } else {
111            let a = a >> (s - 128);
112            let b = b >> (s - 128);
113            Self::from_u128_prefix(a.try_into().unwrap(), b.try_into().unwrap())
114        }
115    }
116
117    /// Compute the Lehmer update matrix for small values.
118    ///
119    /// This is essentially Euclids extended GCD algorithm for 64 bits.
120    ///
121    /// # Panics
122    ///
123    /// Panics if `r0 < r1`.
124    // OPT: Would this be faster using extended binary gcd?
125    // See <https://en.algorithmica.org/hpc/algorithms/gcd>
126    #[inline]
127    #[must_use]
128    pub fn from_u64(mut r0: u64, mut r1: u64) -> Self {
129        debug_assert!(r0 >= r1);
130        if r1 == 0_u64 {
131            return Matrix::IDENTITY;
132        }
133        let mut q00 = 1_u64;
134        let mut q01 = 0_u64;
135        let mut q10 = 0_u64;
136        let mut q11 = 1_u64;
137        loop {
138            // Loop is unrolled once to avoid swapping variables and tracking parity.
139            let q = r0 / r1;
140            r0 -= q * r1;
141            q00 += q * q10;
142            q01 += q * q11;
143            if r0 == 0_u64 {
144                return Matrix(q10, q11, q00, q01, false);
145            }
146            let q = r1 / r0;
147            r1 -= q * r0;
148            q10 += q * q00;
149            q11 += q * q01;
150            if r1 == 0_u64 {
151                return Matrix(q00, q01, q10, q11, true);
152            }
153        }
154    }
155
    /// Compute the largest valid Lehmer update matrix for a prefix.
    ///
    /// Compute the Lehmer update matrix for a0 and a1 such that the matrix is
    /// valid for any two large integers starting with the bits of a0 and
    /// a1.
    ///
    /// See also `mpn_hgcd2` in GMP, but ours handles the double precision bit
    /// separately in `lehmer_double`.
    /// <https://gmplib.org/repo/gmp-6.1/file/tip/mpn/generic/hgcd2.c#l226>
    ///
    /// # Panics
    ///
    /// Panics if `a0` does not have the highest bit set.
    /// Panics if `a0 < a1`.
    #[inline]
    #[must_use]
    #[allow(clippy::redundant_else)]
    #[allow(clippy::cognitive_complexity)] // REFACTOR: Improve
    pub fn from_u64_prefix(a0: u64, mut a1: u64) -> Self {
        // LIMIT is both the SWAR field boundary (32 bits per cofactor) and
        // the stopping bound for the remainder sequence.
        const LIMIT: u64 = 1_u64 << 32;
        debug_assert!(a0 >= 1_u64 << 63);
        debug_assert!(a0 >= a1);

        // Here we do something original: The cofactors undergo identical
        // operations which makes them a candidate for SIMD instructions.
        // They also never exceed 32 bit, so we can SWAR them in a single u64.
        // Each k packs (u << 32) | v.
        let mut k0 = 1_u64 << 32; // u0 = 1, v0 = 0
        let mut k1 = 1_u64; // u1 = 0, v1 = 1
        let mut even = true;
        if a1 < LIMIT {
            // No reduction step is provably valid from this prefix alone.
            return Matrix::IDENTITY;
        }

        // Compute a2
        let q = a0 / a1;
        let mut a2 = a0 - q * a1;
        let mut k2 = k0 + q * k1;
        if a2 < LIMIT {
            // Unpack the single packed cofactor pair.
            let u2 = k2 >> 32;
            let v2 = k2 % LIMIT;

            // Test i + 1 (odd)
            // Jebelean's exact condition for accepting one (odd) step.
            if a2 >= v2 && a1 - a2 >= u2 {
                return Matrix(0, 1, u2, v2, false);
            } else {
                return Matrix::IDENTITY;
            }
        }

        // Compute a3
        let q = a1 / a2;
        let mut a3 = a1 - q * a2;
        let mut k3 = k1 + q * k2;

        // Loop until a3 < LIMIT, maintaining the last three values
        // of a and the last four values of k.
        while a3 >= LIMIT {
            // Rotate the window: after these three assignments `a3` holds the
            // old `a2` (the next dividend) and `a2` holds the old `a3` (the
            // next divisor); `a3 -= q * a2` below then leaves the new
            // remainder in `a3` without any swap.
            a1 = a2;
            a2 = a3;
            a3 = a1;
            // Identical rotation for the packed cofactors: `k3` ends up equal
            // to the old `k2`, ready for the in-place update `k3 += q * k2`.
            k0 = k1;
            k1 = k2;
            k2 = k3;
            k3 = k1;
            debug_assert!(a2 < a3);
            debug_assert!(a2 > 0);
            let q = a3 / a2;
            a3 -= q * a2;
            k3 += q * k2;
            if a3 < LIMIT {
                // Stopped after an odd number of steps.
                even = false;
                break;
            }
            // Second, identical unrolled iteration; keeps the parity of the
            // step count implicit in the code position.
            a1 = a2;
            a2 = a3;
            a3 = a1;
            k0 = k1;
            k1 = k2;
            k2 = k3;
            k3 = k1;
            debug_assert!(a2 < a3);
            debug_assert!(a2 > 0);
            let q = a3 / a2;
            a3 -= q * a2;
            k3 += q * k2;
        }
        // Unpack k into cofactors u and v
        let u0 = k0 >> 32;
        let u1 = k1 >> 32;
        let u2 = k2 >> 32;
        let u3 = k3 >> 32;
        let v0 = k0 % LIMIT;
        let v1 = k1 % LIMIT;
        let v2 = k2 % LIMIT;
        let v3 = k3 % LIMIT;
        debug_assert!(a2 >= LIMIT);
        debug_assert!(a3 < LIMIT);

        // Use Jebelean's exact condition to determine which outputs are correct.
        // Statistically, i + 2 should be correct about two-thirds of the time.
        if even {
            // Test i + 1 (odd)
            debug_assert!(a2 >= v2);
            if a1 - a2 >= u2 + u1 {
                // Test i + 2 (even)
                if a3 >= u3 && a2 - a3 >= v3 + v2 {
                    // Correct value is i + 2
                    Matrix(u2, v2, u3, v3, true)
                } else {
                    // Correct value is i + 1
                    Matrix(u1, v1, u2, v2, false)
                }
            } else {
                // Correct value is i
                Matrix(u0, v0, u1, v1, true)
            }
        } else {
            // Test i + 1 (even)
            debug_assert!(a2 >= u2);
            if a1 - a2 >= v2 + v1 {
                // Test i + 2 (odd)
                if a3 >= v3 && a2 - a3 >= u3 + u2 {
                    // Correct value is i + 2
                    Matrix(u2, v2, u3, v3, false)
                } else {
                    // Correct value is i + 1
                    Matrix(u1, v1, u2, v2, true)
                }
            } else {
                // Correct value is i
                Matrix(u0, v0, u1, v1, false)
            }
        }
    }
290
291    /// Compute the Lehmer update matrix in full 64 bit precision.
292    ///
293    /// Jebelean solves this by starting in double-precission followed
294    /// by single precision once values are small enough.
295    /// Cohen instead runs a single precision round, refreshes the r0 and r1
296    /// values and continues with another single precision round on top.
297    /// Our approach is similar to Cohen, but instead doing the second round
298    /// on the same matrix, we start we a fresh matrix and multiply both in the
299    /// end. This requires 8 additional multiplications, but allows us to use
300    /// the tighter stopping conditions from Jebelean. It also seems the
301    /// simplest out of these solutions.
302    // OPT: We can update r0 and r1 in place. This won't remove the partially
303    // redundant call to lehmer_update, but it reduces memory usage.
304    #[inline]
305    #[must_use]
306    pub fn from_u128_prefix(r0: u128, r1: u128) -> Self {
307        debug_assert!(r0 >= r1);
308        let s = r0.leading_zeros();
309        let r0s = r0 << s;
310        let r1s = r1 << s;
311        let q = Self::from_u64_prefix((r0s >> 64) as u64, (r1s >> 64) as u64);
312        if q == Matrix::IDENTITY {
313            return q;
314        }
315        // We can return q here and have a perfectly valid single-word Lehmer GCD.
316        q
317        // OPT: Fix the below method to get double-word Lehmer GCD.
318
319        // Recompute r0 and r1 and take the high bits.
320        // TODO: Is it safe to do this based on just the u128 prefix?
321        // let (r0, r1) = q.apply_u128(r0, r1);
322        // let s = r0.leading_zeros();
323        // let r0s = r0 << s;
324        // let r1s = r1 << s;
325        // let qn = Self::from_u64_prefix((r0s >> 64) as u64, (r1s >> 64) as
326        // u64);
327
328        // // Multiply matrices qn * q
329        // qn.compose(q)
330    }
331}
332
#[cfg(test)]
#[allow(clippy::cast_lossless)]
#[allow(clippy::many_single_char_names)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use core::{
        cmp::{max, min},
        mem::swap,
        str::FromStr,
    };
    use proptest::{proptest, test_runner::Config};

    // Reference implementation: plain Euclidean gcd on `u128`.
    fn gcd(mut a: u128, mut b: u128) -> u128 {
        while b != 0 {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    // Reference implementation: plain Euclidean gcd on `Uint`.
    fn gcd_uint<const BITS: usize, const LIMBS: usize>(
        mut a: Uint<BITS, LIMBS>,
        mut b: Uint<BITS, LIMBS>,
    ) -> Uint<BITS, LIMBS> {
        while b != Uint::ZERO {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    // Worked example: a full Euclid run on (252, 105) must produce the
    // expected matrix and reduce the pair to (gcd, 0) = (21, 0).
    #[test]
    fn test_from_u64_example() {
        let (a, b) = (252, 105);
        let m = Matrix::from_u64(a, b);
        assert_eq!(m, Matrix(2, 5, 5, 12, false));
        let (a, b) = m.apply_u128(a as u128, b as u128);
        assert_eq!(a, 21);
        assert_eq!(b, 0);
    }

    // `from_u64` runs Euclid to completion, so applying the matrix must
    // yield exactly (gcd, 0).
    #[test]
    fn test_from_u64() {
        proptest!(|(a: u64, b: u64)| {
            let (a, b) = (max(a,b), min(a,b));
            let m = Matrix::from_u64(a, b);
            let (c, d) = m.apply_u128(a as u128, b as u128);
            assert!(c >= d);
            assert_eq!(c, gcd(a as u128, b as u128));
            assert_eq!(d, 0);
        });
    }

    // A prefix matrix must either be the identity (no change) or make
    // strict progress while preserving the gcd of the full-width values.
    #[test]
    fn test_from_u64_prefix() {
        proptest!(|(a: u128, b: u128)| {
            // Prepare input
            let (a, b) = (max(a,b), min(a,b));
            let s = a.leading_zeros();
            let (sa, sb) = (a << s, b << s);

            let m = Matrix::from_u64_prefix((sa >> 64) as u64, (sb >> 64) as u64);
            let (c, d) = m.apply_u128(a, b);
            assert!(c >= d);
            if m == Matrix::IDENTITY {
                assert_eq!(c, a);
                assert_eq!(d, b);
            } else {
                assert!(c <= a);
                assert!(d < b);
                assert_eq!(gcd(a, b), gcd(c, d));
            }
        });
    }

    // Shared checker: `Matrix::from` must either be the identity or
    // strictly reduce (a, b) while preserving their gcd.
    fn test_form_uint_one<const BITS: usize, const LIMBS: usize>(
        a: Uint<BITS, LIMBS>,
        b: Uint<BITS, LIMBS>,
    ) {
        let (a, b) = (max(a, b), min(a, b));
        let m = Matrix::from(a, b);
        let (mut c, mut d) = (a, b);
        m.apply(&mut c, &mut d);
        assert!(c >= d);
        if m == Matrix::IDENTITY {
            assert_eq!(c, a);
            assert_eq!(d, b);
        } else {
            assert!(c <= a);
            assert!(d < b);
            assert_eq!(gcd_uint(a, b), gcd_uint(c, d));
        }
    }

    // Regression case for the (currently disabled) double-word path.
    #[test]
    fn test_from_uint_cases() {
        // This case fails with the double-word version above.
        type U129 = Uint<129, 3>;
        test_form_uint_one(
            U129::from_str("0x01de6ef6f3caa963a548d7a411b05b9988").unwrap(),
            U129::from_str("0x006d7c4641f88b729a97889164dd8d07db").unwrap(),
        );
    }

    // Run the shared checker across all generated `Uint` sizes.
    #[test]
    #[allow(clippy::absurd_extreme_comparisons)] // Generated code
    fn test_from_uint_proptest() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let config = Config { cases: 10, ..Default::default() };
            proptest!(config, |(a: U, b: U)| {
                test_form_uint_one(a, b);
            });
        });
    }
}