1use super::{super::TrieMask, RlpNode, CHILD_INDEX_RANGE};
2use alloy_primitives::{hex, B256};
3use alloy_rlp::{length_of_length, Buf, BufMut, Decodable, Encodable, Header, EMPTY_STRING_CODE};
4use core::{fmt, ops::Range, slice::Iter};
5
6use alloc::sync::Arc;
7#[allow(unused_imports)]
8use alloc::vec::Vec;
9
/// A trie branch node that owns the RLP encodings of its children.
///
/// Only present children are stored; `state_mask` records which child slots they occupy.
/// When encoded, the node is an RLP list with one entry per child slot followed by an empty
/// value slot (a non-empty branch value is rejected when decoding).
#[derive(PartialEq, Eq, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BranchNode {
    /// RLP encodings of the present children, in ascending child-index order.
    /// Encoding and iteration treat the last `state_mask.count_ones()` entries as the children.
    pub stack: Vec<RlpNode>,
    /// Bit `i` is set when the child at index `i` is present.
    pub state_mask: TrieMask,
}
23
24impl fmt::Debug for BranchNode {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 f.debug_struct("BranchNode")
27 .field("stack", &self.stack.iter().map(hex::encode).collect::<Vec<_>>())
28 .field("state_mask", &self.state_mask)
29 .field("first_child_index", &self.as_ref().first_child_index())
30 .finish()
31 }
32}
33
34impl Encodable for BranchNode {
35 #[inline]
36 fn encode(&self, out: &mut dyn BufMut) {
37 self.as_ref().encode(out)
38 }
39
40 #[inline]
41 fn length(&self) -> usize {
42 self.as_ref().length()
43 }
44}
45
46impl Decodable for BranchNode {
47 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
48 let mut bytes = Header::decode_bytes(buf, true)?;
49
50 let mut stack = Vec::new();
51 let mut state_mask = TrieMask::default();
52 for index in CHILD_INDEX_RANGE {
53 if bytes.len() <= 1 {
55 return Err(alloy_rlp::Error::InputTooShort);
56 }
57
58 if bytes[0] == EMPTY_STRING_CODE {
59 bytes.advance(1);
60 continue;
61 }
62
63 let Header { payload_length, .. } = Header::decode(&mut &bytes[..])?;
65 let len = payload_length + length_of_length(payload_length);
66 stack.push(RlpNode::from_raw_rlp(&bytes[..len])?);
67 bytes.advance(len);
68 state_mask.set_bit(index);
69 }
70
71 let bytes = Header::decode_bytes(&mut bytes, false)?;
73 if !bytes.is_empty() {
74 return Err(alloy_rlp::Error::Custom("branch values not supported"));
75 }
76 debug_assert!(bytes.is_empty(), "bytes {}", alloy_primitives::hex::encode(bytes));
77
78 Ok(Self { stack, state_mask })
79 }
80}
81
82impl BranchNode {
83 pub const fn new(stack: Vec<RlpNode>, state_mask: TrieMask) -> Self {
85 Self { stack, state_mask }
86 }
87
88 pub fn as_ref(&self) -> BranchNodeRef<'_> {
90 BranchNodeRef::new(&self.stack, self.state_mask)
91 }
92}
93
/// A borrowed view of a branch node: a stack slice plus the mask of present children.
#[derive(Clone)]
pub struct BranchNodeRef<'a> {
    /// Stack of RLP-encoded nodes. This node's children are the last
    /// `state_mask.count_ones()` entries (see `first_child_index`); earlier entries are ignored.
    pub stack: &'a [RlpNode],
    /// Bit `i` is set when the child at index `i` is present.
    pub state_mask: TrieMask,
}
107
108impl fmt::Debug for BranchNodeRef<'_> {
109 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 f.debug_struct("BranchNodeRef")
111 .field("stack", &self.stack.iter().map(hex::encode).collect::<Vec<_>>())
112 .field("state_mask", &self.state_mask)
113 .field("first_child_index", &self.first_child_index())
114 .finish()
115 }
116}
117
118impl Encodable for BranchNodeRef<'_> {
122 #[inline]
123 fn encode(&self, out: &mut dyn BufMut) {
124 Header { list: true, payload_length: self.rlp_payload_length() }.encode(out);
125
126 for (_, child) in self.children() {
128 if let Some(child) = child {
129 out.put_slice(child);
130 } else {
131 out.put_u8(EMPTY_STRING_CODE);
132 }
133 }
134
135 out.put_u8(EMPTY_STRING_CODE);
136 }
137
138 #[inline]
139 fn length(&self) -> usize {
140 let payload_length = self.rlp_payload_length();
141 payload_length + length_of_length(payload_length)
142 }
143}
144
145impl<'a> BranchNodeRef<'a> {
146 #[inline]
148 pub const fn new(stack: &'a [RlpNode], state_mask: TrieMask) -> Self {
149 Self { stack, state_mask }
150 }
151
152 #[inline]
159 pub fn first_child_index(&self) -> usize {
160 self.stack.len().checked_sub(self.state_mask.count_ones() as usize).unwrap()
161 }
162
163 #[inline]
165 pub fn children(&self) -> impl Iterator<Item = (u8, Option<&RlpNode>)> + '_ {
166 BranchChildrenIter::new(self)
167 }
168
169 #[inline]
172 pub fn child_hashes(&self, hash_mask: TrieMask) -> impl Iterator<Item = B256> + '_ {
173 self.children()
174 .filter_map(|(i, c)| c.map(|c| (i, c)))
175 .filter(move |(index, _)| hash_mask.is_bit_set(*index))
176 .map(|(_, child)| B256::from_slice(&child[1..]))
177 }
178
179 #[inline]
181 pub fn rlp(&self, rlp: &mut Vec<u8>) -> RlpNode {
182 self.encode(rlp);
183 RlpNode::from_rlp(rlp)
184 }
185
186 #[inline]
188 fn rlp_payload_length(&self) -> usize {
189 let mut payload_length = 1;
190 for (_, child) in self.children() {
191 if let Some(child) = child {
192 payload_length += child.len();
193 } else {
194 payload_length += 1;
195 }
196 }
197 payload_length
198 }
199}
200
/// Iterator over all child slots of a [`BranchNodeRef`], yielding `(index, Option<child>)`.
#[derive(Debug)]
struct BranchChildrenIter<'a> {
    /// Child indices not yet visited (starts as `CHILD_INDEX_RANGE`).
    range: Range<u8>,
    /// Which child indices have a corresponding entry in `stack_iter`.
    state_mask: TrieMask,
    /// The node's children (the tail of its stack); one entry is consumed per set bit.
    stack_iter: Iter<'a, RlpNode>,
}
208
209impl<'a> BranchChildrenIter<'a> {
210 fn new(node: &BranchNodeRef<'a>) -> Self {
212 Self {
213 range: CHILD_INDEX_RANGE,
214 state_mask: node.state_mask,
215 stack_iter: node.stack[node.first_child_index()..].iter(),
216 }
217 }
218}
219
220impl<'a> Iterator for BranchChildrenIter<'a> {
221 type Item = (u8, Option<&'a RlpNode>);
222
223 #[inline]
224 fn next(&mut self) -> Option<Self::Item> {
225 let i = self.range.next()?;
226 let value = if self.state_mask.is_bit_set(i) {
227 Some(unsafe { self.stack_iter.next().unwrap_unchecked() })
230 } else {
231 None
232 };
233 Some((i, value))
234 }
235
236 #[inline]
237 fn size_hint(&self) -> (usize, Option<usize>) {
238 let len = self.len();
239 (len, Some(len))
240 }
241}
242
243impl core::iter::FusedIterator for BranchChildrenIter<'_> {}
244
245impl ExactSizeIterator for BranchChildrenIter<'_> {
246 #[inline]
247 fn len(&self) -> usize {
248 self.range.len()
249 }
250}
251
/// Compact form of a branch node that records child presence as bitmasks and keeps selected
/// child hashes directly.
/// NOTE(review): where this form is persisted or consumed is not visible here — confirm with
/// callers.
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BranchNodeCompact {
    /// Bit `i` is set when the node has a child at index `i`.
    pub state_mask: TrieMask,
    /// Subset of `state_mask` (enforced by `new`). Presumably marks children that are
    /// themselves stored as trie nodes — TODO confirm.
    pub tree_mask: TrieMask,
    /// Subset of `state_mask` (enforced by `new`) marking the children whose hash is kept in
    /// `hashes`.
    pub hash_mask: TrieMask,
    /// One hash per bit set in `hash_mask`, ordered by child index (see `hash_for_nibble`).
    /// Held behind an `Arc` so cloning the node does not copy the vector.
    pub hashes: Arc<Vec<B256>>,
    /// Optional hash of this node itself.
    /// NOTE(review): when this is populated is not visible here — confirm with callers.
    pub root_hash: Option<B256>,
}
286
287impl BranchNodeCompact {
288 pub fn new(
290 state_mask: impl Into<TrieMask>,
291 tree_mask: impl Into<TrieMask>,
292 hash_mask: impl Into<TrieMask>,
293 hashes: Vec<B256>,
294 root_hash: Option<B256>,
295 ) -> Self {
296 let (state_mask, tree_mask, hash_mask) =
297 (state_mask.into(), tree_mask.into(), hash_mask.into());
298 assert!(
299 tree_mask.is_subset_of(state_mask),
300 "state mask: {state_mask:?} tree mask: {tree_mask:?}"
301 );
302 assert!(
303 hash_mask.is_subset_of(state_mask),
304 "state_mask {state_mask:?} hash_mask: {hash_mask:?}"
305 );
306 assert_eq!(hash_mask.count_ones() as usize, hashes.len());
307 Self { state_mask, tree_mask, hash_mask, hashes: hashes.into(), root_hash }
308 }
309
310 pub fn hash_for_nibble(&self, nibble: u8) -> B256 {
312 let mask = *TrieMask::from_nibble(nibble) - 1;
313 let index = (*self.hash_mask & mask).count_ones();
314 self.hashes[index as usize]
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nodes::{ExtensionNode, LeafNode};
    use nybbles::Nibbles;

    /// Encodes `node`, decodes the bytes again and checks that the result is unchanged.
    fn assert_rlp_roundtrip(node: &BranchNode) {
        let encoded = alloy_rlp::encode(node);
        let decoded = BranchNode::decode(&mut encoded.as_slice()).unwrap();
        assert_eq!(decoded, *node);
    }

    #[test]
    fn rlp_branch_node_roundtrip() {
        // A branch without any children.
        assert_rlp_roundtrip(&BranchNode::default());

        // Two hashed children, at nibbles 2 and 6.
        let hashed_children = vec![
            RlpNode::word_rlp(&B256::repeat_byte(1)),
            RlpNode::word_rlp(&B256::repeat_byte(2)),
        ];
        assert_rlp_roundtrip(&BranchNode::new(hashed_children, TrieMask::new(0b1000100)));

        // A single inline leaf child at nibble 1.
        let leaf = LeafNode::new(Nibbles::from_nibbles(hex!("0203")), hex!("1234").to_vec());
        let leaf_rlp = leaf.as_ref().rlp(&mut Vec::new());
        assert_rlp_roundtrip(&BranchNode::new(vec![leaf_rlp.clone()], TrieMask::new(0b0010)));

        // A single inline extension child (wrapping the leaf above) at nibble 5.
        let extension = ExtensionNode::new(Nibbles::from_nibbles(hex!("0203")), leaf_rlp);
        let extension_rlp = extension.as_ref().rlp(&mut Vec::new());
        assert_rlp_roundtrip(&BranchNode::new(vec![extension_rlp], TrieMask::new(0b00000100000)));

        // Every slot occupied by the same hashed child.
        let full_stack = vec![RlpNode::word_rlp(&B256::repeat_byte(23)); 16];
        assert_rlp_roundtrip(&BranchNode::new(full_stack, TrieMask::new(u16::MAX)));
    }
}