Skip to main content

linera_rpc/simple/
client.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2// Copyright (c) Zefchain Labs, Inc.
3// SPDX-License-Identifier: Apache-2.0
4
5use futures::{sink::SinkExt, stream::StreamExt};
6use linera_base::{
7    crypto::CryptoHash,
8    data_types::{BlobContent, BlockHeight, NetworkDescription},
9    identifiers::{BlobId, ChainId, EventId},
10    time::{timer, Duration},
11};
12use linera_chain::{
13    data_types::BlockProposal,
14    types::{
15        ConfirmedBlockCertificate, LiteCertificate, TimeoutCertificate, ValidatedBlockCertificate,
16    },
17};
18use linera_core::{
19    data_types::{ChainInfoQuery, ChainInfoResponse},
20    node::{BlobStream, CrossChainMessageDelivery, NodeError, NotificationStream, ValidatorNode},
21};
22use linera_storage::Arc as CacheArc;
23use linera_version::VersionInfo;
24
25use super::{codec, transport::TransportProtocol};
26use crate::{
27    config::ValidatorPublicNetworkPreConfig, HandleConfirmedCertificateRequest,
28    HandleLiteCertRequest, HandleTimeoutCertificateRequest, HandleValidatedCertificateRequest,
29    RpcMessage,
30};
31
32#[derive(Clone)]
33pub struct SimpleClient {
34    network: ValidatorPublicNetworkPreConfig<TransportProtocol>,
35    send_timeout: Duration,
36    recv_timeout: Duration,
37}
38
39impl SimpleClient {
40    pub(crate) fn new(
41        network: ValidatorPublicNetworkPreConfig<TransportProtocol>,
42        send_timeout: Duration,
43        recv_timeout: Duration,
44    ) -> Self {
45        Self {
46            network,
47            send_timeout,
48            recv_timeout,
49        }
50    }
51
52    async fn send_recv_internal(&self, message: RpcMessage) -> Result<RpcMessage, codec::Error> {
53        let address = format!("{}:{}", self.network.host, self.network.port);
54        let mut stream = self.network.protocol.connect(address).await?;
55        // Send message
56        timer::timeout(self.send_timeout, stream.send(message))
57            .await
58            .map_err(|timeout| codec::Error::IoError(timeout.into()))??;
59        // Wait for reply
60        timer::timeout(self.recv_timeout, stream.next())
61            .await
62            .map_err(|timeout| codec::Error::IoError(timeout.into()))?
63            .transpose()?
64            .ok_or_else(|| codec::Error::IoError(std::io::ErrorKind::UnexpectedEof.into()))
65    }
66
67    async fn query<Response>(&self, query: RpcMessage) -> Result<Response, Response::Error>
68    where
69        Response: TryFrom<RpcMessage>,
70        Response::Error: From<codec::Error>,
71    {
72        self.send_recv_internal(query).await?.try_into()
73    }
74}
75
76impl ValidatorNode for SimpleClient {
77    type NotificationStream = NotificationStream;
78
79    fn address(&self) -> String {
80        format!(
81            "{}://{}:{}",
82            self.network.protocol, self.network.host, self.network.port
83        )
84    }
85
86    /// Initiates a new block.
87    async fn handle_block_proposal(
88        &self,
89        proposal: BlockProposal,
90    ) -> Result<ChainInfoResponse, NodeError> {
91        let request = RpcMessage::BlockProposal(Box::new(proposal));
92        self.query(request).await
93    }
94
95    /// Processes a lite certificate.
96    async fn handle_lite_certificate(
97        &self,
98        certificate: LiteCertificate<'_>,
99        delivery: CrossChainMessageDelivery,
100    ) -> Result<ChainInfoResponse, NodeError> {
101        let wait_for_outgoing_messages = delivery.wait_for_outgoing_messages();
102        let request = RpcMessage::LiteCertificate(Box::new(HandleLiteCertRequest {
103            certificate: certificate.cloned(),
104            wait_for_outgoing_messages,
105        }));
106        self.query(request).await
107    }
108
109    /// Processes a validated certificate.
110    async fn handle_validated_certificate(
111        &self,
112        certificate: ValidatedBlockCertificate,
113    ) -> Result<ChainInfoResponse, NodeError> {
114        let request = HandleValidatedCertificateRequest { certificate };
115        let request = RpcMessage::ValidatedCertificate(Box::new(request));
116        self.query(request).await
117    }
118
119    /// Processes a confirmed certificate.
120    async fn handle_confirmed_certificate(
121        &self,
122        certificate: CacheArc<ConfirmedBlockCertificate>,
123        delivery: CrossChainMessageDelivery,
124    ) -> Result<ChainInfoResponse, NodeError> {
125        let wait_for_outgoing_messages = delivery.wait_for_outgoing_messages();
126        let request = HandleConfirmedCertificateRequest {
127            certificate: CacheArc::unwrap_or_clone(certificate),
128            wait_for_outgoing_messages,
129        };
130        let request = RpcMessage::ConfirmedCertificate(Box::new(request));
131        self.query(request).await
132    }
133
134    /// Processes a timeout certificate.
135    async fn handle_timeout_certificate(
136        &self,
137        certificate: TimeoutCertificate,
138    ) -> Result<ChainInfoResponse, NodeError> {
139        let request = HandleTimeoutCertificateRequest { certificate };
140        let request = RpcMessage::TimeoutCertificate(Box::new(request));
141        self.query(request).await
142    }
143
144    /// Handles information queries for this chain.
145    async fn handle_chain_info_query(
146        &self,
147        query: ChainInfoQuery,
148    ) -> Result<ChainInfoResponse, NodeError> {
149        let request = RpcMessage::ChainInfoQuery(Box::new(query));
150        self.query(request).await
151    }
152
153    async fn subscribe(&self, chains: Vec<ChainId>) -> Result<NotificationStream, NodeError> {
154        let mut stream = self
155            .network
156            .protocol
157            .connect((self.network.host.clone(), self.network.port))
158            .await
159            .map_err(|e| NodeError::ClientIoError {
160                error: e.to_string(),
161            })?;
162        // Send subscription request
163        timer::timeout(
164            self.send_timeout,
165            stream.send(RpcMessage::SubscribeNotifications(chains)),
166        )
167        .await
168        .map_err(|timeout| NodeError::ClientIoError {
169            error: timeout.to_string(),
170        })?
171        .map_err(|e| NodeError::ClientIoError {
172            error: e.to_string(),
173        })?;
174        // Return a stream that reads notifications from the connection
175        let notification_stream = stream.filter_map(|result| async {
176            match result {
177                Ok(RpcMessage::Notification(notification)) => Some(*notification),
178                _ => None,
179            }
180        });
181        Ok(Box::pin(notification_stream) as NotificationStream)
182    }
183
184    async fn get_version_info(&self) -> Result<VersionInfo, NodeError> {
185        self.query(RpcMessage::VersionInfoQuery).await
186    }
187
188    async fn get_network_description(&self) -> Result<NetworkDescription, NodeError> {
189        self.query(RpcMessage::NetworkDescriptionQuery).await
190    }
191
192    async fn upload_blob(&self, content: BlobContent) -> Result<BlobId, NodeError> {
193        self.query(RpcMessage::UploadBlob(Box::new(content))).await
194    }
195
196    async fn download_blob(&self, blob_id: BlobId) -> Result<BlobContent, NodeError> {
197        self.query(RpcMessage::DownloadBlob(Box::new(blob_id)))
198            .await
199    }
200
201    async fn download_blobs(&self, blob_ids: Vec<BlobId>) -> Result<BlobStream, NodeError> {
202        let mut stream = self
203            .network
204            .protocol
205            .connect((self.network.host.clone(), self.network.port))
206            .await
207            .map_err(|e| NodeError::ClientIoError {
208                error: e.to_string(),
209            })?;
210        timer::timeout(
211            self.send_timeout,
212            stream.send(RpcMessage::DownloadBlobs(blob_ids)),
213        )
214        .await
215        .map_err(|timeout| NodeError::ClientIoError {
216            error: timeout.to_string(),
217        })?
218        .map_err(|e| NodeError::ClientIoError {
219            error: e.to_string(),
220        })?;
221        let blob_stream = stream.filter_map(|result| async {
222            match result {
223                Ok(RpcMessage::DownloadBlobResponse(blob)) => Some(Ok(*blob)),
224                Ok(RpcMessage::Error(err)) => Some(Err(*err)),
225                Ok(_) => Some(Err(NodeError::UnexpectedMessage)),
226                Err(e) => Some(Err(NodeError::ClientIoError {
227                    error: e.to_string(),
228                })),
229            }
230        });
231        Ok(Box::pin(blob_stream))
232    }
233
234    async fn download_pending_blob(
235        &self,
236        chain_id: ChainId,
237        blob_id: BlobId,
238    ) -> Result<BlobContent, NodeError> {
239        self.query(RpcMessage::DownloadPendingBlob(Box::new((
240            chain_id, blob_id,
241        ))))
242        .await
243    }
244
245    async fn handle_pending_blob(
246        &self,
247        chain_id: ChainId,
248        blob: BlobContent,
249    ) -> Result<ChainInfoResponse, NodeError> {
250        self.query(RpcMessage::HandlePendingBlob(Box::new((chain_id, blob))))
251            .await
252    }
253
254    async fn download_certificate(
255        &self,
256        hash: CryptoHash,
257    ) -> Result<ConfirmedBlockCertificate, NodeError> {
258        Ok(self
259            .download_certificates(vec![hash])
260            .await?
261            .into_iter()
262            .next()
263            .unwrap()) // UNWRAP: We know there is exactly one certificate, otherwise we would have an error.
264    }
265
266    async fn download_certificates(
267        &self,
268        hashes: Vec<CryptoHash>,
269    ) -> Result<Vec<ConfirmedBlockCertificate>, NodeError> {
270        let certificates = self
271            .query::<Vec<ConfirmedBlockCertificate>>(RpcMessage::DownloadCertificates(
272                hashes.clone(),
273            ))
274            .await?;
275
276        if certificates.len() != hashes.len() {
277            let missing_hashes: Vec<CryptoHash> = hashes
278                .into_iter()
279                .filter(|hash| !certificates.iter().any(|cert| cert.hash() == *hash))
280                .collect();
281            Err(NodeError::MissingCertificates(missing_hashes))
282        } else {
283            Ok(certificates)
284        }
285    }
286
287    async fn download_certificates_by_heights(
288        &self,
289        chain_id: ChainId,
290        heights: Vec<BlockHeight>,
291    ) -> Result<Vec<ConfirmedBlockCertificate>, NodeError> {
292        let expected_count = heights.len();
293        let certificates: Vec<ConfirmedBlockCertificate> = self
294            .query(RpcMessage::DownloadCertificatesByHeights(
295                chain_id,
296                heights.clone(),
297            ))
298            .await?;
299
300        if certificates.len() < expected_count {
301            return Err(NodeError::MissingCertificatesByHeights { chain_id, heights });
302        }
303        Ok(certificates)
304    }
305
306    async fn blob_last_used_by(&self, blob_id: BlobId) -> Result<CryptoHash, NodeError> {
307        self.query(RpcMessage::BlobLastUsedBy(Box::new(blob_id)))
308            .await
309    }
310
311    async fn blob_last_used_by_certificate(
312        &self,
313        blob_id: BlobId,
314    ) -> Result<ConfirmedBlockCertificate, NodeError> {
315        self.query::<ConfirmedBlockCertificate>(RpcMessage::BlobLastUsedByCertificate(Box::new(
316            blob_id,
317        )))
318        .await
319    }
320
321    async fn missing_blob_ids(&self, blob_ids: Vec<BlobId>) -> Result<Vec<BlobId>, NodeError> {
322        self.query(RpcMessage::MissingBlobIds(blob_ids)).await
323    }
324
325    async fn event_block_heights(
326        &self,
327        event_ids: Vec<EventId>,
328    ) -> Result<Vec<Option<BlockHeight>>, NodeError> {
329        self.query(RpcMessage::EventBlockHeights(event_ids)).await
330    }
331
332    async fn get_shard_info(
333        &self,
334        chain_id: ChainId,
335    ) -> Result<linera_core::data_types::ShardInfo, NodeError> {
336        let rpc_shard_info: crate::message::ShardInfo =
337            self.query(RpcMessage::ShardInfoQuery(chain_id)).await?;
338        Ok(linera_core::data_types::ShardInfo {
339            shard_id: rpc_shard_info.shard_id,
340            total_shards: rpc_shard_info.total_shards,
341        })
342    }
343}