use alloy_primitives::{Bytes, B256};
use alloy_rlp::{Decodable, Encodable, Header, EMPTY_STRING_CODE};
use core::ops::Range;
use nybbles::Nibbles;
use smallvec::SmallVec;
#[allow(unused_imports)]
use alloc::vec::Vec;
mod branch;
pub use branch::{BranchNode, BranchNodeCompact, BranchNodeRef};
mod extension;
pub use extension::{ExtensionNode, ExtensionNodeRef};
mod leaf;
pub use leaf::{LeafNode, LeafNodeRef};
mod rlp;
pub use rlp::RlpNode;
/// Child slot indices of a branch node: one slot for each nibble value `0x0..=0xf`.
pub const CHILD_INDEX_RANGE: Range<u8> = Range { start: 0, end: 16 };
/// A single trie node in one of its four encodable shapes.
///
/// The RLP form of each variant is produced by the [`Encodable`] impl and parsed back by the
/// [`Decodable`] impl.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TrieNode {
    /// The root of an empty trie; encoded as the RLP empty string (`0x80`).
    EmptyRoot,
    /// A branch node; its RLP form is a 17-item list of 16 child slots plus a value slot.
    Branch(BranchNode),
    /// An extension node; its RLP form is a 2-item list of encoded path and child reference.
    Extension(ExtensionNode),
    /// A leaf node; its RLP form is a 2-item list of encoded path and value.
    Leaf(LeafNode),
}
impl Encodable for TrieNode {
    /// Writes the RLP encoding of this node into `out`.
    ///
    /// [`TrieNode::EmptyRoot`] is written as the single byte [`EMPTY_STRING_CODE`]. Every other
    /// variant forwards to the encoder of the node it wraps.
    #[inline]
    fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
        match self {
            Self::Branch(node) => node.encode(out),
            Self::Extension(node) => node.encode(out),
            Self::Leaf(node) => node.encode(out),
            Self::EmptyRoot => out.put_u8(EMPTY_STRING_CODE),
        }
    }

    /// Returns the number of bytes that [`Encodable::encode`] writes for this node.
    #[inline]
    fn length(&self) -> usize {
        match self {
            Self::Branch(node) => node.length(),
            Self::Extension(node) => node.length(),
            Self::Leaf(node) => node.length(),
            // Exactly the one `EMPTY_STRING_CODE` byte.
            Self::EmptyRoot => 1,
        }
    }
}
impl Decodable for TrieNode {
fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
let mut items = match Header::decode_raw(buf)? {
alloy_rlp::PayloadView::List(list) => list,
alloy_rlp::PayloadView::String(val) => {
return if val.is_empty() {
Ok(Self::EmptyRoot)
} else {
Err(alloy_rlp::Error::UnexpectedString)
}
}
};
match items.len() {
17 => {
let mut branch = BranchNode::default();
for (idx, item) in items.into_iter().enumerate() {
if idx == 16 {
if item != [EMPTY_STRING_CODE] {
return Err(alloy_rlp::Error::Custom(
"branch node values are not supported",
));
}
} else if item != [EMPTY_STRING_CODE] {
branch.stack.push(RlpNode::from_raw_rlp(item)?);
branch.state_mask.set_bit(idx as u8);
}
}
Ok(Self::Branch(branch))
}
2 => {
let mut key = items.remove(0);
let encoded_key = Header::decode_bytes(&mut key, false)?;
if encoded_key.is_empty() {
return Err(alloy_rlp::Error::Custom("trie node key empty"));
}
let key_flag = encoded_key[0] & 0xf0;
let first = match key_flag {
ExtensionNode::ODD_FLAG | LeafNode::ODD_FLAG => Some(encoded_key[0] & 0x0f),
ExtensionNode::EVEN_FLAG | LeafNode::EVEN_FLAG => None,
_ => return Err(alloy_rlp::Error::Custom("node is not extension or leaf")),
};
let key = unpack_path_to_nibbles(first, &encoded_key[1..]);
let node = if key_flag == LeafNode::EVEN_FLAG || key_flag == LeafNode::ODD_FLAG {
let value = Bytes::decode(&mut items.remove(0))?.into();
Self::Leaf(LeafNode::new(key, value))
} else {
Self::Extension(ExtensionNode::new(
key,
RlpNode::from_raw_rlp(items.remove(0))?,
))
};
Ok(node)
}
_ => Err(alloy_rlp::Error::Custom("invalid number of items in the list")),
}
}
}
impl TrieNode {
    /// RLP-encodes this node by appending it to `rlp`.
    ///
    /// Returns the [`RlpNode`] built (via [`RlpNode::from_rlp`]) from exactly the bytes this
    /// call appended. Bytes already in `rlp` are left in place and do not contribute to the
    /// returned reference, so passing a reused, non-empty buffer no longer corrupts the result.
    #[inline]
    pub fn rlp(&self, rlp: &mut Vec<u8>) -> RlpNode {
        // Everything at or after `start` is this node's encoding; earlier bytes belong to the caller.
        let start = rlp.len();
        self.encode(rlp);
        RlpNode::from_rlp(&rlp[start..])
    }
}
/// Builds an [`RlpNode`] from the RLP encoding of a node.
///
/// Deprecated shim that simply forwards to [`RlpNode::from_rlp`].
#[deprecated = "use `RlpNode::from_rlp` instead"]
#[inline]
pub fn rlp_node(rlp: &[u8]) -> RlpNode {
    RlpNode::from_rlp(rlp)
}
/// Builds the [`RlpNode`] for a 32-byte word.
///
/// Deprecated shim that simply forwards to [`RlpNode::word_rlp`].
#[deprecated = "use `RlpNode::word_rlp` instead"]
#[inline]
pub fn word_rlp(word: &B256) -> RlpNode {
    RlpNode::word_rlp(word)
}
/// Builds a [`Nibbles`] path from an optional leading nibble followed by packed bytes.
///
/// - With `first == None` the result is just [`Nibbles::unpack`] applied to `rest`.
/// - With `first == Some(n)` the result is `n` followed by the unpacked nibbles of `rest`, for
///   `2 * rest.len() + 1` nibbles in total.
///
/// The caller must pass `first <= 0xf`. This is only checked by a `debug_assert!`, so in release
/// builds a larger value would be written into the path unchecked. (In this file the decoder
/// always passes a value masked with `0x0f`.)
#[inline]
pub(crate) fn unpack_path_to_nibbles(first: Option<u8>, rest: &[u8]) -> Nibbles {
    let Some(first) = first else { return Nibbles::unpack(rest) };
    debug_assert!(first <= 0xf);
    // One leading nibble plus two nibbles per byte of `rest`.
    let len = rest.len() * 2 + 1;
    // SAFETY (review note; the nybbles-helper behaviour is assumed and should be confirmed
    // against the nybbles docs):
    // - `len >= 1`, so the buffer handed to the closure is non-empty. This assumes
    //   `smallvec_with` provides exactly `len` uninitialized slots, and it means
    //   `split_first_mut` cannot return `None`.
    // - `f` is initialized with `first`. `r` has `2 * rest.len()` slots, which is the number of
    //   nibbles `unpack_to_unchecked` produces from `rest`; presumably it initializes all of them.
    // - `from_repr_unchecked` requires every element to be a nibble. Unpacked bytes satisfy
    //   this, and `first <= 0xf` is the caller's responsibility (see the `debug_assert!` above).
    unsafe {
        Nibbles::from_repr_unchecked(nybbles::smallvec_with(len, |buf| {
            let (f, r) = buf.split_first_mut().unwrap_unchecked();
            f.write(first);
            Nibbles::unpack_to_unchecked(rest, r);
        }))
    }
}
/// Encodes a nibble path in the compact (hex-prefix) form used for leaf and extension keys.
///
/// The first output byte carries a flag chosen by `is_leaf` and by the parity of `nibbles.len()`:
/// - leaf: [`LeafNode::ODD_FLAG`] or [`LeafNode::EVEN_FLAG`];
/// - extension: [`ExtensionNode::ODD_FLAG`] or [`ExtensionNode::EVEN_FLAG`].
///
/// For an odd-length path the first nibble is OR-ed into that byte. For an even-length path the
/// byte holds only the flag. The remaining (always even) number of nibbles is packed into the
/// following bytes, so the output is always `nibbles.len() / 2 + 1` bytes long.
#[inline]
pub fn encode_path_leaf(nibbles: &Nibbles, is_leaf: bool) -> SmallVec<[u8; 36]> {
    let mut nibbles = nibbles.as_slice();
    let encoded_len = nibbles.len() / 2 + 1;
    let odd_nibbles = nibbles.len() % 2 != 0;
    // SAFETY (review note; the nybbles-helper behaviour is assumed and should be confirmed
    // against the nybbles docs):
    // - `encoded_len >= 1`, so `split_first_mut` always yields a first slot. This assumes
    //   `smallvec_with` hands over `encoded_len` slots.
    // - `get_unchecked(0)` and `get_unchecked(1..)` are only evaluated when `odd_nibbles` is
    //   true, which implies `nibbles.len() >= 1`.
    // - After dropping that leading nibble, `2 * (encoded_len - 1)` nibbles remain. That exactly
    //   fills the `encoded_len - 1` bytes of `rest` at two nibbles per byte; `pack_to_unchecked`
    //   presumably initializes all of them.
    unsafe {
        nybbles::smallvec_with(encoded_len, |buf| {
            let (first, rest) = buf.split_first_mut().unwrap_unchecked();
            // Flag in the high nibble; for odd-length paths the first nibble fills the low half.
            first.write(match (is_leaf, odd_nibbles) {
                (true, true) => LeafNode::ODD_FLAG | *nibbles.get_unchecked(0),
                (true, false) => LeafNode::EVEN_FLAG,
                (false, true) => ExtensionNode::ODD_FLAG | *nibbles.get_unchecked(0),
                (false, false) => ExtensionNode::EVEN_FLAG,
            });
            // The first nibble has already been stored, so skip it before packing the rest.
            if odd_nibbles {
                nibbles = nibbles.get_unchecked(1..);
            }
            nybbles::pack_to_unchecked(nibbles, rest);
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TrieMask;
    use alloy_primitives::hex;

    // The empty root encodes to the single RLP empty-string byte 0x80 and decodes back to itself.
    #[test]
    fn rlp_empty_root_node() {
        let empty_root = TrieNode::EmptyRoot;
        let rlp = empty_root.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("80"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), empty_root);
    }

    // The leaf value here is itself the RLP of zero (the byte 0x80). The expected encoding keeps
    // it as a one-byte string (`81 80`), and the node must decode back unchanged.
    #[test]
    fn rlp_zero_value_leaf_roundtrip() {
        let leaf = TrieNode::Leaf(LeafNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            alloy_rlp::encode(alloy_primitives::U256::ZERO),
        ));
        let rlp = leaf.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c68320646f8180"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), leaf);
    }

    // Encode/decode round-trips for each non-empty variant:
    // - even-length leaf path: key prefix byte 0x20 in the expected encoding;
    // - even-length extension path: key prefix byte 0x00;
    // - 16-child branch: its encoding is long, so the returned `RlpNode` is a 33-byte
    //   `0xa0`-prefixed reference rather than the inline encoding.
    #[test]
    fn rlp_trie_node_roundtrip() {
        // leaf
        let leaf = TrieNode::Leaf(LeafNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            hex!("76657262").to_vec(),
        ));
        let rlp = leaf.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c98320646f8476657262"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), leaf);

        // extension
        let mut child = vec![];
        hex!("76657262").to_vec().as_slice().encode(&mut child);
        let extension = TrieNode::Extension(ExtensionNode::new(
            Nibbles::from_nibbles_unchecked(hex!("0604060f")),
            RlpNode::from_raw(&child).unwrap(),
        ));
        let rlp = extension.rlp(&mut vec![]);
        assert_eq!(rlp[..], hex!("c98300646f8476657262"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), extension);

        // branch: all 16 children set, empty value slot (trailing 0x80 in the expected bytes)
        let branch = TrieNode::Branch(BranchNode::new(
            core::iter::repeat(RlpNode::word_rlp(&B256::repeat_byte(23))).take(16).collect(),
            TrieMask::new(u16::MAX),
        ));
        let mut rlp = vec![];
        let rlp_node = branch.rlp(&mut rlp);
        assert_eq!(
            rlp_node[..],
            hex!("a0bed74980bbe29d9c4439c10e9c451e29b306fe74bcf9795ecf0ebbd92a220513")
        );
        assert_eq!(rlp, hex!("f90211a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a0171717171717171717171717171717171717171717171717171717171717171780"));
        assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), branch);
    }

    // Regression vector for an odd-length leaf path. The 32-byte output implies 63 nibbles. The
    // expected first byte 0x35 carries the leaf/odd flag in its high nibble (presumably
    // `LeafNode::ODD_FLAG`) and the first path nibble (5) in its low nibble.
    #[test]
    fn hashed_encode_path_regression() {
        let nibbles = Nibbles::from_nibbles(hex!("05010406040a040203030f010805020b050c04070003070e0909070f010b0a0805020301070c0a0902040b0f000f0006040a04050f020b090701000a0a040b"));
        let path = encode_path_leaf(&nibbles, true);
        let expected = hex!("351464a4233f1852b5c47037e997f1ba852317ca924bf0f064a45f2b9710aa4b");
        assert_eq!(path[..], expected);
    }

    // Property test over random byte inputs, for both leaf and extension encodings. It checks
    // that the first encoded byte carries:
    // - the leaf bit;
    // - the odd bit, matching the input parity;
    // - for odd inputs, the first nibble in the low half.
    #[test]
    #[cfg(feature = "arbitrary")]
    #[cfg_attr(miri, ignore = "no proptest")]
    fn encode_path_first_byte() {
        use proptest::{collection::vec, prelude::*};

        proptest::proptest!(|(input in vec(any::<u8>(), 0..128))| {
            let input = Nibbles::unpack(input);
            prop_assert!(input.iter().all(|&nibble| nibble <= 0xf));
            let input_is_odd = input.len() % 2 == 1;

            // Leaf encoding: leaf bit set, odd bit tracks parity.
            let compact_leaf = encode_path_leaf(&input, true);
            let leaf_flag = compact_leaf[0];
            assert_ne!(leaf_flag & LeafNode::EVEN_FLAG, 0);
            assert_eq!(input_is_odd, (leaf_flag & ExtensionNode::ODD_FLAG) != 0);
            if input_is_odd {
                assert_eq!(leaf_flag & 0x0f, input.first().unwrap());
            }

            // Extension encoding: leaf bit clear, odd bit tracks parity.
            let compact_extension = encode_path_leaf(&input, false);
            let extension_flag = compact_extension[0];
            assert_eq!(extension_flag & LeafNode::EVEN_FLAG, 0);
            assert_eq!(input_is_odd, (extension_flag & ExtensionNode::ODD_FLAG) != 0);
            if input_is_odd {
                assert_eq!(extension_flag & 0x0f, input.first().unwrap());
            }
        });
    }
}