1use alloy_primitives::{Bytes, B256};
4use alloy_rlp::{Decodable, Encodable, Header, EMPTY_STRING_CODE};
5use core::ops::Range;
6use nybbles::Nibbles;
7use smallvec::SmallVec;
8
9#[allow(unused_imports)]
10use alloc::vec::Vec;
11
12mod branch;
13pub use branch::{BranchNode, BranchNodeCompact, BranchNodeRef};
14
15mod extension;
16pub use extension::{ExtensionNode, ExtensionNodeRef};
17
18mod leaf;
19pub use leaf::{LeafNode, LeafNodeRef};
20
21mod rlp;
22pub use rlp::RlpNode;
23
/// Half-open range of the child slots of a branch node: one slot per nibble value `0x0..=0xf`.
pub const CHILD_INDEX_RANGE: Range<u8> = 0x0..0x10;
26
/// A node of a Merkle-Patricia trie, in the form it is RLP-encoded and decoded.
#[derive(PartialEq, Eq, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum TrieNode {
    /// The empty root; encoded as the single RLP empty-string byte (`EMPTY_STRING_CODE`).
    EmptyRoot,
    /// A branch node with up to 16 children, one per nibble.
    Branch(BranchNode),
    /// An extension node: a shared path segment followed by a reference to a single child.
    Extension(ExtensionNode),
    /// A leaf node: the remaining path segment together with the stored value.
    Leaf(LeafNode),
}
40
41impl Encodable for TrieNode {
42 #[inline]
43 fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
44 match self {
45 Self::EmptyRoot => {
46 out.put_u8(EMPTY_STRING_CODE);
47 }
48 Self::Branch(branch) => branch.encode(out),
49 Self::Extension(extension) => extension.encode(out),
50 Self::Leaf(leaf) => leaf.encode(out),
51 }
52 }
53
54 #[inline]
55 fn length(&self) -> usize {
56 match self {
57 Self::EmptyRoot => 1,
58 Self::Branch(branch) => branch.length(),
59 Self::Extension(extension) => extension.length(),
60 Self::Leaf(leaf) => leaf.length(),
61 }
62 }
63}
64
65impl Decodable for TrieNode {
66 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
67 let mut items = match Header::decode_raw(buf)? {
68 alloy_rlp::PayloadView::List(list) => list,
69 alloy_rlp::PayloadView::String(val) => {
70 return if val.is_empty() {
71 Ok(Self::EmptyRoot)
72 } else {
73 Err(alloy_rlp::Error::UnexpectedString)
74 }
75 }
76 };
77
78 match items.len() {
81 17 => {
82 let mut branch = BranchNode::default();
83 for (idx, item) in items.into_iter().enumerate() {
84 if idx == 16 {
85 if item != [EMPTY_STRING_CODE] {
86 return Err(alloy_rlp::Error::Custom(
87 "branch node values are not supported",
88 ));
89 }
90 } else if item != [EMPTY_STRING_CODE] {
91 branch.stack.push(RlpNode::from_raw_rlp(item)?);
92 branch.state_mask.set_bit(idx as u8);
93 }
94 }
95 Ok(Self::Branch(branch))
96 }
97 2 => {
98 let mut key = items.remove(0);
99
100 let encoded_key = Header::decode_bytes(&mut key, false)?;
101 if encoded_key.is_empty() {
102 return Err(alloy_rlp::Error::Custom("trie node key empty"));
103 }
104
105 let key_flag = encoded_key[0] & 0xf0;
107 let first = match key_flag {
109 ExtensionNode::ODD_FLAG | LeafNode::ODD_FLAG => Some(encoded_key[0] & 0x0f),
110 ExtensionNode::EVEN_FLAG | LeafNode::EVEN_FLAG => None,
111 _ => return Err(alloy_rlp::Error::Custom("node is not extension or leaf")),
112 };
113
114 let key = unpack_path_to_nibbles(first, &encoded_key[1..]);
115 let node = if key_flag == LeafNode::EVEN_FLAG || key_flag == LeafNode::ODD_FLAG {
116 let value = Bytes::decode(&mut items.remove(0))?.into();
117 Self::Leaf(LeafNode::new(key, value))
118 } else {
119 Self::Extension(ExtensionNode::new(
121 key,
122 RlpNode::from_raw_rlp(items.remove(0))?,
123 ))
124 };
125 Ok(node)
126 }
127 _ => Err(alloy_rlp::Error::Custom("invalid number of items in the list")),
128 }
129 }
130}
131
132impl TrieNode {
133 #[inline]
135 pub fn rlp(&self, rlp: &mut Vec<u8>) -> RlpNode {
136 self.encode(rlp);
137 RlpNode::from_rlp(rlp)
138 }
139}
140
/// Returns the [`RlpNode`] reference for an already RLP-encoded node.
///
/// Deprecated wrapper that forwards unchanged to [`RlpNode::from_rlp`].
#[inline]
#[deprecated = "use `RlpNode::from_rlp` instead"]
pub fn rlp_node(rlp: &[u8]) -> RlpNode {
    RlpNode::from_rlp(rlp)
}
147
/// Returns the RLP encoding of a 32-byte `word` as an [`RlpNode`].
///
/// Deprecated wrapper that forwards unchanged to [`RlpNode::word_rlp`].
#[inline]
#[deprecated = "use `RlpNode::word_rlp` instead"]
pub fn word_rlp(word: &B256) -> RlpNode {
    RlpNode::word_rlp(word)
}
154
/// Expands the body of a compact-encoded path (as produced by [`encode_path_leaf`]) into
/// [`Nibbles`].
///
/// - `first` is the nibble carried in the low half of the flag byte for odd-length paths, or
///   `None` for even-length paths.
/// - `rest` is the remaining path, packed two nibbles per byte.
///
/// The result has `rest.len() * 2` nibbles, plus one leading nibble when `first` is `Some`.
#[inline]
pub(crate) fn unpack_path_to_nibbles(first: Option<u8>, rest: &[u8]) -> Nibbles {
    // Even-length path: every nibble comes from `rest`.
    let Some(first) = first else { return Nibbles::unpack(rest) };
    // `first` must be a single nibble; this is only checked in debug builds.
    debug_assert!(first <= 0xf);
    let len = rest.len() * 2 + 1;
    // SAFETY: `len >= 1`, so the buffer is non-empty and `split_first_mut` cannot return `None`.
    // `f` is written with `first`. `r` has exactly `rest.len() * 2` slots, one per nibble that
    // `unpack_to_unchecked` writes for `rest`. So every slot is initialized before
    // `from_repr_unchecked` takes ownership. NOTE(review): this assumes `unpack_to_unchecked`
    // requires `r.len() == rest.len() * 2` and that `from_repr_unchecked` only needs values
    // <= 0xf; confirm against the `nybbles` docs.
    unsafe {
        Nibbles::from_repr_unchecked(nybbles::smallvec_with(len, |buf| {
            let (f, r) = buf.split_first_mut().unwrap_unchecked();
            f.write(first);
            Nibbles::unpack_to_unchecked(rest, r);
        }))
    }
}
179
/// Encodes `nibbles` into the compact (hex-prefix) form used for leaf and extension node paths.
///
/// The first output byte is a flag:
/// - `LeafNode::ODD_FLAG`/`EVEN_FLAG` when `is_leaf`, otherwise the [`ExtensionNode`] flags;
/// - the `ODD`/`EVEN` variant is chosen by the parity of `nibbles.len()`;
/// - for odd-length paths, the first nibble is ORed into the flag byte.
///
/// The remaining (even number of) nibbles are packed two per byte after the flag byte. The
/// output is always `nibbles.len() / 2 + 1` bytes long.
#[inline]
pub fn encode_path_leaf(nibbles: &Nibbles, is_leaf: bool) -> SmallVec<[u8; 36]> {
    let mut nibbles = nibbles.as_slice();
    let encoded_len = nibbles.len() / 2 + 1;
    let odd_nibbles = nibbles.len() % 2 != 0;
    // SAFETY:
    // - `encoded_len >= 1`, so `split_first_mut` on the buffer cannot return `None`.
    // - When `odd_nibbles` is true, `nibbles` is non-empty, so `get_unchecked(0)` and
    //   `get_unchecked(1..)` are in bounds.
    // - After the optional leading nibble is dropped, an even number of nibbles remains, and
    //   `rest` holds exactly half that many bytes (`encoded_len - 1 == nibbles.len() / 2` in
    //   both parities) for `pack_to_unchecked` to fill.
    // NOTE(review): this assumes `pack_to_unchecked` requires `rest.len() * 2 == nibbles.len()`
    // and initializes all of `rest`; confirm against the `nybbles` docs.
    unsafe {
        nybbles::smallvec_with(encoded_len, |buf| {
            let (first, rest) = buf.split_first_mut().unwrap_unchecked();
            first.write(match (is_leaf, odd_nibbles) {
                (true, true) => LeafNode::ODD_FLAG | *nibbles.get_unchecked(0),
                (true, false) => LeafNode::EVEN_FLAG,
                (false, true) => ExtensionNode::ODD_FLAG | *nibbles.get_unchecked(0),
                (false, false) => ExtensionNode::EVEN_FLAG,
            });
            // The odd leading nibble already lives in the flag byte.
            if odd_nibbles {
                nibbles = nibbles.get_unchecked(1..);
            }
            nybbles::pack_to_unchecked(nibbles, rest);
        })
    }
}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TrieMask;
    use alloy_primitives::hex;

    /// The empty root encodes to the RLP empty string and decodes back to itself.
    #[test]
    fn rlp_empty_root_node() {
        let empty_root = TrieNode::EmptyRoot;
        let rlp = empty_root.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("80"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), empty_root);
    }

    /// A leaf whose value is the RLP encoding of zero (`0x80`) must roundtrip. The value is
    /// wrapped again as a one-byte RLP string (`81 80`).
    #[test]
    fn rlp_zero_value_leaf_roundtrip() {
        let leaf = TrieNode::Leaf(LeafNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            alloy_rlp::encode(alloy_primitives::U256::ZERO),
        ));
        let rlp = leaf.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c68320646f8180"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), leaf);
    }

    /// Encode/decode roundtrips for a leaf, an extension and a branch node.
    #[test]
    fn rlp_trie_node_roundtrip() {
        // Leaf: even-length path [6, 4, 6, f] compacts to `20 64 6f`; the value is "verb".
        let leaf = TrieNode::Leaf(LeafNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            hex!("76657262").to_vec(),
        ));
        let rlp = leaf.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c98320646f8476657262"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), leaf);

        // Extension: same path with the extension flag (`00 64 6f`); the child is the raw RLP
        // of the string "verb".
        let mut child = vec![];
        hex!("76657262").to_vec().as_slice().encode(&mut child);
        let extension = TrieNode::Extension(ExtensionNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            RlpNode::from_raw(&child).unwrap(),
        ));
        let rlp = extension.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c98300646f8476657262"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), extension);

        // Branch: all 16 children hold the same 32-byte word. The full encoding is long, so the
        // returned reference is an RLP-encoded 32-byte digest rather than the encoding itself.
        let branch = TrieNode::Branch(BranchNode::new(
            core::iter::repeat(RlpNode::word_rlp(&B256::repeat_byte(23))).take(16).collect(),
            TrieMask::new(u16::MAX),
        ));
        let mut rlp = vec![];
        let rlp_node = branch.rlp(&mut rlp);
        assert_eq!(
            rlp_node[..],
            hex!("a0bed74980bbe29d9c4439c10e9c451e29b306fe74bcf9795ecf0ebbd92a220513")
        );
        assert_eq!(rlp, hex!("f90211a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a0171717171717171717171717171717171717171717171717171717171717171780"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), branch);
    }

    /// Odd-length (63-nibble) leaf path. The first byte `0x35` combines the odd-leaf flag with
    /// the first nibble `5`; the remaining nibbles are packed pairwise.
    #[test]
    fn hashed_encode_path_regression() {
        let nibbles = Nibbles::from_nibbles(hex!("05010406040a040203030f010805020b050c04070003070e0909070f010b0a0805020301070c0a0902040b0f000f0006040a04050f020b090701000a0a040b"));
        let path = encode_path_leaf(&nibbles, true);
        let expected = hex!("351464a4233f1852b5c47037e997f1ba852317ca924bf0f064a45f2b9710aa4b");
        assert_eq!(path[..], expected);
    }

    /// Property test for the flag byte produced by `encode_path_leaf`.
    ///
    /// NOTE(review): `Nibbles::unpack` produces two nibbles per input byte, so `input` appears
    /// to always have even length here. If so, the odd-length branches below never run.
    /// TODO: also exercise odd-length paths.
    #[test]
    #[cfg(feature = "arbitrary")]
    #[cfg_attr(miri, ignore = "no proptest")]
    fn encode_path_first_byte() {
        use proptest::{collection::vec, prelude::*};

        proptest::proptest!(|(input in vec(any::<u8>(), 0..128))| {
            let input = Nibbles::unpack(input);
            prop_assert!(input.iter().all(|&nibble| nibble <= 0xf));
            let input_is_odd = input.len() % 2 == 1;

            let compact_leaf = encode_path_leaf(&input, true);
            let leaf_flag = compact_leaf[0];
            // Leaf encodings always carry the leaf bit (the bit set in `LeafNode::EVEN_FLAG`).
            assert_ne!(leaf_flag & LeafNode::EVEN_FLAG, 0);
            // The odd bit (`ExtensionNode::ODD_FLAG`) tracks the path parity.
            assert_eq!(input_is_odd, (leaf_flag & ExtensionNode::ODD_FLAG) != 0);
            if input_is_odd {
                assert_eq!(leaf_flag & 0x0f, input.first().unwrap());
            }

            let compact_extension = encode_path_leaf(&input, false);
            let extension_flag = compact_extension[0];
            // Extension encodings never carry the leaf bit.
            assert_eq!(extension_flag & LeafNode::EVEN_FLAG, 0);
            // The odd bit tracks the path parity.
            assert_eq!(input_is_odd, (extension_flag & ExtensionNode::ODD_FLAG) != 0);
            if input_is_odd {
                assert_eq!(extension_flag & 0x0f, input.first().unwrap());
            }
        });
    }
}