// shared_buffer/mmap.rs

1use std::{
2    fmt::Display,
3    fs::File,
4    ops::{Bound, RangeBounds},
5    path::{Path, PathBuf},
6    sync::Arc,
7};
8
9use bytes::Buf;
10use memmap2::Mmap;
11
12#[derive(Debug, Clone)]
13pub(crate) struct MmappedSlice {
14    mmap: Arc<Mmap>,
15    start: usize,
16    end: usize,
17}
18
19impl MmappedSlice {
20    pub fn from_path(path: &Path) -> Result<Self, MmapError> {
21        let f = File::open(path).map_err(|error| MmapError::FileOpen {
22            error,
23            path: path.to_path_buf(),
24        })?;
25        MmappedSlice::from_file(&f)
26    }
27
28    pub fn from_file(file: &File) -> Result<Self, MmapError> {
29        unsafe {
30            let mmap = Mmap::map(file).map_err(MmapError::Map)?;
31            let end = mmap.len();
32            Ok(MmappedSlice {
33                mmap: Arc::new(mmap),
34                start: 0,
35                end,
36            })
37        }
38    }
39
40    #[inline]
41    pub fn as_slice(&self) -> &[u8] {
42        &self.mmap[self.start..self.end]
43    }
44
45    #[inline]
46    #[track_caller]
47    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
48        let (start, end) = bounds(self.start, self.end, range);
49        MmappedSlice {
50            mmap: Arc::clone(&self.mmap),
51            start,
52            end,
53        }
54    }
55}
56
/// Resolves `range`, expressed relative to `original_start`, into absolute
/// `(start, end)` offsets (end exclusive) inside `original_start..original_end`.
///
/// For example `bounds(10, 20, 2..5)` yields `(12, 15)`. An unbounded start resolves
/// to `original_start` and an unbounded end to `original_end`.
///
/// # Panics
///
/// Panics, reporting the caller's location, if the resolved range is inverted, ends
/// past `original_end`, or if computing an offset overflows `usize`.
#[track_caller]
fn bounds(
    original_start: usize,
    original_end: usize,
    range: impl RangeBounds<usize>,
) -> (usize, usize) {
    let start_offset = match range.start_bound() {
        Bound::Included(&index) => Some(index),
        // An excluded start bound begins *after* `index`.
        Bound::Excluded(&index) => index.checked_add(1),
        Bound::Unbounded => Some(0),
    };
    // Checked addition also guarantees `start >= original_start`.
    let start = start_offset
        .and_then(|offset| original_start.checked_add(offset))
        .expect("Start offset overflows usize");

    let end = match range.end_bound() {
        // An included end bound still covers `index`, so the exclusive end is one past it.
        Bound::Included(&index) => index
            .checked_add(1)
            .and_then(|offset| original_start.checked_add(offset)),
        Bound::Excluded(&index) => original_start.checked_add(index),
        Bound::Unbounded => Some(original_end),
    }
    .expect("End offset overflows usize");

    assert!(start <= end, "{start} <= {end}");
    assert!(
        end <= original_end,
        "End offset out of bounds: {end} <= {original_end}"
    );

    (start, end)
}
88
89impl Buf for MmappedSlice {
90    fn remaining(&self) -> usize {
91        self.as_slice().len()
92    }
93
94    fn chunk(&self) -> &[u8] {
95        self.as_slice()
96    }
97
98    fn advance(&mut self, cnt: usize) {
99        debug_assert!(cnt <= self.remaining());
100        self.start += cnt;
101    }
102}
103
/// Errors that may occur when memory-mapping a file, e.g. through
/// `MmappedSlice::from_path` / `MmappedSlice::from_file` or one of the mmap-based
/// implementations of [`TryFrom`].
///
/// The [`Display`] text only summarises what went wrong. The underlying
/// [`std::io::Error`] is available through [`std::error::Error::source`].
#[derive(Debug)]
pub enum MmapError {
    /// Unable to open the file.
    FileOpen {
        /// The I/O error returned while opening the file.
        error: std::io::Error,
        /// The path that could not be opened.
        path: PathBuf,
    },
    /// Mapping the file into memory failed.
    Map(std::io::Error),
}
116
117impl Display for MmapError {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        match self {
120            MmapError::FileOpen { path, .. } => write!(f, "Unable to open \"{}\"", path.display()),
121            MmapError::Map(_) => write!(f, "Unable to map the file into memory"),
122        }
123    }
124}
125
126impl std::error::Error for MmapError {
127    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
128        match self {
129            MmapError::FileOpen { error, .. } | MmapError::Map(error) => Some(error),
130        }
131    }
132}
133
#[cfg(test)]
mod tests {
    use std::io::Write;

    use bytes::Buf;

    use super::*;

    /// Writes `content` to a fresh temporary file and maps the whole file.
    fn mmap_of(content: &[u8]) -> MmappedSlice {
        let mut temp = tempfile::tempfile().unwrap();
        temp.write_all(content).unwrap();
        MmappedSlice::from_file(&temp).unwrap()
    }

    #[test]
    fn full_range() {
        let (start, end) = bounds(1, 10, ..);

        assert_eq!(start, 1);
        assert_eq!(end, 10);
    }

    #[test]
    fn range_to() {
        let (start, end) = bounds(1, 10, ..5);

        assert_eq!(start, 1);
        assert_eq!(end, 1 + 5);
    }

    #[test]
    fn range_to_inclusive() {
        let (start, end) = bounds(1, 10, ..=5);

        assert_eq!(start, 1);
        assert_eq!(end, 1 + 6);
    }

    #[test]
    fn range_from() {
        let (start, end) = bounds(1, 10, 5..);

        assert_eq!(start, 1 + 5);
        assert_eq!(end, 10);
    }

    #[test]
    fn range() {
        let (start, end) = bounds(1, 10, 5..8);

        assert_eq!(start, 1 + 5);
        assert_eq!(end, 1 + 8);
    }

    #[test]
    fn range_at_end() {
        let (start, end) = bounds(1, 10, 5..9);

        assert_eq!(start, 1 + 5);
        assert_eq!(end, 1 + 9);
    }

    #[test]
    fn range_at_start() {
        // Offset 0 is the first byte of the view, i.e. `original_start` itself.
        let (start, end) = bounds(1, 10, 0..5);

        assert_eq!(start, 1);
        assert_eq!(end, 1 + 5);
    }

    #[test]
    fn range_inclusive() {
        let (start, end) = bounds(1, 10, 1..=5);

        assert_eq!(start, 1 + 1);
        assert_eq!(end, 1 + 5 + 1);
    }

    #[test]
    fn range_inclusive_at_end() {
        let (start, end) = bounds(1, 10, 5..=8);

        assert_eq!(start, 1 + 5);
        assert_eq!(end, 1 + 8 + 1);
    }

    #[test]
    fn empty_range_at_end() {
        assert_eq!(bounds(1, 10, 9..9), (10, 10));
    }

    #[test]
    #[should_panic(expected = "End offset out of bounds")]
    fn end_past_the_view_panics() {
        bounds(1, 10, ..10);
    }

    #[test]
    #[should_panic(expected = "End offset out of bounds")]
    fn inclusive_end_past_the_view_panics() {
        bounds(1, 10, ..=9);
    }

    #[test]
    #[should_panic]
    fn start_past_the_view_panics() {
        bounds(1, 10, 10..);
    }

    #[test]
    #[should_panic]
    fn inverted_range_panics() {
        bounds(1, 10, (Bound::Included(5), Bound::Excluded(3)));
    }

    #[test]
    fn simple_mmap() {
        let content = b"Hello, World!";
        let mmap = mmap_of(content);

        assert_eq!(mmap.as_slice(), content);
    }

    #[test]
    fn slice_mmap() {
        let mmap = mmap_of(b"Hello, World!");

        let slice = mmap.slice(..5);

        assert_eq!(slice.as_slice(), b"Hello");
    }

    #[test]
    fn slicing_is_relative_to_the_slice_not_the_overall_file() {
        let content = "Hello, World!";
        let mmap = mmap_of(content.as_bytes());
        let slice = mmap.slice(3..);

        let sub_slice = slice.slice(4..7);

        assert_eq!(
            std::str::from_utf8(sub_slice.as_slice()).unwrap(),
            &content[3 + 4..3 + 7]
        );
    }

    #[test]
    fn advance_consumes_from_the_front() {
        let mut buf = mmap_of(b"Hello, World!");

        buf.advance(7);

        assert_eq!(buf.remaining(), 6);
        assert_eq!(buf.chunk(), b"World!");
    }
}