Skip to main content

linera_core/unit_tests/
test_utils.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(clippy::cast_possible_truncation)]
5
6use std::{
7    collections::{BTreeMap, HashMap, HashSet},
8    sync::Arc,
9    time::Duration,
10    vec,
11};
12
13use async_trait::async_trait;
14use futures::{
15    future::Either,
16    lock::{Mutex, MutexGuard},
17    Future,
18};
19use linera_base::{
20    crypto::{
21        AccountPublicKey, CryptoHash, ValidatorKeypair, ValidatorPublicKey, ValidatorSecretKey,
22    },
23    data_types::*,
24    identifiers::{AccountOwner, BlobId, ChainId, EventId},
25    ownership::ChainOwnership,
26};
27use linera_chain::{
28    data_types::BlockProposal,
29    types::{
30        CertificateKind, ConfirmedBlock, ConfirmedBlockCertificate, GenericCertificate,
31        LiteCertificate, Timeout, ValidatedBlock,
32    },
33};
34use linera_execution::{committee::Committee, ResourceControlPolicy, WasmRuntime};
35use linera_storage::{Arc as CacheArc, DbStorage, ResultReadCertificates, Storage, TestClock};
36#[cfg(all(not(target_arch = "wasm32"), feature = "storage-service"))]
37use linera_storage_service::client::StorageServiceDatabase;
38use linera_version::VersionInfo;
39#[cfg(feature = "scylladb")]
40use linera_views::scylla_db::ScyllaDbDatabase;
41use linera_views::{
42    memory::MemoryDatabase,
43    random::generate_test_namespace,
44    store::{KeyValueStore, TestKeyValueDatabase},
45};
46use tokio::sync::oneshot;
47use tokio_stream::wrappers::UnboundedReceiverStream;
48#[cfg(feature = "rocksdb")]
49use {
50    linera_views::rocks_db::RocksDbDatabase,
51    tokio::sync::{Semaphore, SemaphorePermit},
52};
53
54use crate::{
55    chain_worker::ChainWorkerConfig,
56    client::{chain_client, Client},
57    data_types::*,
58    environment::{TestSigner, TestWallet},
59    node::{
60        CrossChainMessageDelivery, NodeError, NotificationStream, ValidatorNode,
61        ValidatorNodeProvider,
62    },
63    notifier::ChannelNotifier,
64    worker::{
65        Notification, ProcessableCertificate, WorkerState, DEFAULT_BLOCK_CACHE_SIZE,
66        DEFAULT_EXECUTION_STATE_CACHE_SIZE,
67    },
68};
69
/// The kind of misbehavior simulated by a [`LocalValidatorClient`].
///
/// Only block proposals, certificates and chain info queries look at the fault type; blob,
/// certificate-download and subscription requests are always served normally.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum FaultType {
    /// Processes every request normally.
    Honest,
    /// Rejects block proposals, certificates and chain info queries with an "offline" I/O error.
    Offline,
    /// Like `Offline` for block proposals and certificates, but still answers chain info queries.
    OfflineWithInfo,
    /// Answers block proposals, certificates and chain info queries with `InactiveChain`.
    NoChains,
    /// Processes validated-block certificates but answers with an error instead of the vote.
    DontSendConfirmVote,
    /// Refuses to process validated-block certificates at all.
    DontProcessValidated,
    /// Processes block proposals but answers with an error instead of the vote.
    DontSendValidateVote,
}
80
81/// A validator used for testing. "Faulty" validators ignore block proposals (but not
82/// certificates or info queries) and have the wrong initial balance for all chains.
83///
84/// All methods are executed in spawned Tokio tasks, so that canceling a client task doesn't cause
85/// the validator's tasks to be canceled: In a real network, a validator also wouldn't cancel
86/// tasks if the client stopped waiting for the response.
87struct LocalValidator<S>
88where
89    S: Storage,
90{
91    state: WorkerState<S>,
92    notifier: Arc<ChannelNotifier<Notification>>,
93}
94
95#[derive(Clone)]
96pub struct LocalValidatorClient<S>
97where
98    S: Storage,
99{
100    public_key: ValidatorPublicKey,
101    client: Arc<Mutex<LocalValidator<S>>>,
102    fault_type: FaultType,
103}
104
105impl<S> ValidatorNode for LocalValidatorClient<S>
106where
107    S: Storage + Clone + Send + Sync + 'static,
108{
109    type NotificationStream = NotificationStream;
110
111    fn address(&self) -> String {
112        format!("local:{}", self.public_key)
113    }
114
115    async fn handle_block_proposal(
116        &self,
117        proposal: BlockProposal,
118    ) -> Result<ChainInfoResponse, NodeError> {
119        self.spawn_and_receive(move |validator, sender| {
120            validator.do_handle_block_proposal(proposal, sender)
121        })
122        .await
123    }
124
125    async fn handle_lite_certificate(
126        &self,
127        certificate: LiteCertificate<'_>,
128        _delivery: CrossChainMessageDelivery,
129    ) -> Result<ChainInfoResponse, NodeError> {
130        let certificate = certificate.cloned();
131        self.spawn_and_receive(move |validator, sender| {
132            validator.do_handle_lite_certificate(certificate, sender)
133        })
134        .await
135    }
136
137    async fn handle_timeout_certificate(
138        &self,
139        certificate: GenericCertificate<Timeout>,
140    ) -> Result<ChainInfoResponse, NodeError> {
141        self.spawn_and_receive(move |validator, sender| {
142            validator.do_handle_certificate(certificate, sender)
143        })
144        .await
145    }
146
147    async fn handle_validated_certificate(
148        &self,
149        certificate: GenericCertificate<ValidatedBlock>,
150    ) -> Result<ChainInfoResponse, NodeError> {
151        self.spawn_and_receive(move |validator, sender| {
152            validator.do_handle_certificate(certificate, sender)
153        })
154        .await
155    }
156
157    async fn handle_confirmed_certificate(
158        &self,
159        certificate: CacheArc<GenericCertificate<ConfirmedBlock>>,
160        _delivery: CrossChainMessageDelivery,
161    ) -> Result<ChainInfoResponse, NodeError> {
162        self.spawn_and_receive(move |validator, sender| {
163            validator.do_handle_certificate(CacheArc::unwrap_or_clone(certificate), sender)
164        })
165        .await
166    }
167
168    async fn handle_chain_info_query(
169        &self,
170        query: ChainInfoQuery,
171    ) -> Result<ChainInfoResponse, NodeError> {
172        self.spawn_and_receive(move |validator, sender| {
173            validator.do_handle_chain_info_query(query, sender)
174        })
175        .await
176    }
177
178    async fn subscribe(&self, chains: Vec<ChainId>) -> Result<NotificationStream, NodeError> {
179        self.spawn_and_receive(move |validator, sender| validator.do_subscribe(chains, sender))
180            .await
181    }
182
183    async fn get_version_info(&self) -> Result<VersionInfo, NodeError> {
184        Ok(Default::default())
185    }
186
187    async fn get_network_description(&self) -> Result<NetworkDescription, NodeError> {
188        Ok(self
189            .client
190            .lock()
191            .await
192            .state
193            .storage_client()
194            .read_network_description()
195            .await
196            .transpose()
197            .ok_or_else(|| NodeError::ViewError {
198                error: "missing NetworkDescription".to_owned(),
199            })??)
200    }
201
202    async fn upload_blob(&self, content: BlobContent) -> Result<BlobId, NodeError> {
203        self.spawn_and_receive(move |validator, sender| validator.do_upload_blob(content, sender))
204            .await
205    }
206
207    async fn download_blob(&self, blob_id: BlobId) -> Result<BlobContent, NodeError> {
208        self.spawn_and_receive(move |validator, sender| validator.do_download_blob(blob_id, sender))
209            .await
210    }
211
212    async fn download_blobs(
213        &self,
214        blob_ids: Vec<BlobId>,
215    ) -> Result<crate::node::BlobStream, NodeError> {
216        let this = self.clone();
217        let stream = futures::stream::unfold(blob_ids.into_iter(), move |mut iter| {
218            let this = this.clone();
219            async move {
220                let blob_id = iter.next()?;
221                let result = this.download_blob(blob_id).await;
222                Some((result, iter))
223            }
224        });
225        Ok(Box::pin(stream))
226    }
227
228    async fn download_pending_blob(
229        &self,
230        chain_id: ChainId,
231        blob_id: BlobId,
232    ) -> Result<BlobContent, NodeError> {
233        self.spawn_and_receive(move |validator, sender| {
234            validator.do_download_pending_blob(chain_id, blob_id, sender)
235        })
236        .await
237    }
238
239    async fn handle_pending_blob(
240        &self,
241        chain_id: ChainId,
242        blob: BlobContent,
243    ) -> Result<ChainInfoResponse, NodeError> {
244        self.spawn_and_receive(move |validator, sender| {
245            validator.do_handle_pending_blob(chain_id, blob, sender)
246        })
247        .await
248    }
249
250    async fn download_certificate(
251        &self,
252        hash: CryptoHash,
253    ) -> Result<ConfirmedBlockCertificate, NodeError> {
254        self.spawn_and_receive(move |validator, sender| {
255            validator.do_download_certificate(hash, sender)
256        })
257        .await
258    }
259
260    async fn download_certificates(
261        &self,
262        hashes: Vec<CryptoHash>,
263    ) -> Result<Vec<ConfirmedBlockCertificate>, NodeError> {
264        self.spawn_and_receive(move |validator, sender| {
265            validator.do_download_certificates(hashes, sender)
266        })
267        .await
268    }
269
270    async fn download_certificates_by_heights(
271        &self,
272        chain_id: ChainId,
273        heights: Vec<BlockHeight>,
274    ) -> Result<Vec<ConfirmedBlockCertificate>, NodeError> {
275        self.spawn_and_receive(move |validator, sender| {
276            validator.do_download_certificates_by_heights(chain_id, heights, sender)
277        })
278        .await
279    }
280
281    async fn blob_last_used_by(&self, blob_id: BlobId) -> Result<CryptoHash, NodeError> {
282        self.spawn_and_receive(move |validator, sender| {
283            validator.do_blob_last_used_by(blob_id, sender)
284        })
285        .await
286    }
287
288    async fn blob_last_used_by_certificate(
289        &self,
290        blob_id: BlobId,
291    ) -> Result<ConfirmedBlockCertificate, NodeError> {
292        self.spawn_and_receive(move |validator, sender| {
293            validator.do_blob_last_used_by_certificate(blob_id, sender)
294        })
295        .await
296    }
297
298    async fn missing_blob_ids(&self, blob_ids: Vec<BlobId>) -> Result<Vec<BlobId>, NodeError> {
299        self.spawn_and_receive(move |validator, sender| {
300            validator.do_missing_blob_ids(blob_ids, sender)
301        })
302        .await
303    }
304
305    async fn event_block_heights(
306        &self,
307        event_ids: Vec<EventId>,
308    ) -> Result<Vec<Option<BlockHeight>>, NodeError> {
309        self.spawn_and_receive(move |validator, sender| {
310            validator.do_event_block_heights(event_ids, sender)
311        })
312        .await
313    }
314
315    async fn get_shard_info(
316        &self,
317        _chain_id: ChainId,
318    ) -> Result<crate::data_types::ShardInfo, NodeError> {
319        // For test purposes, return a dummy shard info
320        Ok(crate::data_types::ShardInfo {
321            shard_id: 0,
322            total_shards: 1,
323        })
324    }
325}
326
327impl<S> LocalValidatorClient<S>
328where
329    S: Storage + Clone + Send + Sync + 'static,
330{
331    fn new(public_key: ValidatorPublicKey, state: WorkerState<S>) -> Self {
332        let client = LocalValidator {
333            state,
334            notifier: Arc::new(ChannelNotifier::default()),
335        };
336        Self {
337            public_key,
338            client: Arc::new(Mutex::new(client)),
339            fault_type: FaultType::Honest,
340        }
341    }
342
343    pub fn name(&self) -> ValidatorPublicKey {
344        self.public_key
345    }
346
347    pub fn fault_type(&self) -> FaultType {
348        self.fault_type
349    }
350
351    fn set_fault_type(&mut self, fault_type: FaultType) {
352        self.fault_type = fault_type;
353    }
354
355    /// Obtains the basic `ChainInfo` data for the local validator chain, with chain manager values.
356    pub async fn chain_info_with_manager_values(
357        &mut self,
358        chain_id: ChainId,
359    ) -> Result<Box<ChainInfo>, NodeError> {
360        let query = ChainInfoQuery::new(chain_id).with_manager_values();
361        let response = self.handle_chain_info_query(query).await?;
362        Ok(response.info)
363    }
364
365    /// Executes the future produced by `f` in a new thread in a new Tokio runtime.
366    /// Returns the value that the future puts into the sender.
367    async fn spawn_and_receive<F, R, T>(&self, f: F) -> T
368    where
369        T: Send + 'static,
370        R: Future<Output = Result<(), T>> + Send,
371        F: FnOnce(Self, oneshot::Sender<T>) -> R + Send + 'static,
372    {
373        let validator = self.clone();
374        let (sender, receiver) = oneshot::channel();
375        tokio::spawn(async move {
376            if f(validator, sender).await.is_err() {
377                tracing::debug!("result could not be sent");
378            }
379        });
380        receiver.await.unwrap()
381    }
382
383    async fn do_handle_block_proposal(
384        self,
385        proposal: BlockProposal,
386        sender: oneshot::Sender<Result<ChainInfoResponse, NodeError>>,
387    ) -> Result<(), Result<ChainInfoResponse, NodeError>> {
388        let result = match self.fault_type {
389            FaultType::Offline | FaultType::OfflineWithInfo => Err(NodeError::ClientIoError {
390                error: "offline".to_string(),
391            }),
392            FaultType::NoChains => Err(NodeError::InactiveChain(proposal.content.block.chain_id)),
393            FaultType::DontSendValidateVote
394            | FaultType::Honest
395            | FaultType::DontSendConfirmVote
396            | FaultType::DontProcessValidated => {
397                let (response_result, _actions) = self
398                    .client
399                    .lock()
400                    .await
401                    .state
402                    .handle_block_proposal(proposal)
403                    .await;
404                let result = response_result.map_err(NodeError::from);
405                if self.fault_type == FaultType::DontSendValidateVote {
406                    Err(NodeError::ClientIoError {
407                        error: "refusing to validate".to_string(),
408                    })
409                } else {
410                    result
411                }
412            }
413        };
414        // In a local node cross-chain messages can't get lost, so we can ignore the actions here.
415        sender.send(result)
416    }
417
418    async fn do_handle_lite_certificate(
419        self,
420        certificate: LiteCertificate<'_>,
421        sender: oneshot::Sender<Result<ChainInfoResponse, NodeError>>,
422    ) -> Result<(), Result<ChainInfoResponse, NodeError>> {
423        let client = self.client.clone();
424        let validator = client.lock().await;
425        let result = async move {
426            match validator.state.full_certificate(certificate).await? {
427                Either::Left(confirmed) => {
428                    self.do_handle_certificate_internal(confirmed, &validator)
429                        .await
430                }
431                Either::Right(validated) => {
432                    self.do_handle_certificate_internal(validated, &validator)
433                        .await
434                }
435            }
436        }
437        .await;
438        sender.send(result)
439    }
440
441    async fn do_handle_certificate_internal<T: ProcessableCertificate>(
442        &self,
443        certificate: GenericCertificate<T>,
444        validator: &MutexGuard<'_, LocalValidator<S>>,
445    ) -> Result<ChainInfoResponse, NodeError> {
446        match self.fault_type {
447            FaultType::DontProcessValidated if T::KIND == CertificateKind::Validated => {
448                Err(NodeError::ClientIoError {
449                    error: "refusing to process validated block".to_string(),
450                })
451            }
452            FaultType::NoChains => Err(NodeError::InactiveChain(certificate.value().chain_id())),
453            FaultType::Honest
454            | FaultType::DontSendConfirmVote
455            | FaultType::DontProcessValidated
456            | FaultType::DontSendValidateVote => {
457                let result = validator
458                    .state
459                    .fully_handle_certificate_with_notifications(certificate, &validator.notifier)
460                    .await
461                    .map_err(Into::into);
462                if T::KIND == CertificateKind::Validated
463                    && self.fault_type == FaultType::DontSendConfirmVote
464                {
465                    Err(NodeError::ClientIoError {
466                        error: "refusing to confirm".to_string(),
467                    })
468                } else {
469                    result
470                }
471            }
472            FaultType::Offline | FaultType::OfflineWithInfo => Err(NodeError::ClientIoError {
473                error: "offline".to_string(),
474            }),
475        }
476    }
477
478    async fn do_handle_certificate<T: ProcessableCertificate>(
479        self,
480        certificate: GenericCertificate<T>,
481        sender: oneshot::Sender<Result<ChainInfoResponse, NodeError>>,
482    ) -> Result<(), Result<ChainInfoResponse, NodeError>> {
483        let validator = self.client.lock().await;
484        let result = self
485            .do_handle_certificate_internal(certificate, &validator)
486            .await;
487        sender.send(result)
488    }
489
490    async fn do_handle_chain_info_query(
491        self,
492        query: ChainInfoQuery,
493        sender: oneshot::Sender<Result<ChainInfoResponse, NodeError>>,
494    ) -> Result<(), Result<ChainInfoResponse, NodeError>> {
495        let validator = self.client.lock().await;
496        let result = match self.fault_type {
497            FaultType::Offline => Err(NodeError::ClientIoError {
498                error: "offline".to_string(),
499            }),
500            FaultType::NoChains => Err(NodeError::InactiveChain(query.chain_id)),
501            FaultType::Honest
502            | FaultType::DontSendConfirmVote
503            | FaultType::DontProcessValidated
504            | FaultType::DontSendValidateVote
505            | FaultType::OfflineWithInfo => validator
506                .state
507                .handle_chain_info_query(query)
508                .await
509                .map_err(Into::into),
510        };
511        sender.send(result)
512    }
513
514    async fn do_subscribe(
515        self,
516        chains: Vec<ChainId>,
517        sender: oneshot::Sender<Result<NotificationStream, NodeError>>,
518    ) -> Result<(), Result<NotificationStream, NodeError>> {
519        let validator = self.client.lock().await;
520        let rx = validator.notifier.subscribe(chains);
521        let stream: NotificationStream = Box::pin(UnboundedReceiverStream::new(rx));
522        sender.send(Ok(stream))
523    }
524
525    async fn do_upload_blob(
526        self,
527        content: BlobContent,
528        sender: oneshot::Sender<Result<BlobId, NodeError>>,
529    ) -> Result<(), Result<BlobId, NodeError>> {
530        let validator = self.client.lock().await;
531        let blob = Blob::new(content);
532        let id = blob.id();
533        let storage = validator.state.storage_client();
534        let result = match storage.maybe_write_blobs(&[blob]).await {
535            Ok(has_state) if has_state.first() == Some(&true) => Ok(id),
536            Ok(_) => Err(NodeError::BlobsNotFound(vec![id])),
537            Err(error) => Err(error.into()),
538        };
539        sender.send(result)
540    }
541
542    async fn do_download_blob(
543        self,
544        blob_id: BlobId,
545        sender: oneshot::Sender<Result<BlobContent, NodeError>>,
546    ) -> Result<(), Result<BlobContent, NodeError>> {
547        let validator = self.client.lock().await;
548        let blob = validator
549            .state
550            .storage_client()
551            .read_blob(blob_id)
552            .await
553            .map_err(Into::into);
554        let blob = match blob {
555            Ok(blob) => blob.ok_or_else(|| NodeError::BlobsNotFound(vec![blob_id])),
556            Err(error) => Err(error),
557        };
558        sender.send(blob.map(|blob| CacheArc::unwrap_or_clone(blob).into_content()))
559    }
560
561    async fn do_download_pending_blob(
562        self,
563        chain_id: ChainId,
564        blob_id: BlobId,
565        sender: oneshot::Sender<Result<BlobContent, NodeError>>,
566    ) -> Result<(), Result<BlobContent, NodeError>> {
567        let validator = self.client.lock().await;
568        let result = validator
569            .state
570            .download_pending_blob(chain_id, blob_id)
571            .await
572            .map_err(Into::into);
573        sender.send(result.map(|blob| blob.content().clone()))
574    }
575
576    async fn do_handle_pending_blob(
577        self,
578        chain_id: ChainId,
579        blob: BlobContent,
580        sender: oneshot::Sender<Result<ChainInfoResponse, NodeError>>,
581    ) -> Result<(), Result<ChainInfoResponse, NodeError>> {
582        let validator = self.client.lock().await;
583        let result = validator
584            .state
585            .handle_pending_blob(chain_id, Blob::new(blob))
586            .await
587            .map_err(Into::into);
588        sender.send(result)
589    }
590
591    async fn do_download_certificate(
592        self,
593        hash: CryptoHash,
594        sender: oneshot::Sender<Result<ConfirmedBlockCertificate, NodeError>>,
595    ) -> Result<(), Result<ConfirmedBlockCertificate, NodeError>> {
596        let validator = self.client.lock().await;
597        let certificate = validator
598            .state
599            .storage_client()
600            .read_certificate(hash)
601            .await
602            .map_err(Into::into);
603
604        let certificate = match certificate {
605            Err(error) => Err(error),
606            Ok(entry) => match entry {
607                Some(certificate) => Ok(CacheArc::unwrap_or_clone(certificate)),
608                None => {
609                    panic!("Missing certificate: {hash}");
610                }
611            },
612        };
613
614        sender.send(certificate)
615    }
616
617    async fn do_download_certificates(
618        self,
619        hashes: Vec<CryptoHash>,
620        sender: oneshot::Sender<Result<Vec<ConfirmedBlockCertificate>, NodeError>>,
621    ) -> Result<(), Result<Vec<ConfirmedBlockCertificate>, NodeError>> {
622        let validator = self.client.lock().await;
623        let certificates = validator
624            .state
625            .storage_client()
626            .read_certificates(&hashes)
627            .await
628            .map_err(Into::into);
629
630        let certificates = match certificates {
631            Err(error) => Err(error),
632            Ok(certificates) => match ResultReadCertificates::new(certificates, hashes) {
633                ResultReadCertificates::Certificates(certificates) => Ok(certificates),
634                ResultReadCertificates::InvalidHashes(hashes) => {
635                    panic!("Missing certificates: {hashes:?}")
636                }
637            },
638        };
639
640        sender.send(certificates)
641    }
642
643    async fn do_download_certificates_by_heights(
644        self,
645        chain_id: ChainId,
646        heights: Vec<BlockHeight>,
647        sender: oneshot::Sender<Result<Vec<ConfirmedBlockCertificate>, NodeError>>,
648    ) -> Result<(), Result<Vec<ConfirmedBlockCertificate>, NodeError>> {
649        // First, use do_handle_chain_info_query to get the certificate hashes
650        let (query_sender, query_receiver) = oneshot::channel();
651        let query = ChainInfoQuery::new(chain_id).with_sent_certificate_hashes_by_heights(heights);
652
653        let self_clone = self.clone();
654        self.do_handle_chain_info_query(query, query_sender)
655            .await
656            .expect("Failed to handle chain info query");
657
658        // Get the response from the chain info query
659        let chain_info_response = query_receiver.await.map_err(|_| {
660            Err(NodeError::ClientIoError {
661                error: "Failed to receive chain info response".to_string(),
662            })
663        })?;
664
665        let hashes = match chain_info_response {
666            Ok(response) => response.info.requested_sent_certificate_hashes,
667            Err(e) => {
668                return sender.send(Err(e));
669            }
670        };
671
672        // Now use do_download_certificates to get the actual certificates
673        let (cert_sender, cert_receiver) = oneshot::channel();
674        self_clone
675            .do_download_certificates(hashes, cert_sender)
676            .await?;
677
678        // Forward the result to the original sender
679        let result = cert_receiver.await.map_err(|_| {
680            Err(NodeError::ClientIoError {
681                error: "Failed to receive certificates".to_string(),
682            })
683        })?;
684
685        sender.send(result)
686    }
687
688    async fn do_blob_last_used_by(
689        self,
690        blob_id: BlobId,
691        sender: oneshot::Sender<Result<CryptoHash, NodeError>>,
692    ) -> Result<(), Result<CryptoHash, NodeError>> {
693        let validator = self.client.lock().await;
694        let blob_state = validator
695            .state
696            .storage_client()
697            .read_blob_state(blob_id)
698            .await
699            .map_err(Into::into);
700        let certificate_hash = match blob_state {
701            Err(err) => Err(err),
702            Ok(blob_state) => match blob_state {
703                None => Err(NodeError::BlobsNotFound(vec![blob_id])),
704                Some(blob_state) => blob_state
705                    .last_used_by
706                    .ok_or_else(|| NodeError::BlobsNotFound(vec![blob_id])),
707            },
708        };
709
710        sender.send(certificate_hash)
711    }
712
713    async fn do_blob_last_used_by_certificate(
714        self,
715        blob_id: BlobId,
716        sender: oneshot::Sender<Result<ConfirmedBlockCertificate, NodeError>>,
717    ) -> Result<(), Result<ConfirmedBlockCertificate, NodeError>> {
718        match self.blob_last_used_by(blob_id).await {
719            Ok(cert_hash) => {
720                let cert = self.download_certificate(cert_hash).await;
721                sender.send(cert)
722            }
723            Err(err) => sender.send(Err(err)),
724        }
725    }
726
727    async fn do_missing_blob_ids(
728        self,
729        blob_ids: Vec<BlobId>,
730        sender: oneshot::Sender<Result<Vec<BlobId>, NodeError>>,
731    ) -> Result<(), Result<Vec<BlobId>, NodeError>> {
732        let validator = self.client.lock().await;
733        let missing_blob_ids = validator
734            .state
735            .storage_client()
736            .missing_blobs(&blob_ids)
737            .await
738            .map_err(Into::into);
739        sender.send(missing_blob_ids)
740    }
741
742    async fn do_event_block_heights(
743        self,
744        event_ids: Vec<EventId>,
745        sender: oneshot::Sender<Result<Vec<Option<BlockHeight>>, NodeError>>,
746    ) -> Result<(), Result<Vec<Option<BlockHeight>>, NodeError>> {
747        let validator = self.client.lock().await;
748        let heights = validator
749            .state
750            .storage_client()
751            .read_event_block_heights(&event_ids)
752            .await
753            .map_err(Into::into);
754        sender.send(heights)
755    }
756}
757
758#[derive(Clone)]
759pub struct NodeProvider<S>(Arc<std::sync::Mutex<Vec<LocalValidatorClient<S>>>>)
760where
761    S: Storage;
762
763impl<S> NodeProvider<S>
764where
765    S: Storage + Clone,
766{
767    fn all_nodes(&self) -> Vec<LocalValidatorClient<S>> {
768        self.0.lock().unwrap().clone()
769    }
770}
771
772impl<S> ValidatorNodeProvider for NodeProvider<S>
773where
774    S: Storage + Clone + Send + Sync + 'static,
775{
776    type Node = LocalValidatorClient<S>;
777
778    fn make_node(&self, _name: &str) -> Result<Self::Node, NodeError> {
779        unimplemented!()
780    }
781
782    fn make_nodes_from_list<A>(
783        &self,
784        validators: impl IntoIterator<Item = (ValidatorPublicKey, A)>,
785    ) -> Result<impl Iterator<Item = (ValidatorPublicKey, Self::Node)>, NodeError>
786    where
787        A: AsRef<str>,
788    {
789        let list = self.0.lock().unwrap();
790        Ok(validators
791            .into_iter()
792            .map(|(public_key, address)| {
793                list.iter()
794                    .find(|client| client.public_key == public_key)
795                    .ok_or_else(|| NodeError::CannotResolveValidatorAddress {
796                        address: address.as_ref().to_string(),
797                    })
798                    .map(|client| (public_key, client.clone()))
799            })
800            .collect::<Result<Vec<_>, _>>()?
801            .into_iter())
802    }
803}
804
805impl<S> FromIterator<LocalValidatorClient<S>> for NodeProvider<S>
806where
807    S: Storage,
808{
809    fn from_iter<T>(iter: T) -> Self
810    where
811        T: IntoIterator<Item = LocalValidatorClient<S>>,
812    {
813        Self(Arc::new(std::sync::Mutex::new(iter.into_iter().collect())))
814    }
815}
816
817// NOTE:
818// * To communicate with a quorum of validators, chain clients iterate over a copy of
819// `validator_clients` to spawn I/O tasks.
820// * When using `LocalValidatorClient`, clients communicate with an exact quorum then stop.
821// * Most tests have 1 faulty validator out 4 so that there is exactly only 1 quorum to
822// communicate with.
823pub struct TestBuilder<B: StorageBuilder> {
824    storage_builder: B,
825    pub initial_committee: Committee,
826    admin_description: Option<ChainDescription>,
827    network_description: Option<NetworkDescription>,
828    genesis_storage_builder: GenesisStorageBuilder,
829    node_provider: NodeProvider<B::Storage>,
830    pub validator_storages: HashMap<ValidatorPublicKey, B::Storage>,
831    pub validator_key_pairs: HashMap<ValidatorPublicKey, ValidatorSecretKey>,
832    chain_client_storages: Vec<B::Storage>,
833    pub chain_owners: BTreeMap<ChainId, AccountOwner>,
834    pub signer: TestSigner,
835}
836
837#[async_trait]
838pub trait StorageBuilder {
839    type Storage: Storage + Clone + Send + Sync + 'static;
840
841    async fn build(&mut self) -> Result<Self::Storage, anyhow::Error>;
842
843    fn clock(&self) -> &TestClock;
844}
845
846#[derive(Default)]
847struct GenesisStorageBuilder {
848    accounts: Vec<GenesisAccount>,
849}
850
851struct GenesisAccount {
852    description: ChainDescription,
853    public_key: AccountPublicKey,
854}
855
856impl GenesisStorageBuilder {
857    fn add(&mut self, description: ChainDescription, public_key: AccountPublicKey) {
858        self.accounts.push(GenesisAccount {
859            description,
860            public_key,
861        })
862    }
863
864    async fn build<S>(&self, storage: S) -> S
865    where
866        S: Storage + Clone + Send + Sync + 'static,
867    {
868        for account in &self.accounts {
869            storage
870                .create_chain(account.description.clone())
871                .await
872                .unwrap();
873        }
874        storage
875    }
876}
877
/// A chain client whose environment uses storage `S` and talks to the local test validators
/// through [`NodeProvider`].
pub type ChainClient<S> = crate::client::ChainClient<crate::environment::Impl<S, NodeProvider<S>>>;
879
880impl<S: Storage + Clone + Send + Sync + 'static> ChainClient<S> {
881    /// Reads the hashed certificate values in descending order from the given hash.
882    pub async fn read_confirmed_blocks_downward(
883        &self,
884        from: CryptoHash,
885        limit: u32,
886    ) -> anyhow::Result<Vec<Arc<ConfirmedBlock>>> {
887        let mut hash = Some(from);
888        let mut values = Vec::new();
889        for _ in 0..limit {
890            let Some(next_hash) = hash else {
891                break;
892            };
893            let value = self.read_confirmed_block(next_hash).await?;
894            hash = value.block().header.previous_block_hash;
895            values.push(value);
896        }
897        Ok(values)
898    }
899}
900
901impl<B> TestBuilder<B>
902where
903    B: StorageBuilder,
904{
905    pub async fn new(
906        mut storage_builder: B,
907        count: usize,
908        with_faulty_validators: usize,
909        mut signer: TestSigner,
910    ) -> Result<Self, anyhow::Error> {
911        let mut validators = Vec::new();
912        for _ in 0..count {
913            let validator_keypair = ValidatorKeypair::generate();
914            let account_public_key = signer.generate_new();
915            validators.push((validator_keypair, account_public_key));
916        }
917        let for_committee = validators
918            .iter()
919            .map(|(validating, account)| (validating.public_key, *account))
920            .collect::<Vec<_>>();
921        let initial_committee = Committee::make_simple(for_committee);
922        let mut validator_clients = Vec::new();
923        let mut validator_storages = HashMap::new();
924        let mut validator_key_pairs = HashMap::new();
925        let mut faulty_validators = HashSet::new();
926        for (i, (validator_keypair, _account_public_key)) in validators.into_iter().enumerate() {
927            let validator_public_key = validator_keypair.public_key;
928            let storage = storage_builder.build().await?;
929            let secret_key_copy = validator_keypair.secret_key.copy();
930            let config = ChainWorkerConfig {
931                nickname: format!("Node {i}"),
932                ..ChainWorkerConfig::default()
933            }
934            .with_key_pair(Some(validator_keypair.secret_key));
935            let state = WorkerState::new(storage.clone(), config, None);
936            let mut validator = LocalValidatorClient::new(validator_public_key, state);
937            if i < with_faulty_validators {
938                faulty_validators.insert(validator_public_key);
939                validator.set_fault_type(FaultType::NoChains);
940            }
941            validator_clients.push(validator);
942            validator_storages.insert(validator_public_key, storage);
943            validator_key_pairs.insert(validator_public_key, secret_key_copy);
944        }
945        tracing::info!(
946            "Test will use the following faulty validators: {:?}",
947            faulty_validators
948        );
949        Ok(Self {
950            storage_builder,
951            initial_committee,
952            admin_description: None,
953            network_description: None,
954            genesis_storage_builder: GenesisStorageBuilder::default(),
955            node_provider: NodeProvider::from_iter(validator_clients),
956            validator_storages,
957            validator_key_pairs,
958            chain_client_storages: Vec::new(),
959            chain_owners: BTreeMap::new(),
960            signer,
961        })
962    }
963
964    pub fn with_policy(mut self, policy: ResourceControlPolicy) -> Self {
965        let validators = self.initial_committee.validators().clone();
966        self.initial_committee =
967            Committee::new(validators, policy).expect("committee votes should not overflow");
968        self
969    }
970
971    pub fn with_cross_chain_message_chunk_limit(self, limit: usize) -> Self {
972        let validator_clients = self.node_provider.0.lock().unwrap();
973        for validator in validator_clients.iter() {
974            let mut inner = validator.client.try_lock().expect("no contention at setup");
975            inner.state.set_cross_chain_message_chunk_limit(limit);
976        }
977        drop(validator_clients);
978        self
979    }
980
981    /// Returns the [`FaultType`] currently configured for the given validator, or `None`
982    /// if no validator with that key is in the test setup.
983    pub fn fault_type(&self, public_key: &ValidatorPublicKey) -> Option<FaultType> {
984        self.node_provider
985            .0
986            .lock()
987            .unwrap()
988            .iter()
989            .find(|client| client.public_key == *public_key)
990            .map(|client| client.fault_type())
991    }
992
993    pub fn set_fault_type(&mut self, indexes: impl AsRef<[usize]>, fault_type: FaultType) {
994        let mut faulty_validators = vec![];
995        let mut validator_clients = self.node_provider.0.lock().unwrap();
996        for index in indexes.as_ref() {
997            let validator = &mut validator_clients[*index];
998            validator.set_fault_type(fault_type);
999            faulty_validators.push(validator.public_key);
1000        }
1001        tracing::info!(
1002            "Making the following validators {:?}: {:?}",
1003            fault_type,
1004            faulty_validators
1005        );
1006    }
1007
    /// Creates the root chain with the given `index`, and returns a client for it.
    ///
    /// Root chain 0 is the admin chain and needs to be initialized first, otherwise its balance
    /// is automatically set to zero.
    ///
    /// The new chain has a single owner whose key is freshly generated by `self.signer`. It starts
    /// at epoch 0 with the given `balance`. Its description is recorded in the genesis store so
    /// that storages created later include it. It is also created in every validator storage and
    /// in every client storage created so far. When `index` is 0, the admin chain description and
    /// the network description are recorded as well.
    ///
    /// # Errors
    ///
    /// Propagates errors from creating the admin chain (when done implicitly) and from
    /// creating the returned client. Storage write failures panic instead.
    pub async fn add_root_chain(
        &mut self,
        index: u32,
        balance: Amount,
    ) -> anyhow::Result<ChainClient<B::Storage>> {
        // Make sure the admin chain is initialized.
        // (`Box::pin` is required because this `async fn` calls itself recursively.)
        if self.admin_description.is_none() && index != 0 {
            Box::pin(self.add_root_chain(0, Amount::ZERO)).await?;
        }
        let origin = ChainOrigin::Root(index);
        let public_key = self.signer.generate_new();
        let open_chain_config = InitialChainConfig {
            ownership: ChainOwnership::single(public_key.into()),
            epoch: Epoch(0),
            balance,
            application_permissions: ApplicationPermissions::default(),
        };
        let description = ChainDescription::new(origin, open_chain_config, Timestamp::from(0));
        // The serialized initial committee: written to every validator storage below, and its
        // hash is referenced by the network description.
        let committee_blob = Blob::new_committee(bcs::to_bytes(&self.initial_committee).unwrap());
        if index == 0 {
            self.admin_description = Some(description.clone());
            self.network_description = Some(NetworkDescription {
                admin_chain_id: description.id(),
                // dummy values to fill the description
                genesis_config_hash: CryptoHash::test_hash("genesis config"),
                genesis_timestamp: Timestamp::from(0),
                genesis_committee_blob_hash: committee_blob.id().hash,
                name: "test network".to_string(),
            });
        }
        // Remember what's in the genesis store for future clients to join.
        self.genesis_storage_builder
            .add(description.clone(), public_key);

        // Set together with `admin_description` when root chain 0 was added; the recursive
        // call above guarantees that has happened by now.
        let network_description = self.network_description.as_ref().unwrap();

        for validator in self.node_provider.all_nodes() {
            let storage = self
                .validator_storages
                .get_mut(&validator.public_key)
                .unwrap();
            storage
                .write_network_description(network_description)
                .await
                .expect("writing the NetworkDescription should succeed");
            storage
                .write_blob(&committee_blob)
                .await
                .expect("writing a blob should succeed");
            storage.create_chain(description.clone()).await.unwrap();
        }
        // Storages of clients created earlier only got the genesis chains known at that time,
        // so the new chain is added to them explicitly.
        for storage in &mut self.chain_client_storages {
            storage.create_chain(description.clone()).await.unwrap();
        }
        let chain_id = description.id();
        // Record the owner; `make_client_with_options` passes it to clients created for this chain.
        self.chain_owners.insert(chain_id, public_key.into());
        self.make_client(chain_id, None, BlockHeight::ZERO).await
    }
1070
1071    pub fn genesis_chains(&self) -> Vec<(AccountPublicKey, Amount)> {
1072        let mut result = Vec::new();
1073        for (i, genesis_account) in self.genesis_storage_builder.accounts.iter().enumerate() {
1074            assert_eq!(
1075                genesis_account.description.origin(),
1076                ChainOrigin::Root(i as u32)
1077            );
1078            result.push((
1079                genesis_account.public_key,
1080                genesis_account.description.config().balance,
1081            ));
1082        }
1083        result
1084    }
1085
1086    pub fn admin_chain_id(&self) -> ChainId {
1087        self.admin_description
1088            .as_ref()
1089            .expect("admin chain not initialized")
1090            .id()
1091    }
1092
1093    pub fn admin_description(&self) -> Option<&ChainDescription> {
1094        self.admin_description.as_ref()
1095    }
1096
1097    pub fn make_node_provider(&self) -> NodeProvider<B::Storage> {
1098        self.node_provider.clone()
1099    }
1100
1101    pub fn node(&mut self, index: usize) -> LocalValidatorClient<B::Storage> {
1102        self.node_provider.0.lock().unwrap()[index].clone()
1103    }
1104
1105    pub async fn make_storage(&mut self) -> anyhow::Result<B::Storage> {
1106        let storage = self.storage_builder.build().await?;
1107        let network_description = self.network_description.as_ref().unwrap();
1108        let committee_blob = Blob::new_committee(bcs::to_bytes(&self.initial_committee).unwrap());
1109        storage
1110            .write_network_description(network_description)
1111            .await
1112            .expect("writing the NetworkDescription should succeed");
1113        storage
1114            .write_blob(&committee_blob)
1115            .await
1116            .expect("writing a blob should succeed");
1117        Ok(self.genesis_storage_builder.build(storage).await)
1118    }
1119
1120    pub async fn make_client_with_options(
1121        &mut self,
1122        chain_id: ChainId,
1123        block_hash: Option<CryptoHash>,
1124        block_height: BlockHeight,
1125        options: chain_client::Options,
1126        follow_only: bool,
1127    ) -> anyhow::Result<ChainClient<B::Storage>> {
1128        // Note that new clients are only given the genesis store: they must figure out
1129        // the rest by asking validators.
1130        let storage = self.make_storage().await?;
1131        self.chain_client_storages.push(storage.clone());
1132        let mode = if follow_only {
1133            crate::client::ListeningMode::FollowChain
1134        } else {
1135            crate::client::ListeningMode::FullChain
1136        };
1137        let client = Arc::new(Client::new(
1138            crate::environment::Impl {
1139                network: self.make_node_provider(),
1140                storage,
1141                signer: self.signer.clone(),
1142                wallet: TestWallet::default(),
1143            },
1144            self.admin_chain_id(),
1145            false,
1146            [(chain_id, mode)],
1147            format!("Client node for {chain_id:.8}"),
1148            Some(Duration::from_secs(30)),
1149            Some(Duration::from_secs(1)),
1150            1000,
1151            options,
1152            DEFAULT_BLOCK_CACHE_SIZE,
1153            DEFAULT_EXECUTION_STATE_CACHE_SIZE,
1154            &crate::client::RequestsSchedulerConfig::default(),
1155        ));
1156        Ok(client.create_chain_client(
1157            chain_id,
1158            block_hash,
1159            block_height,
1160            &None,
1161            self.chain_owners.get(&chain_id).copied(),
1162            None,
1163            follow_only,
1164        ))
1165    }
1166
1167    pub async fn make_client(
1168        &mut self,
1169        chain_id: ChainId,
1170        block_hash: Option<CryptoHash>,
1171        block_height: BlockHeight,
1172    ) -> anyhow::Result<ChainClient<B::Storage>> {
1173        self.make_client_with_options(
1174            chain_id,
1175            block_hash,
1176            block_height,
1177            chain_client::Options::test_default(),
1178            false,
1179        )
1180        .await
1181    }
1182
1183    /// Tries to find a (confirmation) certificate for the given chain_id and block height.
1184    pub async fn check_that_validators_have_certificate(
1185        &self,
1186        chain_id: ChainId,
1187        block_height: BlockHeight,
1188        target_count: usize,
1189    ) -> Option<ConfirmedBlockCertificate> {
1190        let query = ChainInfoQuery::new(chain_id)
1191            .with_sent_certificate_hashes_by_heights(vec![block_height]);
1192        let mut count = 0;
1193        let mut certificate = None;
1194        for validator in self.node_provider.all_nodes() {
1195            if let Ok(response) = validator.handle_chain_info_query(query.clone()).await {
1196                if response.check(validator.public_key).is_ok() {
1197                    let ChainInfo {
1198                        mut requested_sent_certificate_hashes,
1199                        ..
1200                    } = *response.info;
1201                    debug_assert!(requested_sent_certificate_hashes.len() <= 1);
1202                    if let Some(cert_hash) = requested_sent_certificate_hashes.pop() {
1203                        if let Ok(cert) = validator.download_certificate(cert_hash).await {
1204                            if cert.inner().block().header.chain_id == chain_id
1205                                && cert.inner().block().header.height == block_height
1206                            {
1207                                cert.check(&self.initial_committee).unwrap();
1208                                count += 1;
1209                                certificate = Some(cert);
1210                            }
1211                        }
1212                    }
1213                }
1214            }
1215        }
1216        assert!(count >= target_count);
1217        certificate
1218    }
1219
1220    /// Tries to find a (confirmation) certificate for the given chain_id and block height, and are
1221    /// in the expected round.
1222    pub async fn check_that_validators_are_in_round(
1223        &self,
1224        chain_id: ChainId,
1225        block_height: BlockHeight,
1226        round: Round,
1227        target_count: usize,
1228    ) {
1229        let query = ChainInfoQuery::new(chain_id);
1230        let mut count = 0;
1231        for validator in self.node_provider.all_nodes() {
1232            if let Ok(response) = validator.handle_chain_info_query(query.clone()).await {
1233                if response.info.manager.current_round == round
1234                    && response.info.next_block_height == block_height
1235                    && response.check(validator.public_key).is_ok()
1236                {
1237                    count += 1;
1238                }
1239            }
1240        }
1241        assert!(count >= target_count);
1242    }
1243
1244    /// Panics if any validator has a nonempty outbox for the given chain.
1245    pub async fn check_that_validators_have_empty_outboxes(&self, chain_id: ChainId) {
1246        for validator in self.node_provider.all_nodes() {
1247            let guard = validator.client.lock().await;
1248            let chain = guard.state.chain_state_view(chain_id).await.unwrap();
1249            assert_eq!(chain.outboxes.indices().await.unwrap(), []);
1250        }
1251    }
1252}
1253
/// Maximum number of `RocksDbStorageBuilder`s that may exist at the same time.
#[cfg(feature = "rocksdb")]
const MAX_CONCURRENT_ROCKS_DB_BUILDERS: usize = 5;

/// Limits concurrency for RocksDB tests to avoid "too many open files" errors: each
/// `RocksDbStorageBuilder` holds a permit from this semaphore for as long as it exists.
#[cfg(feature = "rocksdb")]
static ROCKS_DB_SEMAPHORE: Semaphore = Semaphore::const_new(MAX_CONCURRENT_ROCKS_DB_BUILDERS);
1257
1258/// State shared by every [`StorageBuilder`] in this module. The actual
1259/// database type varies, so `build_storage` is generic over it.
1260#[derive(Default)]
1261struct CommonStorageBuilder {
1262    namespace: String,
1263    instance_counter: usize,
1264    wasm_runtime: Option<WasmRuntime>,
1265    clock: TestClock,
1266}
1267
1268impl CommonStorageBuilder {
1269    fn with_wasm_runtime(wasm_runtime: impl Into<Option<WasmRuntime>>) -> Self {
1270        Self {
1271            wasm_runtime: wasm_runtime.into(),
1272            ..Self::default()
1273        }
1274    }
1275
1276    async fn build_storage<DB>(
1277        &mut self,
1278        config: DB::Config,
1279    ) -> anyhow::Result<DbStorage<DB, TestClock>>
1280    where
1281        DB: TestKeyValueDatabase + Clone + Send + Sync + 'static,
1282        DB::Store: KeyValueStore + Clone + Send + Sync + 'static,
1283        DB::Error: std::error::Error + Send + Sync + 'static,
1284    {
1285        self.instance_counter += 1;
1286        if self.namespace.is_empty() {
1287            self.namespace = generate_test_namespace();
1288        }
1289        let namespace = format!("{}_{}", self.namespace, self.instance_counter);
1290        Ok(
1291            DbStorage::new_for_testing(config, &namespace, self.wasm_runtime, self.clock.clone())
1292                .await?,
1293        )
1294    }
1295}
1296
1297#[derive(Default)]
1298pub struct MemoryStorageBuilder {
1299    inner: CommonStorageBuilder,
1300}
1301
1302impl MemoryStorageBuilder {
1303    /// Creates a [`MemoryStorageBuilder`] that uses the specified [`WasmRuntime`] to run Wasm
1304    /// applications.
1305    pub fn with_wasm_runtime(wasm_runtime: impl Into<Option<WasmRuntime>>) -> Self {
1306        Self {
1307            inner: CommonStorageBuilder::with_wasm_runtime(wasm_runtime),
1308        }
1309    }
1310}
1311
1312#[async_trait]
1313impl StorageBuilder for MemoryStorageBuilder {
1314    type Storage = DbStorage<MemoryDatabase, TestClock>;
1315
1316    async fn build(&mut self) -> Result<Self::Storage, anyhow::Error> {
1317        let config = MemoryDatabase::new_test_config().await?;
1318        self.inner.build_storage::<MemoryDatabase>(config).await
1319    }
1320
1321    fn clock(&self) -> &TestClock {
1322        &self.inner.clock
1323    }
1324}
1325
/// A [`StorageBuilder`] whose storages are backed by [`RocksDbDatabase`].
///
/// Each builder holds a permit from `ROCKS_DB_SEMAPHORE` for its whole lifetime, so
/// constructing one waits while too many others exist.
#[cfg(feature = "rocksdb")]
pub struct RocksDbStorageBuilder {
    inner: CommonStorageBuilder,
    _permit: SemaphorePermit<'static>,
}

#[cfg(feature = "rocksdb")]
impl RocksDbStorageBuilder {
    /// Creates a builder without a Wasm runtime.
    pub async fn new() -> Self {
        Self::with_inner(CommonStorageBuilder::default()).await
    }

    /// Creates a [`RocksDbStorageBuilder`] whose storages run Wasm applications with the
    /// specified [`WasmRuntime`].
    #[cfg(any(feature = "wasmer", feature = "wasmtime"))]
    pub async fn with_wasm_runtime(wasm_runtime: impl Into<Option<WasmRuntime>>) -> Self {
        Self::with_inner(CommonStorageBuilder::with_wasm_runtime(wasm_runtime)).await
    }

    /// Waits for a RocksDB permit and wraps `inner` together with it.
    async fn with_inner(inner: CommonStorageBuilder) -> Self {
        let permit = ROCKS_DB_SEMAPHORE.acquire().await.unwrap();
        Self {
            inner,
            _permit: permit,
        }
    }
}

#[cfg(feature = "rocksdb")]
#[async_trait]
impl StorageBuilder for RocksDbStorageBuilder {
    type Storage = DbStorage<RocksDbDatabase, TestClock>;

    async fn build(&mut self) -> anyhow::Result<Self::Storage> {
        let config = RocksDbDatabase::new_test_config().await?;
        self.inner.build_storage::<RocksDbDatabase>(config).await
    }

    fn clock(&self) -> &TestClock {
        &self.inner.clock
    }
}
1366
/// A [`StorageBuilder`] whose storages are backed by [`StorageServiceDatabase`].
#[cfg(all(not(target_arch = "wasm32"), feature = "storage-service"))]
#[derive(Default)]
pub struct ServiceStorageBuilder {
    inner: CommonStorageBuilder,
}

#[cfg(all(not(target_arch = "wasm32"), feature = "storage-service"))]
impl ServiceStorageBuilder {
    /// Creates a `ServiceStorage` builder without a Wasm runtime.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a `ServiceStorage` builder that uses the given Wasm runtime.
    pub fn with_wasm_runtime(wasm_runtime: impl Into<Option<WasmRuntime>>) -> Self {
        let inner = CommonStorageBuilder::with_wasm_runtime(wasm_runtime);
        Self { inner }
    }
}

#[cfg(all(not(target_arch = "wasm32"), feature = "storage-service"))]
#[async_trait]
impl StorageBuilder for ServiceStorageBuilder {
    type Storage = DbStorage<StorageServiceDatabase, TestClock>;

    async fn build(&mut self) -> anyhow::Result<Self::Storage> {
        let config = StorageServiceDatabase::new_test_config().await?;
        self.inner.build_storage::<StorageServiceDatabase>(config).await
    }

    fn clock(&self) -> &TestClock {
        &self.inner.clock
    }
}
1404
/// A [`StorageBuilder`] whose storages are backed by [`ScyllaDbDatabase`].
#[cfg(feature = "scylladb")]
#[derive(Default)]
pub struct ScyllaDbStorageBuilder {
    inner: CommonStorageBuilder,
}

#[cfg(feature = "scylladb")]
impl ScyllaDbStorageBuilder {
    /// Creates a [`ScyllaDbStorageBuilder`] whose storages run Wasm applications with the
    /// specified [`WasmRuntime`].
    pub fn with_wasm_runtime(wasm_runtime: impl Into<Option<WasmRuntime>>) -> Self {
        let inner = CommonStorageBuilder::with_wasm_runtime(wasm_runtime);
        Self { inner }
    }
}

#[cfg(feature = "scylladb")]
#[async_trait]
impl StorageBuilder for ScyllaDbStorageBuilder {
    type Storage = DbStorage<ScyllaDbDatabase, TestClock>;

    async fn build(&mut self) -> anyhow::Result<Self::Storage> {
        let config = ScyllaDbDatabase::new_test_config().await?;
        self.inner.build_storage::<ScyllaDbDatabase>(config).await
    }

    fn clock(&self) -> &TestClock {
        &self.inner.clock
    }
}
1436
1437pub trait ClientOutcomeResultExt<T, E> {
1438    /// Unwraps the result and panics if it's not `Committed`.
1439    /// Use this when you expect the operation to succeed without conflicts.
1440    fn unwrap_ok_committed(self) -> T;
1441
1442    /// Unwraps the result, accepting both `Committed` and `Conflict` outcomes.
1443    /// Returns the committed value or the conflicting certificate (boxed).
1444    fn unwrap_ok_or_conflict(self) -> Result<T, Box<ConfirmedBlockCertificate>>;
1445}
1446
1447impl<T, E: std::fmt::Debug> ClientOutcomeResultExt<T, E> for Result<ClientOutcome<T>, E> {
1448    fn unwrap_ok_committed(self) -> T {
1449        match self.unwrap() {
1450            ClientOutcome::Committed(t) => t,
1451            ClientOutcome::WaitForTimeout(timeout) => {
1452                panic!("unexpected timeout: {timeout}")
1453            }
1454            ClientOutcome::Conflict(certificate) => {
1455                panic!("unexpected conflict: {}", certificate.hash())
1456            }
1457        }
1458    }
1459
1460    fn unwrap_ok_or_conflict(self) -> Result<T, Box<ConfirmedBlockCertificate>> {
1461        match self.unwrap() {
1462            ClientOutcome::Committed(t) => Ok(t),
1463            ClientOutcome::Conflict(certificate) => Err(certificate),
1464            ClientOutcome::WaitForTimeout(timeout) => {
1465                panic!("unexpected timeout: {timeout}")
1466            }
1467        }
1468    }
1469}