1use crate::{TransportErrorKind, TransportResult};
23use alloy_json_rpc as j;
24use serde::Serialize;
25use std::{
26 borrow::Cow,
27 collections::VecDeque,
28 sync::{Arc, PoisonError, RwLock},
29};
30
31pub type MockResponse = j::ResponsePayload;
33
34#[derive(Debug, Clone, Default)]
40pub struct Asserter {
41 responses: Arc<RwLock<VecDeque<MockResponse>>>,
42}
43
44impl Asserter {
45 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn push(&self, response: MockResponse) {
52 self.write_q().push_back(response);
53 }
54
55 #[track_caller]
61 pub fn push_success<R: Serialize>(&self, response: &R) {
62 let s = serde_json::to_string(response).unwrap();
63 self.push(MockResponse::Success(serde_json::value::RawValue::from_string(s).unwrap()));
64 }
65
66 pub fn push_failure(&self, error: j::ErrorPayload) {
68 self.push(MockResponse::Failure(error));
69 }
70
71 pub fn push_failure_msg(&self, msg: impl Into<Cow<'static, str>>) {
73 self.push_failure(j::ErrorPayload::internal_error_message(msg.into()));
74 }
75
76 pub fn pop_response(&self) -> Option<MockResponse> {
78 self.write_q().pop_front()
79 }
80
81 pub fn read_q(&self) -> impl std::ops::Deref<Target = VecDeque<MockResponse>> + '_ {
83 self.responses.read().unwrap_or_else(PoisonError::into_inner)
84 }
85
86 pub fn write_q(&self) -> impl std::ops::DerefMut<Target = VecDeque<MockResponse>> + '_ {
88 self.responses.write().unwrap_or_else(PoisonError::into_inner)
89 }
90}
91
92#[derive(Clone, Debug)]
96pub struct MockTransport {
97 asserter: Asserter,
98}
99
100impl MockTransport {
101 pub const fn new(asserter: Asserter) -> Self {
103 Self { asserter }
104 }
105
106 pub const fn asserter(&self) -> &Asserter {
108 &self.asserter
109 }
110
111 async fn handle(self, req: j::RequestPacket) -> TransportResult<j::ResponsePacket> {
112 Ok(match req {
113 j::RequestPacket::Single(req) => j::ResponsePacket::Single(self.map_request(req)?),
114 j::RequestPacket::Batch(reqs) => j::ResponsePacket::Batch(
115 reqs.into_iter()
116 .map(|req| self.map_request(req))
117 .collect::<TransportResult<_>>()?,
118 ),
119 })
120 }
121
122 fn map_request(&self, req: j::SerializedRequest) -> TransportResult<j::Response> {
123 Ok(j::Response {
124 id: req.id().clone(),
125 payload: self
126 .asserter
127 .pop_response()
128 .ok_or_else(|| TransportErrorKind::custom_str("empty asserter response queue"))?,
129 })
130 }
131}
132
133impl std::ops::Deref for MockTransport {
134 type Target = Asserter;
135
136 fn deref(&self) -> &Self::Target {
137 &self.asserter
138 }
139}
140
141impl tower::Service<j::RequestPacket> for MockTransport {
142 type Response = j::ResponsePacket;
143 type Error = crate::TransportError;
144 type Future = crate::TransportFut<'static>;
145
146 fn poll_ready(
147 &mut self,
148 _cx: &mut std::task::Context<'_>,
149 ) -> std::task::Poll<Result<(), Self::Error>> {
150 std::task::Poll::Ready(Ok(()))
151 }
152
153 fn call(&mut self, req: j::RequestPacket) -> Self::Future {
154 Box::pin(self.clone().handle(req))
155 }
156}
157
158