use std::future::Future;
use tokio::sync::mpsc;
#[cfg(not(web))]
mod implementation {
    //! Native implementation backed by the Tokio runtime.

    use super::*;

    /// Marker trait, blanket-implemented for every `Send + Sync` type on native targets.
    pub trait Post: Send + Sync {}

    impl<T: Send + Sync> Post for T {}

    /// Uninhabited input type for [`Blocking`] tasks that never receive messages.
    pub type NoInput = std::convert::Infallible;

    /// Handle returned by [`spawn`]; awaiting it yields the task's output or a Tokio `JoinError`.
    pub type NonBlockingFuture<R> = tokio::task::JoinHandle<R>;

    /// Alias for a Tokio `JoinHandle` (not used within this module; presumably for
    /// blocking work — confirm against callers).
    pub type BlockingFuture<R> = tokio::task::JoinHandle<R>;

    /// Stream of messages delivered to a [`Blocking`] task through [`Blocking::send`].
    pub type InputReceiver<T> = tokio_stream::wrappers::UnboundedReceiverStream<T>;

    /// Error returned by [`Blocking::send`]; carries the undelivered message.
    pub use mpsc::error::SendError;

    /// Spawns `future` onto the current Tokio runtime via `tokio::task::spawn`.
    pub fn spawn<F: Future<Output: Send> + Send + 'static>(
        future: F,
    ) -> NonBlockingFuture<F::Output> {
        tokio::task::spawn(future)
    }

    /// A task running on Tokio's blocking thread pool.
    ///
    /// The task receives `Input` messages through an unbounded channel and
    /// eventually produces an `Output`.
    pub struct Blocking<Input = NoInput, Output = ()> {
        sender: mpsc::UnboundedSender<Input>,
        join_handle: tokio::task::JoinHandle<Output>,
    }

    impl<Input: Send + 'static, Output: Send + 'static> Blocking<Input, Output> {
        /// Starts `work` on a blocking thread, handing it the receiving end of the
        /// input channel.
        ///
        /// This is `async` only for API parity with the web implementation; it does
        /// not wait for anything on native targets.
        pub async fn spawn<F: Future<Output = Output>>(
            work: impl FnOnce(InputReceiver<Input>) -> F + Send + 'static,
        ) -> Self {
            let (sender, receiver) = mpsc::unbounded_channel();
            let join_handle = tokio::task::spawn_blocking(|| {
                // The work future stays on this blocking thread and is driven by the
                // `futures` executor, so it does not need to be `Send`.
                futures::executor::block_on(work(receiver.into()))
            });
            Self {
                sender,
                join_handle,
            }
        }

        /// Waits for the task to finish and returns its output.
        ///
        /// # Panics
        ///
        /// If the task panicked, that panic is resumed on the caller with its
        /// original payload. Panics if the task was cancelled instead of completing.
        pub async fn join(self) -> Output {
            match self.join_handle.await {
                Ok(output) => output,
                // Re-raise the task's own panic (as the web implementation does)
                // rather than reporting it as a cancellation.
                Err(error) => match error.try_into_panic() {
                    Ok(payload) => std::panic::resume_unwind(payload),
                    Err(error) => panic!("task shouldn't be cancelled: {error:?}"),
                },
            }
        }

        /// Sends `message` to the task's [`InputReceiver`].
        ///
        /// # Errors
        ///
        /// Returns the message inside [`SendError`] if the channel's receiving side
        /// has been dropped.
        pub fn send(&self, message: Input) -> Result<(), SendError<Input>> {
            self.sender.send(message)
        }
    }
}
#[cfg(web)]
mod implementation {
    //! Web implementation: plain tasks run on the current thread's JavaScript event
    //! loop, and [`Blocking`] tasks run in a separate Web Worker.

    use std::convert::TryFrom;
    use futures::{channel::oneshot, stream, StreamExt as _};
    use wasm_bindgen::prelude::*;
    use web_sys::js_sys;
    use super::*;
    use crate::dyn_convert;

    /// Marker trait, blanket-implemented for every type implementing
    /// `dyn_convert::DynInto<JsValue>`.
    pub trait Post: dyn_convert::DynInto<JsValue> {}

    impl<T: dyn_convert::DynInto<JsValue>> Post for T {}

    /// Uninhabited input type for [`Blocking`] tasks that never receive messages.
    pub enum NoInput {}

    impl TryFrom<JsValue> for NoInput {
        type Error = JsValue;

        /// Always fails, handing the value back: no `JsValue` can become a `NoInput`.
        fn try_from(value: JsValue) -> Result<Self, JsValue> {
            Err(value)
        }
    }

    impl From<NoInput> for JsValue {
        fn from(no_input: NoInput) -> Self {
            // `NoInput` has no values, so this match is trivially exhaustive.
            match no_input {}
        }
    }

    /// Error returned by [`Blocking::send`] when the message could not be posted.
    ///
    /// The undelivered message is available in the public field, matching the
    /// shape of `tokio::sync::mpsc::error::SendError` used on native targets.
    pub struct SendError<T>(pub T);

    impl<T> std::fmt::Debug for SendError<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.debug_struct("SendError").finish_non_exhaustive()
        }
    }

    impl<T> std::fmt::Display for SendError<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            write!(f, "send error")
        }
    }

    impl<T> std::error::Error for SendError<T> {}

    /// A task running in a dedicated Web Worker.
    ///
    /// The task receives `Input` messages posted to the worker and eventually
    /// produces an `Output`.
    pub struct Blocking<Input = NoInput, Output = ()> {
        join_handle: wasm_thread::JoinHandle<Output>,
        // Records `Input` without storing one; `fn(Input)` keeps `Blocking`'s auto
        // traits independent of `Input`.
        _phantom: std::marker::PhantomData<fn(Input)>,
    }

    /// Stream of messages delivered to a [`Blocking`] task, decoded from `JsValue`.
    pub type InputReceiver<T> =
        stream::Map<tokio_stream::wrappers::UnboundedReceiverStream<JsValue>, fn(JsValue) -> T>;

    /// Converts `value` into `T`, panicking if the conversion fails.
    fn convert_or_panic<V, T: TryFrom<V, Error: std::fmt::Debug>>(value: V) -> T {
        T::try_from(value).expect("type correctness should ensure this can be deserialized")
    }

    /// Handle returned by [`spawn`]: a `oneshot::Receiver` that receives the task's output.
    ///
    /// Spelled the same as the native implementation's type alias.
    pub type NonBlockingFuture<R> = oneshot::Receiver<R>;

    /// Alternative spelling of [`NonBlockingFuture`], kept for existing callers.
    pub type NonblockingFuture<R> = NonBlockingFuture<R>;

    /// Spawns `future` on the current thread via `wasm_bindgen_futures::spawn_local`.
    pub fn spawn<F: Future + 'static>(future: F) -> NonBlockingFuture<F::Output> {
        let (send, recv) = oneshot::channel();
        wasm_bindgen_futures::spawn_local(async move {
            // Best effort: if the caller dropped the handle, the output is discarded.
            let _ = send.send(future.await);
        });
        recv
    }

    impl<Input, Output> Blocking<Input, Output> {
        /// Starts `work` in a new Web Worker and waits until the worker is ready to
        /// receive messages.
        ///
        /// Messages posted to the worker are converted from `JsValue` into `Input`
        /// and delivered through the [`InputReceiver`] passed to `work`.
        ///
        /// # Panics
        ///
        /// Panics if the worker cannot be started or if it does not signal
        /// readiness. Inside the worker, a message that fails to convert into
        /// `Input` causes a panic.
        pub async fn spawn<F: Future<Output = Output>>(
            work: impl FnOnce(InputReceiver<Input>) -> F + Send + 'static,
        ) -> Self
        where
            Input: Into<JsValue> + TryFrom<JsValue, Error: std::fmt::Debug>,
            Output: Send + 'static,
        {
            let (ready_sender, ready_receiver) = oneshot::channel();
            let join_handle = wasm_thread::Builder::new()
                .spawn(|| async move {
                    let (input_sender, input_receiver) = mpsc::unbounded_channel::<JsValue>();
                    let input_receiver =
                        tokio_stream::wrappers::UnboundedReceiverStream::new(input_receiver);
                    // Forward the payload of every message posted to this worker into
                    // the input channel.
                    let onmessage = wasm_bindgen::closure::Closure::<
                        dyn FnMut(web_sys::MessageEvent) -> Result<(), JsError>,
                    >::new(
                        move |event: web_sys::MessageEvent| -> Result<(), JsError> {
                            input_sender.send(event.data())?;
                            Ok(())
                        },
                    );
                    js_sys::global()
                        .dyn_into::<web_sys::DedicatedWorkerGlobalScope>()
                        .unwrap()
                        .set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
                    // The handler must outlive this scope because the worker keeps
                    // invoking it, so the closure is intentionally leaked.
                    onmessage.forget();
                    // Signal readiness only after `onmessage` is installed, so that
                    // messages sent once `spawn` returns reach the handler.
                    ready_sender.send(()).unwrap();
                    work(input_receiver.map(convert_or_panic::<JsValue, Input>)).await
                })
                .expect("should successfully start Web Worker");
            ready_receiver
                .await
                .expect("should successfully initialize the worker thread");
            Self {
                join_handle,
                _phantom: Default::default(),
            }
        }

        /// Posts `message` to the worker.
        ///
        /// # Errors
        ///
        /// Returns the message inside [`SendError`] if posting it fails.
        pub fn send(&self, message: Input) -> Result<(), SendError<Input>>
        where
            Input: Into<JsValue> + Clone,
        {
            // Convert a clone so the original can be handed back on failure.
            self.join_handle
                .thread()
                .post_message(&message.clone().into())
                .map_err(|_| SendError(message))
        }

        /// Waits for the worker to finish and returns its output.
        ///
        /// # Panics
        ///
        /// If joining reports a panic, that panic is resumed on the caller.
        pub async fn join(self) -> Output {
            match self.join_handle.join_async().await {
                Ok(output) => output,
                Err(panic) => std::panic::resume_unwind(panic),
            }
        }
    }
}
pub use implementation::*;