ruint/algorithms/gcd/
matrix.rs

1#![allow(clippy::use_self)]
2
3use crate::Uint;
4
/// ⚠️ Lehmer update matrix
#[doc = crate::algorithms::unstable_warning!()]
/// Signs are implicit, the boolean `.4` encodes which of two sign
/// patterns applies. The signs and layout of the matrix are:
///
/// ```text
///     true          false
///  [ .0  -.1]    [-.0   .1]
///  [-.2   .3]    [ .2  -.3]
/// ```
///
/// Only magnitudes are stored (as `u64`), which keeps all matrix
/// arithmetic unsigned; the single flag is sufficient because only these
/// two checkerboard sign patterns ever occur (see `compose`).
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Matrix(pub u64, pub u64, pub u64, pub u64, pub bool);
17
18impl Matrix {
    /// The identity matrix `[[1, 0], [0, 1]]`, encoded with the `true` sign pattern.
    pub const IDENTITY: Self = Self(1, 0, 0, 1, true);
20
21    /// ⚠️ Returns the matrix product `self * other`.
22    #[doc = crate::algorithms::unstable_warning!()]
23    #[inline]
24    #[allow(clippy::suspicious_operation_groupings)]
25    #[must_use]
26    pub const fn compose(self, other: Self) -> Self {
27        Self(
28            self.0 * other.0 + self.1 * other.2,
29            self.0 * other.1 + self.1 * other.3,
30            self.2 * other.0 + self.3 * other.2,
31            self.2 * other.1 + self.3 * other.3,
32            self.4 ^ !other.4,
33        )
34    }
35
36    /// ⚠️ Applies the matrix to a `Uint`.
37    #[doc = crate::algorithms::unstable_warning!()]
38    #[inline]
39    pub fn apply<const BITS: usize, const LIMBS: usize>(
40        &self,
41        a: &mut Uint<BITS, LIMBS>,
42        b: &mut Uint<BITS, LIMBS>,
43    ) {
44        if BITS == 0 {
45            return;
46        }
47        // OPT: We can avoid the temporary if we implement a dedicated matrix
48        // multiplication.
49        let (c, d) = if self.4 {
50            (
51                Uint::from(self.0) * *a - Uint::from(self.1) * *b,
52                Uint::from(self.3) * *b - Uint::from(self.2) * *a,
53            )
54        } else {
55            (
56                Uint::from(self.1) * *b - Uint::from(self.0) * *a,
57                Uint::from(self.2) * *a - Uint::from(self.3) * *b,
58            )
59        };
60        *a = c;
61        *b = d;
62    }
63
64    /// ⚠️ Applies the matrix to a `u128`.
65    #[doc = crate::algorithms::unstable_warning!()]
66    #[inline]
67    #[must_use]
68    pub const fn apply_u128(&self, a: u128, b: u128) -> (u128, u128) {
69        // Intermediate values can overflow but the final result will fit, so we
70        // compute mod 2^128.
71        if self.4 {
72            (
73                (self.0 as u128)
74                    .wrapping_mul(a)
75                    .wrapping_sub((self.1 as u128).wrapping_mul(b)),
76                (self.3 as u128)
77                    .wrapping_mul(b)
78                    .wrapping_sub((self.2 as u128).wrapping_mul(a)),
79            )
80        } else {
81            (
82                (self.1 as u128)
83                    .wrapping_mul(b)
84                    .wrapping_sub((self.0 as u128).wrapping_mul(a)),
85                (self.2 as u128)
86                    .wrapping_mul(a)
87                    .wrapping_sub((self.3 as u128).wrapping_mul(b)),
88            )
89        }
90    }
91
92    #[allow(clippy::missing_panics_doc)] // False positive.
93    /// ⚠️ Compute a Lehmer update matrix from two `Uint`s.
94    #[doc = crate::algorithms::unstable_warning!()]
95    /// # Panics
96    ///
97    /// Panics if `b > a`.
98    #[inline]
99    #[must_use]
100    pub fn from<const BITS: usize, const LIMBS: usize>(
101        a: Uint<BITS, LIMBS>,
102        b: Uint<BITS, LIMBS>,
103    ) -> Self {
104        assert!(a >= b);
105
106        // Grab the first 128 bits.
107        let s = a.bit_len();
108        if s <= 64 {
109            Self::from_u64(a.try_into().unwrap(), b.try_into().unwrap())
110        } else if s <= 128 {
111            Self::from_u128_prefix(a.try_into().unwrap(), b.try_into().unwrap())
112        } else {
113            let a = a >> (s - 128);
114            let b = b >> (s - 128);
115            Self::from_u128_prefix(a.try_into().unwrap(), b.try_into().unwrap())
116        }
117    }
118
    /// ⚠️ Compute the Lehmer update matrix for small values.
    #[doc = crate::algorithms::unstable_warning!()]
    /// This is essentially Euclids extended GCD algorithm for 64 bits.
    ///
    /// The returned sign flag is `true` after an even number of division
    /// steps and `false` after an odd number (the two `return`s below),
    /// which also determines the row order of the cofactors.
    ///
    /// # Panics
    ///
    /// Panics if `r0 < r1`.
    // OPT: Would this be faster using extended binary gcd?
    // See <https://en.algorithmica.org/hpc/algorithms/gcd>
    #[inline]
    #[must_use]
    pub fn from_u64(mut r0: u64, mut r1: u64) -> Self {
        debug_assert!(r0 >= r1);
        // gcd(r0, 0) is already r0: nothing to reduce.
        if r1 == 0_u64 {
            return Matrix::IDENTITY;
        }
        // Cofactor accumulators, starting from the identity:
        // (q00, q01) is the row for the current larger value,
        // (q10, q11) the row for the smaller one.
        let mut q00 = 1_u64;
        let mut q01 = 0_u64;
        let mut q10 = 0_u64;
        let mut q11 = 1_u64;
        loop {
            // Loop is unrolled once to avoid swapping variables and tracking parity.
            // First half-step: reduce r0 by r1.
            let q = r0 / r1;
            r0 -= q * r1;
            q00 += q * q10;
            q01 += q * q11;
            if r0 == 0_u64 {
                // Odd number of steps: rows come out swapped relative to the
                // even case, encoded by the `false` sign pattern.
                return Matrix(q10, q11, q00, q01, false);
            }
            // Second half-step: reduce r1 by r0 (roles exchanged).
            let q = r1 / r0;
            r1 -= q * r0;
            q10 += q * q00;
            q11 += q * q01;
            if r1 == 0_u64 {
                return Matrix(q00, q01, q10, q11, true);
            }
        }
    }
157
    /// ⚠️ Compute the largest valid Lehmer update matrix for a prefix.
    #[doc = crate::algorithms::unstable_warning!()]
    /// Compute the Lehmer update matrix for a0 and a1 such that the matrix is
    /// valid for any two large integers starting with the bits of a0 and
    /// a1.
    ///
    /// See also `mpn_hgcd2` in GMP, but ours handles the double precision bit
    /// separately in `lehmer_double`.
    /// <https://gmplib.org/repo/gmp-6.1/file/tip/mpn/generic/hgcd2.c#l226>
    ///
    /// # Panics
    ///
    /// Panics if `a0` does not have the highest bit set.
    /// Panics if `a0 < a1`.
    #[inline]
    #[must_use]
    #[allow(clippy::redundant_else)]
    #[allow(clippy::cognitive_complexity)] // REFACTOR: Improve
    pub fn from_u64_prefix(a0: u64, mut a1: u64) -> Self {
        // Cofactors never exceed 32 bits, so steps are only taken while the
        // remainders stay at or above this limit.
        const LIMIT: u64 = 1_u64 << 32;
        debug_assert!(a0 >= 1_u64 << 63);
        debug_assert!(a0 >= a1);

        // Here we do something original: The cofactors undergo identical
        // operations which makes them a candidate for SIMD instructions.
        // They also never exceed 32 bit, so we can SWAR them in a single u64.
        // Each `k` packs `u` in the high 32 bits and `v` in the low 32 bits.
        let mut k0 = 1_u64 << 32; // u0 = 1, v0 = 0
        let mut k1 = 1_u64; // u1 = 0, v1 = 1
        let mut even = true;
        if a1 < LIMIT {
            // Not even a single valid division step can be taken.
            return Matrix::IDENTITY;
        }

        // Compute a2
        let q = a0 / a1;
        let mut a2 = a0 - q * a1;
        let mut k2 = k0 + q * k1;
        if a2 < LIMIT {
            let u2 = k2 >> 32;
            let v2 = k2 % LIMIT;

            // Test i + 1 (odd)
            if a2 >= v2 && a1 - a2 >= u2 {
                return Matrix(0, 1, u2, v2, false);
            } else {
                return Matrix::IDENTITY;
            }
        }

        // Compute a3
        let q = a1 / a2;
        let mut a3 = a1 - q * a2;
        let mut k3 = k1 + q * k2;

        // Loop until a3 < LIMIT, maintaining the last three values
        // of a and the last four values of k.
        while a3 >= LIMIT {
            // Rotate the window. Note the order: after these assignments
            // `a3` holds the *old* a2 (the next dividend) and `a2` the old
            // a3 (the next divisor), so the division below performs the
            // next Euclidean step in place. Same trick for k3/k2.
            a1 = a2;
            a2 = a3;
            a3 = a1;
            k0 = k1;
            k1 = k2;
            k2 = k3;
            k3 = k1;
            debug_assert!(a2 < a3);
            debug_assert!(a2 > 0);
            let q = a3 / a2;
            a3 -= q * a2;
            k3 += q * k2;
            if a3 < LIMIT {
                // Stopped after the first of the two unrolled steps, so the
                // total step count is odd.
                even = false;
                break;
            }
            // Second unrolled step, identical to the first.
            a1 = a2;
            a2 = a3;
            a3 = a1;
            k0 = k1;
            k1 = k2;
            k2 = k3;
            k3 = k1;
            debug_assert!(a2 < a3);
            debug_assert!(a2 > 0);
            let q = a3 / a2;
            a3 -= q * a2;
            k3 += q * k2;
        }
        // Unpack k into cofactors u and v
        let u0 = k0 >> 32;
        let u1 = k1 >> 32;
        let u2 = k2 >> 32;
        let u3 = k3 >> 32;
        let v0 = k0 % LIMIT;
        let v1 = k1 % LIMIT;
        let v2 = k2 % LIMIT;
        let v3 = k3 % LIMIT;
        debug_assert!(a2 >= LIMIT);
        debug_assert!(a3 < LIMIT);

        // Use Jebelean's exact condition to determine which outputs are correct.
        // Statistically, i + 2 should be correct about two-thirds of the time.
        if even {
            // Test i + 1 (odd)
            debug_assert!(a2 >= v2);
            if a1 - a2 >= u2 + u1 {
                // Test i + 2 (even)
                if a3 >= u3 && a2 - a3 >= v3 + v2 {
                    // Correct value is i + 2
                    Matrix(u2, v2, u3, v3, true)
                } else {
                    // Correct value is i + 1
                    Matrix(u1, v1, u2, v2, false)
                }
            } else {
                // Correct value is i
                Matrix(u0, v0, u1, v1, true)
            }
        } else {
            // Test i + 1 (even)
            debug_assert!(a2 >= u2);
            if a1 - a2 >= v2 + v1 {
                // Test i + 2 (odd)
                if a3 >= v3 && a2 - a3 >= u3 + u2 {
                    // Correct value is i + 2
                    Matrix(u2, v2, u3, v3, false)
                } else {
                    // Correct value is i + 1
                    Matrix(u1, v1, u2, v2, true)
                }
            } else {
                // Correct value is i
                Matrix(u0, v0, u1, v1, false)
            }
        }
    }
292
293    /// ⚠️ Compute the Lehmer update matrix in full 64 bit precision.
294    #[doc = crate::algorithms::unstable_warning!()]
295    /// Jebelean solves this by starting in double-precission followed
296    /// by single precision once values are small enough.
297    /// Cohen instead runs a single precision round, refreshes the r0 and r1
298    /// values and continues with another single precision round on top.
299    /// Our approach is similar to Cohen, but instead doing the second round
300    /// on the same matrix, we start we a fresh matrix and multiply both in the
301    /// end. This requires 8 additional multiplications, but allows us to use
302    /// the tighter stopping conditions from Jebelean. It also seems the
303    /// simplest out of these solutions.
304    // OPT: We can update r0 and r1 in place. This won't remove the partially
305    // redundant call to lehmer_update, but it reduces memory usage.
306    #[inline]
307    #[must_use]
308    pub fn from_u128_prefix(r0: u128, r1: u128) -> Self {
309        debug_assert!(r0 >= r1);
310        let s = r0.leading_zeros();
311        let r0s = r0 << s;
312        let r1s = r1 << s;
313        let q = Self::from_u64_prefix((r0s >> 64) as u64, (r1s >> 64) as u64);
314        if q == Matrix::IDENTITY {
315            return q;
316        }
317        // We can return q here and have a perfectly valid single-word Lehmer GCD.
318        q
319        // OPT: Fix the below method to get double-word Lehmer GCD.
320
321        // Recompute r0 and r1 and take the high bits.
322        // TODO: Is it safe to do this based on just the u128 prefix?
323        // let (r0, r1) = q.apply_u128(r0, r1);
324        // let s = r0.leading_zeros();
325        // let r0s = r0 << s;
326        // let r1s = r1 << s;
327        // let qn = Self::from_u64_prefix((r0s >> 64) as u64, (r1s >> 64) as
328        // u64);
329
330        // // Multiply matrices qn * q
331        // qn.compose(q)
332    }
333}
334
#[cfg(test)]
#[allow(clippy::cast_lossless)]
#[allow(clippy::many_single_char_names)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use core::{
        cmp::{max, min},
        mem::swap,
        str::FromStr,
    };
    use proptest::{proptest, test_runner::Config};

    /// Reference GCD on `u128` via plain Euclidean remainders; ground truth
    /// for the matrix-based implementations.
    fn gcd(mut a: u128, mut b: u128) -> u128 {
        while b != 0 {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    /// Reference GCD on `Uint`, mirroring `gcd` above.
    fn gcd_uint<const BITS: usize, const LIMBS: usize>(
        mut a: Uint<BITS, LIMBS>,
        mut b: Uint<BITS, LIMBS>,
    ) -> Uint<BITS, LIMBS> {
        while b != Uint::ZERO {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    // Hand-checked example: applying the computed matrix to (252, 105) must
    // yield (gcd, 0) = (21, 0).
    #[test]
    fn test_from_u64_example() {
        let (a, b) = (252, 105);
        let m = Matrix::from_u64(a, b);
        assert_eq!(m, Matrix(2, 5, 5, 12, false));
        let (a, b) = m.apply_u128(a as u128, b as u128);
        assert_eq!(a, 21);
        assert_eq!(b, 0);
    }

    // The full-precision matrix must reduce any (a, b) to (gcd, 0).
    #[test]
    fn test_from_u64() {
        proptest!(|(a: u64, b: u64)| {
            let (a, b) = (max(a,b), min(a,b));
            let m = Matrix::from_u64(a, b);
            let (c, d) = m.apply_u128(a as u128, b as u128);
            assert!(c >= d);
            assert_eq!(c, gcd(a as u128, b as u128));
            assert_eq!(d, 0);
        });
    }

    // A prefix matrix need not finish the reduction, but it must make
    // progress (unless it is the identity) and preserve the GCD.
    #[test]
    fn test_from_u64_prefix() {
        proptest!(|(a: u128, b: u128)| {
            // Prepare input
            let (a, b) = (max(a,b), min(a,b));
            let s = a.leading_zeros();
            let (sa, sb) = (a << s, b << s);

            let m = Matrix::from_u64_prefix((sa >> 64) as u64, (sb >> 64) as u64);
            let (c, d) = m.apply_u128(a, b);
            assert!(c >= d);
            if m == Matrix::IDENTITY {
                assert_eq!(c, a);
                assert_eq!(d, b);
            } else {
                assert!(c <= a);
                assert!(d < b);
                assert_eq!(gcd(a, b), gcd(c, d));
            }
        });
    }

    /// Shared check for `Matrix::from` on one `(a, b)` pair: the update must
    /// be GCD-preserving and strictly reducing unless it is the identity.
    fn test_form_uint_one<const BITS: usize, const LIMBS: usize>(
        a: Uint<BITS, LIMBS>,
        b: Uint<BITS, LIMBS>,
    ) {
        let (a, b) = (max(a, b), min(a, b));
        let m = Matrix::from(a, b);
        let (mut c, mut d) = (a, b);
        m.apply(&mut c, &mut d);
        assert!(c >= d);
        if m == Matrix::IDENTITY {
            assert_eq!(c, a);
            assert_eq!(d, b);
        } else {
            assert!(c <= a);
            assert!(d < b);
            assert_eq!(gcd_uint(a, b), gcd_uint(c, d));
        }
    }

    // Regression test for a known-bad input of the (disabled) double-word
    // variant of `from_u128_prefix`.
    #[test]
    fn test_from_uint_cases() {
        // This case fails with the double-word version above.
        type U129 = Uint<129, 3>;
        test_form_uint_one(
            U129::from_str("0x01de6ef6f3caa963a548d7a411b05b9988").unwrap(),
            U129::from_str("0x006d7c4641f88b729a97889164dd8d07db").unwrap(),
        );
    }

    // Property test across all configured Uint sizes.
    #[test]
    #[allow(clippy::absurd_extreme_comparisons)] // Generated code
    fn test_from_uint_proptest() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let config = Config { cases: 10, ..Default::default() };
            proptest!(config, |(a: U, b: U)| {
                test_form_uint_one(a, b);
            });
        });
    }
}