alloy_rlp/
decode.rs

1use crate::{Error, Header, Result};
2use bytes::{Bytes, BytesMut};
3use core::marker::{PhantomData, PhantomPinned};
4
5/// A type that can be decoded from an RLP blob.
6pub trait Decodable: Sized {
7    /// Decodes the blob into the appropriate type. `buf` must be advanced past
8    /// the decoded object.
9    fn decode(buf: &mut &[u8]) -> Result<Self>;
10}
11
12/// An active RLP decoder, with a specific slice of a payload.
13#[derive(Debug)]
14pub struct Rlp<'a> {
15    payload_view: &'a [u8],
16}
17
18impl<'a> Rlp<'a> {
19    /// Instantiate an RLP decoder with a payload slice.
20    pub fn new(mut payload: &'a [u8]) -> Result<Self> {
21        let payload_view = Header::decode_bytes(&mut payload, true)?;
22        Ok(Self { payload_view })
23    }
24
25    /// Decode the next item from the buffer.
26    #[inline]
27    pub fn get_next<T: Decodable>(&mut self) -> Result<Option<T>> {
28        if self.payload_view.is_empty() {
29            Ok(None)
30        } else {
31            T::decode(&mut self.payload_view).map(Some)
32        }
33    }
34}
35
36impl<T: ?Sized> Decodable for PhantomData<T> {
37    fn decode(_buf: &mut &[u8]) -> Result<Self> {
38        Ok(Self)
39    }
40}
41
42impl Decodable for PhantomPinned {
43    fn decode(_buf: &mut &[u8]) -> Result<Self> {
44        Ok(Self)
45    }
46}
47
48impl Decodable for bool {
49    #[inline]
50    fn decode(buf: &mut &[u8]) -> Result<Self> {
51        Ok(match u8::decode(buf)? {
52            0 => false,
53            1 => true,
54            _ => return Err(Error::Custom("invalid bool value, must be 0 or 1")),
55        })
56    }
57}
58
59impl<const N: usize> Decodable for [u8; N] {
60    #[inline]
61    fn decode(from: &mut &[u8]) -> Result<Self> {
62        let bytes = Header::decode_bytes(from, false)?;
63        Self::try_from(bytes).map_err(|_| Error::UnexpectedLength)
64    }
65}
66
67macro_rules! decode_integer {
68    ($($t:ty),+ $(,)?) => {$(
69        impl Decodable for $t {
70            #[inline]
71            fn decode(buf: &mut &[u8]) -> Result<Self> {
72                let bytes = Header::decode_bytes(buf, false)?;
73                static_left_pad(bytes).map(<$t>::from_be_bytes)
74            }
75        }
76    )+};
77}
78
79decode_integer!(u8, u16, u32, u64, usize, u128);
80
81impl Decodable for Bytes {
82    #[inline]
83    fn decode(buf: &mut &[u8]) -> Result<Self> {
84        Header::decode_bytes(buf, false).map(|x| Self::from(x.to_vec()))
85    }
86}
87
88impl Decodable for BytesMut {
89    #[inline]
90    fn decode(buf: &mut &[u8]) -> Result<Self> {
91        Header::decode_bytes(buf, false).map(Self::from)
92    }
93}
94
95impl Decodable for alloc::string::String {
96    #[inline]
97    fn decode(buf: &mut &[u8]) -> Result<Self> {
98        Header::decode_str(buf).map(Into::into)
99    }
100}
101
102impl<T: Decodable> Decodable for alloc::vec::Vec<T> {
103    #[inline]
104    fn decode(buf: &mut &[u8]) -> Result<Self> {
105        let mut bytes = Header::decode_bytes(buf, true)?;
106        let mut vec = Self::new();
107        let payload_view = &mut bytes;
108        while !payload_view.is_empty() {
109            vec.push(T::decode(payload_view)?);
110        }
111        Ok(vec)
112    }
113}
114
/// Implements [`Decodable`] for wrapper types by decoding the wrapped value and
/// passing it to a constructor.
///
/// Each entry has the form `[generics] <Wrapper>::constructor(Inner)`. Any attributes
/// written before an entry (for example `#[cfg(...)]`) are copied onto the generated
/// `impl`.
macro_rules! wrap_impl {
    ($($(#[$meta:meta])* [$($params:tt)*] <$outer:ty>::$ctor:ident($inner:ty)),+ $(,)?) => {$(
        $(#[$meta])*
        impl<$($params)*> Decodable for $outer {
            #[inline]
            fn decode(buf: &mut &[u8]) -> Result<Self> {
                let inner = <$inner as Decodable>::decode(buf)?;
                Ok(<$outer>::$ctor(inner))
            }
        }
    )+};
}
126
127wrap_impl! {
128    #[cfg(feature = "arrayvec")]
129    [const N: usize] <arrayvec::ArrayVec<u8, N>>::from([u8; N]),
130    [T: Decodable] <alloc::boxed::Box<T>>::new(T),
131    [T: Decodable] <alloc::rc::Rc<T>>::new(T),
132    #[cfg(target_has_atomic = "ptr")]
133    [T: Decodable] <alloc::sync::Arc<T>>::new(T),
134}
135
136impl<T: ?Sized + alloc::borrow::ToOwned> Decodable for alloc::borrow::Cow<'_, T>
137where
138    T::Owned: Decodable,
139{
140    #[inline]
141    fn decode(buf: &mut &[u8]) -> Result<Self> {
142        T::Owned::decode(buf).map(Self::Owned)
143    }
144}
145
#[cfg(any(feature = "std", feature = "core-net"))]
mod std_impl {
    use super::*;
    #[cfg(all(feature = "core-net", not(feature = "std")))]
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    #[cfg(feature = "std")]
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Decodes an IP address from an RLP byte string, choosing the family by length:
    /// a 4-byte payload is IPv4 and a 16-byte payload is IPv6.
    ///
    /// Any other payload length is rejected with [`Error::UnexpectedLength`].
    impl Decodable for IpAddr {
        #[inline]
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let bytes = Header::decode_bytes(buf, false)?;
            // Go through the fallible helper rather than `expect`, so this library
            // path has no panic site even though the length match guarantees success.
            match bytes.len() {
                4 => slice_to_array::<4>(bytes).map(Ipv4Addr::from).map(Self::V4),
                16 => slice_to_array::<16>(bytes).map(Ipv6Addr::from).map(Self::V6),
                _ => Err(Error::UnexpectedLength),
            }
        }
    }

    /// Decodes an IPv4 address from an RLP byte string of exactly 4 bytes.
    impl Decodable for Ipv4Addr {
        #[inline]
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let bytes = Header::decode_bytes(buf, false)?;
            slice_to_array::<4>(bytes).map(Self::from)
        }
    }

    /// Decodes an IPv6 address from an RLP byte string of exactly 16 bytes.
    impl Decodable for Ipv6Addr {
        #[inline]
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let bytes = Header::decode_bytes(buf, false)?;
            slice_to_array::<16>(bytes).map(Self::from)
        }
    }

    /// Copies `slice` into an `N`-byte array.
    ///
    /// Fails with [`Error::UnexpectedLength`] when `slice.len() != N`.
    #[inline]
    fn slice_to_array<const N: usize>(slice: &[u8]) -> Result<[u8; N]> {
        slice.try_into().map_err(|_| Error::UnexpectedLength)
    }
}
188
189/// Decodes the entire input, ensuring no trailing bytes remain.
190///
191/// # Errors
192///
193/// Returns an error if the encoding is invalid or if data remains after decoding the RLP item.
194#[inline]
195pub fn decode_exact<T: Decodable>(bytes: impl AsRef<[u8]>) -> Result<T> {
196    let mut buf = bytes.as_ref();
197    let out = T::decode(&mut buf)?;
198
199    // check if there are any remaining bytes after decoding
200    if !buf.is_empty() {
201        // TODO: introduce a new variant TrailingBytes to better distinguish this error
202        return Err(Error::UnexpectedLength);
203    }
204
205    Ok(out)
206}
207
208/// Left-pads a slice to a statically known size array.
209///
210/// # Errors
211///
212/// Returns an error if the slice is too long or if the first byte is 0.
213#[inline]
214pub(crate) fn static_left_pad<const N: usize>(data: &[u8]) -> Result<[u8; N]> {
215    if data.len() > N {
216        return Err(Error::Overflow);
217    }
218
219    let mut v = [0; N];
220
221    // yes, data may empty, e.g. we decode a bool false value
222    if data.is_empty() {
223        return Ok(v);
224    }
225
226    if data[0] == 0 {
227        return Err(Error::LeadingZero);
228    }
229
230    // SAFETY: length checked above
231    unsafe { v.get_unchecked_mut(N - data.len()..) }.copy_from_slice(data);
232    Ok(v)
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{encode, Encodable};
    use core::fmt::Debug;
    use hex_literal::hex;

    // NOTE(review): presumably needed when `String`/`Vec` are not in the prelude
    // (no_std builds); the allow keeps builds where they are redundant warning-free.
    #[allow(unused_imports)]
    use alloc::{string::String, vec::Vec};

    /// Runs every `(expected, input)` fixture through `T::decode`.
    ///
    /// For successful fixtures it also checks that `expected` encodes back to exactly
    /// `input`, and that decoding consumed the whole input.
    fn check_decode<'a, T, IT>(fixtures: IT)
    where
        T: Encodable + Decodable + PartialEq + Debug,
        IT: IntoIterator<Item = (Result<T>, &'a [u8])>,
    {
        for (expected, mut input) in fixtures {
            if let Ok(expected) = &expected {
                assert_eq!(crate::encode(expected), input, "{expected:?}");
            }

            let orig = input;
            assert_eq!(
                T::decode(&mut input),
                expected,
                "input: {}{}",
                hex::encode(orig),
                expected.as_ref().map_or_else(
                    |_| String::new(),
                    |expected| format!("; expected: {}", hex::encode(crate::encode(expected)))
                )
            );

            if expected.is_ok() {
                assert_eq!(input, &[]);
            }
        }
    }

    #[test]
    fn rlp_bool() {
        let out = [0x80];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Ok(false), val);

        let out = [0x01];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Ok(true), val);

        // Integers other than 0 and 1 must be rejected rather than coerced.
        let out = [0x02];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Err(Error::Custom("invalid bool value, must be 0 or 1")), val);

        // A literal 0x00 byte is a non-canonical zero and is rejected by the integer decoder.
        let out = [0x00];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Err(Error::LeadingZero), val);
    }

    #[test]
    fn rlp_strings() {
        check_decode::<Bytes, _>([
            (Ok(hex!("00")[..].to_vec().into()), &hex!("00")[..]),
            (
                Ok(hex!("6f62636465666768696a6b6c6d")[..].to_vec().into()),
                &hex!("8D6F62636465666768696A6B6C6D")[..],
            ),
            (Err(Error::UnexpectedList), &hex!("C0")[..]),
        ])
    }

    #[test]
    fn rlp_fixed_length() {
        check_decode([
            (Ok(hex!("6f62636465666768696a6b6c6d")), &hex!("8D6F62636465666768696A6B6C6D")[..]),
            (Err(Error::UnexpectedLength), &hex!("8C6F62636465666768696A6B6C")[..]),
            (Err(Error::UnexpectedLength), &hex!("8E6F62636465666768696A6B6C6D6E")[..]),
        ])
    }

    #[test]
    fn rlp_u64() {
        check_decode([
            (Ok(9_u64), &hex!("09")[..]),
            (Ok(0_u64), &hex!("80")[..]),
            (Ok(0x0505_u64), &hex!("820505")[..]),
            (Ok(0xCE05050505_u64), &hex!("85CE05050505")[..]),
            (Err(Error::Overflow), &hex!("8AFFFFFFFFFFFFFFFFFF7C")[..]),
            (Err(Error::InputTooShort), &hex!("8BFFFFFFFFFFFFFFFFFF7C")[..]),
            (Err(Error::UnexpectedList), &hex!("C0")[..]),
            (Err(Error::LeadingZero), &hex!("00")[..]),
            (Err(Error::NonCanonicalSingleByte), &hex!("8105")[..]),
            (Err(Error::LeadingZero), &hex!("8200F4")[..]),
            (Err(Error::NonCanonicalSize), &hex!("B8020004")[..]),
            (
                Err(Error::Overflow),
                &hex!("A101000000000000000000000000000000000000008B000000000000000000000000")[..],
            ),
        ])
    }

    #[test]
    fn rlp_vectors() {
        check_decode::<Vec<u64>, _>([
            (Ok(vec![]), &hex!("C0")[..]),
            (Ok(vec![0xBBCCB5_u64, 0xFFC0B5_u64]), &hex!("C883BBCCB583FFC0B5")[..]),
        ])
    }

    #[cfg(feature = "std")]
    #[test]
    fn rlp_ip() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

        let localhost4 = Ipv4Addr::new(127, 0, 0, 1);
        let localhost6 = localhost4.to_ipv6_mapped();
        let expected4 = &hex!("847F000001")[..];
        let expected6 = &hex!("9000000000000000000000ffff7f000001")[..];
        check_decode::<Ipv4Addr, _>([(Ok(localhost4), expected4)]);
        check_decode::<Ipv6Addr, _>([(Ok(localhost6), expected6)]);
        check_decode::<IpAddr, _>([
            (Ok(IpAddr::V4(localhost4)), expected4),
            (Ok(IpAddr::V6(localhost6)), expected6),
        ]);
    }

    #[test]
    fn malformed_rlp() {
        check_decode::<Bytes, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<[u8; 5], _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        #[cfg(feature = "std")]
        check_decode::<std::net::IpAddr, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<Vec<u8>, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<String, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<u8, _>([(Err(Error::InputTooShort), &hex!("82")[..])]);
        check_decode::<u64, _>([(Err(Error::InputTooShort), &hex!("82")[..])]);
    }

    #[test]
    fn rlp_full() {
        fn check_decode_exact<T: Decodable + Encodable + PartialEq + Debug>(input: T) {
            let encoded = encode(&input);
            assert_eq!(decode_exact::<T>(&encoded), Ok(input));
            assert_eq!(
                decode_exact::<T>([encoded, vec![0x00]].concat()),
                Err(Error::UnexpectedLength)
            );
        }

        check_decode_exact::<String>("".into());
        check_decode_exact::<String>("test1234".into());
        check_decode_exact::<Vec<u64>>(vec![]);
        check_decode_exact::<Vec<u64>>(vec![0; 4]);
    }
}