ruint/algorithms/gcd/
mod.rs1#![allow(clippy::module_name_repetitions)]
2
3mod matrix;
7
8pub use self::matrix::Matrix as LehmerMatrix;
9use crate::Uint;
10use core::mem::swap;
11
12#[inline]
18#[must_use]
19pub fn gcd<const BITS: usize, const LIMBS: usize>(
20 mut a: Uint<BITS, LIMBS>,
21 mut b: Uint<BITS, LIMBS>,
22) -> Uint<BITS, LIMBS> {
23 if b > a {
24 swap(&mut a, &mut b);
25 }
26 while b != Uint::ZERO {
27 debug_assert!(a >= b);
28 let m = LehmerMatrix::from(a, b);
29 if m == LehmerMatrix::IDENTITY {
30 a %= b;
34 swap(&mut a, &mut b);
35 } else {
36 m.apply(&mut a, &mut b);
37 }
38 }
39 a
40}
41
42#[inline]
65#[must_use]
66pub fn gcd_extended<const BITS: usize, const LIMBS: usize>(
67 mut a: Uint<BITS, LIMBS>,
68 mut b: Uint<BITS, LIMBS>,
69) -> (
70 Uint<BITS, LIMBS>,
71 Uint<BITS, LIMBS>,
72 Uint<BITS, LIMBS>,
73 bool,
74) {
75 if BITS == 0 {
76 return (Uint::ZERO, Uint::ZERO, Uint::ZERO, false);
77 }
78 let swapped = a < b;
79 if swapped {
80 swap(&mut a, &mut b);
81 }
82
83 let mut s0 = Uint::ONE;
85 let mut s1 = Uint::ZERO;
86 let mut t0 = Uint::ZERO;
87 let mut t1 = Uint::ONE;
88 let mut even = true;
89 while b != Uint::ZERO {
90 debug_assert!(a >= b);
91 let m = LehmerMatrix::from(a, b);
92 if m == LehmerMatrix::IDENTITY {
93 let q = a / b;
97 a -= q * b;
98 swap(&mut a, &mut b);
99 s0 -= q * s1;
100 swap(&mut s0, &mut s1);
101 t0 -= q * t1;
102 swap(&mut t0, &mut t1);
103 even = !even;
104 } else {
105 m.apply(&mut a, &mut b);
106 m.apply(&mut s0, &mut s1);
107 m.apply(&mut t0, &mut t1);
108 even ^= !m.4;
109 }
110 }
111 if even {
113 t0 = Uint::ZERO - t0;
115 } else {
116 s0 = Uint::ZERO - s0;
118 }
119 if swapped {
120 swap(&mut s0, &mut t0);
121 even = !even;
122 }
123 (a, s0, t0, even)
124}
125
126#[inline]
144#[must_use]
145pub fn inv_mod<const BITS: usize, const LIMBS: usize>(
146 num: Uint<BITS, LIMBS>,
147 modulus: Uint<BITS, LIMBS>,
148) -> Option<Uint<BITS, LIMBS>> {
149 if BITS == 0 || modulus.is_zero() {
150 return None;
151 }
152 let mut a = modulus;
153 let mut b = num;
154 if b >= a {
155 b %= a;
156 }
157 if b.is_zero() {
158 return None;
159 }
160
161 let mut t0 = Uint::ZERO;
162 let mut t1 = Uint::ONE;
163 let mut even = true;
164 while b != Uint::ZERO {
165 debug_assert!(a >= b);
166 let m = LehmerMatrix::from(a, b);
167 if m == LehmerMatrix::IDENTITY {
168 let q = a / b;
172 a -= q * b;
173 swap(&mut a, &mut b);
174 t0 -= q * t1;
175 swap(&mut t0, &mut t1);
176 even = !even;
177 } else {
178 m.apply(&mut a, &mut b);
179 m.apply(&mut t0, &mut t1);
180 even ^= !m.4;
181 }
182 }
183 if a == Uint::ONE {
184 Some(if even { modulus + t0 } else { t0 })
186 } else {
187 None
188 }
189}
190
#[cfg(test)]
#[allow(clippy::cast_lossless)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use proptest::{proptest, test_runner::Config};

    /// Reference GCD: the plain Euclidean algorithm (`a %= b`, then swap)
    /// repeated until `b` is zero. Used as the oracle for the tests below.
    fn gcd_ref<const BITS: usize, const LIMBS: usize>(
        mut a: Uint<BITS, LIMBS>,
        mut b: Uint<BITS, LIMBS>,
    ) -> Uint<BITS, LIMBS> {
        loop {
            if b == Uint::ZERO {
                return a;
            }
            a %= b;
            swap(&mut a, &mut b);
        }
    }

    /// Fixed 129-bit input pair checked against the reference implementation.
    #[test]
    fn test_gcd_one() {
        use core::str::FromStr;
        const BITS: usize = 129;
        const LIMBS: usize = nlimbs(BITS);
        type U = Uint<BITS, LIMBS>;

        let a = U::from_str("0x006d7c4641f88b729a97889164dd8d07db").unwrap();
        let b = U::from_str("0x01de6ef6f3caa963a548d7a411b05b9988").unwrap();

        let expected = gcd_ref(a, b);
        assert_eq!(gcd(a, b), expected);
    }

    /// `gcd` must agree with `gcd_ref` on random inputs for every bit size
    /// that `const_for!` iterates over.
    #[test]
    // Presumably required by code expanded from the macros below -- confirm.
    #[allow(clippy::absurd_extreme_comparisons)]
    fn test_gcd() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let config = Config {
                cases: 10,
                ..Config::default()
            };
            proptest!(config, |(a: U, b: U)| {
                let expected = gcd_ref(a, b);
                assert_eq!(gcd(a, b), expected);
            });
        });
    }

    /// `gcd_extended` must return the reference GCD together with cofactors
    /// that reproduce it, with `sign` selecting which product is subtracted.
    #[test]
    fn test_gcd_extended() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let config = Config {
                cases: 5,
                ..Config::default()
            };
            proptest!(config, |(a: U, b: U)| {
                let (g, x, y, sign) = gcd_extended(a, b);
                assert_eq!(g, gcd_ref(a, b));

                let combination = if sign {
                    a * x - b * y
                } else {
                    b * y - a * x
                };
                assert_eq!(combination, g);
            });
        });
    }
}