// linera_rpc/simple/client.rs

// Copyright (c) Facebook, Inc. and its affiliates.
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

5use std::sync::Arc;
6
7use futures::{sink::SinkExt, stream::StreamExt};
8use linera_base::{
9    crypto::CryptoHash,
10    data_types::{BlobContent, BlockHeight, NetworkDescription},
11    identifiers::{BlobId, ChainId, EventId},
12    time::{timer, Duration},
13};
14use linera_chain::{
15    data_types::BlockProposal,
16    types::{
17        ConfirmedBlockCertificate, LiteCertificate, TimeoutCertificate, ValidatedBlockCertificate,
18    },
19};
20use linera_core::{
21    data_types::{ChainInfoQuery, ChainInfoResponse},
22    node::{BlobStream, CrossChainMessageDelivery, NodeError, NotificationStream, ValidatorNode},
23};
24use linera_version::VersionInfo;
25
26use super::{codec, transport::TransportProtocol};
27use crate::{
28    config::ValidatorPublicNetworkPreConfig, HandleConfirmedCertificateRequest,
29    HandleLiteCertRequest, HandleTimeoutCertificateRequest, HandleValidatedCertificateRequest,
30    RpcMessage,
31};
32
33#[derive(Clone)]
34pub struct SimpleClient {
35    network: ValidatorPublicNetworkPreConfig<TransportProtocol>,
36    send_timeout: Duration,
37    recv_timeout: Duration,
38}
39
40impl SimpleClient {
41    pub(crate) fn new(
42        network: ValidatorPublicNetworkPreConfig<TransportProtocol>,
43        send_timeout: Duration,
44        recv_timeout: Duration,
45    ) -> Self {
46        Self {
47            network,
48            send_timeout,
49            recv_timeout,
50        }
51    }
52
53    async fn send_recv_internal(&self, message: RpcMessage) -> Result<RpcMessage, codec::Error> {
54        let address = format!("{}:{}", self.network.host, self.network.port);
55        let mut stream = self.network.protocol.connect(address).await?;
56        // Send message
57        timer::timeout(self.send_timeout, stream.send(message))
58            .await
59            .map_err(|timeout| codec::Error::IoError(timeout.into()))??;
60        // Wait for reply
61        timer::timeout(self.recv_timeout, stream.next())
62            .await
63            .map_err(|timeout| codec::Error::IoError(timeout.into()))?
64            .transpose()?
65            .ok_or_else(|| codec::Error::IoError(std::io::ErrorKind::UnexpectedEof.into()))
66    }
67
68    async fn query<Response>(&self, query: RpcMessage) -> Result<Response, Response::Error>
69    where
70        Response: TryFrom<RpcMessage>,
71        Response::Error: From<codec::Error>,
72    {
73        self.send_recv_internal(query).await?.try_into()
74    }
75}
76
77impl ValidatorNode for SimpleClient {
78    type NotificationStream = NotificationStream;
79
80    fn address(&self) -> String {
81        format!(
82            "{}://{}:{}",
83            self.network.protocol, self.network.host, self.network.port
84        )
85    }
86
87    /// Initiates a new block.
88    async fn handle_block_proposal(
89        &self,
90        proposal: BlockProposal,
91    ) -> Result<ChainInfoResponse, NodeError> {
92        let request = RpcMessage::BlockProposal(Box::new(proposal));
93        self.query(request).await
94    }
95
96    /// Processes a lite certificate.
97    async fn handle_lite_certificate(
98        &self,
99        certificate: LiteCertificate<'_>,
100        delivery: CrossChainMessageDelivery,
101    ) -> Result<ChainInfoResponse, NodeError> {
102        let wait_for_outgoing_messages = delivery.wait_for_outgoing_messages();
103        let request = RpcMessage::LiteCertificate(Box::new(HandleLiteCertRequest {
104            certificate: certificate.cloned(),
105            wait_for_outgoing_messages,
106        }));
107        self.query(request).await
108    }
109
110    /// Processes a validated certificate.
111    async fn handle_validated_certificate(
112        &self,
113        certificate: ValidatedBlockCertificate,
114    ) -> Result<ChainInfoResponse, NodeError> {
115        let request = HandleValidatedCertificateRequest { certificate };
116        let request = RpcMessage::ValidatedCertificate(Box::new(request));
117        self.query(request).await
118    }
119
120    /// Processes a confirmed certificate.
121    async fn handle_confirmed_certificate(
122        &self,
123        certificate: Arc<ConfirmedBlockCertificate>,
124        delivery: CrossChainMessageDelivery,
125    ) -> Result<ChainInfoResponse, NodeError> {
126        let wait_for_outgoing_messages = delivery.wait_for_outgoing_messages();
127        let request = HandleConfirmedCertificateRequest {
128            certificate: Arc::unwrap_or_clone(certificate),
129            wait_for_outgoing_messages,
130        };
131        let request = RpcMessage::ConfirmedCertificate(Box::new(request));
132        self.query(request).await
133    }
134
135    /// Processes a timeout certificate.
136    async fn handle_timeout_certificate(
137        &self,
138        certificate: TimeoutCertificate,
139    ) -> Result<ChainInfoResponse, NodeError> {
140        let request = HandleTimeoutCertificateRequest { certificate };
141        let request = RpcMessage::TimeoutCertificate(Box::new(request));
142        self.query(request).await
143    }
144
145    /// Handles information queries for this chain.
146    async fn handle_chain_info_query(
147        &self,
148        query: ChainInfoQuery,
149    ) -> Result<ChainInfoResponse, NodeError> {
150        let request = RpcMessage::ChainInfoQuery(Box::new(query));
151        self.query(request).await
152    }
153
154    async fn subscribe(&self, chains: Vec<ChainId>) -> Result<NotificationStream, NodeError> {
155        let mut stream = self
156            .network
157            .protocol
158            .connect((self.network.host.clone(), self.network.port))
159            .await
160            .map_err(|e| NodeError::ClientIoError {
161                error: e.to_string(),
162            })?;
163        // Send subscription request
164        timer::timeout(
165            self.send_timeout,
166            stream.send(RpcMessage::SubscribeNotifications(chains)),
167        )
168        .await
169        .map_err(|timeout| NodeError::ClientIoError {
170            error: timeout.to_string(),
171        })?
172        .map_err(|e| NodeError::ClientIoError {
173            error: e.to_string(),
174        })?;
175        // Return a stream that reads notifications from the connection
176        let notification_stream = stream.filter_map(|result| async {
177            match result {
178                Ok(RpcMessage::Notification(notification)) => Some(*notification),
179                _ => None,
180            }
181        });
182        Ok(Box::pin(notification_stream) as NotificationStream)
183    }
184
185    async fn get_version_info(&self) -> Result<VersionInfo, NodeError> {
186        self.query(RpcMessage::VersionInfoQuery).await
187    }
188
189    async fn get_network_description(&self) -> Result<NetworkDescription, NodeError> {
190        self.query(RpcMessage::NetworkDescriptionQuery).await
191    }
192
193    async fn upload_blob(&self, content: BlobContent) -> Result<BlobId, NodeError> {
194        self.query(RpcMessage::UploadBlob(Box::new(content))).await
195    }
196
197    async fn download_blob(&self, blob_id: BlobId) -> Result<BlobContent, NodeError> {
198        self.query(RpcMessage::DownloadBlob(Box::new(blob_id)))
199            .await
200    }
201
202    async fn download_blobs(&self, blob_ids: Vec<BlobId>) -> Result<BlobStream, NodeError> {
203        let mut stream = self
204            .network
205            .protocol
206            .connect((self.network.host.clone(), self.network.port))
207            .await
208            .map_err(|e| NodeError::ClientIoError {
209                error: e.to_string(),
210            })?;
211        timer::timeout(
212            self.send_timeout,
213            stream.send(RpcMessage::DownloadBlobs(blob_ids)),
214        )
215        .await
216        .map_err(|timeout| NodeError::ClientIoError {
217            error: timeout.to_string(),
218        })?
219        .map_err(|e| NodeError::ClientIoError {
220            error: e.to_string(),
221        })?;
222        let blob_stream = stream.filter_map(|result| async {
223            match result {
224                Ok(RpcMessage::DownloadBlobResponse(blob)) => Some(Ok(*blob)),
225                Ok(RpcMessage::Error(err)) => Some(Err(*err)),
226                Ok(_) => Some(Err(NodeError::UnexpectedMessage)),
227                Err(e) => Some(Err(NodeError::ClientIoError {
228                    error: e.to_string(),
229                })),
230            }
231        });
232        Ok(Box::pin(blob_stream))
233    }
234
235    async fn download_pending_blob(
236        &self,
237        chain_id: ChainId,
238        blob_id: BlobId,
239    ) -> Result<BlobContent, NodeError> {
240        self.query(RpcMessage::DownloadPendingBlob(Box::new((
241            chain_id, blob_id,
242        ))))
243        .await
244    }
245
246    async fn handle_pending_blob(
247        &self,
248        chain_id: ChainId,
249        blob: BlobContent,
250    ) -> Result<ChainInfoResponse, NodeError> {
251        self.query(RpcMessage::HandlePendingBlob(Box::new((chain_id, blob))))
252            .await
253    }
254
255    async fn download_certificate(
256        &self,
257        hash: CryptoHash,
258    ) -> Result<ConfirmedBlockCertificate, NodeError> {
259        Ok(self
260            .download_certificates(vec![hash])
261            .await?
262            .into_iter()
263            .next()
264            .unwrap()) // UNWRAP: We know there is exactly one certificate, otherwise we would have an error.
265    }
266
267    async fn download_certificates(
268        &self,
269        hashes: Vec<CryptoHash>,
270    ) -> Result<Vec<ConfirmedBlockCertificate>, NodeError> {
271        let certificates = self
272            .query::<Vec<ConfirmedBlockCertificate>>(RpcMessage::DownloadCertificates(
273                hashes.clone(),
274            ))
275            .await?;
276
277        if certificates.len() != hashes.len() {
278            let missing_hashes: Vec<CryptoHash> = hashes
279                .into_iter()
280                .filter(|hash| !certificates.iter().any(|cert| cert.hash() == *hash))
281                .collect();
282            Err(NodeError::MissingCertificates(missing_hashes))
283        } else {
284            Ok(certificates)
285        }
286    }
287
288    async fn download_certificates_by_heights(
289        &self,
290        chain_id: ChainId,
291        heights: Vec<BlockHeight>,
292    ) -> Result<Vec<ConfirmedBlockCertificate>, NodeError> {
293        let expected_count = heights.len();
294        let certificates: Vec<ConfirmedBlockCertificate> = self
295            .query(RpcMessage::DownloadCertificatesByHeights(
296                chain_id,
297                heights.clone(),
298            ))
299            .await?;
300
301        if certificates.len() < expected_count {
302            return Err(NodeError::MissingCertificatesByHeights { chain_id, heights });
303        }
304        Ok(certificates)
305    }
306
307    async fn blob_last_used_by(&self, blob_id: BlobId) -> Result<CryptoHash, NodeError> {
308        self.query(RpcMessage::BlobLastUsedBy(Box::new(blob_id)))
309            .await
310    }
311
312    async fn blob_last_used_by_certificate(
313        &self,
314        blob_id: BlobId,
315    ) -> Result<ConfirmedBlockCertificate, NodeError> {
316        self.query::<ConfirmedBlockCertificate>(RpcMessage::BlobLastUsedByCertificate(Box::new(
317            blob_id,
318        )))
319        .await
320    }
321
322    async fn missing_blob_ids(&self, blob_ids: Vec<BlobId>) -> Result<Vec<BlobId>, NodeError> {
323        self.query(RpcMessage::MissingBlobIds(blob_ids)).await
324    }
325
326    async fn event_block_heights(
327        &self,
328        event_ids: Vec<EventId>,
329    ) -> Result<Vec<Option<BlockHeight>>, NodeError> {
330        self.query(RpcMessage::EventBlockHeights(event_ids)).await
331    }
332
333    async fn get_shard_info(
334        &self,
335        chain_id: ChainId,
336    ) -> Result<linera_core::data_types::ShardInfo, NodeError> {
337        let rpc_shard_info: crate::message::ShardInfo =
338            self.query(RpcMessage::ShardInfoQuery(chain_id)).await?;
339        Ok(linera_core::data_types::ShardInfo {
340            shard_id: rpc_shard_info.shard_id,
341            total_shards: rpc_shard_info.total_shards,
342        })
343    }
344}