1use std::sync::Arc;
6
7use futures::{sink::SinkExt, stream::StreamExt};
8use linera_base::{
9 crypto::CryptoHash,
10 data_types::{BlobContent, BlockHeight, NetworkDescription},
11 identifiers::{BlobId, ChainId, EventId},
12 time::{timer, Duration},
13};
14use linera_chain::{
15 data_types::BlockProposal,
16 types::{
17 ConfirmedBlockCertificate, LiteCertificate, TimeoutCertificate, ValidatedBlockCertificate,
18 },
19};
20use linera_core::{
21 data_types::{ChainInfoQuery, ChainInfoResponse},
22 node::{BlobStream, CrossChainMessageDelivery, NodeError, NotificationStream, ValidatorNode},
23};
24use linera_version::VersionInfo;
25
26use super::{codec, transport::TransportProtocol};
27use crate::{
28 config::ValidatorPublicNetworkPreConfig, HandleConfirmedCertificateRequest,
29 HandleLiteCertRequest, HandleTimeoutCertificateRequest, HandleValidatedCertificateRequest,
30 RpcMessage,
31};
32
33#[derive(Clone)]
34pub struct SimpleClient {
35 network: ValidatorPublicNetworkPreConfig<TransportProtocol>,
36 send_timeout: Duration,
37 recv_timeout: Duration,
38}
39
40impl SimpleClient {
41 pub(crate) fn new(
42 network: ValidatorPublicNetworkPreConfig<TransportProtocol>,
43 send_timeout: Duration,
44 recv_timeout: Duration,
45 ) -> Self {
46 Self {
47 network,
48 send_timeout,
49 recv_timeout,
50 }
51 }
52
53 async fn send_recv_internal(&self, message: RpcMessage) -> Result<RpcMessage, codec::Error> {
54 let address = format!("{}:{}", self.network.host, self.network.port);
55 let mut stream = self.network.protocol.connect(address).await?;
56 timer::timeout(self.send_timeout, stream.send(message))
58 .await
59 .map_err(|timeout| codec::Error::IoError(timeout.into()))??;
60 timer::timeout(self.recv_timeout, stream.next())
62 .await
63 .map_err(|timeout| codec::Error::IoError(timeout.into()))?
64 .transpose()?
65 .ok_or_else(|| codec::Error::IoError(std::io::ErrorKind::UnexpectedEof.into()))
66 }
67
68 async fn query<Response>(&self, query: RpcMessage) -> Result<Response, Response::Error>
69 where
70 Response: TryFrom<RpcMessage>,
71 Response::Error: From<codec::Error>,
72 {
73 self.send_recv_internal(query).await?.try_into()
74 }
75}
76
77impl ValidatorNode for SimpleClient {
78 type NotificationStream = NotificationStream;
79
80 fn address(&self) -> String {
81 format!(
82 "{}://{}:{}",
83 self.network.protocol, self.network.host, self.network.port
84 )
85 }
86
87 async fn handle_block_proposal(
89 &self,
90 proposal: BlockProposal,
91 ) -> Result<ChainInfoResponse, NodeError> {
92 let request = RpcMessage::BlockProposal(Box::new(proposal));
93 self.query(request).await
94 }
95
96 async fn handle_lite_certificate(
98 &self,
99 certificate: LiteCertificate<'_>,
100 delivery: CrossChainMessageDelivery,
101 ) -> Result<ChainInfoResponse, NodeError> {
102 let wait_for_outgoing_messages = delivery.wait_for_outgoing_messages();
103 let request = RpcMessage::LiteCertificate(Box::new(HandleLiteCertRequest {
104 certificate: certificate.cloned(),
105 wait_for_outgoing_messages,
106 }));
107 self.query(request).await
108 }
109
110 async fn handle_validated_certificate(
112 &self,
113 certificate: ValidatedBlockCertificate,
114 ) -> Result<ChainInfoResponse, NodeError> {
115 let request = HandleValidatedCertificateRequest { certificate };
116 let request = RpcMessage::ValidatedCertificate(Box::new(request));
117 self.query(request).await
118 }
119
120 async fn handle_confirmed_certificate(
122 &self,
123 certificate: Arc<ConfirmedBlockCertificate>,
124 delivery: CrossChainMessageDelivery,
125 ) -> Result<ChainInfoResponse, NodeError> {
126 let wait_for_outgoing_messages = delivery.wait_for_outgoing_messages();
127 let request = HandleConfirmedCertificateRequest {
128 certificate: Arc::unwrap_or_clone(certificate),
129 wait_for_outgoing_messages,
130 };
131 let request = RpcMessage::ConfirmedCertificate(Box::new(request));
132 self.query(request).await
133 }
134
135 async fn handle_timeout_certificate(
137 &self,
138 certificate: TimeoutCertificate,
139 ) -> Result<ChainInfoResponse, NodeError> {
140 let request = HandleTimeoutCertificateRequest { certificate };
141 let request = RpcMessage::TimeoutCertificate(Box::new(request));
142 self.query(request).await
143 }
144
145 async fn handle_chain_info_query(
147 &self,
148 query: ChainInfoQuery,
149 ) -> Result<ChainInfoResponse, NodeError> {
150 let request = RpcMessage::ChainInfoQuery(Box::new(query));
151 self.query(request).await
152 }
153
154 async fn subscribe(&self, chains: Vec<ChainId>) -> Result<NotificationStream, NodeError> {
155 let mut stream = self
156 .network
157 .protocol
158 .connect((self.network.host.clone(), self.network.port))
159 .await
160 .map_err(|e| NodeError::ClientIoError {
161 error: e.to_string(),
162 })?;
163 timer::timeout(
165 self.send_timeout,
166 stream.send(RpcMessage::SubscribeNotifications(chains)),
167 )
168 .await
169 .map_err(|timeout| NodeError::ClientIoError {
170 error: timeout.to_string(),
171 })?
172 .map_err(|e| NodeError::ClientIoError {
173 error: e.to_string(),
174 })?;
175 let notification_stream = stream.filter_map(|result| async {
177 match result {
178 Ok(RpcMessage::Notification(notification)) => Some(*notification),
179 _ => None,
180 }
181 });
182 Ok(Box::pin(notification_stream) as NotificationStream)
183 }
184
185 async fn get_version_info(&self) -> Result<VersionInfo, NodeError> {
186 self.query(RpcMessage::VersionInfoQuery).await
187 }
188
189 async fn get_network_description(&self) -> Result<NetworkDescription, NodeError> {
190 self.query(RpcMessage::NetworkDescriptionQuery).await
191 }
192
193 async fn upload_blob(&self, content: BlobContent) -> Result<BlobId, NodeError> {
194 self.query(RpcMessage::UploadBlob(Box::new(content))).await
195 }
196
197 async fn download_blob(&self, blob_id: BlobId) -> Result<BlobContent, NodeError> {
198 self.query(RpcMessage::DownloadBlob(Box::new(blob_id)))
199 .await
200 }
201
202 async fn download_blobs(&self, blob_ids: Vec<BlobId>) -> Result<BlobStream, NodeError> {
203 let mut stream = self
204 .network
205 .protocol
206 .connect((self.network.host.clone(), self.network.port))
207 .await
208 .map_err(|e| NodeError::ClientIoError {
209 error: e.to_string(),
210 })?;
211 timer::timeout(
212 self.send_timeout,
213 stream.send(RpcMessage::DownloadBlobs(blob_ids)),
214 )
215 .await
216 .map_err(|timeout| NodeError::ClientIoError {
217 error: timeout.to_string(),
218 })?
219 .map_err(|e| NodeError::ClientIoError {
220 error: e.to_string(),
221 })?;
222 let blob_stream = stream.filter_map(|result| async {
223 match result {
224 Ok(RpcMessage::DownloadBlobResponse(blob)) => Some(Ok(*blob)),
225 Ok(RpcMessage::Error(err)) => Some(Err(*err)),
226 Ok(_) => Some(Err(NodeError::UnexpectedMessage)),
227 Err(e) => Some(Err(NodeError::ClientIoError {
228 error: e.to_string(),
229 })),
230 }
231 });
232 Ok(Box::pin(blob_stream))
233 }
234
235 async fn download_pending_blob(
236 &self,
237 chain_id: ChainId,
238 blob_id: BlobId,
239 ) -> Result<BlobContent, NodeError> {
240 self.query(RpcMessage::DownloadPendingBlob(Box::new((
241 chain_id, blob_id,
242 ))))
243 .await
244 }
245
246 async fn handle_pending_blob(
247 &self,
248 chain_id: ChainId,
249 blob: BlobContent,
250 ) -> Result<ChainInfoResponse, NodeError> {
251 self.query(RpcMessage::HandlePendingBlob(Box::new((chain_id, blob))))
252 .await
253 }
254
255 async fn download_certificate(
256 &self,
257 hash: CryptoHash,
258 ) -> Result<ConfirmedBlockCertificate, NodeError> {
259 Ok(self
260 .download_certificates(vec![hash])
261 .await?
262 .into_iter()
263 .next()
264 .unwrap()) }
266
267 async fn download_certificates(
268 &self,
269 hashes: Vec<CryptoHash>,
270 ) -> Result<Vec<ConfirmedBlockCertificate>, NodeError> {
271 let certificates = self
272 .query::<Vec<ConfirmedBlockCertificate>>(RpcMessage::DownloadCertificates(
273 hashes.clone(),
274 ))
275 .await?;
276
277 if certificates.len() != hashes.len() {
278 let missing_hashes: Vec<CryptoHash> = hashes
279 .into_iter()
280 .filter(|hash| !certificates.iter().any(|cert| cert.hash() == *hash))
281 .collect();
282 Err(NodeError::MissingCertificates(missing_hashes))
283 } else {
284 Ok(certificates)
285 }
286 }
287
288 async fn download_certificates_by_heights(
289 &self,
290 chain_id: ChainId,
291 heights: Vec<BlockHeight>,
292 ) -> Result<Vec<ConfirmedBlockCertificate>, NodeError> {
293 let expected_count = heights.len();
294 let certificates: Vec<ConfirmedBlockCertificate> = self
295 .query(RpcMessage::DownloadCertificatesByHeights(
296 chain_id,
297 heights.clone(),
298 ))
299 .await?;
300
301 if certificates.len() < expected_count {
302 return Err(NodeError::MissingCertificatesByHeights { chain_id, heights });
303 }
304 Ok(certificates)
305 }
306
307 async fn blob_last_used_by(&self, blob_id: BlobId) -> Result<CryptoHash, NodeError> {
308 self.query(RpcMessage::BlobLastUsedBy(Box::new(blob_id)))
309 .await
310 }
311
312 async fn blob_last_used_by_certificate(
313 &self,
314 blob_id: BlobId,
315 ) -> Result<ConfirmedBlockCertificate, NodeError> {
316 self.query::<ConfirmedBlockCertificate>(RpcMessage::BlobLastUsedByCertificate(Box::new(
317 blob_id,
318 )))
319 .await
320 }
321
322 async fn missing_blob_ids(&self, blob_ids: Vec<BlobId>) -> Result<Vec<BlobId>, NodeError> {
323 self.query(RpcMessage::MissingBlobIds(blob_ids)).await
324 }
325
326 async fn event_block_heights(
327 &self,
328 event_ids: Vec<EventId>,
329 ) -> Result<Vec<Option<BlockHeight>>, NodeError> {
330 self.query(RpcMessage::EventBlockHeights(event_ids)).await
331 }
332
333 async fn get_shard_info(
334 &self,
335 chain_id: ChainId,
336 ) -> Result<linera_core::data_types::ShardInfo, NodeError> {
337 let rpc_shard_info: crate::message::ShardInfo =
338 self.query(RpcMessage::ShardInfoQuery(chain_id)).await?;
339 Ok(linera_core::data_types::ShardInfo {
340 shard_id: rpc_shard_info.shard_id,
341 total_shards: rpc_shard_info.total_shards,
342 })
343 }
344}