1use std::sync::Arc;
5
6use linera_base::identifiers::ChainId;
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
8use tracing::trace;
9
10use crate::worker;
11
12pub struct ChannelNotifier<N> {
18 inner: papaya::HashMap<ChainId, Vec<UnboundedSender<N>>>,
19}
20
21impl<N> Default for ChannelNotifier<N> {
22 fn default() -> Self {
23 Self {
24 inner: papaya::HashMap::default(),
25 }
26 }
27}
28
29impl<N> ChannelNotifier<N> {
30 pub fn add_sender(&self, chain_ids: Vec<ChainId>, sender: &UnboundedSender<N>) {
32 let pinned = self.inner.pin();
33 for id in chain_ids {
34 pinned.update_or_insert_with(
35 id,
36 |senders| senders.iter().cloned().chain([sender.clone()]).collect(),
37 || vec![sender.clone()],
38 );
39 }
40 }
41
42 pub fn subscribe(&self, chain_ids: Vec<ChainId>) -> UnboundedReceiver<N> {
44 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
45 self.add_sender(chain_ids, &tx);
46 rx
47 }
48
49 pub fn subscribe_with_ack(&self, chain_ids: Vec<ChainId>, ack: N) -> UnboundedReceiver<N> {
52 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
53 self.add_sender(chain_ids, &tx);
54 tx.send(ack)
55 .expect("pushing to a new channel should succeed");
56 rx
57 }
58}
59
60impl<N> ChannelNotifier<N>
61where
62 N: Clone,
63{
64 pub fn notify_chain(&self, chain_id: &ChainId, notification: &N) {
66 let pinned = self.inner.pin();
67
68 let Some(senders) = pinned.get(chain_id).cloned() else {
73 trace!("Chain {chain_id} has no subscribers.");
74 return;
75 };
76
77 let mut has_dead = false;
79 for sender in &senders {
80 if sender.send(notification.clone()).is_err() {
81 has_dead = true;
82 }
83 }
84
85 if has_dead {
88 pinned.compute(*chain_id, |entry| {
89 let Some((_key, current_senders)) = entry else {
90 return papaya::Operation::Abort(());
91 };
92 let live: Vec<_> = current_senders
93 .iter()
94 .filter(|s| !s.is_closed())
95 .cloned()
96 .collect();
97 if live.is_empty() {
98 trace!("No more subscribers for chain {chain_id}. Removing entry.");
99 papaya::Operation::Remove
100 } else {
101 papaya::Operation::Insert(live)
102 }
103 });
104 }
105 }
106}
107
108pub trait Notifier: Clone + Send + 'static {
109 fn notify(&self, notifications: &[worker::Notification]);
110}
111
112impl Notifier for Arc<ChannelNotifier<worker::Notification>> {
113 fn notify(&self, notifications: &[worker::Notification]) {
114 for notification in notifications {
115 self.notify_chain(¬ification.chain_id, notification);
116 }
117 }
118}
119
120impl Notifier for () {
121 fn notify(&self, _notifications: &[worker::Notification]) {}
122}
123
/// Testing notifier that appends every received notification, in order, to a
/// shared vector so tests can inspect it afterwards.
///
/// Panics if the mutex is poisoned.
#[cfg(with_testing)]
impl Notifier for Arc<std::sync::Mutex<Vec<worker::Notification>>> {
    fn notify(&self, notifications: &[worker::Notification]) {
        self.lock().unwrap().extend_from_slice(notifications);
    }
}
131
#[cfg(test)]
pub mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::{Duration, Instant},
    };

    use linera_execution::test_utils::dummy_chain_description;

    use super::*;

    /// Blocks until `done()` returns `true` or a generous deadline expires.
    ///
    /// Replaces a fixed sleep, so the test neither flakes on a slow machine nor
    /// waits longer than needed on a fast one. If the deadline is hit, the
    /// assertions that follow report which notifications are missing.
    fn wait_until(done: impl Fn() -> bool) {
        let deadline = Instant::now() + Duration::from_secs(10);
        while !done() && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_concurrent() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();

        let a_rec = Arc::new(AtomicUsize::new(0));
        let b_rec = Arc::new(AtomicUsize::new(0));
        let a_b_rec = Arc::new(AtomicUsize::new(0));

        let mut rx_a = notifier.subscribe(vec![chain_a]);
        let mut rx_b = notifier.subscribe(vec![chain_b]);
        let mut rx_a_b = notifier.subscribe(vec![chain_a, chain_b]);

        let a_rec_clone = a_rec.clone();
        let b_rec_clone = b_rec.clone();
        let a_b_rec_clone = a_b_rec.clone();

        let notifier = Arc::new(notifier);

        std::thread::spawn(move || {
            while rx_a.blocking_recv().is_some() {
                a_rec_clone.fetch_add(1, Ordering::Relaxed);
            }
        });

        std::thread::spawn(move || {
            while rx_b.blocking_recv().is_some() {
                b_rec_clone.fetch_add(1, Ordering::Relaxed);
            }
        });

        std::thread::spawn(move || {
            while rx_a_b.blocking_recv().is_some() {
                a_b_rec_clone.fetch_add(1, Ordering::Relaxed);
            }
        });

        const NOTIFICATIONS_A: usize = 500;
        const NOTIFICATIONS_B: usize = 700;

        let a_notifier = notifier.clone();
        let handle_a = std::thread::spawn(move || {
            for _ in 0..NOTIFICATIONS_A {
                a_notifier.notify_chain(&chain_a, &());
            }
        });

        let handle_b = std::thread::spawn(move || {
            for _ in 0..NOTIFICATIONS_B {
                notifier.notify_chain(&chain_b, &());
            }
        });

        handle_a.join().unwrap();
        handle_b.join().unwrap();

        // Every notification has been queued; wait until the receiver threads
        // have drained them, instead of guessing with a fixed sleep.
        wait_until(|| {
            a_rec.load(Ordering::Relaxed) >= NOTIFICATIONS_A
                && b_rec.load(Ordering::Relaxed) >= NOTIFICATIONS_B
                && a_b_rec.load(Ordering::Relaxed) >= NOTIFICATIONS_A + NOTIFICATIONS_B
        });

        assert_eq!(a_rec.load(Ordering::Relaxed), NOTIFICATIONS_A);
        assert_eq!(b_rec.load(Ordering::Relaxed), NOTIFICATIONS_B);
        assert_eq!(
            a_b_rec.load(Ordering::Relaxed),
            NOTIFICATIONS_A + NOTIFICATIONS_B
        );
    }

    #[test]
    fn test_eviction() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();
        let chain_c = dummy_chain_description(2).id();
        let chain_d = dummy_chain_description(3).id();

        let mut rx_a = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_b = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_c = notifier.subscribe(vec![chain_c, chain_d]);
        let mut rx_d = notifier.subscribe(vec![chain_d]);

        assert_eq!(notifier.inner.len(), 4);

        // Chain C's only subscriber is gone: its entry is evicted.
        rx_c.close();
        notifier.notify_chain(&chain_c, &());
        assert_eq!(notifier.inner.len(), 3);

        // Chain A still has `rx_b`: the entry is pruned but kept.
        rx_a.close();
        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 3);

        // Both of chain B's subscribers are closed: the entry is evicted.
        rx_b.close();
        notifier.notify_chain(&chain_b, &());
        assert_eq!(notifier.inner.len(), 2);

        // Chain A's remaining subscriber (`rx_b`) is now closed too.
        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 1);

        rx_d.close();
        notifier.notify_chain(&chain_d, &());
        assert_eq!(notifier.inner.len(), 0);
    }
}