1use std::{
5 collections::{BTreeMap, BTreeSet},
6 sync::Arc,
7};
8
9use futures::{lock::Mutex, stream::StreamExt, FutureExt};
10use linera_base::identifiers::{ApplicationId, ChainId};
11use linera_client::chain_listener::{ClientContext, ListenerCommand};
12use linera_core::{
13 client::{ChainClient, ListeningMode},
14 node::NotificationStream,
15 worker::Reason,
16};
17use linera_sdk::abis::controller::{LocalWorkerState, Operation, WorkerCommand};
18use serde_json::json;
19use tokio::{
20 select,
21 sync::mpsc::{self, UnboundedSender},
22};
23use tokio_util::sync::CancellationToken;
24use tracing::{debug, error, info};
25
26use crate::task_processor::{OperatorMap, TaskProcessor};
27
/// Message sent by the [`Controller`] over a processor's update channel.
///
/// The controller sends the full list of application IDs assigned to the
/// processor's chain, or an empty list when the chain no longer has any
/// local services.
#[derive(Debug)]
pub struct Update {
    /// Application IDs assigned to the receiving processor's chain.
    pub application_ids: Vec<ApplicationId>,
}
33
/// Controller-side handle to a spawned `TaskProcessor`.
struct ProcessorHandle {
    /// Sending half of the channel whose receiving half was handed to the
    /// processor when it was spawned. A failed send means the receiver is gone.
    update_sender: mpsc::UnboundedSender<Update>,
}
37
/// Watches a controller chain and keeps one `TaskProcessor` per service chain
/// in sync with the state reported by the controller application.
pub struct Controller<Ctx: ClientContext> {
    /// ID of the controller chain; only used in log messages here.
    chain_id: ChainId,
    /// Controller application that is queried for state and receives the
    /// worker registration operation.
    controller_id: ApplicationId,
    /// Shared client context, locked to create chain clients for service chains.
    context: Arc<Mutex<Ctx>>,
    /// Client used to query the controller application, submit operations and
    /// subscribe to notifications (presumably the client of `chain_id` — TODO confirm).
    chain_client: ChainClient<Ctx::Environment>,
    /// Cancels the main loop; each processor receives a child token.
    cancellation_token: CancellationToken,
    /// Notification stream obtained from `chain_client.subscribe()`.
    notifications: NotificationStream,
    /// Operators handed to every processor; the keys are advertised as the
    /// worker's capabilities on registration.
    operators: OperatorMap,
    /// Handles of the spawned processors, keyed by service chain.
    processors: BTreeMap<ChainId, ProcessorHandle>,
    /// Channel for `Listen` / `StopListening` commands to the chain listener.
    command_sender: UnboundedSender<ListenerCommand>,
}
49
50impl<Ctx> Controller<Ctx>
51where
52 Ctx: ClientContext + Send + Sync + 'static,
53 Ctx::Environment: 'static,
54 <Ctx::Environment as linera_core::Environment>::Storage: Clone,
55{
56 pub fn new(
57 chain_id: ChainId,
58 controller_id: ApplicationId,
59 context: Arc<Mutex<Ctx>>,
60 chain_client: ChainClient<Ctx::Environment>,
61 cancellation_token: CancellationToken,
62 operators: OperatorMap,
63 command_sender: UnboundedSender<ListenerCommand>,
64 ) -> Self {
65 let notifications = chain_client.subscribe().expect("client subscription");
66 Self {
67 chain_id,
68 controller_id,
69 context,
70 chain_client,
71 cancellation_token,
72 notifications,
73 operators,
74 processors: BTreeMap::new(),
75 command_sender,
76 }
77 }
78
79 pub async fn run(mut self) {
80 info!(
81 "Watching for notifications for controller chain {}",
82 self.chain_id
83 );
84 self.process_controller_state().await;
85 loop {
86 select! {
87 Some(notification) = self.notifications.next() => {
88 if let Reason::NewBlock { .. } = notification.reason {
89 debug!("Processing notification on controller chain {}", self.chain_id);
90 self.process_controller_state().await;
91 }
92 }
93 _ = self.cancellation_token.cancelled().fuse() => {
94 break;
95 }
96 }
97 }
98 debug!("Notification stream ended.");
99 }
100
101 async fn process_controller_state(&mut self) {
102 let state = match self.query_controller_state().await {
103 Ok(state) => state,
104 Err(error) => {
105 error!("Error reading controller state: {error}");
106 return;
107 }
108 };
109 let Some(worker) = state.local_worker else {
110 self.register_worker().await;
112 return;
113 };
114 assert_eq!(
115 worker.owner,
116 self.chain_client
117 .preferred_owner()
118 .expect("The current wallet should own the chain being watched"),
119 "We should be registered with the current account owner."
120 );
121
122 let mut chain_apps: BTreeMap<ChainId, Vec<ApplicationId>> = BTreeMap::new();
124 for service in &state.local_services {
125 chain_apps
126 .entry(service.chain_id)
127 .or_default()
128 .push(service.application_id);
129 }
130
131 let old_chains: BTreeSet<_> = self.processors.keys().cloned().collect();
132
133 for (service_chain_id, application_ids) in chain_apps {
135 if let Err(err) = self
136 .update_or_spawn_processor(service_chain_id, application_ids)
137 .await
138 {
139 error!("Error updating or spawning processor: {err}");
140 return;
141 }
142 }
143
144 let active_chains: std::collections::BTreeSet<_> =
147 state.local_services.iter().map(|s| s.chain_id).collect();
148 let stale_chains: BTreeSet<_> = self
149 .processors
150 .keys()
151 .filter(|chain_id| !active_chains.contains(chain_id))
152 .cloned()
153 .collect();
154 for chain_id in &stale_chains {
155 if let Some(handle) = self.processors.get(chain_id) {
156 let update = Update {
157 application_ids: Vec::new(),
158 };
159 if handle.update_sender.send(update).is_err() {
160 self.processors.remove(chain_id);
162 }
163 }
164 }
165
166 let new_chains: BTreeMap<_, _> = active_chains
167 .difference(&old_chains)
168 .map(|chain_id| (*chain_id, ListeningMode::FullChain))
169 .collect();
170
171 if let Err(err) = self
172 .command_sender
173 .send(ListenerCommand::Listen(new_chains))
174 {
175 error!(%err, "error sending a command to chain listener");
176 }
177 if let Err(err) = self
178 .command_sender
179 .send(ListenerCommand::StopListening(stale_chains))
180 {
181 error!(%err, "error sending a command to chain listener");
182 }
183 }
184
185 async fn register_worker(&mut self) {
186 let capabilities = self.operators.keys().cloned().collect();
187 let command = WorkerCommand::RegisterWorker { capabilities };
188 let owner = self
189 .chain_client
190 .preferred_owner()
191 .expect("The current wallet should own the chain being watched");
192 let bytes =
193 bcs::to_bytes(&Operation::ExecuteWorkerCommand { owner, command }).expect("bcs bytes");
194 let operation = linera_execution::Operation::User {
195 application_id: self.controller_id,
196 bytes,
197 };
198 if let Err(e) = self
199 .chain_client
200 .execute_operations(vec![operation], vec![])
201 .await
202 {
203 error!("Failed to execute worker on-chain registration: {e}");
205 }
206 }
207
208 async fn update_or_spawn_processor(
209 &mut self,
210 service_chain_id: ChainId,
211 application_ids: Vec<ApplicationId>,
212 ) -> Result<(), anyhow::Error> {
213 if let Some(handle) = self.processors.get(&service_chain_id) {
214 let update = Update {
216 application_ids: application_ids.clone(),
217 };
218 if handle.update_sender.send(update).is_err() {
219 self.processors.remove(&service_chain_id);
221 self.spawn_processor(service_chain_id, application_ids)
222 .await?;
223 }
224 } else {
225 self.spawn_processor(service_chain_id, application_ids)
227 .await?;
228 }
229 Ok(())
230 }
231
232 async fn spawn_processor(
233 &mut self,
234 service_chain_id: ChainId,
235 application_ids: Vec<ApplicationId>,
236 ) -> Result<(), anyhow::Error> {
237 info!(
238 "Spawning TaskProcessor for chain {} with applications {:?}",
239 service_chain_id, application_ids
240 );
241
242 let (update_sender, update_receiver) = mpsc::unbounded_channel();
243
244 let chain_client = self
245 .context
246 .lock()
247 .await
248 .make_chain_client(service_chain_id)
249 .await?;
250 let processor = TaskProcessor::new(
251 service_chain_id,
252 application_ids,
253 chain_client,
254 self.cancellation_token.child_token(),
255 self.operators.clone(),
256 Some(update_receiver),
257 );
258
259 tokio::spawn(processor.run());
260
261 self.processors
262 .insert(service_chain_id, ProcessorHandle { update_sender });
263
264 Ok(())
265 }
266
267 async fn query_controller_state(&mut self) -> Result<LocalWorkerState, anyhow::Error> {
268 let query = "query { localWorkerState }";
269 let bytes = serde_json::to_vec(&json!({"query": query}))?;
270 let query = linera_execution::Query::User {
271 application_id: self.controller_id,
272 bytes,
273 };
274 let linera_execution::QueryOutcome {
275 response,
276 operations: _,
277 } = self.chain_client.query_application(query, None).await?;
278 let linera_execution::QueryResponse::User(response) = response else {
279 anyhow::bail!("cannot get a system response for a user query");
280 };
281 let mut response: serde_json::Value = serde_json::from_slice(&response)?;
282 let state = serde_json::from_value(response["data"]["localWorkerState"].take())?;
283 Ok(state)
284 }
285}