alloy_trie/nodes/branch.rs

1use super::{super::TrieMask, RlpNode, CHILD_INDEX_RANGE};
2use alloy_primitives::{hex, B256};
3use alloy_rlp::{length_of_length, Buf, BufMut, Decodable, Encodable, Header, EMPTY_STRING_CODE};
4use core::{fmt, ops::Range, slice::Iter};
5
6use alloc::sync::Arc;
7#[allow(unused_imports)]
8use alloc::vec::Vec;
9
10/// A branch node in an Ethereum Merkle Patricia Trie.
11///
12/// Branch node is a 17-element array consisting of 16 slots that correspond to each hexadecimal
13/// character and an additional slot for a value. We do exclude the node value since all paths have
14/// a fixed size.
15#[derive(PartialEq, Eq, Clone, Default)]
16#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
17pub struct BranchNode {
18    /// The collection of RLP encoded children.
19    pub stack: Vec<RlpNode>,
20    /// The bitmask indicating the presence of children at the respective nibble positions
21    pub state_mask: TrieMask,
22}
23
24impl fmt::Debug for BranchNode {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        f.debug_struct("BranchNode")
27            .field("stack", &self.stack.iter().map(hex::encode).collect::<Vec<_>>())
28            .field("state_mask", &self.state_mask)
29            .field("first_child_index", &self.as_ref().first_child_index())
30            .finish()
31    }
32}
33
34impl Encodable for BranchNode {
35    #[inline]
36    fn encode(&self, out: &mut dyn BufMut) {
37        self.as_ref().encode(out)
38    }
39
40    #[inline]
41    fn length(&self) -> usize {
42        self.as_ref().length()
43    }
44}
45
46impl Decodable for BranchNode {
47    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
48        let mut bytes = Header::decode_bytes(buf, true)?;
49
50        let mut stack = Vec::new();
51        let mut state_mask = TrieMask::default();
52        for index in CHILD_INDEX_RANGE {
53            // The buffer must contain empty string code for value.
54            if bytes.len() <= 1 {
55                return Err(alloy_rlp::Error::InputTooShort);
56            }
57
58            if bytes[0] == EMPTY_STRING_CODE {
59                bytes.advance(1);
60                continue;
61            }
62
63            // Decode without advancing
64            let Header { payload_length, .. } = Header::decode(&mut &bytes[..])?;
65            let len = payload_length + length_of_length(payload_length);
66            stack.push(RlpNode::from_raw_rlp(&bytes[..len])?);
67            bytes.advance(len);
68            state_mask.set_bit(index);
69        }
70
71        // Consume empty string code for branch node value.
72        let bytes = Header::decode_bytes(&mut bytes, false)?;
73        if !bytes.is_empty() {
74            return Err(alloy_rlp::Error::Custom("branch values not supported"));
75        }
76        debug_assert!(bytes.is_empty(), "bytes {}", alloy_primitives::hex::encode(bytes));
77
78        Ok(Self { stack, state_mask })
79    }
80}
81
82impl BranchNode {
83    /// Creates a new branch node with the given stack and state mask.
84    pub const fn new(stack: Vec<RlpNode>, state_mask: TrieMask) -> Self {
85        Self { stack, state_mask }
86    }
87
88    /// Return branch node as [BranchNodeRef].
89    pub fn as_ref(&self) -> BranchNodeRef<'_> {
90        BranchNodeRef::new(&self.stack, self.state_mask)
91    }
92}
93
94/// A reference to [BranchNode] and its state mask.
95/// NOTE: The stack may contain more items that specified in the state mask.
96#[derive(Clone)]
97pub struct BranchNodeRef<'a> {
98    /// Reference to the collection of RLP encoded nodes.
99    /// NOTE: The referenced stack might have more items than the number of children
100    /// for this node. We should only ever access items starting from
101    /// [BranchNodeRef::first_child_index].
102    pub stack: &'a [RlpNode],
103    /// Reference to bitmask indicating the presence of children at
104    /// the respective nibble positions.
105    pub state_mask: TrieMask,
106}
107
108impl fmt::Debug for BranchNodeRef<'_> {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        f.debug_struct("BranchNodeRef")
111            .field("stack", &self.stack.iter().map(hex::encode).collect::<Vec<_>>())
112            .field("state_mask", &self.state_mask)
113            .field("first_child_index", &self.first_child_index())
114            .finish()
115    }
116}
117
118/// Implementation of RLP encoding for branch node in Ethereum Merkle Patricia Trie.
119/// Encode it as a 17-element list consisting of 16 slots that correspond to
120/// each child of the node (0-f) and an additional slot for a value.
121impl Encodable for BranchNodeRef<'_> {
122    #[inline]
123    fn encode(&self, out: &mut dyn BufMut) {
124        Header { list: true, payload_length: self.rlp_payload_length() }.encode(out);
125
126        // Extend the RLP buffer with the present children
127        for (_, child) in self.children() {
128            if let Some(child) = child {
129                out.put_slice(child);
130            } else {
131                out.put_u8(EMPTY_STRING_CODE);
132            }
133        }
134
135        out.put_u8(EMPTY_STRING_CODE);
136    }
137
138    #[inline]
139    fn length(&self) -> usize {
140        let payload_length = self.rlp_payload_length();
141        payload_length + length_of_length(payload_length)
142    }
143}
144
145impl<'a> BranchNodeRef<'a> {
146    /// Create a new branch node from the stack of nodes.
147    #[inline]
148    pub const fn new(stack: &'a [RlpNode], state_mask: TrieMask) -> Self {
149        Self { stack, state_mask }
150    }
151
152    /// Returns the stack index of the first child for this node.
153    ///
154    /// # Panics
155    ///
156    /// If the stack length is less than number of children specified in state mask.
157    /// Means that the node is in inconsistent state.
158    #[inline]
159    pub fn first_child_index(&self) -> usize {
160        self.stack.len().checked_sub(self.state_mask.count_ones() as usize).unwrap()
161    }
162
163    /// Returns an iterator over children of the branch node.
164    #[inline]
165    pub fn children(&self) -> impl Iterator<Item = (u8, Option<&RlpNode>)> + '_ {
166        BranchChildrenIter::new(self)
167    }
168
169    /// Given the hash mask of children, return an iterator over stack items
170    /// that match the mask.
171    #[inline]
172    pub fn child_hashes(&self, hash_mask: TrieMask) -> impl Iterator<Item = B256> + '_ {
173        self.children()
174            .filter_map(|(i, c)| c.map(|c| (i, c)))
175            .filter(move |(index, _)| hash_mask.is_bit_set(*index))
176            .map(|(_, child)| B256::from_slice(&child[1..]))
177    }
178
179    /// RLP-encodes the node and returns either `rlp(node)` or `rlp(keccak(rlp(node)))`.
180    #[inline]
181    pub fn rlp(&self, rlp: &mut Vec<u8>) -> RlpNode {
182        self.encode(rlp);
183        RlpNode::from_rlp(rlp)
184    }
185
186    /// Returns the length of RLP encoded fields of branch node.
187    #[inline]
188    fn rlp_payload_length(&self) -> usize {
189        let mut payload_length = 1;
190        for (_, child) in self.children() {
191            if let Some(child) = child {
192                payload_length += child.len();
193            } else {
194                payload_length += 1;
195            }
196        }
197        payload_length
198    }
199}
200
201/// Iterator over branch node children.
202#[derive(Debug)]
203struct BranchChildrenIter<'a> {
204    range: Range<u8>,
205    state_mask: TrieMask,
206    stack_iter: Iter<'a, RlpNode>,
207}
208
209impl<'a> BranchChildrenIter<'a> {
210    /// Create new iterator over branch node children.
211    fn new(node: &BranchNodeRef<'a>) -> Self {
212        Self {
213            range: CHILD_INDEX_RANGE,
214            state_mask: node.state_mask,
215            stack_iter: node.stack[node.first_child_index()..].iter(),
216        }
217    }
218}
219
220impl<'a> Iterator for BranchChildrenIter<'a> {
221    type Item = (u8, Option<&'a RlpNode>);
222
223    #[inline]
224    fn next(&mut self) -> Option<Self::Item> {
225        let i = self.range.next()?;
226        let value = if self.state_mask.is_bit_set(i) {
227            // SAFETY: `first_child_index` guarantees that `stack` is exactly
228            // `state_mask.count_ones()` long.
229            Some(unsafe { self.stack_iter.next().unwrap_unchecked() })
230        } else {
231            None
232        };
233        Some((i, value))
234    }
235
236    #[inline]
237    fn size_hint(&self) -> (usize, Option<usize>) {
238        let len = self.len();
239        (len, Some(len))
240    }
241}
242
243impl core::iter::FusedIterator for BranchChildrenIter<'_> {}
244
245impl ExactSizeIterator for BranchChildrenIter<'_> {
246    #[inline]
247    fn len(&self) -> usize {
248        self.range.len()
249    }
250}
251
252/// A struct representing a branch node in an Ethereum trie.
253///
254/// A branch node can have up to 16 children, each corresponding to one of the possible nibble
255/// values (`0` to `f`) in the trie's path.
256///
257/// The masks in a BranchNode are used to efficiently represent and manage information about the
258/// presence and types of its children. They are bitmasks, where each bit corresponds to a nibble
259/// (half-byte, or 4 bits) value from `0` to `f`.
260#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
261#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
262pub struct BranchNodeCompact {
263    /// The bitmask indicating the presence of children at the respective nibble positions in the
264    /// trie. If the bit at position i (counting from the right) is set (1), it indicates that a
265    /// child exists for the nibble value i. If the bit is unset (0), it means there is no child
266    /// for that nibble value.
267    pub state_mask: TrieMask,
268    /// The bitmask representing the children at the respective nibble positions in the trie that
269    /// are also stored in the database. If the bit at position `i` (counting from the right)
270    /// is set (1) and also present in the state_mask, it indicates that the corresponding
271    /// child at the nibble value `i` is stored in the database. If the bit is unset (0), it means
272    /// the child is not stored in the database.
273    pub tree_mask: TrieMask,
274    /// The bitmask representing the hashed branch children nodes at the respective nibble
275    /// positions in the trie. If the bit at position `i` (counting from the right) is set (1)
276    /// and also present in the state_mask, it indicates that the corresponding child at the
277    /// nibble value `i` is a hashed branch child node. If the bit is unset (0), it means the child
278    /// is not a hashed branch child node.
279    pub hash_mask: TrieMask,
280    /// Collection of hashes associated with the children of the branch node.
281    /// Each child hash is calculated by hashing two consecutive sub-branch roots.
282    pub hashes: Arc<Vec<B256>>,
283    /// An optional root hash of the subtree rooted at this branch node.
284    pub root_hash: Option<B256>,
285}
286
287impl BranchNodeCompact {
288    /// Creates a new [BranchNodeCompact] from the given parameters.
289    pub fn new(
290        state_mask: impl Into<TrieMask>,
291        tree_mask: impl Into<TrieMask>,
292        hash_mask: impl Into<TrieMask>,
293        hashes: Vec<B256>,
294        root_hash: Option<B256>,
295    ) -> Self {
296        let (state_mask, tree_mask, hash_mask) =
297            (state_mask.into(), tree_mask.into(), hash_mask.into());
298        assert!(
299            tree_mask.is_subset_of(state_mask),
300            "state mask: {state_mask:?} tree mask: {tree_mask:?}"
301        );
302        assert!(
303            hash_mask.is_subset_of(state_mask),
304            "state_mask {state_mask:?} hash_mask: {hash_mask:?}"
305        );
306        assert_eq!(hash_mask.count_ones() as usize, hashes.len());
307        Self { state_mask, tree_mask, hash_mask, hashes: hashes.into(), root_hash }
308    }
309
310    /// Returns the hash associated with the given nibble.
311    pub fn hash_for_nibble(&self, nibble: u8) -> B256 {
312        let mask = *TrieMask::from_nibble(nibble) - 1;
313        let index = (*self.hash_mask & mask).count_ones();
314        self.hashes[index as usize]
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nodes::{ExtensionNode, LeafNode};
    use nybbles::Nibbles;

    /// Encodes `node`, decodes the bytes back, and asserts that the result equals the input.
    fn assert_roundtrip(node: &BranchNode) {
        let encoded = alloy_rlp::encode(node);
        let decoded = BranchNode::decode(&mut &encoded[..]).unwrap();
        assert_eq!(&decoded, node);
    }

    #[test]
    fn rlp_branch_node_roundtrip() {
        // No children at all.
        assert_roundtrip(&BranchNode::default());

        // Two hashed children, at nibbles 2 and 6.
        let hashed_children = vec![
            RlpNode::word_rlp(&B256::repeat_byte(1)),
            RlpNode::word_rlp(&B256::repeat_byte(2)),
        ];
        assert_roundtrip(&BranchNode::new(hashed_children, TrieMask::new(0b1000100)));

        // An embedded leaf child at nibble 1.
        let leaf = LeafNode::new(Nibbles::from_nibbles(hex!("0203")), hex!("1234").to_vec());
        let mut leaf_buf = vec![];
        let leaf_rlp = leaf.as_ref().rlp(&mut leaf_buf);
        assert_roundtrip(&BranchNode::new(vec![leaf_rlp.clone()], TrieMask::new(0b0010)));

        // An extension child, pointing at the leaf above, at nibble 5.
        let extension = ExtensionNode::new(Nibbles::from_nibbles(hex!("0203")), leaf_rlp);
        let mut extension_buf = vec![];
        let extension_rlp = extension.as_ref().rlp(&mut extension_buf);
        assert_roundtrip(&BranchNode::new(vec![extension_rlp], TrieMask::new(0b00000100000)));

        // Every one of the 16 slots filled with the same hash.
        let full_stack: Vec<RlpNode> =
            core::iter::repeat(RlpNode::word_rlp(&B256::repeat_byte(23))).take(16).collect();
        assert_roundtrip(&BranchNode::new(full_stack, TrieMask::new(u16::MAX)));
    }
}