1use crate::{Error, Header, Result};
2use bytes::{Bytes, BytesMut};
3use core::marker::{PhantomData, PhantomPinned};
4
/// A type that can be decoded from RLP-encoded bytes.
pub trait Decodable: Sized {
    /// Decodes one value of this type from the front of `buf`.
    ///
    /// `buf` is passed as `&mut &[u8]` so that an implementation can advance the slice past the
    /// bytes it consumed. Code in this file (the `Vec<T>` impl and the test helper) relies on
    /// that: after a successful decode, `buf` is expected to start at the next undecoded byte.
    ///
    /// # Errors
    ///
    /// Returns an error when the bytes at the front of `buf` are not a valid encoding of
    /// `Self`.
    fn decode(buf: &mut &[u8]) -> Result<Self>;
}
11
/// A cursor over a borrowed RLP payload from which items are decoded one at a time.
#[derive(Debug)]
pub struct Rlp<'a> {
    /// The part of the payload that has not been decoded yet.
    payload_view: &'a [u8],
}
17
18impl<'a> Rlp<'a> {
19 pub fn new(mut payload: &'a [u8]) -> Result<Self> {
21 let payload_view = Header::decode_bytes(&mut payload, true)?;
22 Ok(Self { payload_view })
23 }
24
25 #[inline]
27 pub fn get_next<T: Decodable>(&mut self) -> Result<Option<T>> {
28 if self.payload_view.is_empty() {
29 Ok(None)
30 } else {
31 T::decode(&mut self.payload_view).map(Some)
32 }
33 }
34}
35
36impl<T: ?Sized> Decodable for PhantomData<T> {
37 fn decode(_buf: &mut &[u8]) -> Result<Self> {
38 Ok(Self)
39 }
40}
41
42impl Decodable for PhantomPinned {
43 fn decode(_buf: &mut &[u8]) -> Result<Self> {
44 Ok(Self)
45 }
46}
47
48impl Decodable for bool {
49 #[inline]
50 fn decode(buf: &mut &[u8]) -> Result<Self> {
51 Ok(match u8::decode(buf)? {
52 0 => false,
53 1 => true,
54 _ => return Err(Error::Custom("invalid bool value, must be 0 or 1")),
55 })
56 }
57}
58
59impl<const N: usize> Decodable for [u8; N] {
60 #[inline]
61 fn decode(from: &mut &[u8]) -> Result<Self> {
62 let bytes = Header::decode_bytes(from, false)?;
63 Self::try_from(bytes).map_err(|_| Error::UnexpectedLength)
64 }
65}
66
67macro_rules! decode_integer {
68 ($($t:ty),+ $(,)?) => {$(
69 impl Decodable for $t {
70 #[inline]
71 fn decode(buf: &mut &[u8]) -> Result<Self> {
72 let bytes = Header::decode_bytes(buf, false)?;
73 static_left_pad(bytes).map(<$t>::from_be_bytes)
74 }
75 }
76 )+};
77}
78
79decode_integer!(u8, u16, u32, u64, usize, u128);
80
81impl Decodable for Bytes {
82 #[inline]
83 fn decode(buf: &mut &[u8]) -> Result<Self> {
84 Header::decode_bytes(buf, false).map(|x| Self::from(x.to_vec()))
85 }
86}
87
88impl Decodable for BytesMut {
89 #[inline]
90 fn decode(buf: &mut &[u8]) -> Result<Self> {
91 Header::decode_bytes(buf, false).map(Self::from)
92 }
93}
94
95impl Decodable for alloc::string::String {
96 #[inline]
97 fn decode(buf: &mut &[u8]) -> Result<Self> {
98 Header::decode_str(buf).map(Into::into)
99 }
100}
101
102impl<T: Decodable> Decodable for alloc::vec::Vec<T> {
103 #[inline]
104 fn decode(buf: &mut &[u8]) -> Result<Self> {
105 let mut bytes = Header::decode_bytes(buf, true)?;
106 let mut vec = Self::new();
107 let payload_view = &mut bytes;
108 while !payload_view.is_empty() {
109 vec.push(T::decode(payload_view)?);
110 }
111 Ok(vec)
112 }
113}
114
115macro_rules! wrap_impl {
116 ($($(#[$attr:meta])* [$($gen:tt)*] <$t:ty>::$new:ident($t2:ty)),+ $(,)?) => {$(
117 $(#[$attr])*
118 impl<$($gen)*> Decodable for $t {
119 #[inline]
120 fn decode(buf: &mut &[u8]) -> Result<Self> {
121 <$t2 as Decodable>::decode(buf).map(<$t>::$new)
122 }
123 }
124 )+};
125}
126
127wrap_impl! {
128 #[cfg(feature = "arrayvec")]
129 [const N: usize] <arrayvec::ArrayVec<u8, N>>::from([u8; N]),
130 [T: Decodable] <alloc::boxed::Box<T>>::new(T),
131 [T: Decodable] <alloc::rc::Rc<T>>::new(T),
132 #[cfg(target_has_atomic = "ptr")]
133 [T: Decodable] <alloc::sync::Arc<T>>::new(T),
134}
135
136impl<T: ?Sized + alloc::borrow::ToOwned> Decodable for alloc::borrow::Cow<'_, T>
137where
138 T::Owned: Decodable,
139{
140 #[inline]
141 fn decode(buf: &mut &[u8]) -> Result<Self> {
142 T::Owned::decode(buf).map(Self::Owned)
143 }
144}
145
/// `Decodable` impls for IP address types, taken from `std::net` or `core::net` depending on
/// the enabled features.
#[cfg(any(feature = "std", feature = "core-net"))]
mod std_impl {
    use super::*;
    #[cfg(all(feature = "core-net", not(feature = "std")))]
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    #[cfg(feature = "std")]
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// The payload length selects the variant: 4 bytes for V4, 16 bytes for V6. Any other
    /// length yields `Error::UnexpectedLength`.
    impl Decodable for IpAddr {
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let payload = Header::decode_bytes(buf, false)?;
            match payload.len() {
                // The conversions below cannot fail: the length was just matched.
                4 => slice_to_array::<4>(payload).map(|octets| Self::V4(Ipv4Addr::from(octets))),
                16 => slice_to_array::<16>(payload).map(|octets| Self::V6(Ipv6Addr::from(octets))),
                _ => Err(Error::UnexpectedLength),
            }
        }
    }

    /// Requires a payload of exactly 4 bytes.
    impl Decodable for Ipv4Addr {
        #[inline]
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let payload = Header::decode_bytes(buf, false)?;
            let octets = slice_to_array::<4>(payload)?;
            Ok(Self::from(octets))
        }
    }

    /// Requires a payload of exactly 16 bytes.
    impl Decodable for Ipv6Addr {
        #[inline]
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let payload = Header::decode_bytes(buf, false)?;
            let octets = slice_to_array::<16>(payload)?;
            Ok(Self::from(octets))
        }
    }

    /// Copies `slice` into an `N`-byte array, or returns `Error::UnexpectedLength` when the
    /// slice is not exactly `N` bytes long.
    #[inline]
    fn slice_to_array<const N: usize>(slice: &[u8]) -> Result<[u8; N]> {
        match <[u8; N]>::try_from(slice) {
            Ok(array) => Ok(array),
            Err(_) => Err(Error::UnexpectedLength),
        }
    }
}
188
189#[inline]
195pub fn decode_exact<T: Decodable>(bytes: impl AsRef<[u8]>) -> Result<T> {
196 let mut buf = bytes.as_ref();
197 let out = T::decode(&mut buf)?;
198
199 if !buf.is_empty() {
201 return Err(Error::UnexpectedLength);
203 }
204
205 Ok(out)
206}
207
208#[inline]
214pub(crate) fn static_left_pad<const N: usize>(data: &[u8]) -> Result<[u8; N]> {
215 if data.len() > N {
216 return Err(Error::Overflow);
217 }
218
219 let mut v = [0; N];
220
221 if data.is_empty() {
223 return Ok(v);
224 }
225
226 if data[0] == 0 {
227 return Err(Error::LeadingZero);
228 }
229
230 unsafe { v.get_unchecked_mut(N - data.len()..) }.copy_from_slice(data);
232 Ok(v)
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{encode, Encodable};
    use core::fmt::Debug;
    use hex_literal::hex;

    #[allow(unused_imports)]
    use alloc::{string::String, vec::Vec};

    /// Runs every `(expected, input)` fixture through `T::decode`.
    ///
    /// For `Ok` fixtures this also checks that encoding the expected value reproduces `input`
    /// and that decoding consumes the whole input.
    fn check_decode<'a, T, IT>(fixtures: IT)
    where
        T: Encodable + Decodable + PartialEq + Debug,
        IT: IntoIterator<Item = (Result<T>, &'a [u8])>,
    {
        for (want, raw) in fixtures {
            if let Ok(value) = &want {
                assert_eq!(crate::encode(value), raw, "{value:?}");
            }

            let mut cursor = raw;
            let got = T::decode(&mut cursor);
            assert_eq!(
                got,
                want,
                "input: {}{}",
                hex::encode(raw),
                match &want {
                    Ok(value) => format!("; expected: {}", hex::encode(crate::encode(value))),
                    Err(_) => String::new(),
                }
            );

            if want.is_ok() {
                assert_eq!(cursor, &[]);
            }
        }
    }

    #[test]
    fn rlp_bool() {
        let cases: [(u8, bool); 2] = [(0x80, false), (0x01, true)];
        for (byte, want) in cases {
            let encoded = [byte];
            let decoded = bool::decode(&mut &encoded[..]);
            assert_eq!(Ok(want), decoded);
        }
    }

    #[test]
    fn rlp_strings() {
        check_decode::<Bytes, _>([
            (Ok(Bytes::from(hex!("00").to_vec())), &hex!("00")[..]),
            (
                Ok(Bytes::from(hex!("6f62636465666768696a6b6c6d").to_vec())),
                &hex!("8D6F62636465666768696A6B6C6D")[..],
            ),
            (Err(Error::UnexpectedList), &hex!("C0")[..]),
        ]);
    }

    #[test]
    fn rlp_fixed_length() {
        check_decode::<[u8; 13], _>([
            (Ok(hex!("6f62636465666768696a6b6c6d")), &hex!("8D6F62636465666768696A6B6C6D")[..]),
            (Err(Error::UnexpectedLength), &hex!("8C6F62636465666768696A6B6C")[..]),
            (Err(Error::UnexpectedLength), &hex!("8E6F62636465666768696A6B6C6D6E")[..]),
        ]);
    }

    #[test]
    fn rlp_u64() {
        check_decode::<u64, _>([
            (Ok(9), &hex!("09")[..]),
            (Ok(0), &hex!("80")[..]),
            (Ok(0x0505), &hex!("820505")[..]),
            (Ok(0xCE05050505), &hex!("85CE05050505")[..]),
            (Err(Error::Overflow), &hex!("8AFFFFFFFFFFFFFFFFFF7C")[..]),
            (Err(Error::InputTooShort), &hex!("8BFFFFFFFFFFFFFFFFFF7C")[..]),
            (Err(Error::UnexpectedList), &hex!("C0")[..]),
            (Err(Error::LeadingZero), &hex!("00")[..]),
            (Err(Error::NonCanonicalSingleByte), &hex!("8105")[..]),
            (Err(Error::LeadingZero), &hex!("8200F4")[..]),
            (Err(Error::NonCanonicalSize), &hex!("B8020004")[..]),
            (
                Err(Error::Overflow),
                &hex!("A101000000000000000000000000000000000000008B000000000000000000000000")[..],
            ),
        ]);
    }

    #[test]
    fn rlp_vectors() {
        check_decode::<Vec<u64>, _>([
            (Ok(Vec::new()), &hex!("C0")[..]),
            (Ok(vec![0xBBCCB5, 0xFFC0B5]), &hex!("C883BBCCB583FFC0B5")[..]),
        ]);
    }

    #[cfg(feature = "std")]
    #[test]
    fn rlp_ip() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

        let v4 = Ipv4Addr::new(127, 0, 0, 1);
        let v6 = v4.to_ipv6_mapped();
        let v4_bytes = &hex!("847F000001")[..];
        let v6_bytes = &hex!("9000000000000000000000ffff7f000001")[..];

        check_decode::<Ipv4Addr, _>([(Ok(v4), v4_bytes)]);
        check_decode::<Ipv6Addr, _>([(Ok(v6), v6_bytes)]);
        check_decode::<IpAddr, _>([
            (Ok(IpAddr::V4(v4)), v4_bytes),
            (Ok(IpAddr::V6(v6)), v6_bytes),
        ]);
    }

    #[test]
    fn malformed_rlp() {
        /// Both inputs are a lone header byte with no payload after it; every type must
        /// reject them with `InputTooShort`.
        fn rejects_truncated_list<T>()
        where
            T: Encodable + Decodable + PartialEq + Debug,
        {
            check_decode::<T, _>([
                (Err(Error::InputTooShort), &hex!("C1")[..]),
                (Err(Error::InputTooShort), &hex!("D7")[..]),
            ]);
        }

        rejects_truncated_list::<Bytes>();
        rejects_truncated_list::<[u8; 5]>();
        #[cfg(feature = "std")]
        rejects_truncated_list::<std::net::IpAddr>();
        rejects_truncated_list::<Vec<u8>>();
        rejects_truncated_list::<String>();
        // NOTE(review): `String` is exercised twice; the repeat is redundant and is retained
        // only so the set of executed checks stays unchanged.
        rejects_truncated_list::<String>();

        check_decode::<u8, _>([(Err(Error::InputTooShort), &hex!("82")[..])]);
        check_decode::<u64, _>([(Err(Error::InputTooShort), &hex!("82")[..])]);
    }

    #[test]
    fn rlp_full() {
        /// `decode_exact` must round-trip the exact encoding and reject one trailing byte.
        fn check_decode_exact<T: Decodable + Encodable + PartialEq + Debug>(input: T) {
            let exact = encode(&input);
            let mut with_trailing = exact.clone();
            with_trailing.push(0x00);

            assert_eq!(decode_exact::<T>(&exact), Ok(input));
            assert_eq!(decode_exact::<T>(with_trailing), Err(Error::UnexpectedLength));
        }

        check_decode_exact::<String>(String::new());
        check_decode_exact::<String>(String::from("test1234"));
        check_decode_exact::<Vec<u64>>(Vec::new());
        check_decode_exact::<Vec<u64>>(vec![0; 4]);
    }
}