1use std::sync::Arc;
5
6use linera_base::identifiers::ChainId;
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
8use tracing::trace;
9
10use crate::worker;
11
/// Registry that fans notifications of type `N` out to channel subscribers.
///
/// Each chain ID maps to the sending halves of the unbounded channels that
/// were subscribed to that chain.
pub struct ChannelNotifier<N> {
    /// For every chain, the senders of all subscriptions registered for it.
    inner: papaya::HashMap<ChainId, Vec<UnboundedSender<N>>>,
}
20
21impl<N> Default for ChannelNotifier<N> {
22 fn default() -> Self {
23 Self {
24 inner: papaya::HashMap::default(),
25 }
26 }
27}
28
29impl<N> ChannelNotifier<N> {
30 fn add_sender(&self, chain_ids: Vec<ChainId>, sender: &UnboundedSender<N>) {
31 let pinned = self.inner.pin();
32 for id in chain_ids {
33 pinned.update_or_insert_with(
34 id,
35 |senders| senders.iter().cloned().chain([sender.clone()]).collect(),
36 || vec![sender.clone()],
37 );
38 }
39 }
40
41 pub fn subscribe(&self, chain_ids: Vec<ChainId>) -> UnboundedReceiver<N> {
43 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
44 self.add_sender(chain_ids, &tx);
45 rx
46 }
47
48 pub fn subscribe_with_ack(&self, chain_ids: Vec<ChainId>, ack: N) -> UnboundedReceiver<N> {
51 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
52 self.add_sender(chain_ids, &tx);
53 tx.send(ack)
54 .expect("pushing to a new channel should succeed");
55 rx
56 }
57}
58
59impl<N> ChannelNotifier<N>
60where
61 N: Clone,
62{
63 pub fn notify_chain(&self, chain_id: &ChainId, notification: &N) {
65 self.inner.pin().compute(*chain_id, |senders| {
66 let Some((_key, senders)) = senders else {
67 trace!("Chain {chain_id:?} has no subscribers.");
68 return papaya::Operation::Abort(());
69 };
70 let live_senders = senders
71 .iter()
72 .filter(|sender| sender.send(notification.clone()).is_ok())
73 .cloned()
74 .collect::<Vec<_>>();
75 if live_senders.is_empty() {
76 trace!("No more subscribers for chain {chain_id:?}. Removing entry.");
77 return papaya::Operation::Remove;
78 }
79 papaya::Operation::Insert(live_senders)
80 });
81 }
82}
83
/// A sink for batches of worker notifications.
///
/// Implementors must be `Clone`, `Send` and `'static`.
pub trait Notifier: Clone + Send + 'static {
    /// Hands every notification in `notifications` to this notifier.
    fn notify(&self, notifications: &[worker::Notification]);
}
87
88impl Notifier for Arc<ChannelNotifier<worker::Notification>> {
89 fn notify(&self, notifications: &[worker::Notification]) {
90 for notification in notifications {
91 self.notify_chain(¬ification.chain_id, notification);
92 }
93 }
94}
95
/// No-op notifier: every notification passed to it is ignored.
impl Notifier for () {
    fn notify(&self, _notifications: &[worker::Notification]) {}
}
99
/// Recording notifier: appends a clone of every notification to the shared
/// vector. Only compiled under the `with_testing` cfg.
#[cfg(with_testing)]
impl Notifier for Arc<std::sync::Mutex<Vec<worker::Notification>>> {
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn notify(&self, notifications: &[worker::Notification]) {
        self.lock().unwrap().extend_from_slice(notifications);
    }
}
107
#[cfg(test)]
pub mod tests {
    use std::{
        sync::{atomic::Ordering, Arc},
        time::Duration,
    };

    use linera_execution::test_utils::dummy_chain_description;

    use super::*;

    /// Spawns a thread that drains `receiver` until the channel ends, and
    /// returns a counter of how many messages it has received so far.
    fn spawn_counter(
        mut receiver: tokio::sync::mpsc::UnboundedReceiver<()>,
    ) -> Arc<std::sync::atomic::AtomicUsize> {
        let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let thread_counter = Arc::clone(&counter);
        std::thread::spawn(move || {
            while receiver.blocking_recv().is_some() {
                thread_counter.fetch_add(1, Ordering::Relaxed);
            }
        });
        counter
    }

    /// Two threads notify two different chains concurrently; every subscriber
    /// must see exactly the notifications of the chains it subscribed to.
    #[test]
    fn test_concurrent() {
        const NOTIFICATIONS_A: usize = 500;
        const NOTIFICATIONS_B: usize = 700;

        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();

        // All subscriptions are registered before any counting thread starts.
        let rx_a = notifier.subscribe(vec![chain_a]);
        let rx_b = notifier.subscribe(vec![chain_b]);
        let rx_a_b = notifier.subscribe(vec![chain_a, chain_b]);

        let a_rec = spawn_counter(rx_a);
        let b_rec = spawn_counter(rx_b);
        let a_b_rec = spawn_counter(rx_a_b);

        let notifier = Arc::new(notifier);

        let a_notifier = Arc::clone(&notifier);
        let handle_a = std::thread::spawn(move || {
            (0..NOTIFICATIONS_A).for_each(|_| a_notifier.notify_chain(&chain_a, &()));
        });

        let handle_b = std::thread::spawn(move || {
            (0..NOTIFICATIONS_B).for_each(|_| notifier.notify_chain(&chain_b, &()));
        });

        handle_a.join().unwrap();
        handle_b.join().unwrap();

        // Give the counting threads time to drain their channels.
        std::thread::sleep(Duration::from_millis(100));

        assert_eq!(a_rec.load(Ordering::Relaxed), NOTIFICATIONS_A);
        assert_eq!(b_rec.load(Ordering::Relaxed), NOTIFICATIONS_B);
        assert_eq!(
            a_b_rec.load(Ordering::Relaxed),
            NOTIFICATIONS_A + NOTIFICATIONS_B
        );
    }

    /// A chain's entry disappears once a notification finds all of its
    /// subscribers closed.
    #[test]
    fn test_eviction() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();
        let chain_c = dummy_chain_description(2).id();
        let chain_d = dummy_chain_description(3).id();

        let mut rx_a = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_b = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_c = notifier.subscribe(vec![chain_c, chain_d]);
        let mut rx_d = notifier.subscribe(vec![chain_d]);

        let entry_count = || notifier.inner.len();
        assert_eq!(entry_count(), 4);

        // Chain C has a single subscriber; closing it empties the entry.
        rx_c.close();
        notifier.notify_chain(&chain_c, &());
        assert_eq!(entry_count(), 3);

        // Chain A still has the second subscriber, so its entry survives.
        rx_a.close();
        notifier.notify_chain(&chain_a, &());
        assert_eq!(entry_count(), 3);

        // Both subscribers of chain B are now closed.
        rx_b.close();
        notifier.notify_chain(&chain_b, &());
        assert_eq!(entry_count(), 2);

        // Chain A's remaining subscriber was closed in the previous step.
        notifier.notify_chain(&chain_a, &());
        assert_eq!(entry_count(), 1);

        // Closing the last receiver leaves chain D with closed senders only.
        rx_d.close();
        notifier.notify_chain(&chain_d, &());
        assert_eq!(entry_count(), 0);
    }
}