1use std::{
5 collections::{BTreeMap, BTreeSet, HashMap},
6 sync::Arc,
7};
8
9use futures::{lock::Mutex, stream::StreamExt, FutureExt};
10use linera_base::{
11 data_types::{MessagePolicy, TimeDelta},
12 identifiers::{ApplicationId, ChainId, GenericApplicationId},
13};
14use linera_client::chain_listener::{ClientContext, ListenerCommand};
15use linera_core::{
16 client::ChainClient,
17 node::NotificationStream,
18 worker::{Notification, Reason},
19};
20use linera_sdk::abis::controller::{
21 LocalWorkerState, ManagedServiceId, Operation, PendingService, WorkerCommand,
22};
23use serde_json::json;
24use tokio::{
25 select,
26 sync::mpsc::{self, UnboundedSender},
27};
28use tokio_util::sync::CancellationToken;
29use tracing::{debug, error, info};
30
31use crate::task_processor::{OperatorMap, TaskProcessor};
32
/// A message delivered to a running `TaskProcessor`, carrying the latest set
/// of applications it should handle on its chain. An empty list signals that
/// the chain has no active services left.
#[derive(Debug)]
pub struct Update {
    // Applications currently deployed as services on the processor's chain.
    pub application_ids: Vec<ApplicationId>,
}
38
/// Handle to a spawned `TaskProcessor` task.
struct ProcessorHandle {
    // Sending fails once the processor task has exited and dropped its
    // receiver; callers treat a send error as "processor gone" and either
    // drop the handle or respawn the processor.
    update_sender: mpsc::UnboundedSender<Update>,
}
42
/// Watches the controller chain for state changes and orchestrates local
/// `TaskProcessor` tasks and chain-listener commands accordingly.
pub struct Controller<Ctx: ClientContext> {
    /// The controller chain watched for notifications.
    chain_id: ChainId,
    /// The controller application deployed on `chain_id`.
    controller_id: ApplicationId,
    /// Shared client context, used to construct clients for service chains.
    context: Arc<Mutex<Ctx>>,
    /// Client for the controller chain itself.
    chain_client: ChainClient<Ctx::Environment>,
    /// Cancels the run loop; child tokens are handed to spawned processors.
    cancellation_token: CancellationToken,
    /// Notification stream for the controller chain.
    notifications: NotificationStream,
    /// Locally available operators; their keys are advertised as worker
    /// capabilities at registration time.
    operators: OperatorMap,
    /// Retry delay forwarded to each spawned task processor.
    retry_delay: TimeDelta,
    /// One running processor per service chain.
    processors: BTreeMap<ChainId, ProcessorHandle>,
    /// Local chains we asked the listener to track that currently host no
    /// services (recomputed on every controller-state pass).
    listened_local_chains: BTreeSet<ChainId>,
    /// The message policies most recently pushed to the chain listener,
    /// kept to avoid re-sending unchanged policies.
    current_message_policies: BTreeMap<ChainId, MessagePolicy>,
    /// Command channel to the chain listener.
    command_sender: UnboundedSender<ListenerCommand>,
    /// Per chain: the services awaiting their start block height, plus the
    /// notification stream used to watch that chain for new blocks.
    pending_services_notifications: BTreeMap<
        ChainId,
        (
            HashMap<ManagedServiceId, PendingService>,
            NotificationStream,
        ),
    >,
}
64
65impl<Ctx> Controller<Ctx>
66where
67 Ctx: ClientContext + Send + Sync + 'static,
68 Ctx::Environment: 'static,
69 <Ctx::Environment as linera_core::Environment>::Storage: Clone,
70{
71 #[allow(clippy::too_many_arguments)]
72 pub fn new(
73 chain_id: ChainId,
74 controller_id: ApplicationId,
75 context: Arc<Mutex<Ctx>>,
76 chain_client: ChainClient<Ctx::Environment>,
77 cancellation_token: CancellationToken,
78 operators: OperatorMap,
79 retry_delay: TimeDelta,
80 command_sender: UnboundedSender<ListenerCommand>,
81 ) -> Self {
82 let notifications = chain_client.subscribe().expect("client subscription");
83 Self {
84 chain_id,
85 controller_id,
86 context,
87 chain_client,
88 cancellation_token,
89 notifications,
90 operators,
91 retry_delay,
92 processors: BTreeMap::new(),
93 listened_local_chains: BTreeSet::new(),
94 current_message_policies: BTreeMap::new(),
95 command_sender,
96 pending_services_notifications: BTreeMap::new(),
97 }
98 }
99
    /// Main event loop: reacts to new blocks on the controller chain and on
    /// any chains hosting pending services, until cancellation.
    pub async fn run(mut self) {
        info!(
            "Watching for notifications for controller chain {}",
            self.chain_id
        );
        // Process the current state once up front so we act on blocks that
        // were committed before we started listening.
        self.process_controller_state().await;
        loop {
            // Build a future that resolves when ANY pending-service chain
            // produces a notification. With no such chains, use a
            // never-resolving future so this `select!` branch is inert.
            // The future is rebuilt every iteration because
            // `process_controller_state` may add or remove tracked chains.
            let pending_services_notifications: std::pin::Pin<
                Box<dyn futures::Future<Output = (ChainId, Option<Notification>)> + Send>,
            > = if !self.pending_services_notifications.is_empty() {
                Box::pin(
                    futures::future::select_all(
                        self.pending_services_notifications.iter_mut().map(
                            |(chain_id, (_, notifications))| {
                                notifications.next().map(|result| (*chain_id, result))
                            },
                        ),
                    )
                    .map(|((chain_id, maybe_notification), _, _)| (chain_id, maybe_notification)),
                )
            } else {
                Box::pin(futures::future::pending())
            };
            select! {
                Some(notification) = self.notifications.next() => {
                    // Only new blocks can change the on-chain controller state.
                    if let Reason::NewBlock { .. } = notification.reason {
                        debug!("Processing notification on controller chain {}", self.chain_id);
                        self.process_controller_state().await;
                    }
                }
                (chain_id, Some(notification)) = pending_services_notifications => {
                    self.process_pending_service_notification(chain_id, notification).await;
                }
                _ = self.cancellation_token.cancelled().fuse() => {
                    break;
                }
            }
        }
        debug!("Notification stream ended.");
    }
140
141 async fn process_pending_service_notification(
142 &mut self,
143 chain_id: ChainId,
144 notification: Notification,
145 ) {
146 debug!(
147 "Processing notification on pending service chain {}",
148 chain_id
149 );
150 if let Reason::NewBlock { height, .. } = notification.reason {
151 let pending_services = &mut self
152 .pending_services_notifications
153 .get_mut(&chain_id)
154 .expect("the entry should exist")
155 .0;
156 for (service_id, pending_service) in &*pending_services {
157 if pending_service.start_block_height <= height {
158 let bytes = bcs::to_bytes(&Operation::StartLocalService {
159 service_id: *service_id,
160 })
161 .expect("bcs bytes");
162 let operation = linera_execution::Operation::User {
163 application_id: self.controller_id,
164 bytes,
165 };
166 if let Err(e) = self
167 .chain_client
168 .execute_operations(vec![operation], vec![])
169 .await
170 {
171 error!("Failed to execute worker on-chain registration: {e}");
173 }
174 }
175 }
176 pending_services
177 .retain(|_, pending_service| pending_service.start_block_height > height);
178 if pending_services.is_empty() {
179 let _ = self.pending_services_notifications.remove(&chain_id);
180 }
181 }
182 }
183
    /// Reads the controller application's state and reconciles local tasks
    /// with it: registers the worker if needed, tracks pending services,
    /// spawns/updates task processors, and instructs the chain listener
    /// which chains to (stop) listen(ing) to and which message policies to
    /// apply.
    async fn process_controller_state(&mut self) {
        let state = match self.query_controller_state().await {
            Ok(state) => state,
            Err(error) => {
                // Transient read failure: log and retry on the next block.
                error!("Error reading controller state: {error}");
                return;
            }
        };
        // Not registered yet: submit the registration and wait for the next
        // notification to pick up the resulting state.
        let Some(worker) = state.local_worker else {
            self.register_worker().await;
            return;
        };
        assert_eq!(
            worker.owner,
            self.chain_client
                .preferred_owner()
                .expect("The current wallet should own the chain being watched"),
            "We should be registered with the current account owner."
        );

        // Start tracking chains that host pending services, subscribing to
        // one notification stream per chain.
        for (managed_service_id, (chain_id, pending_service)) in &state.local_pending_services {
            if self.pending_services_notifications.contains_key(chain_id) {
                continue;
            }
            let service_notifications = self
                .chain_client
                .subscribe_to(*chain_id)
                .expect("client subscription");
            self.pending_services_notifications
                .entry(*chain_id)
                .or_insert_with(|| (HashMap::new(), service_notifications))
                .0
                .insert(*managed_service_id, pending_service.clone());
        }

        // Group active services by their hosting chain.
        let mut chain_apps: BTreeMap<ChainId, Vec<ApplicationId>> = BTreeMap::new();
        for service in &state.local_services {
            chain_apps
                .entry(service.chain_id)
                .or_default()
                .push(service.application_id);
        }

        // For each service chain, accept only message bundles from its own
        // applications, the controller application, and the system.
        let mut message_policies: BTreeMap<_, _> = chain_apps
            .iter()
            .map(|(chain_id, apps)| {
                let message_policy = MessagePolicy {
                    reject_message_bundles_without_application_ids: Some(
                        apps.iter()
                            .map(|app_id| GenericApplicationId::User(*app_id))
                            .chain(std::iter::once(GenericApplicationId::User(
                                self.controller_id,
                            )))
                            .chain(std::iter::once(GenericApplicationId::System))
                            .collect(),
                    ),
                    ..Default::default()
                };
                (*chain_id, message_policy)
            })
            .collect();
        // Explicit on-chain policies take precedence over the derived ones.
        message_policies.extend(state.local_message_policy);

        // Only forward policies that actually changed since the last pass.
        let message_policies_to_update: BTreeMap<_, _> = message_policies
            .iter()
            .filter(|(chain_id, message_policy)| {
                self.current_message_policies.get(chain_id) != Some(*message_policy)
            })
            .map(|(chain_id, message_policy)| (*chain_id, message_policy.clone()))
            .collect();

        // Snapshot processor chains BEFORE spawning, so newly spawned chains
        // count as "new" in the listen-set diff below.
        let old_chains: BTreeSet<_> = self.processors.keys().cloned().collect();

        for (service_chain_id, application_ids) in chain_apps {
            if let Err(err) = self
                .update_or_spawn_processor(service_chain_id, application_ids)
                .await
            {
                error!("Error updating or spawning processor: {err}");
                return;
            }
        }

        // Processors for chains with no remaining services get an empty
        // update; a failed send means the task already exited, so drop it.
        let active_chains: std::collections::BTreeSet<_> =
            state.local_services.iter().map(|s| s.chain_id).collect();
        let stale_chains: BTreeSet<_> = self
            .processors
            .keys()
            .filter(|chain_id| !active_chains.contains(chain_id))
            .cloned()
            .collect();
        for chain_id in &stale_chains {
            if let Some(handle) = self.processors.get(chain_id) {
                let update = Update {
                    application_ids: Vec::new(),
                };
                if handle.update_sender.send(update).is_err() {
                    self.processors.remove(chain_id);
                }
            }
        }

        let local_chains: BTreeSet<_> = state.local_chains.iter().cloned().collect();

        // Everything we were listening to before this pass.
        let old_listened: BTreeSet<_> = old_chains
            .union(&self.listened_local_chains)
            .cloned()
            .collect();

        // Everything we should be listening to now.
        let desired_listened: BTreeSet<_> = active_chains.union(&local_chains).cloned().collect();

        let owner = worker.owner;
        let mut new_chains: BTreeMap<_, _> = desired_listened
            .difference(&old_listened)
            .map(|chain_id| (*chain_id, Some(owner)))
            .collect();

        // Pending-service chains are listened to without an owner.
        // NOTE(review): if a chain appears in both sets, this overwrites
        // `Some(owner)` with `None` — presumably intended; confirm.
        new_chains.extend(
            state
                .local_pending_services
                .iter()
                .map(|(_, (chain_id, _))| *chain_id)
                .collect::<BTreeSet<_>>()
                .difference(&old_listened)
                .map(|chain_id| (*chain_id, None)),
        );

        let chains_to_stop: BTreeSet<_> = old_listened
            .difference(&desired_listened)
            .cloned()
            .collect();

        // Remember only local chains WITHOUT services; chains with services
        // are accounted for via `self.processors` on the next pass.
        self.listened_local_chains = local_chains.difference(&active_chains).cloned().collect();

        if let Err(error) = self
            .command_sender
            .send(ListenerCommand::Listen(new_chains))
        {
            error!(%error, "error sending a command to chain listener");
        }
        if let Err(error) = self
            .command_sender
            .send(ListenerCommand::StopListening(chains_to_stop))
        {
            error!(%error, "error sending a command to chain listener");
        }
        if let Err(error) = self.command_sender.send(ListenerCommand::SetMessagePolicy(
            message_policies_to_update,
        )) {
            error!(%error, "error sending a command to chain listener");
        }
        self.current_message_policies = message_policies;
    }
358
359 #[allow(clippy::needless_pass_by_ref_mut)]
362 async fn register_worker(&mut self) {
363 let capabilities = self.operators.keys().cloned().collect();
364 let command = WorkerCommand::RegisterWorker { capabilities };
365 let owner = self
366 .chain_client
367 .preferred_owner()
368 .expect("The current wallet should own the chain being watched");
369 let bytes =
370 bcs::to_bytes(&Operation::ExecuteWorkerCommand { owner, command }).expect("bcs bytes");
371 let operation = linera_execution::Operation::User {
372 application_id: self.controller_id,
373 bytes,
374 };
375 if let Err(e) = self
376 .chain_client
377 .execute_operations(vec![operation], vec![])
378 .await
379 {
380 error!("Failed to execute worker on-chain registration: {e}");
382 }
383 }
384
385 async fn update_or_spawn_processor(
386 &mut self,
387 service_chain_id: ChainId,
388 application_ids: Vec<ApplicationId>,
389 ) -> Result<(), anyhow::Error> {
390 if let Some(handle) = self.processors.get(&service_chain_id) {
391 let update = Update {
393 application_ids: application_ids.clone(),
394 };
395 if handle.update_sender.send(update).is_err() {
396 self.processors.remove(&service_chain_id);
398 self.spawn_processor(service_chain_id, application_ids)
399 .await?;
400 }
401 } else {
402 self.spawn_processor(service_chain_id, application_ids)
404 .await?;
405 }
406 Ok(())
407 }
408
409 async fn spawn_processor(
410 &mut self,
411 service_chain_id: ChainId,
412 application_ids: Vec<ApplicationId>,
413 ) -> Result<(), anyhow::Error> {
414 info!(
415 "Spawning TaskProcessor for chain {} with applications {:?}",
416 service_chain_id, application_ids
417 );
418
419 let (update_sender, update_receiver) = mpsc::unbounded_channel();
420
421 let mut chain_client = self
422 .context
423 .lock()
424 .await
425 .make_chain_client(service_chain_id)
426 .await?;
427 if let Some(owner) = self.chain_client.preferred_owner() {
430 chain_client.set_preferred_owner(owner);
431 }
432 let processor = TaskProcessor::new(
433 service_chain_id,
434 application_ids,
435 chain_client,
436 self.cancellation_token.child_token(),
437 self.operators.clone(),
438 self.retry_delay,
439 Some(update_receiver),
440 );
441
442 tokio::spawn(processor.run());
443
444 self.processors
445 .insert(service_chain_id, ProcessorHandle { update_sender });
446
447 Ok(())
448 }
449
450 #[allow(clippy::needless_pass_by_ref_mut)]
453 async fn query_controller_state(&mut self) -> Result<LocalWorkerState, anyhow::Error> {
454 let query = "query { localWorkerState }";
455 let bytes = serde_json::to_vec(&json!({"query": query}))?;
456 let query = linera_execution::Query::User {
457 application_id: self.controller_id,
458 bytes,
459 };
460 let (
461 linera_execution::QueryOutcome {
462 response,
463 operations: _,
464 },
465 _,
466 ) = self.chain_client.query_application(query, None).await?;
467 let linera_execution::QueryResponse::User(response) = response else {
468 anyhow::bail!("cannot get a system response for a user query");
469 };
470 let mut response: serde_json::Value = serde_json::from_slice(&response)?;
471 let state = serde_json::from_value(response["data"]["localWorkerState"].take())?;
472 Ok(state)
473 }
474}