1use std::{
5 collections::{BTreeMap, BTreeSet},
6 sync::Arc,
7};
8
9use futures::{lock::Mutex, stream::StreamExt, FutureExt};
10use linera_base::identifiers::{ApplicationId, ChainId};
11use linera_client::chain_listener::{ClientContext, ListenerCommand};
12use linera_core::{client::ChainClient, node::NotificationStream, worker::Reason};
13use linera_sdk::abis::controller::{LocalWorkerState, Operation, WorkerCommand};
14use serde_json::json;
15use tokio::{
16 select,
17 sync::mpsc::{self, UnboundedSender},
18};
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, error, info};
21
22use crate::task_processor::{OperatorMap, TaskProcessor};
23
24#[derive(Debug)]
26pub struct Update {
27 pub application_ids: Vec<ApplicationId>,
28}
29
30struct ProcessorHandle {
31 update_sender: mpsc::UnboundedSender<Update>,
32}
33
34pub struct Controller<Ctx: ClientContext> {
35 chain_id: ChainId,
36 controller_id: ApplicationId,
37 context: Arc<Mutex<Ctx>>,
38 chain_client: ChainClient<Ctx::Environment>,
39 cancellation_token: CancellationToken,
40 notifications: NotificationStream,
41 operators: OperatorMap,
42 processors: BTreeMap<ChainId, ProcessorHandle>,
43 listened_local_chains: BTreeSet<ChainId>,
44 command_sender: UnboundedSender<ListenerCommand>,
45}
46
47impl<Ctx> Controller<Ctx>
48where
49 Ctx: ClientContext + Send + Sync + 'static,
50 Ctx::Environment: 'static,
51 <Ctx::Environment as linera_core::Environment>::Storage: Clone,
52{
53 pub fn new(
54 chain_id: ChainId,
55 controller_id: ApplicationId,
56 context: Arc<Mutex<Ctx>>,
57 chain_client: ChainClient<Ctx::Environment>,
58 cancellation_token: CancellationToken,
59 operators: OperatorMap,
60 command_sender: UnboundedSender<ListenerCommand>,
61 ) -> Self {
62 let notifications = chain_client.subscribe().expect("client subscription");
63 Self {
64 chain_id,
65 controller_id,
66 context,
67 chain_client,
68 cancellation_token,
69 notifications,
70 operators,
71 processors: BTreeMap::new(),
72 listened_local_chains: BTreeSet::new(),
73 command_sender,
74 }
75 }
76
77 pub async fn run(mut self) {
78 info!(
79 "Watching for notifications for controller chain {}",
80 self.chain_id
81 );
82 self.process_controller_state().await;
83 loop {
84 select! {
85 Some(notification) = self.notifications.next() => {
86 if let Reason::NewBlock { .. } = notification.reason {
87 debug!("Processing notification on controller chain {}", self.chain_id);
88 self.process_controller_state().await;
89 }
90 }
91 _ = self.cancellation_token.cancelled().fuse() => {
92 break;
93 }
94 }
95 }
96 debug!("Notification stream ended.");
97 }
98
99 async fn process_controller_state(&mut self) {
100 let state = match self.query_controller_state().await {
101 Ok(state) => state,
102 Err(error) => {
103 error!("Error reading controller state: {error}");
104 return;
105 }
106 };
107 let Some(worker) = state.local_worker else {
108 self.register_worker().await;
110 return;
111 };
112 assert_eq!(
113 worker.owner,
114 self.chain_client
115 .preferred_owner()
116 .expect("The current wallet should own the chain being watched"),
117 "We should be registered with the current account owner."
118 );
119
120 let mut chain_apps: BTreeMap<ChainId, Vec<ApplicationId>> = BTreeMap::new();
122 for service in &state.local_services {
123 chain_apps
124 .entry(service.chain_id)
125 .or_default()
126 .push(service.application_id);
127 }
128
129 let old_chains: BTreeSet<_> = self.processors.keys().cloned().collect();
130
131 for (service_chain_id, application_ids) in chain_apps {
133 if let Err(err) = self
134 .update_or_spawn_processor(service_chain_id, application_ids)
135 .await
136 {
137 error!("Error updating or spawning processor: {err}");
138 return;
139 }
140 }
141
142 let active_chains: std::collections::BTreeSet<_> =
145 state.local_services.iter().map(|s| s.chain_id).collect();
146 let stale_chains: BTreeSet<_> = self
147 .processors
148 .keys()
149 .filter(|chain_id| !active_chains.contains(chain_id))
150 .cloned()
151 .collect();
152 for chain_id in &stale_chains {
153 if let Some(handle) = self.processors.get(chain_id) {
154 let update = Update {
155 application_ids: Vec::new(),
156 };
157 if handle.update_sender.send(update).is_err() {
158 self.processors.remove(chain_id);
160 }
161 }
162 }
163
164 let local_chains: BTreeSet<_> = state.local_chains.iter().cloned().collect();
166
167 let old_listened: BTreeSet<_> = old_chains
169 .union(&self.listened_local_chains)
170 .cloned()
171 .collect();
172
173 let desired_listened: BTreeSet<_> = active_chains.union(&local_chains).cloned().collect();
175
176 let owner = worker.owner;
178 let new_chains: BTreeMap<_, _> = desired_listened
179 .difference(&old_listened)
180 .map(|chain_id| (*chain_id, Some(owner)))
181 .collect();
182
183 let chains_to_stop: BTreeSet<_> = old_listened
185 .difference(&desired_listened)
186 .cloned()
187 .collect();
188
189 self.listened_local_chains = local_chains.difference(&active_chains).cloned().collect();
192
193 if let Err(error) = self.command_sender.send(ListenerCommand::SetMessagePolicy(
194 state.local_message_policy,
195 )) {
196 error!(%error, "error sending a command to chain listener");
197 }
198 if let Err(error) = self
199 .command_sender
200 .send(ListenerCommand::Listen(new_chains))
201 {
202 error!(%error, "error sending a command to chain listener");
203 }
204 if let Err(error) = self
205 .command_sender
206 .send(ListenerCommand::StopListening(chains_to_stop))
207 {
208 error!(%error, "error sending a command to chain listener");
209 }
210 }
211
212 async fn register_worker(&mut self) {
213 let capabilities = self.operators.keys().cloned().collect();
214 let command = WorkerCommand::RegisterWorker { capabilities };
215 let owner = self
216 .chain_client
217 .preferred_owner()
218 .expect("The current wallet should own the chain being watched");
219 let bytes =
220 bcs::to_bytes(&Operation::ExecuteWorkerCommand { owner, command }).expect("bcs bytes");
221 let operation = linera_execution::Operation::User {
222 application_id: self.controller_id,
223 bytes,
224 };
225 if let Err(e) = self
226 .chain_client
227 .execute_operations(vec![operation], vec![])
228 .await
229 {
230 error!("Failed to execute worker on-chain registration: {e}");
232 }
233 }
234
235 async fn update_or_spawn_processor(
236 &mut self,
237 service_chain_id: ChainId,
238 application_ids: Vec<ApplicationId>,
239 ) -> Result<(), anyhow::Error> {
240 if let Some(handle) = self.processors.get(&service_chain_id) {
241 let update = Update {
243 application_ids: application_ids.clone(),
244 };
245 if handle.update_sender.send(update).is_err() {
246 self.processors.remove(&service_chain_id);
248 self.spawn_processor(service_chain_id, application_ids)
249 .await?;
250 }
251 } else {
252 self.spawn_processor(service_chain_id, application_ids)
254 .await?;
255 }
256 Ok(())
257 }
258
259 async fn spawn_processor(
260 &mut self,
261 service_chain_id: ChainId,
262 application_ids: Vec<ApplicationId>,
263 ) -> Result<(), anyhow::Error> {
264 info!(
265 "Spawning TaskProcessor for chain {} with applications {:?}",
266 service_chain_id, application_ids
267 );
268
269 let (update_sender, update_receiver) = mpsc::unbounded_channel();
270
271 let chain_client = self
272 .context
273 .lock()
274 .await
275 .make_chain_client(service_chain_id)
276 .await?;
277 let processor = TaskProcessor::new(
278 service_chain_id,
279 application_ids,
280 chain_client,
281 self.cancellation_token.child_token(),
282 self.operators.clone(),
283 Some(update_receiver),
284 );
285
286 tokio::spawn(processor.run());
287
288 self.processors
289 .insert(service_chain_id, ProcessorHandle { update_sender });
290
291 Ok(())
292 }
293
294 async fn query_controller_state(&mut self) -> Result<LocalWorkerState, anyhow::Error> {
295 let query = "query { localWorkerState }";
296 let bytes = serde_json::to_vec(&json!({"query": query}))?;
297 let query = linera_execution::Query::User {
298 application_id: self.controller_id,
299 bytes,
300 };
301 let linera_execution::QueryOutcome {
302 response,
303 operations: _,
304 } = self.chain_client.query_application(query, None).await?;
305 let linera_execution::QueryResponse::User(response) = response else {
306 anyhow::bail!("cannot get a system response for a user query");
307 };
308 let mut response: serde_json::Value = serde_json::from_slice(&response)?;
309 let state = serde_json::from_value(response["data"]["localWorkerState"].take())?;
310 Ok(state)
311 }
312}