1use std::{
6 collections::HashMap,
7 io, mem,
8 net::SocketAddr,
9 pin::{pin, Pin},
10 sync::Arc,
11};
12
13use async_trait::async_trait;
14use futures::{
15 future,
16 stream::{self, FuturesUnordered, SplitSink, SplitStream},
17 Sink, SinkExt, Stream, StreamExt, TryStreamExt,
18};
19use linera_base::identifiers::ChainId;
20use linera_core::{JoinSetExt as _, TaskHandle};
21use serde::{Deserialize, Serialize};
22use tokio::{
23 io::AsyncWriteExt,
24 net::{lookup_host, TcpListener, TcpStream, ToSocketAddrs, UdpSocket},
25 sync::Mutex,
26 task::JoinSet,
27};
28use tokio_util::{codec::Framed, sync::CancellationToken, udp::UdpFramed};
29use tracing::{error, warn};
30
31use crate::{
32 simple::{codec, codec::Codec},
33 RpcMessage,
34};
35
/// Default maximum datagram size, in bytes, stored as a string (presumably so it can be used
/// directly as a CLI default value — confirm at the use site).
///
/// 65507 is the largest UDP payload that fits in an IPv4 datagram
/// (65535 - 20 bytes of IPv4 header - 8 bytes of UDP header).
pub const DEFAULT_MAX_DATAGRAM_SIZE: &str = "65507";

/// Threshold that triggers reaping of finished tasks in the servers below: the UDP server
/// prunes once it tracks at least this many per-peer handlers, and the TCP server reaps after
/// every this-many accepted connections.
const REAP_TASKS_THRESHOLD: usize = 100;
41
/// The network transport used to exchange [`RpcMessage`]s.
#[derive(clap::ValueEnum, Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum TransportProtocol {
    /// Datagram transport over UDP.
    Udp,
    /// Connection-oriented stream transport over TCP.
    Tcp,
}
48
49impl std::str::FromStr for TransportProtocol {
50 type Err = String;
51
52 fn from_str(s: &str) -> Result<Self, Self::Err> {
53 clap::ValueEnum::from_str(s, true)
54 }
55}
56
57impl std::fmt::Display for TransportProtocol {
58 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
59 write!(f, "{:?}", self)
60 }
61}
62
63impl TransportProtocol {
64 pub fn scheme(&self) -> &'static str {
65 match self {
66 TransportProtocol::Udp => "udp",
67 TransportProtocol::Tcp => "tcp",
68 }
69 }
70}
71
/// A client-side pool used to send [`RpcMessage`]s to remote peers identified by address.
pub trait ConnectionPool: Send {
    /// Sends `message` to the peer at `address`.
    ///
    /// The UDP implementation in this file parses `address` as a socket address; the TCP
    /// implementation connects to it, reusing a cached connection when one exists.
    fn send_message_to<'a>(
        &'a mut self,
        message: RpcMessage,
        address: &'a str,
    ) -> future::BoxFuture<'a, Result<(), codec::Error>>;
}
80
/// Server-side logic that processes incoming [`RpcMessage`]s.
///
/// The servers in this file clone the handler once per UDP message and once per TCP
/// connection, so cloning should presumably be cheap (e.g. shared state behind an `Arc`) —
/// confirm for each implementor.
#[async_trait]
pub trait MessageHandler: Clone {
    /// Processes `message` and returns the reply to send back, or `None` for no reply.
    async fn handle_message(&mut self, message: RpcMessage) -> Option<RpcMessage>;

    /// Handles a `SubscribeNotifications` request for the given chains.
    ///
    /// Returns a stream of notifications to forward to the peer, or `None` if the handler
    /// does not support subscriptions (the default). In this file only the TCP server
    /// calls this method.
    async fn handle_subscribe(
        &mut self,
        _chains: Vec<ChainId>,
    ) -> Option<Pin<Box<dyn Stream<Item = RpcMessage> + Send>>> {
        None
    }
}
99
/// A handle to a server task started by [`TransportProtocol::spawn_server`].
pub struct ServerHandle {
    /// The spawned server task; its output is the server's final result.
    pub handle: TaskHandle<Result<(), std::io::Error>>,
}
105
106impl ServerHandle {
107 pub async fn join(self) -> Result<(), std::io::Error> {
108 self.handle.await.map_err(|_| {
109 std::io::Error::new(
110 std::io::ErrorKind::Interrupted,
111 "Server task did not finish successfully",
112 )
113 })?
114 }
115}
116
/// A bidirectional, framed channel of [`RpcMessage`]s.
///
/// It is a [`Stream`] of decoded incoming messages and a [`Sink`] for outgoing ones; both
/// directions fail with [`codec::Error`].
pub trait Transport:
    Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
{
}

// Blanket impl: any type that is both such a stream and such a sink is a `Transport`.
impl<T> Transport for T where
    T: Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
{
}
130
131impl TransportProtocol {
132 pub async fn connect(
134 self,
135 address: impl ToSocketAddrs,
136 ) -> Result<impl Transport, std::io::Error> {
137 let mut addresses = lookup_host(address)
138 .await
139 .expect("Invalid address to connect to");
140 let address = addresses
141 .next()
142 .expect("Couldn't resolve address to connect to");
143
144 let stream: futures::future::Either<_, _> = match self {
145 TransportProtocol::Udp => {
146 let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
147
148 UdpFramed::new(socket, Codec)
149 .with(move |message| future::ready(Ok((message, address))))
150 .map_ok(|(message, _address)| message)
151 .left_stream()
152 }
153 TransportProtocol::Tcp => {
154 let stream = TcpStream::connect(address).await?;
155
156 Framed::new(stream, Codec).right_stream()
157 }
158 };
159
160 Ok(stream)
161 }
162
163 pub async fn make_outgoing_connection_pool(
165 self,
166 ) -> Result<Box<dyn ConnectionPool>, std::io::Error> {
167 let pool: Box<dyn ConnectionPool> = match self {
168 Self::Udp => Box::new(UdpConnectionPool::new().await?),
169 Self::Tcp => Box::new(TcpConnectionPool::new()),
170 };
171 Ok(pool)
172 }
173
174 pub fn spawn_server<S>(
176 self,
177 address: impl ToSocketAddrs + Send + 'static,
178 state: S,
179 shutdown_signal: CancellationToken,
180 join_set: &mut JoinSet<()>,
181 ) -> ServerHandle
182 where
183 S: MessageHandler + Send + 'static,
184 {
185 let handle = match self {
186 Self::Udp => join_set.spawn_task(UdpServer::run(address, state, shutdown_signal)),
187 Self::Tcp => join_set.spawn_task(TcpServer::run(address, state, shutdown_signal)),
188 };
189 ServerHandle { handle }
190 }
191}
192
/// A [`ConnectionPool`] that sends all datagrams from a single local UDP socket.
struct UdpConnectionPool {
    /// The framed socket; each send supplies the destination address alongside the message.
    transport: UdpFramed<Codec>,
}
197
198impl UdpConnectionPool {
199 async fn new() -> Result<Self, std::io::Error> {
200 let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
201 let transport = UdpFramed::new(socket, Codec);
202 Ok(Self { transport })
203 }
204}
205
206impl ConnectionPool for UdpConnectionPool {
207 fn send_message_to<'a>(
208 &'a mut self,
209 message: RpcMessage,
210 address: &'a str,
211 ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
212 Box::pin(async move {
213 let address = address.parse().map_err(std::io::Error::other)?;
214 self.transport.send((message, address)).await
215 })
216 }
217}
218
/// A UDP server that handles each received datagram in its own task.
pub struct UdpServer<State> {
    /// The handler, cloned into every per-message task.
    handler: State,
    /// Sending half of the socket, shared by the per-message tasks to send replies.
    udp_sink: SharedUdpSink,
    /// Receiving half of the socket.
    udp_stream: SplitStream<UdpFramed<Codec>>,
    /// The most recently spawned handler task for each peer; a new task for the same peer
    /// awaits it before sending its reply.
    active_handlers: HashMap<SocketAddr, TaskHandle<()>>,
    /// All spawned per-message tasks.
    join_set: JoinSet<()>,
}

/// The sending half of a UDP socket, shared between tasks behind an async mutex.
type SharedUdpSink = Arc<Mutex<SplitSink<UdpFramed<Codec>, (RpcMessage, SocketAddr)>>>;
230
231impl<State> UdpServer<State>
232where
233 State: MessageHandler + Send + 'static,
234{
235 pub async fn run(
237 address: impl ToSocketAddrs,
238 state: State,
239 shutdown_signal: CancellationToken,
240 ) -> Result<(), std::io::Error> {
241 let mut server = Self::bind(address, state).await?;
242
243 loop {
244 tokio::select! { biased;
245 _ = shutdown_signal.cancelled() => {
246 server.shutdown().await;
247 return Ok(());
248 }
249 result = server.udp_stream.next() => match result {
250 Some(Ok((message, peer))) => server.handle_message(message, peer),
251 Some(Err(error)) => server.handle_error(error).await?,
252 None => unreachable!("`UdpFramed` should never return `None`"),
253 },
254 }
255 }
256 }
257
258 async fn bind(address: impl ToSocketAddrs, handler: State) -> Result<Self, std::io::Error> {
261 let socket = UdpSocket::bind(address).await?;
262 let (udp_sink, udp_stream) = UdpFramed::new(socket, Codec).split();
263
264 Ok(UdpServer {
265 handler,
266 udp_sink: Arc::new(Mutex::new(udp_sink)),
267 udp_stream,
268 active_handlers: HashMap::new(),
269 join_set: JoinSet::new(),
270 })
271 }
272
273 fn handle_message(&mut self, message: RpcMessage, peer: SocketAddr) {
275 let previous_task = self.active_handlers.remove(&peer);
276 let mut state = self.handler.clone();
277 let udp_sink = self.udp_sink.clone();
278
279 let new_task = self.join_set.spawn_task(async move {
280 if let Some(reply) = state.handle_message(message).await {
281 if let Some(task) = previous_task {
282 if let Err(error) = task.await {
283 warn!("Message handler task panicked: {}", error);
284 }
285 }
286 let status = udp_sink.lock().await.send((reply, peer)).await;
287 if let Err(error) = status {
288 error!("Failed to send query response: {}", error);
289 }
290 }
291 });
292
293 self.active_handlers.insert(peer, new_task);
294
295 if self.active_handlers.len() >= REAP_TASKS_THRESHOLD {
296 self.active_handlers.retain(|_, task| task.is_running());
298 self.join_set.reap_finished_tasks();
299 }
300 }
301
302 async fn handle_error(&mut self, error: codec::Error) -> Result<(), std::io::Error> {
304 match error {
305 codec::Error::IoError(io_error) => {
306 error!("I/O error in UDP server: {io_error}");
307 self.shutdown().await;
308 Err(io_error)
309 }
310 other_error => {
311 warn!("Received an invalid message: {other_error}");
312 Ok(())
313 }
314 }
315 }
316
317 async fn shutdown(&mut self) {
319 let handlers = mem::take(&mut self.active_handlers);
320 let mut handler_results = handlers.into_values().collect::<FuturesUnordered<_>>();
321
322 while let Some(result) = handler_results.next().await {
323 if let Err(error) = result {
324 warn!("Message handler panicked: {}", error);
325 }
326 }
327
328 self.join_set.await_all_tasks().await;
329 }
330}
331
/// A [`ConnectionPool`] that keeps one TCP connection open per destination address.
struct TcpConnectionPool {
    /// Open connections, keyed by the address string they were opened with.
    streams: HashMap<String, Framed<TcpStream, Codec>>,
}
336
337impl TcpConnectionPool {
338 fn new() -> Self {
339 let streams = HashMap::new();
340 Self { streams }
341 }
342
343 async fn get_stream(
344 &mut self,
345 address: &str,
346 ) -> Result<&mut Framed<TcpStream, Codec>, io::Error> {
347 if !self.streams.contains_key(address) {
348 match TcpStream::connect(address).await {
349 Ok(s) => {
350 self.streams
351 .insert(address.to_string(), Framed::new(s, Codec));
352 }
353 Err(error) => {
354 error!("Failed to open connection to {}: {}", address, error);
355 return Err(error);
356 }
357 };
358 };
359 Ok(self.streams.get_mut(address).unwrap())
360 }
361}
362
363impl ConnectionPool for TcpConnectionPool {
364 fn send_message_to<'a>(
365 &'a mut self,
366 message: RpcMessage,
367 address: &'a str,
368 ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
369 Box::pin(async move {
370 let stream = self.get_stream(address).await?;
371 let result = stream.send(message).await;
372 if result.is_err() {
373 self.streams.remove(address);
374 }
375 result
376 })
377 }
378}
379
/// The state for serving a single accepted TCP connection.
pub struct TcpServer<State> {
    /// The framed connection to the peer.
    connection: Framed<TcpStream, Codec>,
    /// The handler for messages received on this connection.
    handler: State,
    /// When cancelled, the connection is shut down.
    shutdown_signal: CancellationToken,
}
386
387impl<State> TcpServer<State>
388where
389 State: MessageHandler + Send + 'static,
390{
391 pub async fn run(
396 address: impl ToSocketAddrs,
397 handler: State,
398 shutdown_signal: CancellationToken,
399 ) -> Result<(), std::io::Error> {
400 let listener = TcpListener::bind(address).await?;
401
402 let accept_stream = stream::try_unfold(listener, |listener| async move {
403 let (socket, _) = listener.accept().await?;
404 Ok::<_, io::Error>(Some((socket, listener)))
405 });
406 let mut accept_stream = pin!(accept_stream);
407
408 let connection_shutdown_signal = shutdown_signal.child_token();
409 let mut join_set = JoinSet::new();
410 let mut reap_countdown = REAP_TASKS_THRESHOLD;
411
412 loop {
413 tokio::select! { biased;
414 _ = shutdown_signal.cancelled() => {
415 join_set.await_all_tasks().await;
416 return Ok(());
417 }
418 maybe_socket = accept_stream.next() => match maybe_socket {
419 Some(Ok(socket)) => {
420 let server = TcpServer::new_connection(
421 socket,
422 handler.clone(),
423 connection_shutdown_signal.clone(),
424 );
425 join_set.spawn_task(server.serve());
426 reap_countdown -= 1;
427 }
428 Some(Err(error)) => {
429 join_set.await_all_tasks().await;
430 return Err(error);
431 }
432 None => unreachable!(
433 "The `accept_stream` should never finish unless there's an error",
434 ),
435 },
436 }
437
438 if reap_countdown == 0 {
439 join_set.reap_finished_tasks();
440 reap_countdown = REAP_TASKS_THRESHOLD;
441 }
442 }
443 }
444
445 fn new_connection(
448 tcp_stream: TcpStream,
449 handler: State,
450 shutdown_signal: CancellationToken,
451 ) -> Self {
452 TcpServer {
453 connection: Framed::new(tcp_stream, Codec),
454 handler,
455 shutdown_signal,
456 }
457 }
458
459 async fn serve(mut self) {
461 loop {
462 tokio::select! { biased;
463 _ = self.shutdown_signal.cancelled() => {
464 let mut tcp_stream = self.connection.into_inner();
465 if let Err(error) = tcp_stream.shutdown().await {
466 let peer = tcp_stream
467 .peer_addr()
468 .map_or_else(|_| "an unknown peer".to_owned(), |address| address.to_string());
469 warn!("Failed to close connection to {peer}: {error:?}");
470 }
471 return;
472 }
473 result = self.connection.next() => match result {
474 Some(Ok(RpcMessage::SubscribeNotifications(chains))) => {
475 self.handle_subscription(chains).await;
476 return;
477 }
478 Some(Ok(message)) => self.handle_message(message).await,
479 Some(Err(error)) => {
480 Self::handle_error(&error);
481 return;
482 }
483 None => break,
484 },
485 }
486 }
487 }
488
489 async fn handle_message(&mut self, message: RpcMessage) {
491 if let Some(reply) = self.handler.handle_message(message).await {
492 if let Err(error) = self.connection.send(reply).await {
493 error!("Failed to send query response: {error}");
494 }
495 }
496 }
497
498 async fn handle_subscription(&mut self, chains: Vec<ChainId>) {
500 let Some(mut stream) = self.handler.handle_subscribe(chains).await else {
501 return;
502 };
503 loop {
504 tokio::select! { biased;
505 _ = self.shutdown_signal.cancelled() => break,
506 msg = stream.next() => match msg {
507 Some(notification) => {
508 if let Err(error) = self.connection.send(notification).await {
509 error!("Failed to send notification: {error}");
510 break;
511 }
512 }
513 None => break,
514 }
515 }
516 }
517 }
518
519 fn handle_error(error: &codec::Error) {
524 if !matches!(
525 error,
526 codec::Error::IoError(error)
527 if error.kind() == io::ErrorKind::UnexpectedEof
528 || error.kind() == io::ErrorKind::ConnectionReset
529 ) {
530 error!("Error while reading TCP stream: {error}");
531 }
532 }
533}