use crate::{alloc::string::ToString, Error, Panic, Result, Revert, SolError};
use alloc::{string::String, vec::Vec};
use core::{convert::Infallible, fmt, iter::FusedIterator, marker::PhantomData};
mod event;
pub use event::SolEventInterface;
pub trait SolInterface: Sized {
const NAME: &'static str;
const MIN_DATA_LENGTH: usize;
const COUNT: usize;
fn selector(&self) -> [u8; 4];
fn selector_at(i: usize) -> Option<[u8; 4]>;
fn valid_selector(selector: [u8; 4]) -> bool;
fn type_check(selector: [u8; 4]) -> Result<()> {
if Self::valid_selector(selector) {
Ok(())
} else {
Err(Error::UnknownSelector { name: Self::NAME, selector: selector.into() })
}
}
fn abi_decode_raw(selector: [u8; 4], data: &[u8], validate: bool) -> Result<Self>;
fn abi_encoded_size(&self) -> usize;
fn abi_encode_raw(&self, out: &mut Vec<u8>);
#[inline]
fn selectors() -> Selectors<Self> {
Selectors::new()
}
#[inline]
fn abi_encode(&self) -> Vec<u8> {
let mut out = Vec::with_capacity(4 + self.abi_encoded_size());
out.extend(self.selector());
self.abi_encode_raw(&mut out);
out
}
#[inline]
fn abi_decode(data: &[u8], validate: bool) -> Result<Self> {
if data.len() < Self::MIN_DATA_LENGTH.saturating_add(4) {
Err(crate::Error::type_check_fail(data, Self::NAME))
} else {
let (selector, data) = data.split_first_chunk().unwrap();
Self::abi_decode_raw(*selector, data, validate)
}
}
}
/// An interface with no selectors at all; no value of it can ever exist.
///
/// Serves as the custom-error parameter of [`GenericContractError`], leaving
/// only the built-in revert and panic variants reachable.
impl SolInterface for Infallible {
    const NAME: &'static str = "GenericContractError";

    // Nothing can be decoded: with `usize::MAX` the saturating length check in
    // the default `abi_decode` rejects every realistic input.
    const MIN_DATA_LENGTH: usize = usize::MAX;

    const COUNT: usize = 0;

    #[inline]
    fn selector(&self) -> [u8; 4] {
        // `Infallible` is uninhabited, so `&self` can never exist. The empty
        // match lets the compiler prove that, instead of panicking at runtime.
        match *self {}
    }

    #[inline]
    fn selector_at(_i: usize) -> Option<[u8; 4]> {
        None
    }

    #[inline]
    fn valid_selector(_selector: [u8; 4]) -> bool {
        false
    }

    #[inline]
    fn abi_decode_raw(selector: [u8; 4], _data: &[u8], _validate: bool) -> Result<Self> {
        // `valid_selector` rejects everything, so `type_check` always yields
        // `Err(UnknownSelector)`; the `Ok` branch cannot be taken.
        Self::type_check(selector).map(|()| unreachable!())
    }

    #[inline]
    fn abi_encoded_size(&self) -> usize {
        match *self {}
    }

    #[inline]
    fn abi_encode_raw(&self, _out: &mut Vec<u8>) {
        match *self {}
    }
}
/// A [`ContractError`] without custom errors. Because [`Infallible`] has no
/// values, only the [`Revert`] and [`Panic`] variants can actually be held.
pub type GenericContractError = ContractError<Infallible>;

/// Decoded contract revert data: one of the contract's own custom errors
/// (from the interface `T`), a [`Revert`], or a [`Panic`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContractError<T> {
    /// A custom error defined by the error interface `T`.
    CustomError(T),
    /// A built-in revert (the `Error(string)` selector).
    Revert(Revert),
    /// A built-in panic (the `Panic(uint256)` selector).
    Panic(Panic),
}
/// Wraps a custom error from the interface `T`.
impl<T: SolInterface> From<T> for ContractError<T> {
    #[inline]
    fn from(value: T) -> Self {
        Self::CustomError(value)
    }
}

/// Wraps a [`Revert`].
impl<T> From<Revert> for ContractError<T> {
    #[inline]
    fn from(value: Revert) -> Self {
        Self::Revert(value)
    }
}

/// Wraps a [`Panic`].
impl<T> From<Panic> for ContractError<T> {
    #[inline]
    fn from(value: Panic) -> Self {
        Self::Panic(value)
    }
}

/// Extracts the [`Revert`]; any other variant is handed back unchanged as the error.
impl<T> TryFrom<ContractError<T>> for Revert {
    type Error = ContractError<T>;

    #[inline]
    fn try_from(value: ContractError<T>) -> Result<Self, Self::Error> {
        if let ContractError::Revert(revert) = value {
            Ok(revert)
        } else {
            Err(value)
        }
    }
}

/// Extracts the [`Panic`]; any other variant is handed back unchanged as the error.
impl<T> TryFrom<ContractError<T>> for Panic {
    type Error = ContractError<T>;

    #[inline]
    fn try_from(value: ContractError<T>) -> Result<Self, Self::Error> {
        if let ContractError::Panic(panic) = value {
            Ok(panic)
        } else {
            Err(value)
        }
    }
}
/// Formats the wrapped error exactly as its own `Display` impl would, with the
/// formatter (including any width/precision flags) passed through unchanged.
impl<T: fmt::Display> fmt::Display for ContractError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Revert(inner) => fmt::Display::fmt(inner, f),
            Self::Panic(inner) => fmt::Display::fmt(inner, f),
            Self::CustomError(inner) => fmt::Display::fmt(inner, f),
        }
    }
}
/// Reports the wrapped error (custom, panic, or revert) as the source.
impl<T: core::error::Error + 'static> core::error::Error for ContractError<T> {
    #[inline]
    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
        // Each arm coerces to the trait object via the annotated binding type.
        let inner: &(dyn core::error::Error + 'static) = match self {
            Self::Revert(revert) => revert,
            Self::Panic(panic) => panic,
            Self::CustomError(error) => error,
        };
        Some(inner)
    }
}
/// Extends the custom-error interface `T` with the built-in [`Revert`] and
/// [`Panic`] errors.
///
/// Selectors are enumerated as all of `T`'s selectors, then `Revert`, then `Panic`.
/// When decoding, the built-in selectors are checked before delegating to `T`.
impl<T: SolInterface> SolInterface for ContractError<T> {
    const NAME: &'static str = "ContractError";

    // The smaller of `T::MIN_DATA_LENGTH` and 32 bytes (one ABI word).
    // NOTE(review): 32 presumably is the shorter built-in payload (`Panic`'s
    // single `uint256`); confirm before changing.
    const MIN_DATA_LENGTH: usize = if T::MIN_DATA_LENGTH > 32 { 32 } else { T::MIN_DATA_LENGTH };

    // `T`'s selectors plus the two built-ins.
    const COUNT: usize = T::COUNT + 2;

    #[inline]
    fn selector(&self) -> [u8; 4] {
        match self {
            Self::Revert(_) => Revert::SELECTOR,
            Self::Panic(_) => Panic::SELECTOR,
            Self::CustomError(custom) => custom.selector(),
        }
    }

    #[inline]
    fn selector_at(i: usize) -> Option<[u8; 4]> {
        // Indices below `T::COUNT` belong to `T`; the next two are the built-ins.
        match i.checked_sub(T::COUNT) {
            None => T::selector_at(i),
            Some(0) => Some(Revert::SELECTOR),
            Some(1) => Some(Panic::SELECTOR),
            Some(_) => None,
        }
    }

    #[inline]
    fn valid_selector(selector: [u8; 4]) -> bool {
        selector == Revert::SELECTOR || selector == Panic::SELECTOR || T::valid_selector(selector)
    }

    #[inline]
    fn abi_decode_raw(selector: [u8; 4], data: &[u8], validate: bool) -> Result<Self> {
        if selector == Revert::SELECTOR {
            Revert::abi_decode_raw(data, validate).map(Self::Revert)
        } else if selector == Panic::SELECTOR {
            Panic::abi_decode_raw(data, validate).map(Self::Panic)
        } else {
            T::abi_decode_raw(selector, data, validate).map(Self::CustomError)
        }
    }

    #[inline]
    fn abi_encoded_size(&self) -> usize {
        match self {
            Self::Revert(revert) => revert.abi_encoded_size(),
            Self::Panic(panic) => panic.abi_encoded_size(),
            Self::CustomError(custom) => custom.abi_encoded_size(),
        }
    }

    #[inline]
    fn abi_encode_raw(&self, out: &mut Vec<u8>) {
        match self {
            Self::Revert(revert) => revert.abi_encode_raw(out),
            Self::Panic(panic) => panic.abi_encode_raw(out),
            Self::CustomError(custom) => custom.abi_encode_raw(out),
        }
    }
}
/// Variant checks and borrowing accessors.
impl<T> ContractError<T> {
    /// Returns `true` if this is [`ContractError::CustomError`].
    #[inline]
    pub const fn is_custom_error(&self) -> bool {
        self.as_custom_error().is_some()
    }

    /// Borrows the custom error, or returns `None` for any other variant.
    #[inline]
    pub const fn as_custom_error(&self) -> Option<&T> {
        if let Self::CustomError(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Mutably borrows the custom error, or returns `None` for any other variant.
    #[inline]
    pub fn as_custom_error_mut(&mut self) -> Option<&mut T> {
        if let Self::CustomError(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Returns `true` if this is [`ContractError::Revert`].
    #[inline]
    pub const fn is_revert(&self) -> bool {
        self.as_revert().is_some()
    }

    /// Borrows the [`Revert`], or returns `None` for any other variant.
    #[inline]
    pub const fn as_revert(&self) -> Option<&Revert> {
        if let Self::Revert(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Mutably borrows the [`Revert`], or returns `None` for any other variant.
    #[inline]
    pub fn as_revert_mut(&mut self) -> Option<&mut Revert> {
        if let Self::Revert(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Returns `true` if this is [`ContractError::Panic`].
    #[inline]
    pub const fn is_panic(&self) -> bool {
        self.as_panic().is_some()
    }

    /// Borrows the [`Panic`], or returns `None` for any other variant.
    #[inline]
    pub const fn as_panic(&self) -> Option<&Panic> {
        if let Self::Panic(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Mutably borrows the [`Panic`], or returns `None` for any other variant.
    #[inline]
    pub fn as_panic_mut(&mut self) -> Option<&mut Panic> {
        if let Self::Panic(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
/// A [`RevertReason`] whose contract-error side has no custom errors, only the
/// built-in revert and panic.
pub type GenericRevertReason = RevertReason<Infallible>;

/// Why a call reverted: revert data decoded as a [`ContractError`], or the raw
/// data kept as a string (as produced by `RevertReason::decode` when the data
/// is not a contract error but is valid UTF-8).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RevertReason<T> {
    /// Revert data decoded as a contract error.
    ContractError(ContractError<T>),
    /// Revert data kept verbatim as a string.
    RawString(String),
}
/// Contract errors are formatted through their own `Display`. Raw strings are
/// written verbatim: `write_str` does not apply width/alignment flags.
impl<T: fmt::Display> fmt::Display for RevertReason<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::RawString(raw) => f.write_str(raw),
            Self::ContractError(inner) => fmt::Display::fmt(inner, f),
        }
    }
}
/// Wraps an already-decoded [`ContractError`].
impl<T> From<ContractError<T>> for RevertReason<T> {
    fn from(error: ContractError<T>) -> Self {
        Self::ContractError(error)
    }
}

/// Wraps a [`Revert`] as `RevertReason::ContractError(ContractError::Revert(..))`.
impl<T> From<Revert> for RevertReason<T> {
    fn from(revert: Revert) -> Self {
        Self::from(ContractError::<T>::Revert(revert))
    }
}

/// Wraps a raw, undecoded revert string.
impl<T> From<String> for RevertReason<T> {
    fn from(raw_string: String) -> Self {
        Self::RawString(raw_string)
    }
}
// NOTE(review): the only `From<ContractError<_>>` impl for `RevertReason<T>` is
// `From<ContractError<T>>`, so this bound holds only for `T = Infallible`, and
// `decode` is effectively available on `GenericRevertReason` alone. It may be
// intentional (it can let `T` be inferred at call sites); confirm before relaxing.
impl<T: SolInterface> RevertReason<T>
where
    Self: From<ContractError<Infallible>>,
{
    /// Decodes revert output data.
    ///
    /// First tries `out` as a [`ContractError<T>`] (decoded with `validate = false`).
    /// If that fails, falls back to interpreting `out` as a UTF-8 string.
    /// Returns `None` when `out` is neither.
    pub fn decode(out: &[u8]) -> Option<Self> {
        match ContractError::<T>::abi_decode(out, false) {
            Ok(error) => Some(Self::ContractError(error)),
            Err(_) => core::str::from_utf8(out)
                .ok()
                .map(|text| Self::RawString(text.to_string())),
        }
    }
}
impl<T: SolInterface + fmt::Display> RevertReason<T> {
    /// Renders the revert reason as a `String`.
    ///
    /// Contract errors use their `Display` output; raw strings are cloned as-is.
    // Deliberately shadows `ToString::to_string`; for these bounds both produce
    // the same text as the `Display` impl.
    #[allow(clippy::inherent_to_string_shadow_display)]
    pub fn to_string(&self) -> String {
        match self {
            Self::RawString(raw) => String::clone(raw),
            Self::ContractError(error) => ToString::to_string(error),
        }
    }
}
/// Iterator over the selectors of a [`SolInterface`].
///
/// Yields `T::selector_at(0)`, `T::selector_at(1)`, ... and stops at the first
/// `None`. Created by [`SolInterface::selectors`].
pub struct Selectors<T> {
    index: usize,
    _marker: PhantomData<T>,
}

// Hand-written instead of derived: `#[derive]` would add `T: Clone` / `T: Debug`
// bounds even though no `T` is stored.
impl<T> Clone for Selectors<T> {
    fn clone(&self) -> Self {
        let index = self.index;
        Self { index, _marker: PhantomData }
    }
}

impl<T> fmt::Debug for Selectors<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Selectors");
        builder.field("index", &self.index);
        builder.finish()
    }
}

impl<T> Selectors<T> {
    /// Creates an iterator positioned at the first selector.
    #[inline]
    const fn new() -> Self {
        Self { _marker: PhantomData, index: 0 }
    }
}

impl<T: SolInterface> Iterator for Selectors<T> {
    type Item = [u8; 4];

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // Advance only when a selector was produced, so once `selector_at`
        // returns `None` every later call returns `None` as well.
        T::selector_at(self.index).map(|selector| {
            self.index += 1;
            selector
        })
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = ExactSizeIterator::len(self);
        (remaining, Some(remaining))
    }

    #[inline]
    fn count(self) -> usize {
        ExactSizeIterator::len(&self)
    }
}

impl<T: SolInterface> ExactSizeIterator for Selectors<T> {
    /// Number of selectors not yet yielded: `T::COUNT - index`.
    // NOTE(review): assumes `T::selector_at` returns `Some` for exactly
    // `0..T::COUNT`; otherwise this subtraction can underflow.
    #[inline]
    fn len(&self) -> usize {
        T::COUNT - self.index
    }
}

impl<T: SolInterface> FusedIterator for Selectors<T> {}
// Unit tests: selector enumeration and encode/decode round-trips for
// `GenericContractError` and for `ContractError` wrapping `sol!`-generated errors.
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, U256};

    /// Returns the first four bytes of `keccak256(s)`, i.e. the selector for
    /// the signature string `s` (e.g. `"Error(string)"`).
    fn sel(s: &str) -> [u8; 4] {
        keccak256(s)[..4].try_into().unwrap()
    }

    /// With no custom errors, only the two built-in selectors are enumerated:
    /// `Error(string)` first, then `Panic(uint256)`.
    #[test]
    fn generic_contract_error_enum() {
        assert_eq!(
            GenericContractError::selectors().collect::<Vec<_>>(),
            [sel("Error(string)"), sel("Panic(uint256)")]
        );
    }

    /// A contract with a single argument-less custom error.
    #[test]
    fn contract_error_enum_1() {
        crate::sol! {
            contract C {
                error Err1();
            }
        }
        // Constants of the generated error enum, and of `ContractError`, which
        // adds two built-in selectors after the custom ones.
        assert_eq!(C::CErrors::COUNT, 1);
        assert_eq!(C::CErrors::MIN_DATA_LENGTH, 0);
        assert_eq!(ContractError::<C::CErrors>::COUNT, 1 + 2);
        assert_eq!(ContractError::<C::CErrors>::MIN_DATA_LENGTH, 0);
        assert_eq!(C::CErrors::SELECTORS, [sel("Err1()")]);
        assert_eq!(
            ContractError::<C::CErrors>::selectors().collect::<Vec<_>>(),
            vec![sel("Err1()"), sel("Error(string)"), sel("Panic(uint256)")],
        );
        // Every enumerated selector must be accepted by `valid_selector`.
        for selector in C::CErrors::selectors() {
            assert!(C::CErrors::valid_selector(selector));
        }
        for selector in ContractError::<C::CErrors>::selectors() {
            assert!(ContractError::<C::CErrors>::valid_selector(selector));
        }
    }

    /// A contract with three custom errors of different payload shapes.
    #[test]
    fn contract_error_enum_2() {
        crate::sol! {
            #[derive(Debug, PartialEq, Eq)]
            contract C {
                error Err1();
                error Err2(uint256);
                error Err3(string);
            }
        }
        assert_eq!(C::CErrors::COUNT, 3);
        assert_eq!(C::CErrors::MIN_DATA_LENGTH, 0);
        assert_eq!(ContractError::<C::CErrors>::COUNT, 2 + 3);
        assert_eq!(ContractError::<C::CErrors>::MIN_DATA_LENGTH, 0);
        // The expected order is Err3, Err2, Err1 rather than declaration order.
        // NOTE(review): presumably `sol!` sorts selectors by value; confirm
        // against the macro expansion.
        assert_eq!(
            C::CErrors::SELECTORS,
            [sel("Err3(string)"), sel("Err2(uint256)"), sel("Err1()")]
        );
        assert_eq!(
            ContractError::<C::CErrors>::selectors().collect::<Vec<_>>(),
            [
                sel("Err3(string)"),
                sel("Err2(uint256)"),
                sel("Err1()"),
                sel("Error(string)"),
                sel("Panic(uint256)"),
            ],
        );
        // For each error: encoding directly, via `CErrors`, or via
        // `ContractError` must give the same bytes, and decoding those bytes
        // through each layer must give back the corresponding value.
        let err1 = || C::Err1 {};
        let errors_err1 = || C::CErrors::Err1(err1());
        let contract_error_err1 = || ContractError::<C::CErrors>::CustomError(errors_err1());
        let data = err1().abi_encode();
        assert_eq!(data[..4], C::Err1::SELECTOR);
        assert_eq!(errors_err1().abi_encode(), data);
        assert_eq!(contract_error_err1().abi_encode(), data);
        assert_eq!(C::Err1::abi_decode(&data, true), Ok(err1()));
        assert_eq!(C::CErrors::abi_decode(&data, true), Ok(errors_err1()));
        assert_eq!(ContractError::<C::CErrors>::abi_decode(&data, true), Ok(contract_error_err1()));
        let err2 = || C::Err2 { _0: U256::from(42) };
        let errors_err2 = || C::CErrors::Err2(err2());
        let contract_error_err2 = || ContractError::<C::CErrors>::CustomError(errors_err2());
        let data = err2().abi_encode();
        assert_eq!(data[..4], C::Err2::SELECTOR);
        assert_eq!(errors_err2().abi_encode(), data);
        assert_eq!(contract_error_err2().abi_encode(), data);
        assert_eq!(C::Err2::abi_decode(&data, true), Ok(err2()));
        assert_eq!(C::CErrors::abi_decode(&data, true), Ok(errors_err2()));
        assert_eq!(ContractError::<C::CErrors>::abi_decode(&data, true), Ok(contract_error_err2()));
        let err3 = || C::Err3 { _0: "hello".into() };
        let errors_err3 = || C::CErrors::Err3(err3());
        let contract_error_err3 = || ContractError::<C::CErrors>::CustomError(errors_err3());
        let data = err3().abi_encode();
        assert_eq!(data[..4], C::Err3::SELECTOR);
        assert_eq!(errors_err3().abi_encode(), data);
        assert_eq!(contract_error_err3().abi_encode(), data);
        assert_eq!(C::Err3::abi_decode(&data, true), Ok(err3()));
        assert_eq!(C::CErrors::abi_decode(&data, true), Ok(errors_err3()));
        assert_eq!(ContractError::<C::CErrors>::abi_decode(&data, true), Ok(contract_error_err3()));
        // Every enumerated selector must be accepted by `valid_selector`.
        for selector in C::CErrors::selectors() {
            assert!(C::CErrors::valid_selector(selector));
        }
        for selector in ContractError::<C::CErrors>::selectors() {
            assert!(ContractError::<C::CErrors>::valid_selector(selector));
        }
    }
}