ruint/
base_convert.rs

1use crate::{
2    algorithms::{addmul_nx1, mul_nx1},
3    Uint,
4};
5use core::{fmt, iter::FusedIterator, mem::MaybeUninit};
6
/// Error returned by [`from_base_le`][Uint::from_base_le] and
/// [`from_base_be`][Uint::from_base_be] when the digits cannot be converted.
// The type name intentionally repeats the module name (`base_convert`).
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum BaseConvertError {
    /// The accumulated value does not fit in the target type.
    Overflow,

    /// The requested number base (`.0`) is below two.
    InvalidBase(u64),

    /// A digit (`.0`) is not smaller than the requested base (`.1`).
    InvalidDigit(u64, u64),
}

// `std::error::Error` is only available when the standard library is linked.
#[cfg(feature = "std")]
impl std::error::Error for BaseConvertError {}
24
25impl fmt::Display for BaseConvertError {
26    #[inline]
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        match self {
29            Self::Overflow => f.write_str("the value is too large to fit the target type"),
30            Self::InvalidBase(base) => {
31                write!(f, "the requested number base {base} is less than two")
32            }
33            Self::InvalidDigit(digit, base) => {
34                write!(f, "digit {digit} is out of range for base {base}")
35            }
36        }
37    }
38}
39
40impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
41    /// Returns an iterator over the base `base` digits of the number in
42    /// little-endian order.
43    ///
44    /// Pro tip: instead of setting `base = 10`, set it to the highest
45    /// power of `10` that still fits `u64`. This way much fewer iterations
46    /// are required to extract all the digits.
47    // OPT: Internalize this trick so the user won't have to worry about it.
48    /// # Panics
49    ///
50    /// Panics if the base is less than 2.
51    #[inline]
52    #[track_caller]
53    pub fn to_base_le(&self, base: u64) -> impl Iterator<Item = u64> {
54        SpigotLittle::new(self.limbs, base)
55    }
56
57    /// Returns an iterator over the base `base` digits of the number in
58    /// big-endian order.
59    ///
60    /// Pro tip: instead of setting `base = 10`, set it to the highest
61    /// power of `10` that still fits `u64`. This way much fewer iterations
62    /// are required to extract all the digits.
63    ///
64    /// Use [`to_base_be_2`](Self::to_base_be_2) to extract the maximum number
65    /// of digits at once more efficiently.
66    ///
67    /// # Panics
68    ///
69    /// Panics if the base is less than 2.
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// let n = ruint::aliases::U64::from(1234);
75    /// assert_eq!(n.to_base_be(10).collect::<Vec<_>>(), [1, 2, 3, 4]);
76    /// assert_eq!(n.to_base_be(1000000).collect::<Vec<_>>(), [1234]);
77    ///
78    /// // `to_base_be_2` always returns digits maximally packed into `u64`s.
79    /// assert_eq!(n.to_base_be_2(10).collect::<Vec<_>>(), [1234]);
80    /// assert_eq!(n.to_base_be_2(1000000).collect::<Vec<_>>(), [1234]);
81    /// ```
82    #[inline]
83    #[track_caller]
84    pub fn to_base_be(&self, base: u64) -> impl Iterator<Item = u64> {
85        // Use `to_base_le` if we can heap-allocate it to reverse the order,
86        // as it only performs one division per iteration instead of two.
87        #[cfg(feature = "alloc")]
88        {
89            self.to_base_le(base)
90                .collect::<alloc::vec::Vec<_>>()
91                .into_iter()
92                .rev()
93        }
94        #[cfg(not(feature = "alloc"))]
95        {
96            SpigotBig::new(*self, base)
97        }
98    }
99
100    /// Returns an iterator over the base `base` digits of the number in
101    /// big-endian order.
102    ///
103    /// Always returns digits maximally packed into `u64`s.
104    /// Unlike [`to_base_be`], this method:
105    /// - never heap-allocates memory, so it's always faster
106    /// - always returns digits maximally packed into `u64`s, so passing the
107    ///   constant base like `2`, `8`, instead of the highest power that fits in
108    ///   u64 is not needed
109    ///
110    /// # Panics
111    ///
112    /// Panics if the base is less than 2.
113    ///
114    /// # Examples
115    ///
116    /// See [`to_base_be`].
117    ///
118    /// [`to_base_be`]: Self::to_base_be
119    #[inline]
120    #[track_caller]
121    pub fn to_base_be_2(&self, base: u64) -> impl Iterator<Item = u64> {
122        SpigotBig2::new(self.limbs, base)
123    }
124
125    /// Constructs the [`Uint`] from digits in the base `base` in little-endian.
126    ///
127    /// # Errors
128    ///
129    /// * [`BaseConvertError::InvalidBase`] if the base is less than 2.
130    /// * [`BaseConvertError::InvalidDigit`] if a digit is out of range.
131    /// * [`BaseConvertError::Overflow`] if the number is too large to fit.
132    #[inline]
133    pub fn from_base_le<I>(base: u64, digits: I) -> Result<Self, BaseConvertError>
134    where
135        I: IntoIterator<Item = u64>,
136    {
137        if base < 2 {
138            return Err(BaseConvertError::InvalidBase(base));
139        }
140        if BITS == 0 {
141            for digit in digits {
142                if digit >= base {
143                    return Err(BaseConvertError::InvalidDigit(digit, base));
144                }
145                if digit != 0 {
146                    return Err(BaseConvertError::Overflow);
147                }
148            }
149            return Ok(Self::ZERO);
150        }
151
152        let mut iter = digits.into_iter();
153        let mut result = Self::ZERO;
154        let mut power = Self::ONE;
155        for digit in iter.by_ref() {
156            if digit >= base {
157                return Err(BaseConvertError::InvalidDigit(digit, base));
158            }
159
160            // Add digit to result
161            let overflow = addmul_nx1(&mut result.limbs, power.as_limbs(), digit);
162            if overflow != 0 || result.limbs[LIMBS - 1] > Self::MASK {
163                return Err(BaseConvertError::Overflow);
164            }
165
166            // Update power
167            let overflow = mul_nx1(&mut power.limbs, base);
168            if overflow != 0 || power.limbs[LIMBS - 1] > Self::MASK {
169                // Following digits must be zero
170                break;
171            }
172        }
173        for digit in iter {
174            if digit >= base {
175                return Err(BaseConvertError::InvalidDigit(digit, base));
176            }
177            if digit != 0 {
178                return Err(BaseConvertError::Overflow);
179            }
180        }
181        Ok(result)
182    }
183
184    /// Constructs the [`Uint`] from digits in the base `base` in big-endian.
185    ///
186    /// # Errors
187    ///
188    /// * [`BaseConvertError::InvalidBase`] if the base is less than 2.
189    /// * [`BaseConvertError::InvalidDigit`] if a digit is out of range.
190    /// * [`BaseConvertError::Overflow`] if the number is too large to fit.
191    #[inline]
192    pub fn from_base_be<I: IntoIterator<Item = u64>>(
193        base: u64,
194        digits: I,
195    ) -> Result<Self, BaseConvertError> {
196        // OPT: Special handling of bases that divide 2^64, and bases that are
197        // powers of 2.
198        // OPT: Same trick as with `to_base_le`, find the largest power of base
199        // that fits `u64` and accumulate there first.
200        if base < 2 {
201            return Err(BaseConvertError::InvalidBase(base));
202        }
203
204        let mut result = Self::ZERO;
205        for digit in digits {
206            if digit >= base {
207                return Err(BaseConvertError::InvalidDigit(digit, base));
208            }
209            // Multiply by base.
210            // OPT: keep track of non-zero limbs and mul the minimum.
211            let mut carry = u128::from(digit);
212            #[allow(clippy::cast_possible_truncation)]
213            for limb in &mut result.limbs {
214                carry += u128::from(*limb) * u128::from(base);
215                *limb = carry as u64;
216                carry >>= 64;
217            }
218            if carry > 0 || (LIMBS != 0 && result.limbs[LIMBS - 1] > Self::MASK) {
219                return Err(BaseConvertError::Overflow);
220            }
221        }
222
223        Ok(result)
224    }
225}
226
227struct SpigotLittle<const LIMBS: usize> {
228    base:  u64,
229    limbs: [u64; LIMBS],
230}
231
232impl<const LIMBS: usize> SpigotLittle<LIMBS> {
233    #[inline]
234    #[track_caller]
235    fn new(limbs: [u64; LIMBS], base: u64) -> Self {
236        assert!(base > 1);
237        Self { base, limbs }
238    }
239}
240
241impl<const LIMBS: usize> Iterator for SpigotLittle<LIMBS> {
242    type Item = u64;
243
244    #[inline]
245    #[allow(clippy::cast_possible_truncation)] // Doesn't truncate.
246    fn next(&mut self) -> Option<Self::Item> {
247        let base = self.base;
248        assume!(base > 1); // Checked in `new`.
249
250        let mut zero = 0_u64;
251        let mut remainder = 0_u128;
252        for limb in self.limbs.iter_mut().rev() {
253            zero |= *limb;
254            remainder = (remainder << 64) | u128::from(*limb);
255            *limb = (remainder / u128::from(base)) as u64;
256            remainder %= u128::from(base);
257        }
258        if zero == 0 {
259            None
260        } else {
261            Some(remainder as u64)
262        }
263    }
264}
265
266impl<const LIMBS: usize> FusedIterator for SpigotLittle<LIMBS> {}
267
268/// Implementation of `to_base_be` when `alloc` feature is disabled.
269///
270/// This is generally slower than simply reversing the result of `to_base_le`
271/// as it performs two divisions per iteration instead of one.
272#[cfg(not(feature = "alloc"))]
273struct SpigotBig<const LIMBS: usize, const BITS: usize> {
274    base:  u64,
275    n:     Uint<BITS, LIMBS>,
276    power: Uint<BITS, LIMBS>,
277    done:  bool,
278}
279
280#[cfg(not(feature = "alloc"))]
281impl<const LIMBS: usize, const BITS: usize> SpigotBig<LIMBS, BITS> {
282    #[inline]
283    #[track_caller]
284    fn new(n: Uint<BITS, LIMBS>, base: u64) -> Self {
285        assert!(base > 1);
286
287        Self {
288            n,
289            base,
290            power: Self::highest_power(n, base),
291            done: false,
292        }
293    }
294
295    /// Returns the largest power of `base` that fits in `n`.
296    #[inline]
297    fn highest_power(n: Uint<BITS, LIMBS>, base: u64) -> Uint<BITS, LIMBS> {
298        let mut power = Uint::ONE;
299        if base.is_power_of_two() {
300            loop {
301                match power.checked_shl(base.trailing_zeros() as _) {
302                    Some(p) if p < n => power = p,
303                    _ => break,
304                }
305            }
306        } else if let Ok(base) = Uint::try_from(base) {
307            loop {
308                match power.checked_mul(base) {
309                    Some(p) if p < n => power = p,
310                    _ => break,
311                }
312            }
313        }
314        power
315    }
316}
317
318#[cfg(not(feature = "alloc"))]
319impl<const LIMBS: usize, const BITS: usize> Iterator for SpigotBig<LIMBS, BITS> {
320    type Item = u64;
321
322    #[inline]
323    fn next(&mut self) -> Option<Self::Item> {
324        if self.done {
325            return None;
326        }
327
328        let digit;
329        if self.power == 1 {
330            digit = self.n;
331            self.done = true;
332        } else if self.base.is_power_of_two() {
333            digit = self.n >> self.power.trailing_zeros();
334            self.n &= self.power - Uint::ONE;
335
336            self.power >>= self.base.trailing_zeros();
337        } else {
338            (digit, self.n) = self.n.div_rem(self.power);
339            self.power /= Uint::from(self.base);
340        }
341
342        match u64::try_from(digit) {
343            Ok(digit) => Some(digit),
344            Err(e) => debug_unreachable!("digit {digit}: {e}"),
345        }
346    }
347}
348
349#[cfg(not(feature = "alloc"))]
350impl<const LIMBS: usize, const BITS: usize> core::iter::FusedIterator for SpigotBig<LIMBS, BITS> {}
351
352/// An iterator over the base `base` digits of the number in big-endian order.
353///
354/// See [`Uint::to_base_be_2`] for more details.
355struct SpigotBig2<const LIMBS: usize> {
356    buf: SpigotBuf<LIMBS>,
357}
358
359impl<const LIMBS: usize> SpigotBig2<LIMBS> {
360    #[inline]
361    #[track_caller]
362    fn new(limbs: [u64; LIMBS], base: u64) -> Self {
363        Self {
364            buf: SpigotBuf::new(limbs, base),
365        }
366    }
367}
368
369impl<const LIMBS: usize> Iterator for SpigotBig2<LIMBS> {
370    type Item = u64;
371
372    #[inline]
373    fn next(&mut self) -> Option<Self::Item> {
374        self.buf.next_back()
375    }
376}
377
378impl<const LIMBS: usize> FusedIterator for SpigotBig2<LIMBS> {}
379
380/// Collects [`SpigotLittle`] into a stack-allocated buffer.
381///
382/// Base for [`SpigotBig2`].
383struct SpigotBuf<const LIMBS: usize> {
384    end: usize,
385    buf: [[MaybeUninit<u64>; 2]; LIMBS],
386}
387
388impl<const LIMBS: usize> SpigotBuf<LIMBS> {
389    #[inline]
390    #[track_caller]
391    fn new(limbs: [u64; LIMBS], mut base: u64) -> Self {
392        // We need to do this so we can guarantee that `buf` is big enough.
393        base = crate::utils::max_pow_u64(base);
394
395        let mut buf = [[MaybeUninit::uninit(); 2]; LIMBS];
396        // TODO(MSRV-1.80): let as_slice = buf.as_flattened_mut();
397        let as_slice = unsafe {
398            core::slice::from_raw_parts_mut(buf.as_mut_ptr().cast::<MaybeUninit<u64>>(), LIMBS * 2)
399        };
400        let mut i = 0;
401        for limb in SpigotLittle::new(limbs, base) {
402            debug_assert!(
403                i < as_slice.len(),
404                "base {base} too small for u64 digits of {LIMBS} limbs; this shouldn't happen \
405                 because of the `max_pow_u64` call above"
406            );
407            unsafe { as_slice.get_unchecked_mut(i).write(limb) };
408            i += 1;
409        }
410        Self { end: i, buf }
411    }
412
413    #[inline]
414    fn next_back(&mut self) -> Option<u64> {
415        if self.end == 0 {
416            None
417        } else {
418            self.end -= 1;
419            Some(unsafe { *self.buf.as_ptr().cast::<u64>().add(self.end) })
420        }
421    }
422}
423
#[cfg(test)]
#[allow(clippy::unreadable_literal)]
#[allow(clippy::zero_prefixed_literal)]
mod tests {
    use super::*;
    use crate::utils::max_pow_u64;

    // 90630363884335538722706632492458228784305343302099024356772372330524102404852
    const N: Uint<256, 4> = Uint::from_limbs([
        0xa8ec92344438aaf4_u64,
        0x9819ebdbd1faaab1_u64,
        0x573b1a7064c19c1a_u64,
        0xc85ef7d79691fe79_u64,
    ]);

    #[test]
    fn test_to_base_le() {
        assert_eq!(
            Uint::<64, 1>::from(123456789)
                .to_base_le(10)
                .collect::<Vec<_>>(),
            vec![9, 8, 7, 6, 5, 4, 3, 2, 1]
        );
        assert_eq!(
            N.to_base_le(10000000000000000000_u64).collect::<Vec<_>>(),
            vec![
                2372330524102404852,
                0534330209902435677,
                7066324924582287843,
                0630363884335538722,
                9
            ]
        );
    }

    #[test]
    fn test_zero_has_no_le_digits() {
        // Zero yields no digits from the `SpigotLittle`-based iterators.
        assert_eq!(Uint::<64, 1>::ZERO.to_base_le(10).count(), 0);
        assert_eq!(Uint::<256, 4>::ZERO.to_base_be_2(10).count(), 0);
    }

    #[test]
    fn test_from_base_le() {
        assert_eq!(
            Uint::<64, 1>::from_base_le(10, [9, 8, 7, 6, 5, 4, 3, 2, 1]),
            Ok(Uint::<64, 1>::from(123456789))
        );
        assert_eq!(
            Uint::<256, 4>::from_base_le(10000000000000000000_u64, [
                2372330524102404852,
                0534330209902435677,
                7066324924582287843,
                0630363884335538722,
                9
            ])
            .unwrap(),
            N
        );
    }

    #[test]
    fn test_to_base_be() {
        assert_eq!(
            Uint::<64, 1>::from(123456789)
                .to_base_be(10)
                .collect::<Vec<_>>(),
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9]
        );
        assert_eq!(
            N.to_base_be(10000000000000000000_u64).collect::<Vec<_>>(),
            vec![
                9,
                0630363884335538722,
                7066324924582287843,
                0534330209902435677,
                2372330524102404852
            ]
        );
    }

    #[test]
    fn test_to_base_be_exact_powers() {
        // Exact powers of the base must not merge the leading digit with the
        // zeros that follow it (regression test for the no-`alloc` fallback).
        assert_eq!(Uint::<64, 1>::from(100).to_base_be(10).collect::<Vec<_>>(), [1, 0, 0]);
        assert_eq!(Uint::<64, 1>::from(10).to_base_be(10).collect::<Vec<_>>(), [1, 0]);
        assert_eq!(Uint::<64, 1>::from(8).to_base_be(2).collect::<Vec<_>>(), [1, 0, 0, 0]);
        assert_eq!(Uint::<64, 1>::from(1).to_base_be(2).collect::<Vec<_>>(), [1]);
    }

    #[test]
    fn test_to_base_be_2() {
        assert_eq!(
            Uint::<64, 1>::from(123456789)
                .to_base_be_2(10)
                .collect::<Vec<_>>(),
            vec![123456789]
        );
        assert_eq!(
            N.to_base_be_2(10000000000000000000_u64).collect::<Vec<_>>(),
            vec![
                9,
                0630363884335538722,
                7066324924582287843,
                0534330209902435677,
                2372330524102404852
            ]
        );
    }

    #[test]
    fn test_from_base_be() {
        assert_eq!(
            Uint::<64, 1>::from_base_be(10, [1, 2, 3, 4, 5, 6, 7, 8, 9]),
            Ok(Uint::<64, 1>::from(123456789))
        );
        assert_eq!(
            Uint::<256, 4>::from_base_be(10000000000000000000_u64, [
                9,
                0630363884335538722,
                7066324924582287843,
                0534330209902435677,
                2372330524102404852
            ])
            .unwrap(),
            N
        );
    }

    #[test]
    fn test_from_base_be_overflow() {
        assert_eq!(
            Uint::<0, 0>::from_base_be(10, core::iter::empty()),
            Ok(Uint::<0, 0>::ZERO)
        );
        assert_eq!(
            Uint::<0, 0>::from_base_be(10, core::iter::once(0)),
            Ok(Uint::<0, 0>::ZERO)
        );
        assert_eq!(
            Uint::<0, 0>::from_base_be(10, core::iter::once(1)),
            Err(BaseConvertError::Overflow)
        );
        assert_eq!(
            Uint::<1, 1>::from_base_be(10, [1, 0, 0].into_iter()),
            Err(BaseConvertError::Overflow)
        );
    }

    #[test]
    fn test_roundtrip() {
        fn test<const BITS: usize, const LIMBS: usize>(n: Uint<BITS, LIMBS>, base: u64) {
            assert_eq!(
                n.to_base_be(base).collect::<Vec<_>>(),
                n.to_base_le(base)
                    .collect::<Vec<_>>()
                    .into_iter()
                    .rev()
                    .collect::<Vec<_>>(),
            );

            let digits = n.to_base_le(base);
            let n2 = Uint::<BITS, LIMBS>::from_base_le(base, digits).unwrap();
            assert_eq!(n, n2);

            let digits = n.to_base_be(base);
            let n2 = Uint::<BITS, LIMBS>::from_base_be(base, digits).unwrap();
            assert_eq!(n, n2);

            let digits = n.to_base_be_2(base).collect::<Vec<_>>();
            let n2 = Uint::<BITS, LIMBS>::from_base_be(max_pow_u64(base), digits).unwrap();
            assert_eq!(n, n2);
        }

        let single = |x: u64| x..=x;
        for base in [2..=129, single(1 << 31), single(1 << 32), single(1 << 33)]
            .into_iter()
            .flatten()
        {
            test(Uint::<64, 1>::from(123456789), base);
            test(Uint::<128, 2>::from(123456789), base);
            test(N, base);

            // Exact powers of the base exercise the boundary of the
            // highest-power search in the big-endian fallback.
            test(Uint::<64, 1>::from(base), base);
            let b = Uint::<128, 2>::from(base);
            test(b * b, base);
        }
    }
}