use std::{
io::{self, BufRead as _, Write as _},
path::Path,
};
use fs4::FileExt as _;
use thiserror_context::Context;
use super::{Dirty, Persist};
/// An open file on which `Lock::new` has taken an exclusive lock via
/// `fs4::FileExt::try_lock_exclusive`; the lock is released by the `Drop` impl (`unlock`).
/// The wrapped file is also read through directly (`lock.0`) by `File::read_or_create`.
struct Lock(fs_err::File);
#[derive(Debug, thiserror::Error)]
enum ErrorInner {
#[error("I/O error: {0}")]
IoError(#[from] std::io::Error),
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
}
thiserror_context::impl_context!(Error(ErrorInner));
trait CleanupExt {
type Ok;
type Error;
fn or_cleanup<E>(self, f: impl FnOnce() -> Result<(), E>) -> Result<Self::Ok, Self::Error>
where
E: Into<Self::Error>,
Result<(), E>: Context<Self::Error, Self::Ok, E>;
}
impl<T, W> CleanupExt for Result<T, W>
where
    W: std::fmt::Display + Send + Sync + 'static,
{
    type Ok = T;
    type Error = W;

    /// Runs `cleanup` only when `self` is an `Err`.
    ///
    /// If the cleanup fails too, the cleanup error becomes the returned error and the
    /// original error is attached to it as a context message; a successful cleanup hands
    /// the original error back unchanged.
    fn or_cleanup<E>(self, cleanup: impl FnOnce() -> Result<(), E>) -> Result<T, W>
    where
        E: Into<W>,
        Result<(), E>: Context<W, T, E>,
    {
        match self {
            Ok(success) => Ok(success),
            Err(original) => match cleanup() {
                Ok(()) => Err(original),
                Err(cleanup_failure) => Err::<(), E>(cleanup_failure).context(original),
            },
        }
    }
}
impl Lock {
pub fn new(file: fs_err::File) -> std::io::Result<Self> {
file.file().try_lock_exclusive()?;
Ok(Lock(file))
}
}
impl Drop for Lock {
    /// Releases the lock. `drop` cannot return errors, so a failed unlock is only logged.
    fn drop(&mut self) {
        let Some(error) = self.0.file().unlock().err() else {
            return;
        };
        tracing::warn!("Failed to unlock wallet file: {error}");
    }
}
/// A value of type `T` loaded from / saved to a JSON file at `path`, together with the
/// handle on which an exclusive lock was taken when the file was opened.
pub struct File<T> {
    /// Open file plus its exclusive lock; only held (never read) so the lock outlives
    /// every other use of this value.
    _lock: Lock,
    /// Location of the JSON file the value is read from and saved to.
    path: std::path::PathBuf,
    /// The in-memory value, exposed through `Deref` / `DerefMut`.
    value: T,
    /// Whether `value` may differ from the file contents: set when created via `new` or from
    /// an empty file, and on every `DerefMut` access; cleared after a successful `save`.
    /// NOTE(review): `Persist::as_mut` hands out `&mut T` without setting this flag.
    dirty: Dirty,
}
impl<T> std::ops::Deref for File<T> {
    type Target = T;

    /// Shared access to the loaded value; leaves the dirty flag untouched.
    fn deref(&self) -> &Self::Target {
        let Self { value, .. } = self;
        value
    }
}
impl<T> std::ops::DerefMut for File<T> {
    /// Mutable access to the loaded value.
    ///
    /// Taking the mutable borrow marks the value dirty, whether or not it is then modified.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { value, dirty, .. } = self;
        **dirty = true;
        value
    }
}
/// Base open options: read + write access, creating the file when it is missing, and on
/// unix requesting mode `0o600` (owner read/write) for newly created files.
fn open_options() -> fs_err::OpenOptions {
    let mut opts = fs_err::OpenOptions::new();
    opts.read(true).write(true).create(true);
    // `mode` only affects files created by this open; existing files keep their permissions.
    #[cfg(target_family = "unix")]
    {
        use fs_err::os::unix::fs::OpenOptionsExt as _;
        opts.mode(0o600);
    }
    opts
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> File<T> {
pub fn new(path: &Path, value: T) -> Result<Self, Error> {
let this = Self {
_lock: Lock::new(
fs_err::OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(path)?,
)
.with_context(|| format!("locking path {}", path.display()))?,
path: path.into(),
value,
dirty: Dirty::new(true),
};
Ok(this)
}
pub fn read(path: &Path) -> Result<Self, Error> {
Self::read_or_create(path, || {
Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("path does not exist: {}", path.display()),
)
.into())
})
}
pub fn read_or_create(
path: &Path,
value: impl FnOnce() -> Result<T, Error>,
) -> Result<Self, Error> {
let lock = Lock::new(open_options().read(true).open(path)?)?;
let mut reader = io::BufReader::new(&lock.0);
let file_is_empty = reader.fill_buf()?.is_empty();
Ok(Self {
value: if file_is_empty {
value()?
} else {
serde_json::from_reader(reader)?
},
dirty: Dirty::new(file_is_empty),
path: path.into(),
_lock: lock,
})
}
fn save(&mut self) -> Result<(), Error> {
let mut temp_file_path = self.path.clone();
temp_file_path.set_extension("json.new");
let temp_file = open_options().open(&temp_file_path)?;
let mut temp_file_writer = std::io::BufWriter::new(temp_file);
let remove_temp_file = || fs_err::remove_file(&temp_file_path);
serde_json::to_writer_pretty(&mut temp_file_writer, &self.value)
.map_err(Error::from)
.or_cleanup(remove_temp_file)?;
temp_file_writer
.flush()
.map_err(Error::from)
.or_cleanup(remove_temp_file)?;
fs_err::rename(&temp_file_path, &self.path)?;
*self.dirty = false;
Ok(())
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned + Send> Persist for File<T> {
    type Error = Error;

    /// Mutable access to the stored value.
    // NOTE(review): unlike `DerefMut`, this does not set the dirty flag — confirm that
    // callers going through `as_mut` always persist afterwards.
    fn as_mut(&mut self) -> &mut T {
        let Self { value, .. } = self;
        value
    }

    /// Writes the value to disk via `save`.
    // NOTE(review): `save` performs blocking filesystem I/O directly inside this async fn;
    // confirm that is acceptable for the executor this runs on.
    async fn persist(&mut self) -> Result<(), Error> {
        Self::save(self)
    }

    /// Consumes the handle and returns the value; the remaining fields (including the lock)
    /// are dropped.
    fn into_value(self) -> T {
        let Self { value, .. } = self;
        value
    }
}