1use std::{
10 any::Any,
11 borrow::Cow,
12 collections::HashMap,
13 sync::{
14 atomic::{AtomicUsize, Ordering},
15 Arc, Mutex, MutexGuard,
16 },
17};
18
19use frunk::{hlist, hlist_pat, HList};
20
21use super::{
22 GuestPointer, Instance, InstanceWithFunction, InstanceWithMemory, Runtime, RuntimeError,
23 RuntimeMemory,
24};
25use crate::{memory_layout::FlatLayout, ExportFunction, WitLoad, WitStore};
26
27pub struct MockRuntime;
29
30impl Runtime for MockRuntime {
31 type Export = String;
32 type Memory = Arc<Mutex<Vec<u8>>>;
33}
34
35pub type FunctionHandler<UserData> =
37 Arc<dyn Fn(MockInstance<UserData>, Box<dyn Any>) -> Result<Box<dyn Any>, RuntimeError>>;
38
39pub struct MockInstance<UserData> {
43 memory: Arc<Mutex<Vec<u8>>>,
44 exported_functions: HashMap<String, FunctionHandler<UserData>>,
45 imported_functions: HashMap<String, FunctionHandler<UserData>>,
46 user_data: Arc<Mutex<UserData>>,
47}
48
49impl<UserData> Default for MockInstance<UserData>
50where
51 UserData: Default,
52{
53 fn default() -> Self {
54 MockInstance::new(UserData::default())
55 }
56}
57
58impl<UserData> Clone for MockInstance<UserData> {
59 fn clone(&self) -> Self {
60 MockInstance {
61 memory: self.memory.clone(),
62 exported_functions: self.exported_functions.clone(),
63 imported_functions: self.imported_functions.clone(),
64 user_data: self.user_data.clone(),
65 }
66 }
67}
68
69impl<UserData> MockInstance<UserData> {
70 pub fn new(user_data: UserData) -> Self {
72 let memory = Arc::new(Mutex::new(Vec::new()));
73
74 MockInstance {
75 memory: memory.clone(),
76 exported_functions: HashMap::new(),
77 imported_functions: HashMap::new(),
78 user_data: Arc::new(Mutex::new(user_data)),
79 }
80 .with_exported_function("cabi_free", |_, _: HList![i32]| Ok(hlist![]))
81 .with_exported_function(
82 "cabi_realloc",
83 move |_,
84 hlist_pat![_old_address, _old_size, alignment, new_size]: HList![
85 i32, i32, i32, i32
86 ]| {
87 let allocation_size = usize::try_from(new_size)
88 .expect("Failed to allocate a negative amount of memory");
89
90 let mut memory = memory
91 .lock()
92 .expect("Panic while holding a lock to a `MockInstance`'s memory");
93
94 let address = GuestPointer(memory.len().try_into()?).aligned_at(alignment as u32);
95
96 memory.resize(address.0 as usize + allocation_size, 0);
97
98 assert!(
99 memory.len() <= i32::MAX as usize,
100 "No more memory for allocations"
101 );
102
103 Ok(hlist![address.0 as i32])
104 },
105 )
106 }
107 pub fn with_exported_function<Parameters, Results, Handler>(
111 mut self,
112 name: impl Into<String>,
113 handler: Handler,
114 ) -> Self
115 where
116 Parameters: 'static,
117 Results: 'static,
118 Handler: Fn(MockInstance<UserData>, Parameters) -> Result<Results, RuntimeError> + 'static,
119 {
120 self.add_exported_function(name, handler);
121 self
122 }
123
124 pub fn add_exported_function<Parameters, Results, Handler>(
128 &mut self,
129 name: impl Into<String>,
130 handler: Handler,
131 ) -> &mut Self
132 where
133 Parameters: 'static,
134 Results: 'static,
135 Handler: Fn(MockInstance<UserData>, Parameters) -> Result<Results, RuntimeError> + 'static,
136 {
137 self.exported_functions.insert(
138 name.into(),
139 Arc::new(move |caller, boxed_parameters| {
140 let parameters = boxed_parameters
141 .downcast()
142 .expect("Incorrect parameters used to call handler for exported function");
143
144 handler(caller, *parameters).map(|results| Box::new(results) as Box<dyn Any>)
145 }),
146 );
147 self
148 }
149
150 pub fn call_imported_function<Parameters, Results>(
152 &self,
153 function: &str,
154 parameters: Parameters,
155 ) -> Result<Results, RuntimeError>
156 where
157 Parameters: WitStore + 'static,
158 Results: WitLoad + 'static,
159 {
160 let handler = self
161 .imported_functions
162 .get(function)
163 .unwrap_or_else(|| panic!("Missing function imported from host: {function:?}"));
164
165 let flat_parameters = parameters.lower(&mut self.clone().memory()?)?;
166 let boxed_flat_results = handler(self.clone(), Box::new(flat_parameters))?;
167 let flat_results = *boxed_flat_results
168 .downcast()
169 .expect("Expected an incorrect results type from imported host function");
170
171 Results::lift_from(flat_results, &self.clone().memory()?)
172 }
173
174 pub fn memory_contents(&self) -> Vec<u8> {
176 self.memory.lock().unwrap().clone()
177 }
178}
179
180impl<UserData> Instance for MockInstance<UserData> {
181 type Runtime = MockRuntime;
182 type UserData = UserData;
183 type UserDataReference<'a>
184 = MutexGuard<'a, UserData>
185 where
186 Self::UserData: 'a,
187 Self: 'a;
188 type UserDataMutReference<'a>
189 = MutexGuard<'a, UserData>
190 where
191 Self::UserData: 'a,
192 Self: 'a;
193
194 fn load_export(&mut self, name: &str) -> Option<String> {
195 if name == "memory" || self.exported_functions.contains_key(name) {
196 Some(name.to_owned())
197 } else {
198 None
199 }
200 }
201
202 fn user_data(&self) -> Self::UserDataReference<'_> {
203 self.user_data
204 .try_lock()
205 .expect("Unexpected reentrant access to user data in `MockInstance`")
206 }
207
208 fn user_data_mut(&mut self) -> Self::UserDataMutReference<'_> {
209 self.user_data
210 .try_lock()
211 .expect("Unexpected reentrant access to user data in `MockInstance`")
212 }
213}
214
215impl<Parameters, Results, UserData> InstanceWithFunction<Parameters, Results>
216 for MockInstance<UserData>
217where
218 Parameters: FlatLayout + 'static,
219 Results: FlatLayout + 'static,
220{
221 type Function = String;
222
223 fn function_from_export(
224 &mut self,
225 name: <Self::Runtime as Runtime>::Export,
226 ) -> Result<Option<Self::Function>, RuntimeError> {
227 Ok(Some(name))
228 }
229
230 fn call(
231 &mut self,
232 function: &Self::Function,
233 parameters: Parameters,
234 ) -> Result<Results, RuntimeError> {
235 let handler = self
236 .exported_functions
237 .get(function)
238 .ok_or_else(|| RuntimeError::FunctionNotFound(function.clone()))?;
239
240 let results = handler(self.clone(), Box::new(parameters))?;
241
242 Ok(*results.downcast().unwrap_or_else(|_| {
243 panic!("Incorrect results type expected from handler of expected function: {function}")
244 }))
245 }
246}
247
248impl<UserData> RuntimeMemory<MockInstance<UserData>> for Arc<Mutex<Vec<u8>>> {
249 fn read<'instance>(
250 &self,
251 instance: &'instance MockInstance<UserData>,
252 location: GuestPointer,
253 length: u32,
254 ) -> Result<Cow<'instance, [u8]>, RuntimeError> {
255 let memory = instance
256 .memory
257 .lock()
258 .expect("Panic while holding a lock to a `MockInstance`'s memory");
259
260 let start = location.0 as usize;
261 let end = start + length as usize;
262
263 Ok(Cow::Owned(memory[start..end].to_owned()))
264 }
265
266 fn write(
267 &mut self,
268 instance: &mut MockInstance<UserData>,
269 location: GuestPointer,
270 bytes: &[u8],
271 ) -> Result<(), RuntimeError> {
272 let mut memory = instance
273 .memory
274 .lock()
275 .expect("Panic while holding a lock to a `MockInstance`'s memory");
276
277 let start = location.0 as usize;
278 let end = start + bytes.len();
279
280 memory[start..end].copy_from_slice(bytes);
281
282 Ok(())
283 }
284}
285
286impl<UserData> InstanceWithMemory for MockInstance<UserData> {
287 fn memory_from_export(
288 &self,
289 export: String,
290 ) -> Result<Option<Arc<Mutex<Vec<u8>>>>, RuntimeError> {
291 if export == "memory" {
292 Ok(Some(self.memory.clone()))
293 } else {
294 Err(RuntimeError::NotMemory)
295 }
296 }
297}
298
299impl<Handler, Parameters, Results, UserData> ExportFunction<Handler, Parameters, Results>
300 for MockInstance<UserData>
301where
302 Handler: Fn(MockInstance<UserData>, Parameters) -> Result<Results, RuntimeError> + 'static,
303 Parameters: 'static,
304 Results: 'static,
305{
306 fn export(
307 &mut self,
308 module_name: &str,
309 function_name: &str,
310 handler: Handler,
311 ) -> Result<(), RuntimeError> {
312 let name = format!("{module_name}#{function_name}");
313
314 self.imported_functions.insert(
315 name.clone(),
316 Arc::new(move |instance, boxed_parameters| {
317 let parameters = boxed_parameters.downcast().unwrap_or_else(|_| {
318 panic!(
319 "Incorrect parameters used to call handler for exported function {name:?}"
320 )
321 });
322
323 let results = handler(instance, *parameters)?;
324
325 Ok(Box::new(results))
326 }),
327 );
328
329 Ok(())
330 }
331}
332
/// Associates a results type with the implementing type.
///
/// The blanket implementation below maps every type to itself.
pub trait MockResults {
    /// The associated results type.
    type Results;
}

impl<T> MockResults for T {
    type Results = Self;
}
345
346pub struct MockExportedFunction<Parameters, Results, UserData> {
348 name: String,
349 call_counter: Arc<AtomicUsize>,
350 expected_calls: usize,
351 handler: Arc<dyn Fn(MockInstance<UserData>, Parameters) -> Result<Results, RuntimeError>>,
352}
353
354impl<Parameters, Results, UserData> MockExportedFunction<Parameters, Results, UserData>
355where
356 Parameters: 'static,
357 Results: 'static,
358 UserData: 'static,
359{
360 pub fn new(
367 name: impl Into<String>,
368 handler: impl Fn(MockInstance<UserData>, Parameters) -> Result<Results, RuntimeError> + 'static,
369 expected_calls: usize,
370 ) -> Self {
371 MockExportedFunction {
372 name: name.into(),
373 call_counter: Arc::default(),
374 expected_calls,
375 handler: Arc::new(handler),
376 }
377 }
378
379 pub fn register(&self, instance: &mut MockInstance<UserData>) {
381 let call_counter = self.call_counter.clone();
382 let handler = self.handler.clone();
383
384 instance.add_exported_function(self.name.clone(), move |caller, parameters: Parameters| {
385 call_counter.fetch_add(1, Ordering::AcqRel);
386 handler(caller, parameters)
387 });
388 }
389}
390
391impl<Parameters, Results, UserData> Drop for MockExportedFunction<Parameters, Results, UserData> {
392 fn drop(&mut self) {
393 assert_eq!(
394 self.call_counter.load(Ordering::Acquire),
395 self.expected_calls,
396 "Unexpected number of calls to `{}`",
397 self.name
398 );
399 }
400}