1use std::future::Future;
9
10use tokio::sync::mpsc;
11
12#[cfg(not(web))]
13mod implementation {
14 use super::*;
15
16 pub trait Post: Send + Sync {}
20
21 impl<T: Send + Sync> Post for T {}
22
23 pub type NoInput = std::convert::Infallible;
25
26 pub type NonBlockingFuture<R> = tokio::task::JoinHandle<R>;
28 pub type BlockingFuture<R> = tokio::task::JoinHandle<R>;
30 pub type InputReceiver<T> = tokio_stream::wrappers::UnboundedReceiverStream<T>;
32 pub use mpsc::error::SendError;
34
35 pub fn spawn<F: Future<Output: Send> + Send + 'static>(
37 future: F,
38 ) -> NonBlockingFuture<F::Output> {
39 tokio::task::spawn(future)
40 }
41
42 pub struct Blocking<Input = NoInput, Output = ()> {
44 sender: mpsc::UnboundedSender<Input>,
45 join_handle: tokio::task::JoinHandle<Output>,
46 }
47
48 impl<Input: Send + 'static, Output: Send + 'static> Blocking<Input, Output> {
49 pub async fn spawn<F: Future<Output = Output>>(
51 work: impl FnOnce(InputReceiver<Input>) -> F + Send + 'static,
52 ) -> Self {
53 let (sender, receiver) = mpsc::unbounded_channel();
54 Self {
55 sender,
56 join_handle: tokio::task::spawn_blocking(|| {
57 futures::executor::block_on(work(receiver.into()))
58 }),
59 }
60 }
61
62 pub async fn join(self) -> Output {
64 self.join_handle.await.expect("task shouldn't be cancelled")
65 }
66
67 pub fn send(&self, message: Input) -> Result<(), SendError<Input>> {
69 self.sender.send(message)
70 }
71 }
72}
73
#[cfg(web)]
mod implementation {
    use std::convert::TryFrom;

    use futures::{channel::oneshot, stream, StreamExt as _};
    use wasm_bindgen::prelude::*;
    use web_sys::js_sys;

    use super::*;
    use crate::dyn_convert;

    /// Bound for values that are handed to another thread on the web.
    ///
    /// Implemented for every type that implements `dyn_convert::DynInto<JsValue>`.
    pub trait Post: dyn_convert::DynInto<JsValue> {}

    impl<T: dyn_convert::DynInto<JsValue>> Post for T {}

    /// Input type for a [`Blocking`] task that never receives any messages.
    ///
    /// The enum has no variants, so no value of it can ever exist or be sent.
    pub enum NoInput {}

    impl TryFrom<JsValue> for NoInput {
        type Error = JsValue;
        // There is no `NoInput` value to produce, so conversion always fails and hands the
        // input back unchanged.
        fn try_from(value: JsValue) -> Result<Self, JsValue> {
            Err(value)
        }
    }

    impl From<NoInput> for JsValue {
        fn from(no_input: NoInput) -> Self {
            // Empty match: `NoInput` is uninhabited, so this can never run.
            match no_input {}
        }
    }

    /// Error returned by [`Blocking::send`] when the message could not be posted.
    ///
    /// The unsent message is available through the public field `.0`. This mirrors
    /// `tokio::sync::mpsc::error::SendError`, which is used on native targets, so callers can
    /// recover the message on every platform.
    pub struct SendError<T>(pub T);

    impl<T> std::fmt::Debug for SendError<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.debug_struct("SendError").finish_non_exhaustive()
        }
    }

    impl<T> std::fmt::Display for SendError<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            write!(f, "send error")
        }
    }

    impl<T> std::error::Error for SendError<T> {}

    /// A unit of work running on a separate `wasm_thread` thread (a Web Worker).
    ///
    /// The work can receive `Input` messages (via [`Blocking::send`]) and eventually produces
    /// an `Output` (via [`Blocking::join`]).
    pub struct Blocking<Input = NoInput, Output = ()> {
        join_handle: wasm_thread::JoinHandle<Output>,
        // `Input` is only ever sent to the worker, never stored, hence `fn(Input)`.
        _phantom: std::marker::PhantomData<fn(Input)>,
    }

    /// Stream of messages delivered to the work closure of a [`Blocking`] task.
    ///
    /// Raw `JsValue` messages are converted to `T` by a plain function pointer.
    pub type InputReceiver<T> =
        stream::Map<tokio_stream::wrappers::UnboundedReceiverStream<JsValue>, fn(JsValue) -> T>;

    /// Converts `value` into `T`, panicking if the conversion fails.
    ///
    /// Used on worker input. Only values produced from `T` are expected to arrive, so a
    /// failure indicates a bug.
    fn convert_or_panic<V, T: TryFrom<V, Error: std::fmt::Debug>>(value: V) -> T {
        T::try_from(value).expect("type correctness should ensure this can be deserialized")
    }

    /// Receiver returned by [`spawn`]; it resolves to the spawned future's output.
    pub type NonBlockingFuture<R> = oneshot::Receiver<R>;

    /// Runs `future` with `wasm_bindgen_futures::spawn_local` and returns a receiver for its
    /// output.
    ///
    /// If the receiver has been dropped by the time the future completes, the output is
    /// discarded.
    pub fn spawn<F: Future + 'static>(future: F) -> NonBlockingFuture<F::Output> {
        let (send, recv) = oneshot::channel();
        wasm_bindgen_futures::spawn_local(async {
            let _ = send.send(future.await);
        });
        recv
    }

    impl<Input, Output> Blocking<Input, Output> {
        /// Starts a worker thread that runs `work`.
        ///
        /// The worker installs an `onmessage` handler that forwards every posted message into
        /// the stream passed to `work`. This function returns only after that handler has been
        /// installed.
        ///
        /// # Panics
        ///
        /// Panics if the worker cannot be started, or if it stops before signalling that it is
        /// ready.
        pub async fn spawn<F: Future<Output = Output>>(
            work: impl FnOnce(InputReceiver<Input>) -> F + Send + 'static,
        ) -> Self
        where
            Input: Into<JsValue> + TryFrom<JsValue, Error: std::fmt::Debug>,
            Output: Send + 'static,
        {
            let (ready_sender, ready_receiver) = oneshot::channel();
            let join_handle = wasm_thread::Builder::new()
                .spawn(|| async move {
                    let (input_sender, input_receiver) = mpsc::unbounded_channel::<JsValue>();
                    let input_receiver =
                        tokio_stream::wrappers::UnboundedReceiverStream::new(input_receiver);
                    // Forward each message posted to this worker into the input channel. If
                    // the receiving side is gone, the send error is returned to JavaScript as
                    // a `JsError`.
                    let onmessage = wasm_bindgen::closure::Closure::<
                        dyn FnMut(web_sys::MessageEvent) -> Result<(), JsError>,
                    >::new(
                        move |event: web_sys::MessageEvent| -> Result<(), JsError> {
                            input_sender.send(event.data())?;
                            Ok(())
                        },
                    );
                    js_sys::global()
                        .dyn_into::<web_sys::DedicatedWorkerGlobalScope>()
                        .unwrap()
                        .set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
                    // Leak the Rust closure so that dropping it never invalidates the
                    // installed JS handler.
                    onmessage.forget();
                    // Signal the spawning side that the handler is in place.
                    ready_sender.send(()).unwrap();
                    work(input_receiver.map(convert_or_panic::<JsValue, Input>)).await
                })
                .expect("should successfully start Web Worker");

            ready_receiver
                .await
                .expect("should successfully initialize the worker thread");
            Self {
                join_handle,
                _phantom: Default::default(),
            }
        }

        /// Posts `message` to the worker, converted to a `JsValue`.
        ///
        /// A clone of `message` is converted, so the original can be handed back on failure.
        ///
        /// # Errors
        ///
        /// Returns the original message in a [`SendError`] if `post_message` fails.
        pub fn send(&self, message: Input) -> Result<(), SendError<Input>>
        where
            Input: Into<JsValue> + TryFrom<JsValue> + Clone,
        {
            self.join_handle
                .thread()
                .post_message(&message.clone().into())
                .map_err(|_| SendError(message))
        }

        /// Waits for the worker to finish and returns its output.
        ///
        /// # Panics
        ///
        /// If the worker panicked, its panic is resumed on the caller.
        pub async fn join(self) -> Output {
            match self.join_handle.join_async().await {
                Ok(output) => output,
                Err(panic) => std::panic::resume_unwind(panic),
            }
        }
    }
}
215
216pub use implementation::*;