alloy_trie/nodes/
mod.rs

//! Various trie nodes (branch, extension and leaf) produced by the hash builder.
2
3use alloy_primitives::{B256, Bytes};
4use alloy_rlp::{Decodable, EMPTY_STRING_CODE, Encodable, Header};
5use core::ops::Range;
6use nybbles::Nibbles;
7use smallvec::SmallVec;
8
9#[allow(unused_imports)]
10use alloc::vec::Vec;
11
12mod branch;
13pub use branch::{BranchNode, BranchNodeCompact, BranchNodeRef};
14
15mod extension;
16pub use extension::{ExtensionNode, ExtensionNodeRef};
17
18mod leaf;
19pub use leaf::{LeafNode, LeafNodeRef};
20
21mod rlp;
22pub use rlp::RlpNode;
23
/// Valid child indexes of a branch node: one slot per hex nibble, i.e. `0..16`.
pub const CHILD_INDEX_RANGE: Range<u8> = 0..0x10;
26
27/// Enum representing an MPT trie node.
28#[derive(PartialEq, Eq, Clone, Debug)]
29#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
30pub enum TrieNode {
31    /// Variant representing empty root node.
32    EmptyRoot,
33    /// Variant representing a [BranchNode].
34    Branch(BranchNode),
35    /// Variant representing a [ExtensionNode].
36    Extension(ExtensionNode),
37    /// Variant representing a [LeafNode].
38    Leaf(LeafNode),
39}
40
41impl Encodable for TrieNode {
42    #[inline]
43    fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
44        match self {
45            Self::EmptyRoot => {
46                out.put_u8(EMPTY_STRING_CODE);
47            }
48            Self::Branch(branch) => branch.encode(out),
49            Self::Extension(extension) => extension.encode(out),
50            Self::Leaf(leaf) => leaf.encode(out),
51        }
52    }
53
54    #[inline]
55    fn length(&self) -> usize {
56        match self {
57            Self::EmptyRoot => 1,
58            Self::Branch(branch) => branch.length(),
59            Self::Extension(extension) => extension.length(),
60            Self::Leaf(leaf) => leaf.length(),
61        }
62    }
63}
64
65impl Decodable for TrieNode {
66    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
67        let mut items = match Header::decode_raw(buf)? {
68            alloy_rlp::PayloadView::List(list) => list,
69            alloy_rlp::PayloadView::String(val) => {
70                return if val.is_empty() {
71                    Ok(Self::EmptyRoot)
72                } else {
73                    Err(alloy_rlp::Error::UnexpectedString)
74                };
75            }
76        };
77
78        // A valid number of trie node items is either 17 (branch node)
79        // or 2 (extension or leaf node).
80        match items.len() {
81            17 => {
82                let mut branch = BranchNode::default();
83                for (idx, item) in items.into_iter().enumerate() {
84                    if idx == 16 {
85                        if item != [EMPTY_STRING_CODE] {
86                            return Err(alloy_rlp::Error::Custom(
87                                "branch node values are not supported",
88                            ));
89                        }
90                    } else if item != [EMPTY_STRING_CODE] {
91                        branch.stack.push(RlpNode::from_raw_rlp(item)?);
92                        branch.state_mask.set_bit(idx as u8);
93                    }
94                }
95                Ok(Self::Branch(branch))
96            }
97            2 => {
98                let mut key = items.remove(0);
99
100                let encoded_key = Header::decode_bytes(&mut key, false)?;
101                if encoded_key.is_empty() {
102                    return Err(alloy_rlp::Error::Custom("trie node key empty"));
103                }
104
105                // extract the high order part of the nibble to then pick the odd nibble out
106                let key_flag = encoded_key[0] & 0xf0;
107                // Retrieve first byte. If it's [Some], then the nibbles are odd.
108                let first = match key_flag {
109                    ExtensionNode::ODD_FLAG | LeafNode::ODD_FLAG => Some(encoded_key[0] & 0x0f),
110                    ExtensionNode::EVEN_FLAG | LeafNode::EVEN_FLAG => None,
111                    _ => return Err(alloy_rlp::Error::Custom("node is not extension or leaf")),
112                };
113
114                let key = unpack_path_to_nibbles(first, &encoded_key[1..]);
115                let node = if key_flag == LeafNode::EVEN_FLAG || key_flag == LeafNode::ODD_FLAG {
116                    let value = Bytes::decode(&mut items.remove(0))?.into();
117                    Self::Leaf(LeafNode::new(key, value))
118                } else {
119                    // We don't decode value because it is expected to be RLP encoded.
120                    Self::Extension(ExtensionNode::new(
121                        key,
122                        RlpNode::from_raw_rlp(items.remove(0))?,
123                    ))
124                };
125                Ok(node)
126            }
127            _ => Err(alloy_rlp::Error::Custom("invalid number of items in the list")),
128        }
129    }
130}
131
132impl TrieNode {
133    /// RLP-encodes the node and returns either `rlp(node)` or `rlp(keccak(rlp(node)))`.
134    #[inline]
135    pub fn rlp(&self, rlp: &mut Vec<u8>) -> RlpNode {
136        self.encode(rlp);
137        RlpNode::from_rlp(rlp)
138    }
139}
140
/// Given an RLP-encoded node, returns it either as `rlp(node)` or `rlp(keccak(rlp(node)))`.
///
/// Deprecated: this is a thin wrapper that forwards directly to [`RlpNode::from_rlp`].
#[inline]
#[deprecated = "use `RlpNode::from_rlp` instead"]
pub fn rlp_node(rlp: &[u8]) -> RlpNode {
    RlpNode::from_rlp(rlp)
}

/// Optimization for quick RLP-encoding of a 32-byte word.
///
/// Deprecated: this is a thin wrapper that forwards directly to [`RlpNode::word_rlp`].
#[inline]
#[deprecated = "use `RlpNode::word_rlp` instead"]
pub fn word_rlp(word: &B256) -> RlpNode {
    RlpNode::word_rlp(word)
}
154
155/// Unpack node path to nibbles.
156///
157/// NOTE: The first nibble should be less than or equal to `0xf` if provided.
158/// If first nibble is greater than `0xf`, the method will not panic, but initialize invalid nibbles
159/// instead.
160///
161/// ## Arguments
162///
163/// `first` - first nibble of the path if it is odd
164/// `rest` - rest of the nibbles packed
165#[inline]
166pub(crate) fn unpack_path_to_nibbles(first: Option<u8>, rest: &[u8]) -> Nibbles {
167    let rest = Nibbles::unpack(rest);
168    let Some(first) = first else { return rest };
169    debug_assert!(first <= 0xf);
170    // TODO: optimize
171    Nibbles::from_nibbles_unchecked([first]).join(&rest)
172}
173
/// Encodes a nibble path in the compact (hex-prefix) form used by leaf and extension nodes.
///
/// In the resulting array, each byte holds two "nibbles" (half-bytes, 4 bits each) of the
/// path. The packed nibbles are preceded by a header byte that records the node kind and
/// the parity of the path length.
///
/// ## Arguments
///
/// `nibbles` - the path to encode
/// `is_leaf` - `true` if the path belongs to a leaf node, `false` for an extension node
///
/// The first byte of the encoded vector is set based on the `is_leaf` flag and the parity of
/// the path length (even or odd number of nibbles).
///  - If the node is an extension with even length, the header byte is `0x00`.
///  - If the node is an extension with odd length, the header byte is `0x10 + <first nibble>`.
///  - If the node is a leaf with even length, the header byte is `0x20`.
///  - If the node is a leaf with odd length, the header byte is `0x30 + <first nibble>`.
///
/// If there is an odd number of nibbles, the first nibble is stored in the lower 4 bits of the
/// header byte. The remaining (even number of) nibbles are packed into the following bytes.
///
/// # Returns
///
/// A vector containing the compact byte representation of the nibble sequence, including the
/// header byte.
///
/// The vector's length is `nibbles.len() / 2 + 1`. For stack-allocated nibbles this is at most
/// 33 bytes, which fits within the 36-byte inline capacity of the returned [`SmallVec`].
///
/// # Examples
///
/// ```
/// use alloy_trie::nodes::encode_path_leaf;
/// use nybbles::Nibbles;
///
/// // Extension node with an even path length:
/// let nibbles = Nibbles::from_nibbles(&[0x0A, 0x0B, 0x0C, 0x0D]);
/// assert_eq!(encode_path_leaf(&nibbles, false)[..], [0x00, 0xAB, 0xCD]);
///
/// // Extension node with an odd path length:
/// let nibbles = Nibbles::from_nibbles(&[0x0A, 0x0B, 0x0C]);
/// assert_eq!(encode_path_leaf(&nibbles, false)[..], [0x1A, 0xBC]);
///
/// // Leaf node with an even path length:
/// let nibbles = Nibbles::from_nibbles(&[0x0A, 0x0B, 0x0C, 0x0D]);
/// assert_eq!(encode_path_leaf(&nibbles, true)[..], [0x20, 0xAB, 0xCD]);
///
/// // Leaf node with an odd path length:
/// let nibbles = Nibbles::from_nibbles(&[0x0A, 0x0B, 0x0C]);
/// assert_eq!(encode_path_leaf(&nibbles, true)[..], [0x3A, 0xBC]);
/// ```
#[inline]
pub fn encode_path_leaf(nibbles: &Nibbles, is_leaf: bool) -> SmallVec<[u8; 36]> {
    // Local copy so the odd leading nibble can be sliced off below.
    let mut nibbles = *nibbles;
    // One header byte plus `len / 2` packed bytes. For odd lengths the extra nibble lives in
    // the header byte, so the integer division is exact in both cases.
    let encoded_len = nibbles.len() / 2 + 1;
    let odd_nibbles = nibbles.len() % 2 != 0;
    // SAFETY: `encoded_len` is one header byte plus `nibbles.len() / 2` packed bytes, which
    // matches what is written below. `split_first_mut` cannot fail because `encoded_len >= 1`.
    // `get_unchecked(0)` is only reached when `odd_nibbles` is true, i.e. the path is
    // non-empty. NOTE(review): assumes `smallvec_with` requires the closure to initialize all
    // `encoded_len` bytes — confirm against the nybbles docs.
    unsafe {
        nybbles::smallvec_with(encoded_len, |buf| {
            let (first, rest) = buf.split_first_mut().unwrap_unchecked();
            first.write(match (is_leaf, odd_nibbles) {
                (true, true) => LeafNode::ODD_FLAG | nibbles.get_unchecked(0),
                (true, false) => LeafNode::EVEN_FLAG,
                (false, true) => ExtensionNode::ODD_FLAG | nibbles.get_unchecked(0),
                (false, false) => ExtensionNode::EVEN_FLAG,
            });
            // The odd leading nibble is already stored in the header byte; pack the rest.
            if odd_nibbles {
                nibbles = nibbles.slice(1..);
            }
            nibbles.pack_to_slice_unchecked(rest);
        })
    }
}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TrieMask;
    use alloy_primitives::hex;

    #[test]
    fn rlp_empty_root_node() {
        let empty_root = TrieNode::EmptyRoot;
        let rlp = empty_root.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("80"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), empty_root);
    }

    #[test]
    fn rlp_zero_value_leaf_roundtrip() {
        let leaf = TrieNode::Leaf(LeafNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            alloy_rlp::encode(alloy_primitives::U256::ZERO),
        ));
        let rlp = leaf.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c68320646f8180"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), leaf);
    }

    #[test]
    fn rlp_trie_node_roundtrip() {
        // leaf
        let leaf = TrieNode::Leaf(LeafNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            hex!("76657262").to_vec(),
        ));
        let rlp = leaf.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c98320646f8476657262"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), leaf);

        // extension
        let mut child = vec![];
        hex!("76657262").to_vec().as_slice().encode(&mut child);
        let extension = TrieNode::Extension(ExtensionNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            RlpNode::from_raw(&child).unwrap(),
        ));
        let rlp = extension.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c98300646f8476657262"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), extension);

        // branch
        let branch = TrieNode::Branch(BranchNode::new(
            core::iter::repeat_n(RlpNode::word_rlp(&B256::repeat_byte(23)), 16).collect(),
            TrieMask::new(u16::MAX),
        ));
        let mut rlp = vec![];
        let rlp_node = branch.rlp(&mut rlp);
        assert_eq!(
            rlp_node[..],
            hex!("a0bed74980bbe29d9c4439c10e9c451e29b306fe74bcf9795ecf0ebbd92a220513")
        );
        // `0xf9 0x0211`: a list with a 2-byte payload length of 529 = 16 * 33 + 1 bytes. The
        // payload is 16 child hashes (`0xa0` + 32 bytes of `0x17` each) followed by an empty
        // value slot (`0x80`).
        let mut expected = hex!("f90211").to_vec();
        for _ in 0..16 {
            expected.push(0xa0);
            expected.extend_from_slice(&[0x17; 32]);
        }
        expected.push(EMPTY_STRING_CODE);
        assert_eq!(rlp, expected);
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), branch);
    }

    #[test]
    fn hashed_encode_path_regression() {
        let nibbles = Nibbles::from_nibbles(hex!(
            "05010406040a040203030f010805020b050c04070003070e0909070f010b0a0805020301070c0a0902040b0f000f0006040a04050f020b090701000a0a040b"
        ));
        let path = encode_path_leaf(&nibbles, true);
        let expected = hex!("351464a4233f1852b5c47037e997f1ba852317ca924bf0f064a45f2b9710aa4b");
        assert_eq!(path[..], expected);
    }

    #[test]
    #[cfg(feature = "arbitrary")]
    #[cfg_attr(miri, ignore = "no proptest")]
    fn encode_path_first_byte() {
        use proptest::{collection::vec, prelude::*};

        // Generate raw nibbles rather than bytes. `Nibbles::unpack` always yields an even
        // number of nibbles, which would leave the odd-length checks below untested.
        proptest::proptest!(|(raw in vec(0u8..=0x0f, 0..=64))| {
            let input = Nibbles::from_nibbles(&raw);
            prop_assert!(input.to_vec().iter().all(|&nibble| nibble <= 0xf));
            let input_is_odd = input.len() % 2 == 1;

            let compact_leaf = encode_path_leaf(&input, true);
            prop_assert_eq!(compact_leaf.len(), input.len() / 2 + 1);
            let leaf_flag = compact_leaf[0];
            // Check flag
            prop_assert_ne!(leaf_flag & LeafNode::EVEN_FLAG, 0);
            prop_assert_eq!(input_is_odd, (leaf_flag & ExtensionNode::ODD_FLAG) != 0);
            if input_is_odd {
                prop_assert_eq!(leaf_flag & 0x0f, input.first().unwrap());
            }

            let compact_extension = encode_path_leaf(&input, false);
            prop_assert_eq!(compact_extension.len(), input.len() / 2 + 1);
            let extension_flag = compact_extension[0];
            // Check first byte
            prop_assert_eq!(extension_flag & LeafNode::EVEN_FLAG, 0);
            prop_assert_eq!(input_is_odd, (extension_flag & ExtensionNode::ODD_FLAG) != 0);
            if input_is_odd {
                prop_assert_eq!(extension_flag & 0x0f, input.first().unwrap());
            }
        });
    }

    #[test]
    #[cfg(feature = "arbitrary")]
    #[cfg_attr(miri, ignore = "no proptest")]
    fn leaf_path_roundtrip() {
        use proptest::{collection::vec, prelude::*};

        // Exercises both even and odd path lengths through encode + decode.
        proptest::proptest!(|(raw in vec(0u8..=0x0f, 0..=64))| {
            let leaf = TrieNode::Leaf(LeafNode::new(Nibbles::from_nibbles(&raw), [0x01u8].to_vec()));
            let encoded = alloy_rlp::encode(&leaf);
            prop_assert_eq!(TrieNode::decode(&mut &encoded[..]).unwrap(), leaf);
        });
    }
}