1use alloy_primitives::{B256, Bytes};
4use alloy_rlp::{Decodable, EMPTY_STRING_CODE, Encodable, Header};
5use core::ops::Range;
6use nybbles::Nibbles;
7use smallvec::SmallVec;
8
9#[allow(unused_imports)]
10use alloc::vec::Vec;
11
12mod branch;
13pub use branch::{BranchNode, BranchNodeCompact, BranchNodeRef};
14
15mod extension;
16pub use extension::{ExtensionNode, ExtensionNodeRef};
17
18mod leaf;
19pub use leaf::{LeafNode, LeafNodeRef};
20
21mod rlp;
22pub use rlp::RlpNode;
23
/// Range of valid child indices in a branch node: one slot per nibble value, `0..16`.
pub const CHILD_INDEX_RANGE: Range<u8> = Range { start: 0, end: 16 };
26
/// A trie node in one of the four forms handled by this module's RLP
/// [`Encodable`] / [`Decodable`] implementations.
// NOTE: the derived serde representation identifies variants by name and, for
// index-based formats, by declaration order — do not reorder the variants.
#[derive(PartialEq, Eq, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum TrieNode {
    /// The empty root; RLP-encoded as the single byte [`EMPTY_STRING_CODE`]
    /// (the RLP empty string).
    EmptyRoot,
    /// A branch node: up to 16 children, one per nibble. Branch values are not
    /// supported by the decoder.
    Branch(BranchNode),
    /// An extension node: a key path plus a single child kept as raw RLP.
    Extension(ExtensionNode),
    /// A leaf node: the remaining key path and the stored value.
    Leaf(LeafNode),
}
40
41impl Encodable for TrieNode {
42 #[inline]
43 fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
44 match self {
45 Self::EmptyRoot => {
46 out.put_u8(EMPTY_STRING_CODE);
47 }
48 Self::Branch(branch) => branch.encode(out),
49 Self::Extension(extension) => extension.encode(out),
50 Self::Leaf(leaf) => leaf.encode(out),
51 }
52 }
53
54 #[inline]
55 fn length(&self) -> usize {
56 match self {
57 Self::EmptyRoot => 1,
58 Self::Branch(branch) => branch.length(),
59 Self::Extension(extension) => extension.length(),
60 Self::Leaf(leaf) => leaf.length(),
61 }
62 }
63}
64
impl Decodable for TrieNode {
    /// Decodes a single RLP-encoded trie node from `buf`, advancing it past the node.
    ///
    /// The top-level RLP item selects the node type:
    /// - the empty string decodes to [`TrieNode::EmptyRoot`];
    /// - a list of 17 items decodes to a [`TrieNode::Branch`];
    /// - a list of 2 items decodes to a [`TrieNode::Leaf`] or [`TrieNode::Extension`],
    ///   depending on the flag in the high nibble of the first key byte.
    ///
    /// # Errors
    /// - [`alloy_rlp::Error::UnexpectedString`] for a non-empty top-level string;
    /// - `"branch node values are not supported"` if the 17th branch item is not empty;
    /// - `"trie node key empty"` if a leaf/extension key decodes to zero bytes;
    /// - `"node is not extension or leaf"` if the key's flag nibble is not one of the
    ///   leaf/extension odd/even flags;
    /// - `"invalid number of items in the list"` for any other list length;
    /// - any error propagated from header/byte decoding or [`RlpNode::from_raw_rlp`].
    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
        let mut items = match Header::decode_raw(buf)? {
            alloy_rlp::PayloadView::List(list) => list,
            // The only string payload that forms a valid node is the empty string.
            alloy_rlp::PayloadView::String(val) => {
                return if val.is_empty() {
                    Ok(Self::EmptyRoot)
                } else {
                    Err(alloy_rlp::Error::UnexpectedString)
                };
            }
        };

        match items.len() {
            // Branch node: 16 child slots followed by a value slot.
            17 => {
                let mut branch = BranchNode::default();
                // Items are processed in order, so a malformed child is reported
                // before an unsupported value.
                for (idx, item) in items.into_iter().enumerate() {
                    if idx == 16 {
                        if item != [EMPTY_STRING_CODE] {
                            return Err(alloy_rlp::Error::Custom(
                                "branch node values are not supported",
                            ));
                        }
                    } else if item != [EMPTY_STRING_CODE] {
                        // Present children are stored densely, in index order; the
                        // mask records which nibble slots they occupy.
                        branch.stack.push(RlpNode::from_raw_rlp(item)?);
                        branch.state_mask.set_bit(idx as u8);
                    }
                }
                Ok(Self::Branch(branch))
            }
            // Leaf or extension node: [hex-prefix-encoded key, value-or-child].
            2 => {
                let mut key = items.remove(0);

                let encoded_key = Header::decode_bytes(&mut key, false)?;
                if encoded_key.is_empty() {
                    return Err(alloy_rlp::Error::Custom("trie node key empty"));
                }

                // The high nibble of the first byte carries the node-type/parity flag.
                let key_flag = encoded_key[0] & 0xf0;
                // For odd-length paths the low nibble is the first path nibble.
                // NOTE(review): for even-length paths the low (padding) nibble is
                // ignored rather than required to be zero, so such input is accepted
                // but presumably does not re-encode to the same bytes.
                let first = match key_flag {
                    ExtensionNode::ODD_FLAG | LeafNode::ODD_FLAG => Some(encoded_key[0] & 0x0f),
                    ExtensionNode::EVEN_FLAG | LeafNode::EVEN_FLAG => None,
                    _ => return Err(alloy_rlp::Error::Custom("node is not extension or leaf")),
                };

                let key = unpack_path_to_nibbles(first, &encoded_key[1..]);
                let node = if key_flag == LeafNode::EVEN_FLAG || key_flag == LeafNode::ODD_FLAG {
                    // Leaf: the second item is the value, decoded as an RLP byte string.
                    let value = Bytes::decode(&mut items.remove(0))?.into();
                    Self::Leaf(LeafNode::new(key, value))
                } else {
                    // Extension: the child reference is kept as raw RLP, not decoded.
                    Self::Extension(ExtensionNode::new(
                        key,
                        RlpNode::from_raw_rlp(items.remove(0))?,
                    ))
                };
                Ok(node)
            }
            _ => Err(alloy_rlp::Error::Custom("invalid number of items in the list")),
        }
    }
}
131
132impl TrieNode {
133 #[inline]
135 pub fn rlp(&self, rlp: &mut Vec<u8>) -> RlpNode {
136 self.encode(rlp);
137 RlpNode::from_rlp(rlp)
138 }
139}
140
/// Deprecated shim: forwards `rlp` unchanged to [`RlpNode::from_rlp`].
#[inline]
#[deprecated = "use `RlpNode::from_rlp` instead"]
pub fn rlp_node(rlp: &[u8]) -> RlpNode {
    RlpNode::from_rlp(rlp)
}
147
/// Deprecated shim: forwards `word` unchanged to [`RlpNode::word_rlp`].
#[inline]
#[deprecated = "use `RlpNode::word_rlp` instead"]
pub fn word_rlp(word: &B256) -> RlpNode {
    RlpNode::word_rlp(word)
}
154
155#[inline]
166pub(crate) fn unpack_path_to_nibbles(first: Option<u8>, rest: &[u8]) -> Nibbles {
167 let rest = Nibbles::unpack(rest);
168 let Some(first) = first else { return rest };
169 debug_assert!(first <= 0xf);
170 Nibbles::from_nibbles_unchecked([first]).join(&rest)
172}
173
/// Encodes `nibbles` with the compact (hex-prefix) path encoding used for leaf
/// and extension node keys.
///
/// The first output byte holds the flag: a [`LeafNode`] flag when `is_leaf` is
/// `true`, an [`ExtensionNode`] flag otherwise, using the `ODD_FLAG` variant
/// when `nibbles` has an odd length and `EVEN_FLAG` otherwise. For an odd-length
/// path the first nibble is OR-ed into that byte; the remaining (even number
/// of) nibbles are packed two per byte into the following bytes.
///
/// The returned buffer is sized as `nibbles.len() / 2 + 1` bytes.
#[inline]
pub fn encode_path_leaf(nibbles: &Nibbles, is_leaf: bool) -> SmallVec<[u8; 36]> {
    let mut nibbles = *nibbles;
    let encoded_len = nibbles.len() / 2 + 1;
    let odd_nibbles = nibbles.len() % 2 != 0;
    // SAFETY:
    // - `encoded_len >= 1`, so `buf` is non-empty and `split_first_mut` is `Some`.
    // - `get_unchecked(0)` is only evaluated when the path length is odd, i.e. >= 1.
    // - After dropping the odd leading nibble, `nibbles` holds an even count equal to
    //   `2 * (encoded_len - 1)`, which packs into exactly `rest.len()` bytes.
    // NOTE(review): this assumes `nybbles::smallvec_with` hands the closure a buffer of
    // exactly `encoded_len` elements that must all be initialized — confirm against
    // the nybbles documentation.
    unsafe {
        nybbles::smallvec_with(encoded_len, |buf| {
            let (first, rest) = buf.split_first_mut().unwrap_unchecked();
            first.write(match (is_leaf, odd_nibbles) {
                (true, true) => LeafNode::ODD_FLAG | nibbles.get_unchecked(0),
                (true, false) => LeafNode::EVEN_FLAG,
                (false, true) => ExtensionNode::ODD_FLAG | nibbles.get_unchecked(0),
                (false, false) => ExtensionNode::EVEN_FLAG,
            });
            // The odd leading nibble already lives in the flag byte.
            if odd_nibbles {
                nibbles = nibbles.slice(1..);
            }
            nibbles.pack_to_slice_unchecked(rest);
        })
    }
}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TrieMask;
    use alloy_primitives::hex;

    #[test]
    fn rlp_empty_root_node() {
        let empty_root = TrieNode::EmptyRoot;
        let rlp = empty_root.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("80"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), empty_root);
    }

    #[test]
    fn rlp_zero_value_leaf_roundtrip() {
        let leaf = TrieNode::Leaf(LeafNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            alloy_rlp::encode(alloy_primitives::U256::ZERO),
        ));
        let rlp = leaf.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c68320646f8180"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), leaf);
    }

    #[test]
    fn rlp_trie_node_roundtrip() {
        // Leaf node.
        let leaf = TrieNode::Leaf(LeafNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            hex!("76657262").to_vec(),
        ));
        let rlp = leaf.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c98320646f8476657262"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), leaf);

        // Extension node whose child is the short RLP string `0x8476657262`.
        let mut child = vec![];
        hex!("76657262")[..].encode(&mut child);
        let extension = TrieNode::Extension(ExtensionNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            RlpNode::from_raw(&child).unwrap(),
        ));
        let rlp = extension.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c98300646f8476657262"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), extension);

        // Branch node with all 16 children set to the same 32-byte word.
        let branch = TrieNode::Branch(BranchNode::new(
            core::iter::repeat_n(RlpNode::word_rlp(&B256::repeat_byte(23)), 16).collect(),
            TrieMask::new(u16::MAX),
        ));
        let mut rlp = vec![];
        let rlp_node = branch.rlp(&mut rlp);
        assert_eq!(
            rlp_node[..],
            hex!("a0bed74980bbe29d9c4439c10e9c451e29b306fe74bcf9795ecf0ebbd92a220513")
        );
        // List header for a 0x0211-byte payload, then 16 children each encoded as
        // `0xa0 ++ [0x17; 32]`, then the empty branch value `0x80`.
        let mut expected = hex!("f90211").to_vec();
        for _ in 0..16 {
            expected.push(0xa0);
            expected.extend_from_slice(&[0x17; 32]);
        }
        expected.push(0x80);
        assert_eq!(rlp, expected);
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), branch);
    }

    #[test]
    fn hashed_encode_path_regression() {
        let nibbles = Nibbles::from_nibbles(hex!(
            "05010406040a040203030f010805020b050c04070003070e0909070f010b0a0805020301070c0a0902040b0f000f0006040a04050f020b090701000a0a040b"
        ));
        let path = encode_path_leaf(&nibbles, true);
        let expected = hex!("351464a4233f1852b5c47037e997f1ba852317ca924bf0f064a45f2b9710aa4b");
        assert_eq!(path[..], expected);
    }

    #[test]
    #[cfg(feature = "arbitrary")]
    #[cfg_attr(miri, ignore = "no proptest")]
    fn encode_path_first_byte() {
        use proptest::{collection::vec, prelude::*};

        // Generate raw nibbles rather than bytes: `Nibbles::unpack` on bytes always
        // yields an even number of nibbles, which would leave the odd-length branch
        // of `encode_path_leaf` untested.
        proptest::proptest!(|(raw in vec(0u8..16, 0..=64))| {
            let input = Nibbles::from_nibbles(&raw[..]);
            let input_is_odd = input.len() % 2 == 1;

            let compact_leaf = encode_path_leaf(&input, true);
            prop_assert_eq!(compact_leaf.len(), input.len() / 2 + 1);
            let leaf_flag = compact_leaf[0];
            // Leaf flags carry the leaf bit; the odd bit tracks path parity.
            prop_assert_ne!(leaf_flag & LeafNode::EVEN_FLAG, 0);
            prop_assert_eq!(input_is_odd, (leaf_flag & ExtensionNode::ODD_FLAG) != 0);
            if input_is_odd {
                prop_assert_eq!(leaf_flag & 0x0f, input.first().unwrap());
            }

            let compact_extension = encode_path_leaf(&input, false);
            prop_assert_eq!(compact_extension.len(), input.len() / 2 + 1);
            let extension_flag = compact_extension[0];
            // Extension flags never carry the leaf bit; the odd bit tracks parity.
            prop_assert_eq!(extension_flag & LeafNode::EVEN_FLAG, 0);
            prop_assert_eq!(input_is_odd, (extension_flag & ExtensionNode::ODD_FLAG) != 0);
            if input_is_odd {
                prop_assert_eq!(extension_flag & 0x0f, input.first().unwrap());
            }
        });
    }
}