1use std::{
2 fmt::Display,
3 fs::File,
4 ops::{Bound, RangeBounds},
5 path::{Path, PathBuf},
6 sync::Arc,
7};
8
9use bytes::Buf;
10use memmap2::Mmap;
11
/// A view over a byte range of a memory-mapped file.
///
/// The mapping itself is held behind an [`Arc`], so clones and sub-slices
/// share the same [`Mmap`] and differ only in their `start`/`end` offsets.
#[derive(Debug, Clone)]
pub(crate) struct MmappedSlice {
    /// The shared memory mapping this view indexes into.
    mmap: Arc<Mmap>,
    /// Offset into `mmap` of the first byte of the view (inclusive).
    start: usize,
    /// Offset into `mmap` one past the last byte of the view (exclusive).
    end: usize,
}
18
19impl MmappedSlice {
20 pub fn from_path(path: &Path) -> Result<Self, MmapError> {
21 let f = File::open(path).map_err(|error| MmapError::FileOpen {
22 error,
23 path: path.to_path_buf(),
24 })?;
25 MmappedSlice::from_file(&f)
26 }
27
28 pub fn from_file(file: &File) -> Result<Self, MmapError> {
29 unsafe {
30 let mmap = Mmap::map(file).map_err(MmapError::Map)?;
31 let end = mmap.len();
32 Ok(MmappedSlice {
33 mmap: Arc::new(mmap),
34 start: 0,
35 end,
36 })
37 }
38 }
39
40 #[inline]
41 pub fn as_slice(&self) -> &[u8] {
42 &self.mmap[self.start..self.end]
43 }
44
45 #[inline]
46 #[track_caller]
47 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
48 let (start, end) = bounds(self.start, self.end, range);
49 MmappedSlice {
50 mmap: Arc::clone(&self.mmap),
51 start,
52 end,
53 }
54 }
55}
56
/// Resolves `range` into absolute `(start, end)` offsets.
///
/// `range` is interpreted relative to `original_start`. The returned `start`
/// is inclusive and the returned `end` is exclusive. An unbounded start
/// resolves to `original_start`; an unbounded end resolves to `original_end`.
///
/// # Panics
///
/// Panics when the resolved range is inverted, when it reaches past
/// `original_end`, or when computing an offset overflows `usize`.
#[track_caller]
fn bounds(
    original_start: usize,
    original_end: usize,
    range: impl RangeBounds<usize>,
) -> (usize, usize) {
    let start_offset = match range.start_bound() {
        Bound::Included(&index) => index,
        // An excluded start bound means the first covered element is the one
        // *after* `index`, mirroring how an included end bound adds one.
        Bound::Excluded(&index) => index
            .checked_add(1)
            .expect("Start offset overflows usize"),
        Bound::Unbounded => 0,
    };
    // Checked arithmetic so that overflow panics in release builds as well,
    // instead of wrapping and relying on the assertions below to notice.
    let start = original_start
        .checked_add(start_offset)
        .expect("Start offset overflows usize");

    let end = match range.end_bound() {
        Bound::Included(&index) => index
            .checked_add(1)
            .and_then(|offset| original_start.checked_add(offset))
            .expect("End offset overflows usize"),
        Bound::Excluded(&index) => original_start
            .checked_add(index)
            .expect("End offset overflows usize"),
        Bound::Unbounded => original_end,
    };

    assert!(start <= end, "{start} <= {end}");
    assert!(
        start >= original_start,
        "Start offset out of bounds: {start} >= {original_start}"
    );
    assert!(
        end <= original_end,
        "End offset out of bounds: {end} <= {original_end}"
    );

    (start, end)
}
88
89impl Buf for MmappedSlice {
90 fn remaining(&self) -> usize {
91 self.as_slice().len()
92 }
93
94 fn chunk(&self) -> &[u8] {
95 self.as_slice()
96 }
97
98 fn advance(&mut self, cnt: usize) {
99 debug_assert!(cnt <= self.remaining());
100 self.start += cnt;
101 }
102}
103
/// Errors produced while opening and memory-mapping a file.
#[derive(Debug)]
pub enum MmapError {
    /// The file could not be opened.
    FileOpen {
        /// The I/O error returned when opening the file.
        error: std::io::Error,
        /// The path that was being opened.
        path: PathBuf,
    },
    /// The memory mapping could not be created; wraps the I/O error returned
    /// by the mapping call.
    Map(std::io::Error),
}
116
117impl Display for MmapError {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 match self {
120 MmapError::FileOpen { path, .. } => write!(f, "Unable to open \"{}\"", path.display()),
121 MmapError::Map(_) => write!(f, "Unable to map the file into memory"),
122 }
123 }
124}
125
126impl std::error::Error for MmapError {
127 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
128 match self {
129 MmapError::FileOpen { error, .. } | MmapError::Map(error) => Some(error),
130 }
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::io::Write;

    use super::*;

    /// Writes `content` into a fresh temporary file and maps that file.
    ///
    /// The file handle is returned next to the mapping so that it stays open
    /// for as long as the calling test keeps it bound.
    fn mapped_temp_file(content: &[u8]) -> (File, MmappedSlice) {
        let mut temp = tempfile::tempfile().unwrap();
        temp.write_all(content).unwrap();
        let mmap = MmappedSlice::from_file(&temp).unwrap();
        (temp, mmap)
    }

    // `bounds`: every case resolves a range against the window 1..10.

    #[test]
    fn full_range() {
        assert_eq!(bounds(1, 10, ..), (1, 10));
    }

    #[test]
    fn range_to() {
        assert_eq!(bounds(1, 10, ..5), (1, 1 + 5));
    }

    #[test]
    fn range_to_inclusive() {
        assert_eq!(bounds(1, 10, ..=5), (1, 1 + 6));
    }

    #[test]
    fn range_from() {
        assert_eq!(bounds(1, 10, 5..), (1 + 5, 10));
    }

    #[test]
    fn range() {
        assert_eq!(bounds(1, 10, 5..8), (1 + 5, 1 + 8));
    }

    #[test]
    fn range_at_end() {
        assert_eq!(bounds(1, 10, 5..9), (1 + 5, 1 + 9));
    }

    #[test]
    fn range_at_start() {
        assert_eq!(bounds(1, 10, 1..5), (1 + 1, 1 + 5));
    }

    #[test]
    fn range_inclusive() {
        assert_eq!(bounds(1, 10, 1..=5), (1 + 1, 1 + 5 + 1));
    }

    #[test]
    fn range_inclusive_at_end() {
        assert_eq!(bounds(1, 10, 5..=8), (1 + 5, 1 + 8 + 1));
    }

    // `MmappedSlice`: end-to-end checks against a real temporary file.

    #[test]
    fn simple_mmap() {
        let content = b"Hello, World!";
        let (_temp, mmap) = mapped_temp_file(content);

        assert_eq!(mmap.as_slice(), content);
    }

    #[test]
    fn slice_mmap() {
        let (_temp, mmap) = mapped_temp_file(b"Hello, World!");

        let head = mmap.slice(..5);

        assert_eq!(head.as_slice(), b"Hello");
    }

    #[test]
    fn slicing_is_relative_to_the_slice_not_the_overall_file() {
        let content = "Hello, World!";
        let (_temp, mmap) = mapped_temp_file(content.as_bytes());
        let tail = mmap.slice(3..);

        let nested = tail.slice(4..7);

        let text = std::str::from_utf8(nested.as_slice()).unwrap();
        assert_eq!(text, &content[3 + 4..3 + 7]);
    }
}