1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

//! Abstraction over how different runtimes manipulate the guest WebAssembly module's memory.

use std::borrow::Cow;

use frunk::{hlist, hlist_pat, HList};

use super::{
    traits::{CabiFreeAlias, CabiReallocAlias},
    InstanceWithFunction, Runtime, RuntimeError,
};
use crate::{Layout, WitType};

#[cfg(test)]
#[path = "unit_tests/memory.rs"]
mod tests;

/// A location inside the linear memory of a guest WebAssembly module.
///
/// The wrapped value is the raw 32-bit address; it is only accessible from within this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GuestPointer(pub(crate) u32);

impl GuestPointer {
    /// Returns a new address that's the current address advanced to add padding to ensure it's
    /// aligned to the `alignment` byte boundary.
    ///
    /// The padding is computed with a bit mask, so the result is only meaningful when
    /// `alignment` is a power of two.
    pub const fn aligned_at(&self, alignment: u32) -> Self {
        // For a power-of-two `alignment`, the following computation is equivalent to:
        // `(alignment - (self.0 % alignment)) % alignment`.
        // Source: https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
        //
        // The negation and the mask are computed with wrapping arithmetic directly on the
        // unsigned values. Going through `i32` instead would overflow when negating the address
        // `0x8000_0000` (which casts to `i32::MIN`), even though that address is valid.
        let padding = self.0.wrapping_neg() & alignment.wrapping_sub(1);

        GuestPointer(self.0 + padding)
    }

    /// Returns a new address that's the current address advanced to after the size of `T`.
    ///
    /// No alignment padding is added; the address is moved forward by exactly `T::SIZE` bytes.
    pub const fn after<T: WitType>(&self) -> Self {
        GuestPointer(self.0 + T::SIZE)
    }

    /// Returns a new address that's the current address advanced to add padding to ensure it's
    /// aligned properly for `T`.
    ///
    /// The alignment used is the one declared by the [`Layout`] of `T`.
    pub const fn after_padding_for<T: WitType>(&self) -> Self {
        self.aligned_at(<T::Layout as Layout>::ALIGNMENT)
    }

    /// Returns the address of an element in a contiguous list of properly aligned `T` types.
    ///
    /// The current address is treated as the address of the element at index zero.
    pub const fn index<T: WitType>(&self, index: u32) -> Self {
        // The distance between consecutive elements is the size of `T` rounded up to the
        // alignment of `T`, so that every element in the list starts at an aligned address.
        let element_size = GuestPointer(T::SIZE).after_padding_for::<T>();

        GuestPointer(self.0 + index * element_size.0)
    }
}

/// Interface for accessing a runtime specific memory.
///
/// The `Instance` type parameter is the type of the guest module instance that must be provided
/// on every access.
pub trait RuntimeMemory<Instance> {
    /// Reads `length` bytes from memory from the provided `location`.
    ///
    /// The returned [`Cow`] may borrow from `instance` (for the `'instance` lifetime) or own a
    /// copy of the bytes; which one is returned is up to the implementation.
    ///
    /// # Errors
    ///
    /// Returns a [`RuntimeError`] if the implementation fails to read the requested bytes.
    fn read<'instance>(
        &self,
        instance: &'instance Instance,
        location: GuestPointer,
        length: u32,
    ) -> Result<Cow<'instance, [u8]>, RuntimeError>;

    /// Writes the `bytes` to memory at the provided `location`.
    ///
    /// # Errors
    ///
    /// Returns a [`RuntimeError`] if the implementation fails to write the bytes.
    fn write(
        &mut self,
        instance: &mut Instance,
        location: GuestPointer,
        bytes: &[u8],
    ) -> Result<(), RuntimeError>;
}

/// A handle to interface with a guest Wasm module instance's memory.
///
/// Besides reading and writing bytes, the handle can ask the guest to allocate and deallocate
/// memory through the guest's `cabi_realloc` and `cabi_free` functions.
#[expect(clippy::type_complexity)]
pub struct Memory<'runtime, Instance>
where
    Instance: CabiReallocAlias + CabiFreeAlias,
{
    /// The guest module instance, exclusively borrowed while this handle exists.
    instance: &'runtime mut Instance,
    /// The runtime specific memory object used to perform the reads and writes.
    memory: <Instance::Runtime as Runtime>::Memory,
    /// Handle to the guest's `cabi_realloc` function; `None` until the first allocation
    /// request loads it.
    cabi_realloc: Option<
        <Instance as InstanceWithFunction<HList![i32, i32, i32, i32], HList![i32]>>::Function,
    >,
    /// Handle to the guest's `cabi_free` function; `None` until the first deallocation
    /// request loads it.
    cabi_free: Option<<Instance as InstanceWithFunction<HList![i32], HList![]>>::Function>,
}

impl<'runtime, Instance> Memory<'runtime, Instance>
where
    Instance: CabiReallocAlias + CabiFreeAlias,
{
    /// Builds a [`Memory`] handle for the guest module `instance`, backed by its `memory`
    /// export.
    ///
    /// The handles to the guest's allocation functions start out unset; they are only looked up
    /// when an allocation or deallocation is first requested.
    pub(super) fn new(
        instance: &'runtime mut Instance,
        memory: <Instance::Runtime as Runtime>::Memory,
    ) -> Self {
        let (cabi_realloc, cabi_free) = (None, None);

        Self {
            instance,
            memory,
            cabi_realloc,
            cabi_free,
        }
    }
}

impl<Instance> Memory<'_, Instance>
where
    Instance: CabiReallocAlias + CabiFreeAlias,
    <Instance::Runtime as Runtime>::Memory: RuntimeMemory<Instance>,
{
    /// Reads `length` bytes from `location`.
    ///
    /// The underlying runtime may return either a memory slice or an owned buffer.
    pub fn read(&self, location: GuestPointer, length: u32) -> Result<Cow<'_, [u8]>, RuntimeError> {
        self.memory.read(&*self.instance, location, length)
    }

    /// Writes `bytes` to `location`.
    pub fn write(&mut self, location: GuestPointer, bytes: &[u8]) -> Result<(), RuntimeError> {
        self.memory.write(&mut *self.instance, location, bytes)
    }

    /// Returns a newly allocated buffer of `size` bytes in the guest module's memory
    /// aligned to the requested `alignment`.
    ///
    /// Calls the guest module to allocate the memory, so the resulting allocation is managed by
    /// the guest.
    pub fn allocate(&mut self, size: u32, alignment: u32) -> Result<GuestPointer, RuntimeError> {
        if self.cabi_realloc.is_none() {
            self.cabi_realloc = Some(<Instance as InstanceWithFunction<
                HList![i32, i32, i32, i32],
                HList![i32],
            >>::load_function(self.instance, "cabi_realloc")?);
        }

        let size = i32::try_from(size).map_err(|_| RuntimeError::AllocationTooLarge)?;
        let alignment = i32::try_from(alignment).map_err(|_| RuntimeError::InvalidAlignment)?;

        let cabi_realloc = self
            .cabi_realloc
            .as_ref()
            .expect("`cabi_realloc` function was not loaded before it was called");

        let hlist_pat![allocation_address] = self
            .instance
            .call(cabi_realloc, hlist![0, 0, alignment, size])?;

        Ok(GuestPointer(
            allocation_address
                .try_into()
                .map_err(|_| RuntimeError::AllocationFailed)?,
        ))
    }

    /// Deallocates the `allocation` managed by the guest.
    pub fn deallocate(&mut self, allocation: GuestPointer) -> Result<(), RuntimeError> {
        if self.cabi_free.is_none() {
            self.cabi_free = Some(
                <Instance as InstanceWithFunction<HList![i32], HList![]>>::load_function(
                    self.instance,
                    "cabi_free",
                )?,
            );
        }

        let address = allocation
            .0
            .try_into()
            .map_err(|_| RuntimeError::DeallocateInvalidAddress)?;

        let cabi_free = self
            .cabi_free
            .as_ref()
            .expect("`cabi_free` function was not loaded before it was called");

        self.instance.call(cabi_free, hlist![address])?;

        Ok(())
    }
}