use alloc::collections::BTreeMap;
use alloy_primitives::{ruint::ParseError, Bytes, B256, U256};
use core::{fmt, str::FromStr};
use serde::{Deserialize, Deserializer, Serialize};
/// A storage slot key as it may appear in JSON.
///
/// Deserialization is `#[serde(untagged)]`: serde tries the variants in declaration order, so an
/// input is first attempted as a [`B256`] hash and only falls back to a [`U256`] number if that
/// fails. Do not reorder the variants; that would change which variant a given input maps to.
/// Serialization writes whichever variant is held using that inner type's own serde impl.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum JsonStorageKey {
/// A full 32-byte storage key.
Hash(B256),
/// A storage key given as a number, e.g. a short hex string such as `0xabc`.
Number(U256),
}
impl JsonStorageKey {
    /// Returns this key as a 32-byte word.
    ///
    /// A [`Self::Hash`] is returned unchanged. A [`Self::Number`] is converted with
    /// `B256::from(U256)`, so e.g. the number `0xabc` becomes the word
    /// `0x00…0abc` (left-padded), matching the equivalent hash key.
    pub fn as_b256(&self) -> B256 {
        match *self {
            Self::Number(num) => B256::from(num),
            Self::Hash(hash) => hash,
        }
    }
}
impl Default for JsonStorageKey {
fn default() -> Self {
Self::Hash(Default::default())
}
}
impl From<B256> for JsonStorageKey {
fn from(value: B256) -> Self {
Self::Hash(value)
}
}
/// Wraps a raw 32-byte array as a [`JsonStorageKey::Hash`].
impl From<[u8; 32]> for JsonStorageKey {
    fn from(value: [u8; 32]) -> Self {
        Self::Hash(B256::new(value))
    }
}
impl From<U256> for JsonStorageKey {
fn from(value: U256) -> Self {
Self::Number(value)
}
}
impl FromStr for JsonStorageKey {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Ok(hash) = B256::from_str(s) {
return Ok(Self::Hash(hash));
}
s.parse().map(Self::Number)
}
}
/// Formats the key as `0x`-prefixed hex.
///
/// A [`JsonStorageKey::Hash`] is delegated to `B256`'s `Display` impl (the tests show the full
/// 64-digit form). A [`JsonStorageKey::Number`] is printed as minimal lower-case hex (e.g. `0x2a`,
/// `0x0`) and emitted through [`fmt::Formatter::pad`], so width/fill options still apply.
impl fmt::Display for JsonStorageKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Number(num) => f.pad(&alloc::format!("{num:#x}")),
            Self::Hash(hash) => fmt::Display::fmt(hash, f),
        }
    }
}
pub fn from_bytes_to_b256<'de, D>(bytes: Bytes) -> Result<B256, D::Error>
where
D: Deserializer<'de>,
{
if bytes.0.len() > 32 {
return Err(serde::de::Error::custom("input too long to be a B256"));
}
let mut padded = [0u8; 32];
padded[32 - bytes.0.len()..].copy_from_slice(&bytes.0);
Ok(B256::from_slice(&padded))
}
pub fn deserialize_storage_map<'de, D>(
deserializer: D,
) -> Result<Option<BTreeMap<B256, B256>>, D::Error>
where
D: Deserializer<'de>,
{
let map = Option::<BTreeMap<Bytes, Bytes>>::deserialize(deserializer)?;
match map {
Some(map) => {
let mut res_map = BTreeMap::new();
for (k, v) in map {
let k_deserialized = from_bytes_to_b256::<'de, D>(k)?;
let v_deserialized = from_bytes_to_b256::<'de, D>(v)?;
res_map.insert(k_deserialized, v_deserialized);
}
Ok(Some(res_map))
}
None => Ok(None),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::{String, ToString};
    use serde_json::json;

    #[test]
    fn default_number_storage_key() {
        let key = JsonStorageKey::Number(Default::default());
        assert_eq!(key.to_string(), String::from("0x0"));
    }

    #[test]
    fn default_hash_storage_key() {
        let key = JsonStorageKey::default();
        assert_eq!(
            key.to_string(),
            String::from("0x0000000000000000000000000000000000000000000000000000000000000000")
        );
    }

    #[test]
    fn test_storage_key() {
        let cases = [
            // 32-byte hash with a `0x` prefix.
            "0x0000000000000000000000000000000000000000000000000000000000000001",
            // The same hash without the `0x` prefix.
            "0000000000000000000000000000000000000000000000000000000000000001",
        ];
        let key: JsonStorageKey = serde_json::from_str(&json!(cases[0]).to_string()).unwrap();
        let key2: JsonStorageKey = serde_json::from_str(&json!(cases[1]).to_string()).unwrap();
        assert_eq!(key.as_b256(), key2.as_b256());
    }

    #[test]
    fn test_storage_key_serde_roundtrips() {
        let test_cases = [
            // Full-width hashes keep their zero padding.
            "0x0000000000000000000000000000000000000000000000000000000000000001",
            "0x0000000000000000000000000000000000000000000000000000000000000abc",
            "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
            // Short hex strings are numbers and print in minimal form.
            "0xabc",
            "0xabcd",
        ];
        for input in test_cases {
            let key: JsonStorageKey = serde_json::from_str(&json!(input).to_string()).unwrap();
            let output = key.to_string();
            assert_eq!(
                input, output,
                "Storage key roundtrip failed to preserve the exact hex representation for {}",
                input
            );
        }
    }

    #[test]
    fn test_as_b256() {
        let cases = [
            // Short hex: deserializes as a number.
            "0x0abc",
            // Full 32 bytes: deserializes as a hash.
            "0x0000000000000000000000000000000000000000000000000000000000000abc",
        ];
        let num_key: JsonStorageKey = serde_json::from_str(&json!(cases[0]).to_string()).unwrap();
        let hash_key: JsonStorageKey = serde_json::from_str(&json!(cases[1]).to_string()).unwrap();
        assert_eq!(num_key, JsonStorageKey::Number(U256::from_str(cases[0]).unwrap()));
        assert_eq!(hash_key, JsonStorageKey::Hash(B256::from_str(cases[1]).unwrap()));
        assert_eq!(num_key.as_b256(), hash_key.as_b256());
    }

    #[test]
    fn test_json_storage_key_from_b256() {
        let b256_value = B256::from([1u8; 32]);
        let key = JsonStorageKey::from(b256_value);
        assert_eq!(key, JsonStorageKey::Hash(b256_value));
        assert_eq!(
            key.to_string(),
            "0x0101010101010101010101010101010101010101010101010101010101010101"
        );
    }

    #[test]
    fn test_json_storage_key_from_u256() {
        let u256_value = U256::from(42);
        let key = JsonStorageKey::from(u256_value);
        assert_eq!(key, JsonStorageKey::Number(u256_value));
        assert_eq!(key.to_string(), "0x2a");
    }

    #[test]
    fn test_json_storage_key_from_u8_array() {
        let bytes = [0u8; 32];
        let key = JsonStorageKey::from(bytes);
        assert_eq!(key, JsonStorageKey::Hash(B256::from(bytes)));
    }

    #[test]
    fn test_from_str_parsing() {
        let hex_str = "0x0101010101010101010101010101010101010101010101010101010101010101";
        let key = JsonStorageKey::from_str(hex_str).unwrap();
        assert_eq!(key, JsonStorageKey::Hash(B256::from_str(hex_str).unwrap()));

        let num_str = "42";
        let key = JsonStorageKey::from_str(num_str).unwrap();
        assert_eq!(key, JsonStorageKey::Number(U256::from(42)));
    }

    #[test]
    fn test_from_str_rejects_invalid_input() {
        // Neither a 32-byte hash nor a parsable number.
        assert!(JsonStorageKey::from_str("not a storage key").is_err());
    }

    #[test]
    fn test_deserialize_storage_map_with_valid_data() {
        let json_data = json!({
            "0x0000000000000000000000000000000000000000000000000000000000000001": "0x22",
            "0x0000000000000000000000000000000000000000000000000000000000000002": "0x33"
        });
        let deserialized: Option<BTreeMap<B256, B256>> = deserialize_storage_map(
            &serde_json::from_value::<serde_json::Value>(json_data).unwrap(),
        )
        .unwrap();
        assert_eq!(
            deserialized.unwrap(),
            BTreeMap::from([
                (B256::from(U256::from(1u128)), B256::from(U256::from(0x22u128))),
                (B256::from(U256::from(2u128)), B256::from(U256::from(0x33u128)))
            ])
        );
    }

    #[test]
    fn test_deserialize_storage_map_with_empty_data() {
        let json_data = json!({});
        let deserialized: Option<BTreeMap<B256, B256>> = deserialize_storage_map(
            &serde_json::from_value::<serde_json::Value>(json_data).unwrap(),
        )
        .unwrap();
        assert!(deserialized.unwrap().is_empty());
    }

    #[test]
    fn test_deserialize_storage_map_with_none() {
        let json_data = json!(null);
        let deserialized: Option<BTreeMap<B256, B256>> = deserialize_storage_map(
            &serde_json::from_value::<serde_json::Value>(json_data).unwrap(),
        )
        .unwrap();
        assert!(deserialized.is_none());
    }

    #[test]
    fn test_deserialize_storage_map_rejects_oversized_value() {
        // 33 bytes cannot be left-padded into a 32-byte word.
        let too_long = alloc::format!("0x{}", "01".repeat(33));
        let json_data = json!({
            "0x0000000000000000000000000000000000000000000000000000000000000001": too_long
        });
        let result = deserialize_storage_map(&json_data);
        assert_eq!(result.unwrap_err().to_string(), "input too long to be a B256");
    }

    #[test]
    fn test_from_bytes_to_b256_with_valid_input() {
        let bytes = Bytes::from(vec![0x1, 0x2, 0x3, 0x4]);
        let result = from_bytes_to_b256::<serde_json::Value>(bytes).unwrap();
        let expected = B256::from_slice(&[
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
            2, 3, 4,
        ]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_from_bytes_to_b256_with_exact_32_bytes() {
        let bytes = Bytes::from(vec![
            0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x10, 0x11,
            0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
            0x20,
        ]);
        let result = from_bytes_to_b256::<serde_json::Value>(bytes).unwrap();
        let expected = B256::from_slice(&[
            0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x10, 0x11,
            0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
            0x20,
        ]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_from_bytes_to_b256_with_input_too_long() {
        // One byte more than fits in a B256.
        let bytes = Bytes::from(vec![0x1; 33]);
        let result = from_bytes_to_b256::<serde_json::Value>(bytes);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "input too long to be a B256");
    }

    #[test]
    fn test_from_bytes_to_b256_with_empty_input() {
        // Empty input pads to the all-zero word.
        let bytes = Bytes::from(vec![]);
        let result = from_bytes_to_b256::<serde_json::Value>(bytes).unwrap();
        let expected = B256::from_slice(&[0; 32]);
        assert_eq!(result, expected);
    }
}