linera_base/
task.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5Abstractions over tasks that can be used natively or on the Web.
6 */
7
8use std::future::Future;
9
10use tokio::sync::mpsc;
11
12#[cfg(not(web))]
13mod implementation {
14    use super::*;
15
16    /// Types that can be _explicitly_ sent to a new thread.
17    /// This differs from `Send` in that we can provide an explicit post step
18    /// (e.g. `postMessage` on the Web).
19    pub trait Post: Send + Sync {}
20
21    impl<T: Send + Sync> Post for T {}
22
    /// A type that satisfies the send/receive bounds, but can never be sent or received.
    /// It is an alias for `Infallible`, which has no values.
    pub type NoInput = std::convert::Infallible;

    /// The type of a future awaiting another task (a Tokio join handle).
    pub type NonBlockingFuture<R> = tokio::task::JoinHandle<R>;
    /// The type of a future awaiting another thread (also a Tokio join handle).
    pub type BlockingFuture<R> = tokio::task::JoinHandle<R>;
    /// The stream of inputs available to the spawned task, wrapping the receiving
    /// half of an unbounded Tokio channel.
    pub type InputReceiver<T> = tokio_stream::wrappers::UnboundedReceiverStream<T>;
    /// The type of errors that can result from sending a message to the spawned task.
    pub use mpsc::error::SendError;
34
35    /// Spawns a new task, potentially on the current thread.
36    pub fn spawn<F: Future<Output: Send> + Send + 'static>(
37        future: F,
38    ) -> NonBlockingFuture<F::Output> {
39        tokio::task::spawn(future)
40    }
41
42    /// A new task running in a different thread.
43    pub struct Blocking<Input = NoInput, Output = ()> {
44        sender: mpsc::UnboundedSender<Input>,
45        join_handle: tokio::task::JoinHandle<Output>,
46    }
47
48    impl<Input: Send + 'static, Output: Send + 'static> Blocking<Input, Output> {
49        /// Spawns a blocking task on a new thread with a stream of input messages.
50        pub async fn spawn<F: Future<Output = Output>>(
51            work: impl FnOnce(InputReceiver<Input>) -> F + Send + 'static,
52        ) -> Self {
53            let (sender, receiver) = mpsc::unbounded_channel();
54            Self {
55                sender,
56                join_handle: tokio::task::spawn_blocking(|| {
57                    futures::executor::block_on(work(receiver.into()))
58                }),
59            }
60        }
61
62        /// Waits for the task to complete and returns its output.
63        pub async fn join(self) -> Output {
64            self.join_handle.await.expect("task shouldn't be cancelled")
65        }
66
67        /// Sends a message to the task.
68        pub fn send(&self, message: Input) -> Result<(), SendError<Input>> {
69            self.sender.send(message)
70        }
71    }
72}
73
74#[cfg(web)]
75mod implementation {
76    use std::convert::TryFrom;
77
78    use futures::{channel::oneshot, stream, StreamExt as _};
79    use wasm_bindgen::prelude::*;
80    use web_sys::js_sys;
81
82    use super::*;
83    use crate::dyn_convert;
84
85    /// Types that can be _explicitly_ sent to a new thread.
86    /// This differs from `Send` in that we can provide an explicit post step
87    /// (e.g. `postMessage` on the Web).
88    // TODO(#2809): this trait is overly liberal.
89    pub trait Post: dyn_convert::DynInto<JsValue> {}
90
91    impl<T: dyn_convert::DynInto<JsValue>> Post for T {}
92
93    /// A type that satisfies the send/receive bounds, but can never be sent or received.
94    pub enum NoInput {}
95
96    impl TryFrom<JsValue> for NoInput {
97        type Error = JsValue;
98        fn try_from(value: JsValue) -> Result<Self, JsValue> {
99            Err(value)
100        }
101    }
102
103    impl From<NoInput> for JsValue {
104        fn from(no_input: NoInput) -> Self {
105            match no_input {}
106        }
107    }
108
109    /// The type of errors that can result from sending a message to the spawned task.
110    pub struct SendError<T>(T);
111
112    impl<T> std::fmt::Debug for SendError<T> {
113        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
114            f.debug_struct("SendError").finish_non_exhaustive()
115        }
116    }
117
118    impl<T> std::fmt::Display for SendError<T> {
119        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
120            write!(f, "send error")
121        }
122    }
123
124    impl<T> std::error::Error for SendError<T> {}
125
    /// A new task running in a different thread.
    pub struct Blocking<Input = NoInput, Output = ()> {
        // Handle to the worker thread; used both to post messages to it and to
        // wait for its output.
        join_handle: wasm_thread::JoinHandle<Output>,
        // Records the `Input` type parameter without storing any `Input` value:
        // messages are posted to the worker rather than kept in this struct.
        _phantom: std::marker::PhantomData<fn(Input)>,
    }

    /// The stream of inputs available to the spawned task: a stream of raw
    /// `JsValue` messages, each mapped to `T` through a plain function pointer.
    pub type InputReceiver<T> =
        stream::Map<tokio_stream::wrappers::UnboundedReceiverStream<JsValue>, fn(JsValue) -> T>;
135
136    fn convert_or_panic<V, T: TryFrom<V, Error: std::fmt::Debug>>(value: V) -> T {
137        T::try_from(value).expect("type correctness should ensure this can be deserialized")
138    }
139
140    /// The type of a future awaiting another task.
141    pub type NonBlockingFuture<R> = oneshot::Receiver<R>;
142
143    /// Spawns a new task on the current thread.
144    pub fn spawn<F: Future + 'static>(future: F) -> NonBlockingFuture<F::Output> {
145        let (send, recv) = oneshot::channel();
146        wasm_bindgen_futures::spawn_local(async {
147            let _ = send.send(future.await);
148        });
149        recv
150    }
151
152    impl<Input, Output> Blocking<Input, Output> {
153        /// Spawns a blocking task on a new Web Worker with a stream of input messages.
154        pub async fn spawn<F: Future<Output = Output>>(
155            work: impl FnOnce(InputReceiver<Input>) -> F + Send + 'static,
156        ) -> Self
157        where
158            Input: Into<JsValue> + TryFrom<JsValue, Error: std::fmt::Debug>,
159            Output: Send + 'static,
160        {
161            let (ready_sender, ready_receiver) = oneshot::channel();
162            let join_handle = wasm_thread::Builder::new()
163                .spawn(|| async move {
164                    let (input_sender, input_receiver) = mpsc::unbounded_channel::<JsValue>();
165                    let input_receiver =
166                        tokio_stream::wrappers::UnboundedReceiverStream::new(input_receiver);
167                    let onmessage = wasm_bindgen::closure::Closure::<
168                        dyn FnMut(web_sys::MessageEvent) -> Result<(), JsError>,
169                    >::new(
170                        move |event: web_sys::MessageEvent| -> Result<(), JsError> {
171                            input_sender.send(event.data())?;
172                            Ok(())
173                        },
174                    );
175                    js_sys::global()
176                        .dyn_into::<web_sys::DedicatedWorkerGlobalScope>()
177                        .unwrap()
178                        .set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
179                    onmessage.forget(); // doesn't truly forget it, but lets the JS GC take care of it
180                    ready_sender.send(()).unwrap();
181                    work(input_receiver.map(convert_or_panic::<JsValue, Input>)).await
182                })
183                .expect("should successfully start Web Worker");
184
185            ready_receiver
186                .await
187                .expect("should successfully initialize the worker thread");
188            Self {
189                join_handle,
190                _phantom: Default::default(),
191            }
192        }
193
194        /// Sends a message to the task using
195        /// [`postMessage`](https://developer.mozilla.org/en-US/docs/Web/API/Worker/postMessage).
196        pub fn send(&self, message: Input) -> Result<(), SendError<Input>>
197        where
198            Input: Into<JsValue> + TryFrom<JsValue> + Clone,
199        {
200            self.join_handle
201                .thread()
202                .post_message(&message.clone().into())
203                .map_err(|_| SendError(message))
204        }
205
206        /// Waits for the task to complete and returns its output.
207        pub async fn join(self) -> Output {
208            match self.join_handle.join_async().await {
209                Ok(output) => output,
210                Err(panic) => std::panic::resume_unwind(panic),
211            }
212        }
213    }
214}
215
216pub use implementation::*;