linera_core/
join_set_ext.rs1use futures::channel::oneshot;
12
#[cfg(web)]
mod implementation {
    pub use futures::future::AbortHandle;
    use futures::future;

    use super::*;

    /// The set of tasks spawned on the browser event loop.
    ///
    /// Each entry is the receiving end of a "done" channel. It resolves with `Ok(())`
    /// once the corresponding task has finished. It resolves with
    /// `Err(oneshot::Canceled)` if the task was aborted before finishing.
    #[derive(Default)]
    pub struct JoinSet(Vec<oneshot::Receiver<()>>);

    /// Extension methods for spawning and tracking tasks in a [`JoinSet`].
    pub trait JoinSetExt: Sized {
        /// Spawns `future` with `wasm_bindgen_futures::spawn_local`, tracks it in this set,
        /// and returns a [`TaskHandle`] that resolves to the future's output.
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output>;

        /// Waits until every task in the set has finished or been aborted, removing each
        /// one from the set as it completes.
        fn await_all_tasks(&mut self) -> impl Future<Output = ()>;

        /// Removes the tasks that have already finished or been aborted, without waiting
        /// for the ones still running.
        fn reap_finished_tasks(&mut self);
    }

    impl JoinSetExt for JoinSet {
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output> {
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let (send_done, recv_done) = oneshot::channel();
            let (send_output, recv_output) = oneshot::channel();
            let task = async move {
                // Either receiver may already be gone (handle dropped, set reaped);
                // in that case the notification is simply discarded.
                let _ = send_output.send(future.await);
                let _ = send_done.send(());
            };
            self.0.push(recv_done);
            // Aborting drops `task` together with both senders, unsent, so both
            // receivers then resolve to `Err(Canceled)`.
            wasm_bindgen_futures::spawn_local(
                future::Abortable::new(task, abort_registration).map(drop),
            );

            TaskHandle {
                output_receiver: recv_output,
                abort_handle,
            }
        }

        async fn await_all_tasks(&mut self) {
            // Remove each receiver only after it has resolved. If this future is dropped
            // part-way, the tasks not yet awaited remain tracked in the set.
            while let Some(done) = self.0.last_mut() {
                // `Err(Canceled)` means the task was aborted; either way it is over.
                let _ = done.await;
                self.0.pop();
            }
        }

        fn reap_finished_tasks(&mut self) {
            // `Ok(None)`: still running, keep it.
            // `Ok(Some(()))` (finished) or `Err(Canceled)` (aborted): drop it.
            self.0.retain_mut(|task| task.try_recv() == Ok(None));
        }
    }
}
71
72#[cfg(not(web))]
73mod implementation {
74 pub use tokio::task::AbortHandle;
75
76 use super::*;
77
78 pub type JoinSet = tokio::task::JoinSet<()>;
79
80 #[trait_variant::make(Send)]
82 pub trait JoinSetExt: Sized {
83 fn spawn_task<F: Future<Output: Send> + Send + 'static>(
88 &mut self,
89 future: F,
90 ) -> TaskHandle<F::Output>;
91
92 async fn await_all_tasks(&mut self);
94
95 fn reap_finished_tasks(&mut self);
97 }
98
99 impl JoinSetExt for JoinSet {
100 fn spawn_task<F>(&mut self, future: F) -> TaskHandle<F::Output>
101 where
102 F: Future + Send + 'static,
103 F::Output: Send,
104 {
105 let (output_sender, output_receiver) = oneshot::channel();
106
107 let abort_handle = self.spawn(async move {
108 let _ = output_sender.send(future.await);
109 });
110
111 TaskHandle {
112 output_receiver,
113 abort_handle,
114 }
115 }
116
117 async fn await_all_tasks(&mut self) {
118 while self.join_next().await.is_some() {}
119 }
120
121 fn reap_finished_tasks(&mut self) {
122 while self.try_join_next().is_some() {}
123 }
124 }
125}
126
127use std::{
128 future::Future,
129 pin::Pin,
130 task::{Context, Poll},
131};
132
133use futures::FutureExt as _;
134pub use implementation::*;
135
/// A handle to a task spawned with [`JoinSetExt::spawn_task`].
///
/// Awaiting the handle yields `Ok` with the task's output. It yields
/// `Err(oneshot::Canceled)` if the task ended without sending one (for example,
/// because it was aborted).
pub struct TaskHandle<Output> {
    /// Receiving end of the channel the spawned task sends its output on.
    output_receiver: oneshot::Receiver<Output>,
    /// Handle used by [`TaskHandle::abort`] to abort the task.
    abort_handle: AbortHandle,
}
143
144impl<Output> Future for TaskHandle<Output> {
145 type Output = Result<Output, oneshot::Canceled>;
146
147 fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
148 self.as_mut().output_receiver.poll_unpin(context)
149 }
150}
151
152impl<Output> TaskHandle<Output> {
153 pub fn abort(&self) {
155 self.abort_handle.abort();
156 }
157
158 pub fn is_running(&mut self) -> bool {
160 self.output_receiver.try_recv().is_err()
161 }
162}