1use std::sync::Arc;
5
6use dashmap::DashMap;
7use linera_base::identifiers::ChainId;
8use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
9use tracing::trace;
10
11use crate::worker;
12
/// A registry of notification subscribers, keyed by chain.
///
/// Each subscriber is represented by the sending half of an unbounded tokio
/// channel; the subscriber keeps the matching receiving half. `N` is the type
/// of the notifications pushed through those channels.
pub struct ChannelNotifier<N> {
    /// For every chain ID, the senders of all channels subscribed to it.
    inner: DashMap<ChainId, Vec<UnboundedSender<N>>>,
}
21
22impl<N> Default for ChannelNotifier<N> {
23 fn default() -> Self {
24 Self {
25 inner: DashMap::default(),
26 }
27 }
28}
29
30impl<N> ChannelNotifier<N> {
31 fn add_sender(&self, chain_ids: Vec<ChainId>, sender: &UnboundedSender<N>) {
32 for id in chain_ids {
33 let mut senders = self.inner.entry(id).or_default();
34 senders.push(sender.clone());
35 }
36 }
37
38 pub fn subscribe(&self, chain_ids: Vec<ChainId>) -> UnboundedReceiver<N> {
40 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
41 self.add_sender(chain_ids, &tx);
42 rx
43 }
44
45 pub fn subscribe_with_ack(&self, chain_ids: Vec<ChainId>, ack: N) -> UnboundedReceiver<N> {
48 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
49 self.add_sender(chain_ids, &tx);
50 tx.send(ack)
51 .expect("pushing to a new channel should succeed");
52 rx
53 }
54}
55
56impl<N> ChannelNotifier<N>
57where
58 N: Clone,
59{
60 pub fn notify_chain(&self, chain_id: &ChainId, notification: &N) {
62 let senders_is_empty = {
63 let Some(mut senders) = self.inner.get_mut(chain_id) else {
64 trace!("Chain {chain_id:?} has no subscribers.");
65 return;
66 };
67 let mut dead_senders = vec![];
68 let senders = senders.value_mut();
69
70 for (index, sender) in senders.iter_mut().enumerate() {
71 if sender.send(notification.clone()).is_err() {
72 dead_senders.push(index);
73 }
74 }
75
76 for index in dead_senders.into_iter().rev() {
77 trace!("Removed dead subscriber for chain {chain_id:?}.");
78 senders.remove(index);
79 }
80
81 senders.is_empty()
82 };
83
84 if senders_is_empty {
85 trace!("No more subscribers for chain {chain_id:?}. Removing entry.");
86 self.inner.remove(chain_id);
87 }
88 }
89}
90
91pub trait Notifier: Clone + Send + 'static {
92 fn notify(&self, notifications: &[worker::Notification]);
93}
94
95impl Notifier for Arc<ChannelNotifier<worker::Notification>> {
96 fn notify(&self, notifications: &[worker::Notification]) {
97 for notification in notifications {
98 self.notify_chain(¬ification.chain_id, notification);
99 }
100 }
101}
102
103impl Notifier for () {
104 fn notify(&self, _notifications: &[worker::Notification]) {}
105}
106
/// A notifier that appends a copy of every notification to a shared vector.
#[cfg(with_testing)]
impl Notifier for Arc<std::sync::Mutex<Vec<worker::Notification>>> {
    /// Clones `notifications` onto the end of the shared vector.
    ///
    /// Panics if the mutex is poisoned.
    fn notify(&self, notifications: &[worker::Notification]) {
        self.lock().unwrap().extend_from_slice(notifications);
    }
}
114
#[cfg(test)]
pub mod tests {
    use std::thread::JoinHandle;

    use linera_execution::test_utils::dummy_chain_description;

    use super::*;

    /// Spawns a thread that drains `receiver` until its channel is closed and
    /// returns the number of notifications it received.
    fn spawn_counter(mut receiver: UnboundedReceiver<()>) -> JoinHandle<usize> {
        std::thread::spawn(move || {
            let mut received = 0;
            while receiver.blocking_recv().is_some() {
                received += 1;
            }
            received
        })
    }

    /// Two threads notify two chains concurrently; every subscriber must see
    /// exactly the notifications of the chains it subscribed to.
    #[test]
    fn test_concurrent() {
        const NOTIFICATIONS_A: usize = 500;
        const NOTIFICATIONS_B: usize = 700;

        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();

        let counter_a = spawn_counter(notifier.subscribe(vec![chain_a]));
        let counter_b = spawn_counter(notifier.subscribe(vec![chain_b]));
        let counter_a_b = spawn_counter(notifier.subscribe(vec![chain_a, chain_b]));

        // The scope returns only after both notifying threads have finished.
        std::thread::scope(|scope| {
            scope.spawn(|| {
                for _ in 0..NOTIFICATIONS_A {
                    notifier.notify_chain(&chain_a, &());
                }
            });
            scope.spawn(|| {
                for _ in 0..NOTIFICATIONS_B {
                    notifier.notify_chain(&chain_b, &());
                }
            });
        });

        // Dropping the notifier drops every sender, which closes the channels.
        // The counting threads then drain what is still queued and terminate,
        // so joining them is deterministic and needs no sleep.
        drop(notifier);

        assert_eq!(counter_a.join().unwrap(), NOTIFICATIONS_A);
        assert_eq!(counter_b.join().unwrap(), NOTIFICATIONS_B);
        assert_eq!(
            counter_a_b.join().unwrap(),
            NOTIFICATIONS_A + NOTIFICATIONS_B
        );
    }

    /// Closed receivers are pruned lazily, on the next notification of a chain
    /// they were subscribed to, and chains left without subscribers are
    /// removed from the map.
    #[test]
    fn test_eviction() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();
        let chain_c = dummy_chain_description(2).id();
        let chain_d = dummy_chain_description(3).id();

        // Subscribers per chain:
        //   chain_a -> rx_a, rx_b
        //   chain_b -> rx_a, rx_b
        //   chain_c -> rx_c
        //   chain_d -> rx_a, rx_b, rx_c, rx_d
        let mut rx_a = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_b = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_c = notifier.subscribe(vec![chain_c, chain_d]);
        let mut rx_d = notifier.subscribe(vec![chain_d]);

        assert_eq!(notifier.inner.len(), 4);

        // chain_c loses its only subscriber and is evicted.
        rx_c.close();
        notifier.notify_chain(&chain_c, &());
        assert_eq!(notifier.inner.len(), 3);

        // chain_a still has rx_b, so its entry stays.
        rx_a.close();
        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 3);

        // Both subscribers of chain_b are closed now.
        rx_b.close();
        notifier.notify_chain(&chain_b, &());
        assert_eq!(notifier.inner.len(), 2);

        // chain_a only notices that rx_b is gone on its next notification.
        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 1);

        // Every subscriber of chain_d is closed at this point.
        rx_d.close();
        notifier.notify_chain(&chain_d, &());
        assert_eq!(notifier.inner.len(), 0);
    }
}