// scylla/network/connection.rs

1use super::tls::{TlsConfig, TlsProvider};
2use crate::authentication::AuthenticatorProvider;
3use crate::client::pager::{NextRowError, QueryPager};
4use crate::client::Compression;
5use crate::client::SelfIdentity;
6use crate::cluster::metadata::{PeerEndpoint, UntranslatedEndpoint};
7use crate::cluster::NodeAddr;
8use crate::errors::{
9    BadKeyspaceName, BrokenConnectionError, BrokenConnectionErrorKind, ConnectionError,
10    ConnectionSetupRequestError, ConnectionSetupRequestErrorKind, CqlEventHandlingError, DbError,
11    InternalRequestError, RequestAttemptError, ResponseParseError, SchemaAgreementError,
12    TranslationError, UseKeyspaceError,
13};
14use crate::frame::protocol_features::ProtocolFeatures;
15use crate::frame::{
16    self,
17    request::{self, batch, execute, query, register, SerializableRequest},
18    response::{event::Event, result, Response, ResponseOpcode},
19    server_event_type::EventType,
20    FrameParams, SerializedRequest,
21};
22use crate::policies::address_translator::{AddressTranslator, UntranslatedPeer};
23use crate::policies::timestamp_generator::TimestampGenerator;
24use crate::response::query_result::QueryResult;
25use crate::response::{
26    NonErrorAuthResponse, NonErrorStartupResponse, PagingState, PagingStateResponse, QueryResponse,
27};
28use crate::routing::locator::tablets::{RawTablet, TabletParsingError};
29use crate::routing::{Shard, ShardAwarePortRange, ShardInfo, Sharder, ShardingError};
30use crate::statement::batch::{Batch, BatchStatement};
31use crate::statement::prepared::PreparedStatement;
32use crate::statement::unprepared::Statement;
33use crate::statement::{Consistency, PageSize};
34use bytes::Bytes;
35use futures::{future::RemoteHandle, FutureExt};
36use scylla_cql::frame::frame_errors::CqlResponseParseError;
37use scylla_cql::frame::request::options::{self, Options};
38use scylla_cql::frame::request::CqlRequestKind;
39use scylla_cql::frame::response::authenticate::Authenticate;
40use scylla_cql::frame::response::result::{ResultMetadata, TableSpec};
41use scylla_cql::frame::response::Error;
42use scylla_cql::frame::response::{self, error};
43use scylla_cql::frame::types::SerialConsistency;
44use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
45use scylla_cql::serialize::raw_batch::RawBatchValuesAdapter;
46use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
47use socket2::{SockRef, TcpKeepalive};
48use std::borrow::Cow;
49use std::collections::{BTreeSet, HashMap, HashSet};
50use std::convert::TryFrom;
51use std::net::{IpAddr, SocketAddr};
52use std::num::NonZeroU64;
53use std::sync::atomic::AtomicU64;
54use std::sync::Arc;
55use std::sync::Mutex as StdMutex;
56use std::time::Duration;
57use std::{
58    cmp::Ordering,
59    net::{Ipv4Addr, Ipv6Addr},
60};
61use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
62use tokio::net::{TcpSocket, TcpStream};
63use tokio::sync::{mpsc, oneshot};
64use tokio::time::Instant;
65use tracing::{debug, error, trace, warn};
66use uuid::Uuid;
67
/// Schema-agreement query: reads the schema version of the node we are connected to.
const LOCAL_VERSION: &str = "SELECT schema_version FROM system.local WHERE key='local'";

// FIXME: Make these constants configurable.
//
// An "orphan" is a stream id that was allocated for a {request, response} pair nobody is
// waiting for anymore (because the `Connection::send_request` future was cancelled).
// An "old" orphan is one that has stayed orphaned for longer than
// `OLD_AGE_ORPHAN_THRESHOLD`. A connection that accumulates `OLD_ORPHAN_COUNT_THRESHOLD`
// old orphans is shut down (and then created again by the connection management layer).
const OLD_ORPHAN_COUNT_THRESHOLD: usize = 1024;
const OLD_AGE_ORPHAN_THRESHOLD: Duration = Duration::from_secs(1);
79
/// Represents a write coalescing delay configuration option.
///
/// Pick the variant that matches the delay granularity you need.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum WriteCoalescingDelay {
    /// Sub-millisecond delay, implemented by yielding the current tokio task.
    ///
    /// Tokio sleeps only have millisecond granularity, so a deterministic delay shorter
    /// than one millisecond cannot be implemented reliably; this variant gives up
    /// determinism in exchange for a very short wait.
    SmallNondeterministic,

    /// A delay of the given number of milliseconds.
    ///
    /// Use with caution, and only for throughput-bound applications that send a lot of
    /// requests. Benchmark the application before committing to this option.
    Milliseconds(NonZeroU64),
}
98
99pub(crate) struct Connection {
100    _worker_handle: RemoteHandle<()>,
101
102    connect_address: SocketAddr,
103    config: HostConnectionConfig,
104    features: ConnectionFeatures,
105    router_handle: Arc<RouterHandle>,
106}
107
108struct RouterHandle {
109    submit_channel: mpsc::Sender<Task>,
110
111    // Each request send by `Connection::send_request` needs a unique request id.
112    // This field is a monotonic generator of such ids.
113    request_id_generator: AtomicU64,
114    // If a `Connection::send_request` is cancelled, it sends notification
115    // about orphaning via the sender below.
116    // Also, this sender is unbounded, because only unbounded channels support
117    // pushing values in a synchronous way (without an `.await`), which is
118    // needed for pushing values in `Drop` implementations.
119    orphan_notification_sender: mpsc::UnboundedSender<RequestId>,
120}
121
122impl RouterHandle {
123    fn allocate_request_id(&self) -> RequestId {
124        self.request_id_generator
125            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
126    }
127
128    async fn send_request(
129        &self,
130        request: &impl SerializableRequest,
131        compression: Option<Compression>,
132        tracing: bool,
133    ) -> Result<TaskResponse, InternalRequestError> {
134        let serialized_request = SerializedRequest::make(request, compression, tracing)?;
135        let request_id = self.allocate_request_id();
136
137        let (response_sender, receiver) = oneshot::channel();
138        let response_handler = ResponseHandler {
139            response_sender,
140            request_id,
141        };
142
143        // Dropping `notifier` (before calling `notifier.disable()`) will send a notification to
144        // `Connection::router`. This notification is then used to mark a `stream_id` associated
145        // with this request as orphaned and free associated resources.
146        let notifier = OrphanhoodNotifier::new(request_id, &self.orphan_notification_sender);
147
148        self.submit_channel
149            .send(Task {
150                serialized_request,
151                response_handler,
152            })
153            .await
154            .map_err(|_| -> BrokenConnectionError {
155                BrokenConnectionErrorKind::ChannelError.into()
156            })?;
157
158        let task_response = receiver.await.map_err(|_| -> BrokenConnectionError {
159            BrokenConnectionErrorKind::ChannelError.into()
160        })?;
161
162        // Response was successfully received, so it's time to disable
163        // notification about orphaning.
164        notifier.disable();
165
166        task_response
167    }
168}
169
170#[derive(Default)]
171pub(crate) struct ConnectionFeatures {
172    shard_info: Option<ShardInfo>,
173    shard_aware_port: Option<u16>,
174    protocol_features: ProtocolFeatures,
175}
176
177type RequestId = u64;
178
179struct ResponseHandler {
180    response_sender: oneshot::Sender<Result<TaskResponse, InternalRequestError>>,
181    request_id: RequestId,
182}
183
184// Used to notify `Connection::orphaner` about `Connection::send_request`
185// future being dropped before receiving response.
186struct OrphanhoodNotifier<'a> {
187    enabled: bool,
188    request_id: RequestId,
189    notification_sender: &'a mpsc::UnboundedSender<RequestId>,
190}
191
192impl<'a> OrphanhoodNotifier<'a> {
193    fn new(
194        request_id: RequestId,
195        notification_sender: &'a mpsc::UnboundedSender<RequestId>,
196    ) -> Self {
197        Self {
198            enabled: true,
199            request_id,
200            notification_sender,
201        }
202    }
203
204    fn disable(mut self) {
205        self.enabled = false;
206    }
207}
208
209impl Drop for OrphanhoodNotifier<'_> {
210    fn drop(&mut self) {
211        if self.enabled {
212            let _ = self.notification_sender.send(self.request_id);
213        }
214    }
215}
216
217struct Task {
218    serialized_request: SerializedRequest,
219    response_handler: ResponseHandler,
220}
221
222struct TaskResponse {
223    params: FrameParams,
224    opcode: ResponseOpcode,
225    body: Bytes,
226}
227
228impl<'id: 'map, 'map> SelfIdentity<'id> {
229    fn add_startup_options(&'id self, options: &'map mut HashMap<Cow<'id, str>, Cow<'id, str>>) {
230        /* Driver identity. */
231        let driver_name = self
232            .get_custom_driver_name()
233            .unwrap_or(options::DEFAULT_DRIVER_NAME);
234        options.insert(
235            Cow::Borrowed(options::DRIVER_NAME),
236            Cow::Borrowed(driver_name),
237        );
238
239        let driver_version = self
240            .get_custom_driver_version()
241            .unwrap_or(options::DEFAULT_DRIVER_VERSION);
242        options.insert(
243            Cow::Borrowed(options::DRIVER_VERSION),
244            Cow::Borrowed(driver_version),
245        );
246
247        /* Application identity. */
248        if let Some(application_name) = self.get_application_name() {
249            options.insert(
250                Cow::Borrowed(options::APPLICATION_NAME),
251                Cow::Borrowed(application_name),
252            );
253        }
254
255        if let Some(application_version) = self.get_application_version() {
256            options.insert(
257                Cow::Borrowed(options::APPLICATION_VERSION),
258                Cow::Borrowed(application_version),
259            );
260        }
261
262        /* Client identity. */
263        if let Some(client_id) = self.get_client_id() {
264            options.insert(Cow::Borrowed(options::CLIENT_ID), Cow::Borrowed(client_id));
265        }
266    }
267}
268
269/// Configuration used for new connections.
270///
271/// Before being used for a particular connection, should be customized
272/// for a specific endpoint by converting to [HostConnectionConfig]
273/// using [ConnectionConfig::to_host_connection_config].
274#[derive(Clone)]
275pub(crate) struct ConnectionConfig {
276    pub(crate) local_ip_address: Option<IpAddr>,
277    pub(crate) shard_aware_local_port_range: ShardAwarePortRange,
278    pub(crate) compression: Option<Compression>,
279    pub(crate) tcp_nodelay: bool,
280    pub(crate) tcp_keepalive_interval: Option<Duration>,
281    pub(crate) timestamp_generator: Option<Arc<dyn TimestampGenerator>>,
282    pub(crate) tls_provider: Option<TlsProvider>,
283    pub(crate) connect_timeout: std::time::Duration,
284    // should be Some only in control connections,
285    pub(crate) event_sender: Option<mpsc::Sender<Event>>,
286    pub(crate) default_consistency: Consistency,
287    pub(crate) authenticator: Option<Arc<dyn AuthenticatorProvider>>,
288    pub(crate) address_translator: Option<Arc<dyn AddressTranslator>>,
289    pub(crate) write_coalescing_delay: Option<WriteCoalescingDelay>,
290
291    pub(crate) keepalive_interval: Option<Duration>,
292    pub(crate) keepalive_timeout: Option<Duration>,
293    pub(crate) tablet_sender: Option<mpsc::Sender<(TableSpec<'static>, RawTablet)>>,
294
295    pub(crate) identity: SelfIdentity<'static>,
296}
297
298impl ConnectionConfig {
299    /// Customizes the config for a specific endpoint.
300    pub(crate) fn to_host_connection_config(
301        &self,
302        // Currently, this is only used for cloud; but it makes abstract sense to pass endpoint here
303        // also for non-cloud cases, so let's just allow(unused).
304        #[allow(unused)] endpoint: &UntranslatedEndpoint,
305    ) -> HostConnectionConfig {
306        let tls_config = self
307            .tls_provider
308            .as_ref()
309            .and_then(|provider| provider.make_tls_config(endpoint));
310
311        HostConnectionConfig {
312            local_ip_address: self.local_ip_address,
313            shard_aware_local_port_range: self.shard_aware_local_port_range.clone(),
314            compression: self.compression,
315            tcp_nodelay: self.tcp_nodelay,
316            tcp_keepalive_interval: self.tcp_keepalive_interval,
317            timestamp_generator: self.timestamp_generator.clone(),
318            tls_config,
319            connect_timeout: self.connect_timeout,
320            event_sender: self.event_sender.clone(),
321            default_consistency: self.default_consistency,
322            authenticator: self.authenticator.clone(),
323            address_translator: self.address_translator.clone(),
324            write_coalescing_delay: self.write_coalescing_delay.clone(),
325            keepalive_interval: self.keepalive_interval,
326            keepalive_timeout: self.keepalive_timeout,
327            tablet_sender: self.tablet_sender.clone(),
328            identity: self.identity.clone(),
329        }
330    }
331}
332
333/// Configuration used for new connections, customized for a specific endpoint.
334///
335/// Created from [ConnectionConfig] using [ConnectionConfig::to_host_connection_config].
336#[derive(Clone)]
337pub(crate) struct HostConnectionConfig {
338    pub(crate) local_ip_address: Option<IpAddr>,
339    pub(crate) shard_aware_local_port_range: ShardAwarePortRange,
340    pub(crate) compression: Option<Compression>,
341    pub(crate) tcp_nodelay: bool,
342    pub(crate) tcp_keepalive_interval: Option<Duration>,
343    pub(crate) timestamp_generator: Option<Arc<dyn TimestampGenerator>>,
344    pub(crate) tls_config: Option<TlsConfig>,
345    pub(crate) connect_timeout: std::time::Duration,
346    // should be Some only in control connections,
347    pub(crate) event_sender: Option<mpsc::Sender<Event>>,
348    pub(crate) default_consistency: Consistency,
349    pub(crate) authenticator: Option<Arc<dyn AuthenticatorProvider>>,
350    pub(crate) address_translator: Option<Arc<dyn AddressTranslator>>,
351    pub(crate) write_coalescing_delay: Option<WriteCoalescingDelay>,
352
353    pub(crate) keepalive_interval: Option<Duration>,
354    pub(crate) keepalive_timeout: Option<Duration>,
355    pub(crate) tablet_sender: Option<mpsc::Sender<(TableSpec<'static>, RawTablet)>>,
356
357    pub(crate) identity: SelfIdentity<'static>,
358}
359
#[cfg(test)]
impl Default for HostConnectionConfig {
    /// Test-only defaults: no TLS, compression, authentication or address translation.
    /// TCP_NODELAY is on, the connect timeout is 5 seconds, and writes use small
    /// nondeterministic coalescing.
    fn default() -> Self {
        HostConnectionConfig {
            local_ip_address: None,
            shard_aware_local_port_range: ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
            compression: None,
            tcp_nodelay: true,
            tcp_keepalive_interval: None,
            timestamp_generator: None,
            tls_config: None,
            connect_timeout: Duration::from_secs(5),
            event_sender: None,
            default_consistency: Consistency::default(),
            authenticator: None,
            address_translator: None,
            write_coalescing_delay: Some(WriteCoalescingDelay::SmallNondeterministic),
            // Note: these differ from the `SessionConfig` default values.
            keepalive_interval: None,
            keepalive_timeout: None,
            tablet_sender: None,
            identity: SelfIdentity::default(),
        }
    }
}
388
#[cfg(test)]
impl Default for ConnectionConfig {
    /// Test-only defaults: no TLS provider, compression, authentication or address
    /// translation. TCP_NODELAY is on, the connect timeout is 5 seconds, and writes use
    /// small nondeterministic coalescing.
    fn default() -> Self {
        ConnectionConfig {
            local_ip_address: None,
            shard_aware_local_port_range: ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
            compression: None,
            tcp_nodelay: true,
            tcp_keepalive_interval: None,
            timestamp_generator: None,
            tls_provider: None,
            connect_timeout: Duration::from_secs(5),
            event_sender: None,
            default_consistency: Consistency::default(),
            authenticator: None,
            address_translator: None,
            write_coalescing_delay: Some(WriteCoalescingDelay::SmallNondeterministic),
            // Note: these differ from the `SessionConfig` default values.
            keepalive_interval: None,
            keepalive_timeout: None,
            tablet_sender: None,
            identity: SelfIdentity::default(),
        }
    }
}
417
418impl HostConnectionConfig {
419    fn is_tls(&self) -> bool {
420        self.tls_config.is_some()
421    }
422}
423
424// Used to listen for fatal error in connection
425pub(crate) type ErrorReceiver = tokio::sync::oneshot::Receiver<ConnectionError>;
426
427impl Connection {
428    // Returns new connection and ErrorReceiver which can be used to wait for a fatal error
429    /// Opens a connection and makes it ready to send/receive CQL frames on it,
430    /// but does not yet send any frames (no OPTIONS/STARTUP handshake nor REGISTER requests).
431    async fn new(
432        connect_address: SocketAddr,
433        source_port: Option<u16>,
434        config: HostConnectionConfig,
435    ) -> Result<(Self, ErrorReceiver), ConnectionError> {
436        let stream_connector = tokio::time::timeout(
437            config.connect_timeout,
438            connect_with_source_ip_and_port(connect_address, config.local_ip_address, source_port),
439        )
440        .await;
441        let stream = match stream_connector {
442            Ok(stream) => stream?,
443            Err(_) => {
444                return Err(ConnectionError::ConnectTimeout);
445            }
446        };
447        stream.set_nodelay(config.tcp_nodelay)?;
448
449        if let Some(tcp_keepalive_interval) = config.tcp_keepalive_interval {
450            Self::setup_tcp_keepalive(&stream, tcp_keepalive_interval)?;
451        }
452
453        // TODO: What should be the size of the channel?
454        let (sender, receiver) = mpsc::channel(1024);
455        let (error_sender, error_receiver) = tokio::sync::oneshot::channel();
456        // Unbounded because it allows for synchronous pushes
457        let (orphan_notification_sender, orphan_notification_receiver) = mpsc::unbounded_channel();
458
459        let router_handle = Arc::new(RouterHandle {
460            submit_channel: sender,
461            request_id_generator: AtomicU64::new(0),
462            orphan_notification_sender,
463        });
464
465        let _worker_handle = Self::run_router(
466            config.clone(),
467            stream,
468            receiver,
469            error_sender,
470            orphan_notification_receiver,
471            router_handle.clone(),
472            connect_address.ip(),
473        )
474        .await?;
475
476        let connection = Connection {
477            _worker_handle,
478            config,
479            features: Default::default(),
480            connect_address,
481            router_handle,
482        };
483
484        Ok((connection, error_receiver))
485    }
486
487    fn setup_tcp_keepalive(
488        stream: &TcpStream,
489        tcp_keepalive_interval: Duration,
490    ) -> std::io::Result<()> {
491        // It may be surprising why we call `with_time()` with `tcp_keepalive_interval`
492        // and `with_interval() with some other value. This is due to inconsistent naming:
493        // our interval means time after connection becomes idle until keepalives
494        // begin to be sent (they call it "time"), and their interval is time between
495        // sending keepalives.
496        // We insist on our naming due to other drivers following the same convention.
497        let mut tcp_keepalive = TcpKeepalive::new().with_time(tcp_keepalive_interval);
498
499        // These cfg values are taken from socket2 library, which uses the same constraints.
500        #[cfg(any(
501            target_os = "android",
502            target_os = "dragonfly",
503            target_os = "freebsd",
504            target_os = "fuchsia",
505            target_os = "illumos",
506            target_os = "ios",
507            target_os = "linux",
508            target_os = "macos",
509            target_os = "netbsd",
510            target_os = "tvos",
511            target_os = "watchos",
512            target_os = "windows",
513        ))]
514        {
515            tcp_keepalive = tcp_keepalive.with_interval(Duration::from_secs(1));
516        }
517
518        #[cfg(any(
519            target_os = "android",
520            target_os = "dragonfly",
521            target_os = "freebsd",
522            target_os = "fuchsia",
523            target_os = "illumos",
524            target_os = "ios",
525            target_os = "linux",
526            target_os = "macos",
527            target_os = "netbsd",
528            target_os = "tvos",
529            target_os = "watchos",
530        ))]
531        {
532            tcp_keepalive = tcp_keepalive.with_retries(10);
533        }
534
535        let sf = SockRef::from(&stream);
536        sf.set_tcp_keepalive(&tcp_keepalive)
537    }
538
539    async fn startup(
540        &self,
541        options: HashMap<Cow<'_, str>, Cow<'_, str>>,
542    ) -> Result<NonErrorStartupResponse, ConnectionSetupRequestError> {
543        let err = |kind: ConnectionSetupRequestErrorKind| {
544            ConnectionSetupRequestError::new(CqlRequestKind::Startup, kind)
545        };
546
547        let req_result = self
548            .send_request(&request::Startup { options }, false, false, None)
549            .await;
550
551        // Extract the response to STARTUP request and tidy up the errors.
552        let response = match req_result {
553            Ok(r) => match r.response {
554                Response::Ready => NonErrorStartupResponse::Ready,
555                Response::Authenticate(auth) => NonErrorStartupResponse::Authenticate(auth),
556                Response::Error(Error { error, reason }) => {
557                    return Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason)))
558                }
559                _ => {
560                    return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
561                        r.response.to_response_kind(),
562                    )))
563                }
564            },
565            Err(e) => match e {
566                InternalRequestError::CqlRequestSerialization(e) => return Err(err(e.into())),
567                InternalRequestError::BodyExtensionsParseError(e) => return Err(err(e.into())),
568                InternalRequestError::CqlResponseParseError(e) => match e {
569                    // Parsing of READY response cannot fail, since its body is empty.
570                    // Remaining valid responses are AUTHENTICATE and ERROR.
571                    CqlResponseParseError::CqlAuthenticateParseError(e) => {
572                        return Err(err(e.into()))
573                    }
574                    CqlResponseParseError::CqlErrorParseError(e) => return Err(err(e.into())),
575                    _ => {
576                        return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
577                            e.to_response_kind(),
578                        )))
579                    }
580                },
581                InternalRequestError::BrokenConnection(e) => return Err(err(e.into())),
582                InternalRequestError::UnableToAllocStreamId => {
583                    return Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId))
584                }
585            },
586        };
587
588        Ok(response)
589    }
590
591    async fn get_options(&self) -> Result<response::Supported, ConnectionSetupRequestError> {
592        let err = |kind: ConnectionSetupRequestErrorKind| {
593            ConnectionSetupRequestError::new(CqlRequestKind::Options, kind)
594        };
595
596        let req_result = self
597            .send_request(&request::Options {}, false, false, None)
598            .await;
599
600        // Extract the supported options and tidy up the errors.
601        let supported = match req_result {
602            Ok(r) => match r.response {
603                Response::Supported(supported) => supported,
604                Response::Error(Error { error, reason }) => {
605                    return Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason)))
606                }
607                _ => {
608                    return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
609                        r.response.to_response_kind(),
610                    )))
611                }
612            },
613            Err(e) => match e {
614                InternalRequestError::CqlRequestSerialization(e) => return Err(err(e.into())),
615                InternalRequestError::BodyExtensionsParseError(e) => return Err(err(e.into())),
616                InternalRequestError::CqlResponseParseError(e) => match e {
617                    CqlResponseParseError::CqlSupportedParseError(e) => return Err(err(e.into())),
618                    CqlResponseParseError::CqlErrorParseError(e) => return Err(err(e.into())),
619                    _ => {
620                        return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
621                            e.to_response_kind(),
622                        )))
623                    }
624                },
625                InternalRequestError::BrokenConnection(e) => return Err(err(e.into())),
626                InternalRequestError::UnableToAllocStreamId => {
627                    return Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId))
628                }
629            },
630        };
631
632        Ok(supported)
633    }
634
635    pub(crate) async fn prepare(
636        &self,
637        query: &Statement,
638    ) -> Result<PreparedStatement, RequestAttemptError> {
639        let query_response = self
640            .send_request(
641                &request::Prepare {
642                    query: &query.contents,
643                },
644                true,
645                query.config.tracing,
646                None,
647            )
648            .await?;
649
650        let mut prepared_statement = match query_response.response {
651            Response::Error(error::Error { error, reason }) => {
652                return Err(RequestAttemptError::DbError(error, reason))
653            }
654            Response::Result(result::Result::Prepared(p)) => PreparedStatement::new(
655                p.id,
656                self.features
657                    .protocol_features
658                    .prepared_flags_contain_lwt_mark(p.prepared_metadata.flags as u32),
659                p.prepared_metadata,
660                Arc::new(p.result_metadata),
661                query.contents.clone(),
662                query.get_validated_page_size(),
663                query.config.clone(),
664            ),
665            _ => {
666                return Err(RequestAttemptError::UnexpectedResponse(
667                    query_response.response.to_response_kind(),
668                ))
669            }
670        };
671
672        if let Some(tracing_id) = query_response.tracing_id {
673            prepared_statement.prepare_tracing_ids.push(tracing_id);
674        }
675        Ok(prepared_statement)
676    }
677
678    async fn reprepare(
679        &self,
680        query: impl Into<Statement>,
681        previous_prepared: &PreparedStatement,
682    ) -> Result<(), RequestAttemptError> {
683        let reprepare_query: Statement = query.into();
684        let reprepared = self.prepare(&reprepare_query).await?;
685        // Reprepared statement should keep its id - it's the md5 sum
686        // of statement contents
687        if reprepared.get_id() != previous_prepared.get_id() {
688            Err(RequestAttemptError::RepreparedIdChanged {
689                statement: reprepare_query.contents,
690                expected_id: previous_prepared.get_id().clone().into(),
691                reprepared_id: reprepared.get_id().clone().into(),
692            })
693        } else {
694            Ok(())
695        }
696    }
697
698    async fn perform_authenticate(
699        &mut self,
700        authenticate: &Authenticate,
701    ) -> Result<(), ConnectionSetupRequestError> {
702        let err = |kind: ConnectionSetupRequestErrorKind| {
703            ConnectionSetupRequestError::new(CqlRequestKind::AuthResponse, kind)
704        };
705
706        let authenticator = &authenticate.authenticator_name as &str;
707
708        match self.config.authenticator {
709            Some(ref authenticator_provider) => {
710                let (mut response, mut auth_session) = authenticator_provider
711                    .start_authentication_session(authenticator)
712                    .await
713                    .map_err(|e| err(ConnectionSetupRequestErrorKind::StartAuthSessionError(e)))?;
714
715                loop {
716                    match self.authenticate_response(response).await? {
717                        NonErrorAuthResponse::AuthChallenge(challenge) => {
718                            response = auth_session
719                                .evaluate_challenge(challenge.authenticate_message.as_deref())
720                                .await
721                                .map_err(|e| {
722                                    err(
723                                        ConnectionSetupRequestErrorKind::AuthChallengeEvaluationError(
724                                            e,
725                                        ),
726                                    )
727                                })?;
728                        }
729                        NonErrorAuthResponse::AuthSuccess(success) => {
730                            auth_session
731                                .success(success.success_message.as_deref())
732                                .await
733                                .map_err(|e| {
734                                    err(ConnectionSetupRequestErrorKind::AuthFinishError(e))
735                                })?;
736                            break;
737                        }
738                    }
739                }
740            }
741            None => return Err(err(ConnectionSetupRequestErrorKind::MissingAuthentication)),
742        }
743
744        Ok(())
745    }
746
747    async fn authenticate_response(
748        &self,
749        response: Option<Vec<u8>>,
750    ) -> Result<NonErrorAuthResponse, ConnectionSetupRequestError> {
751        let err = |kind: ConnectionSetupRequestErrorKind| {
752            ConnectionSetupRequestError::new(CqlRequestKind::AuthResponse, kind)
753        };
754
755        let req_result = self
756            .send_request(&request::AuthResponse { response }, false, false, None)
757            .await;
758
759        // Extract non-error response to AUTH_RESPONSE request and tidy up errors.
760        let response = match req_result {
761            Ok(r) => match r.response {
762                Response::AuthSuccess(auth_success) => {
763                    NonErrorAuthResponse::AuthSuccess(auth_success)
764                }
765                Response::AuthChallenge(auth_challenge) => {
766                    NonErrorAuthResponse::AuthChallenge(auth_challenge)
767                }
768                Response::Error(Error { error, reason }) => {
769                    return Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason)))
770                }
771                _ => {
772                    return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
773                        r.response.to_response_kind(),
774                    )))
775                }
776            },
777            Err(e) => match e {
778                InternalRequestError::CqlRequestSerialization(e) => return Err(err(e.into())),
779                InternalRequestError::BodyExtensionsParseError(e) => return Err(err(e.into())),
780                InternalRequestError::CqlResponseParseError(e) => match e {
781                    CqlResponseParseError::CqlAuthSuccessParseError(e) => {
782                        return Err(err(e.into()))
783                    }
784                    CqlResponseParseError::CqlAuthChallengeParseError(e) => {
785                        return Err(err(e.into()))
786                    }
787                    CqlResponseParseError::CqlErrorParseError(e) => return Err(err(e.into())),
788                    _ => {
789                        return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
790                            e.to_response_kind(),
791                        )))
792                    }
793                },
794                InternalRequestError::BrokenConnection(e) => return Err(err(e.into())),
795                InternalRequestError::UnableToAllocStreamId => {
796                    return Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId))
797                }
798            },
799        };
800
801        Ok(response)
802    }
803
804    #[allow(dead_code)]
805    pub(crate) async fn query_single_page(
806        &self,
807        query: impl Into<Statement>,
808        paging_state: PagingState,
809    ) -> Result<(QueryResult, PagingStateResponse), RequestAttemptError> {
810        let query: Statement = query.into();
811
812        // This method is used only for driver internal queries, so no need to consult execution profile here.
813        let consistency = query
814            .config
815            .determine_consistency(self.config.default_consistency);
816        let serial_consistency = query.config.serial_consistency;
817
818        self.query_single_page_with_consistency(
819            query,
820            paging_state,
821            consistency,
822            serial_consistency.flatten(),
823        )
824        .await
825    }
826
827    #[allow(dead_code)]
828    pub(crate) async fn query_single_page_with_consistency(
829        &self,
830        query: impl Into<Statement>,
831        paging_state: PagingState,
832        consistency: Consistency,
833        serial_consistency: Option<SerialConsistency>,
834    ) -> Result<(QueryResult, PagingStateResponse), RequestAttemptError> {
835        let query: Statement = query.into();
836        let page_size = query.get_validated_page_size();
837
838        self.query_raw_with_consistency(
839            &query,
840            consistency,
841            serial_consistency,
842            Some(page_size),
843            paging_state,
844        )
845        .await?
846        .into_query_result_and_paging_state()
847    }
848
849    #[allow(dead_code)]
850    pub(crate) async fn query_unpaged(
851        &self,
852        statement: impl Into<Statement>,
853    ) -> Result<QueryResult, RequestAttemptError> {
854        // This method is used only for driver internal queries, so no need to consult execution profile here.
855        let statement: Statement = statement.into();
856
857        self.query_raw_unpaged(&statement)
858            .await
859            .and_then(QueryResponse::into_query_result)
860    }
861
862    pub(crate) async fn query_raw_unpaged(
863        &self,
864        statement: &Statement,
865    ) -> Result<QueryResponse, RequestAttemptError> {
866        // This method is used only for driver internal queries, so no need to consult execution profile here.
867        self.query_raw_with_consistency(
868            statement,
869            statement
870                .config
871                .determine_consistency(self.config.default_consistency),
872            statement.config.serial_consistency.flatten(),
873            None,
874            PagingState::start(),
875        )
876        .await
877    }
878
879    pub(crate) async fn query_raw_with_consistency(
880        &self,
881        statement: &Statement,
882        consistency: Consistency,
883        serial_consistency: Option<SerialConsistency>,
884        page_size: Option<PageSize>,
885        paging_state: PagingState,
886    ) -> Result<QueryResponse, RequestAttemptError> {
887        let get_timestamp_from_gen = || {
888            self.config
889                .timestamp_generator
890                .as_ref()
891                .map(|gen| gen.next_timestamp())
892        };
893        let timestamp = statement.get_timestamp().or_else(get_timestamp_from_gen);
894
895        let query_frame = query::Query {
896            contents: Cow::Borrowed(&statement.contents),
897            parameters: query::QueryParameters {
898                consistency,
899                serial_consistency,
900                values: Cow::Borrowed(SerializedValues::EMPTY),
901                page_size: page_size.map(Into::into),
902                paging_state,
903                skip_metadata: false,
904                timestamp,
905            },
906        };
907
908        let response = self
909            .send_request(&query_frame, true, statement.config.tracing, None)
910            .await?;
911
912        Ok(response)
913    }
914
915    #[allow(dead_code)]
916    pub(crate) async fn execute_unpaged(
917        &self,
918        prepared: &PreparedStatement,
919        values: SerializedValues,
920    ) -> Result<QueryResult, RequestAttemptError> {
921        // This method is used only for driver internal queries, so no need to consult execution profile here.
922        self.execute_raw_unpaged(prepared, values)
923            .await
924            .and_then(QueryResponse::into_query_result)
925    }
926
927    #[allow(dead_code)]
928    pub(crate) async fn execute_raw_unpaged(
929        &self,
930        prepared: &PreparedStatement,
931        values: SerializedValues,
932    ) -> Result<QueryResponse, RequestAttemptError> {
933        // This method is used only for driver internal queries, so no need to consult execution profile here.
934        self.execute_raw_with_consistency(
935            prepared,
936            &values,
937            prepared
938                .config
939                .determine_consistency(self.config.default_consistency),
940            prepared.config.serial_consistency.flatten(),
941            None,
942            PagingState::start(),
943        )
944        .await
945    }
946
947    pub(crate) async fn execute_raw_with_consistency(
948        &self,
949        prepared_statement: &PreparedStatement,
950        values: &SerializedValues,
951        consistency: Consistency,
952        serial_consistency: Option<SerialConsistency>,
953        page_size: Option<PageSize>,
954        paging_state: PagingState,
955    ) -> Result<QueryResponse, RequestAttemptError> {
956        let get_timestamp_from_gen = || {
957            self.config
958                .timestamp_generator
959                .as_ref()
960                .map(|gen| gen.next_timestamp())
961        };
962        let timestamp = prepared_statement
963            .get_timestamp()
964            .or_else(get_timestamp_from_gen);
965
966        let execute_frame = execute::Execute {
967            id: prepared_statement.get_id().to_owned(),
968            parameters: query::QueryParameters {
969                consistency,
970                serial_consistency,
971                values: Cow::Borrowed(values),
972                page_size: page_size.map(Into::into),
973                timestamp,
974                skip_metadata: prepared_statement.get_use_cached_result_metadata(),
975                paging_state,
976            },
977        };
978
979        let cached_metadata = prepared_statement
980            .get_use_cached_result_metadata()
981            .then(|| prepared_statement.get_result_metadata());
982
983        let query_response = self
984            .send_request(
985                &execute_frame,
986                true,
987                prepared_statement.config.tracing,
988                cached_metadata,
989            )
990            .await?;
991
992        if let Some(spec) = prepared_statement.get_table_spec() {
993            if let Err(e) = self
994                .update_tablets_from_response(spec, &query_response)
995                .await
996            {
997                tracing::warn!("Error while parsing tablet info from custom payload: {}", e);
998            }
999        }
1000
1001        match &query_response.response {
1002            Response::Error(frame::response::Error {
1003                error: DbError::Unprepared { statement_id },
1004                ..
1005            }) => {
1006                debug!("Connection::execute: Got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
1007                // Repreparation of a statement is needed
1008                self.reprepare(prepared_statement.get_statement(), prepared_statement)
1009                    .await?;
1010                let new_response = self
1011                    .send_request(
1012                        &execute_frame,
1013                        true,
1014                        prepared_statement.config.tracing,
1015                        cached_metadata,
1016                    )
1017                    .await?;
1018
1019                if let Some(spec) = prepared_statement.get_table_spec() {
1020                    if let Err(e) = self.update_tablets_from_response(spec, &new_response).await {
1021                        tracing::warn!(
1022                            "Error while parsing tablet info from custom payload: {}",
1023                            e
1024                        );
1025                    }
1026                }
1027
1028                Ok(new_response)
1029            }
1030            _ => Ok(query_response),
1031        }
1032    }
1033
1034    /// Executes a query and fetches its results over multiple pages, using
1035    /// the asynchronous iterator interface.
1036    pub(crate) async fn query_iter(
1037        self: Arc<Self>,
1038        query: Statement,
1039    ) -> Result<QueryPager, NextRowError> {
1040        let consistency = query
1041            .config
1042            .determine_consistency(self.config.default_consistency);
1043        let serial_consistency = query.config.serial_consistency.flatten();
1044
1045        QueryPager::new_for_connection_query_iter(query, self, consistency, serial_consistency)
1046            .await
1047            .map_err(NextRowError::NextPageError)
1048    }
1049
1050    /// Executes a prepared statements and fetches its results over multiple pages, using
1051    /// the asynchronous iterator interface.
1052    pub(crate) async fn execute_iter(
1053        self: Arc<Self>,
1054        prepared_statement: PreparedStatement,
1055        values: SerializedValues,
1056    ) -> Result<QueryPager, NextRowError> {
1057        let consistency = prepared_statement
1058            .config
1059            .determine_consistency(self.config.default_consistency);
1060        let serial_consistency = prepared_statement.config.serial_consistency.flatten();
1061
1062        QueryPager::new_for_connection_execute_iter(
1063            prepared_statement,
1064            values,
1065            self,
1066            consistency,
1067            serial_consistency,
1068        )
1069        .await
1070        .map_err(NextRowError::NextPageError)
1071    }
1072
1073    #[allow(dead_code)]
1074    pub(crate) async fn batch(
1075        &self,
1076        batch: &Batch,
1077        values: impl BatchValues,
1078    ) -> Result<QueryResult, RequestAttemptError> {
1079        self.batch_with_consistency(
1080            batch,
1081            values,
1082            batch
1083                .config
1084                .determine_consistency(self.config.default_consistency),
1085            batch.config.serial_consistency.flatten(),
1086        )
1087        .await
1088        .and_then(QueryResponse::into_query_result)
1089    }
1090
1091    pub(crate) async fn batch_with_consistency(
1092        &self,
1093        init_batch: &Batch,
1094        values: impl BatchValues,
1095        consistency: Consistency,
1096        serial_consistency: Option<SerialConsistency>,
1097    ) -> Result<QueryResponse, RequestAttemptError> {
1098        let batch = self.prepare_batch(init_batch, &values).await?;
1099
1100        let contexts = batch.statements.iter().map(|bs| match bs {
1101            BatchStatement::Query(_) => RowSerializationContext::empty(),
1102            BatchStatement::PreparedStatement(ps) => {
1103                RowSerializationContext::from_prepared(ps.get_prepared_metadata())
1104            }
1105        });
1106
1107        let values = RawBatchValuesAdapter::new(values, contexts);
1108
1109        let get_timestamp_from_gen = || {
1110            self.config
1111                .timestamp_generator
1112                .as_ref()
1113                .map(|gen| gen.next_timestamp())
1114        };
1115        let timestamp = batch.get_timestamp().or_else(get_timestamp_from_gen);
1116
1117        let batch_frame = batch::Batch {
1118            statements: Cow::Borrowed(&batch.statements),
1119            values,
1120            batch_type: batch.get_type(),
1121            consistency,
1122            serial_consistency,
1123            timestamp,
1124        };
1125
1126        loop {
1127            let query_response = self
1128                .send_request(&batch_frame, true, batch.config.tracing, None)
1129                .await
1130                .map_err(RequestAttemptError::from)?;
1131
1132            return match query_response.response {
1133                Response::Error(err) => match err.error {
1134                    DbError::Unprepared { statement_id } => {
1135                        debug!("Connection::batch: got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
1136                        let prepared_statement = batch.statements.iter().find_map(|s| match s {
1137                            BatchStatement::PreparedStatement(s) if *s.get_id() == statement_id => {
1138                                Some(s)
1139                            }
1140                            _ => None,
1141                        });
1142                        if let Some(p) = prepared_statement {
1143                            self.reprepare(p.get_statement(), p).await?;
1144                            continue;
1145                        } else {
1146                            return Err(RequestAttemptError::RepreparedIdMissingInBatch);
1147                        }
1148                    }
1149                    _ => Err(err.into()),
1150                },
1151                Response::Result(_) => Ok(query_response),
1152                _ => Err(RequestAttemptError::UnexpectedResponse(
1153                    query_response.response.to_response_kind(),
1154                )),
1155            };
1156        }
1157    }
1158
1159    async fn prepare_batch<'b>(
1160        &self,
1161        init_batch: &'b Batch,
1162        values: impl BatchValues,
1163    ) -> Result<Cow<'b, Batch>, RequestAttemptError> {
1164        let mut to_prepare = HashSet::<&str>::new();
1165
1166        {
1167            let mut values_iter = values.batch_values_iter();
1168            for stmt in &init_batch.statements {
1169                if let BatchStatement::Query(query) = stmt {
1170                    if let Some(false) = values_iter.is_empty_next() {
1171                        to_prepare.insert(&query.contents);
1172                    }
1173                } else {
1174                    values_iter.skip_next();
1175                }
1176            }
1177        }
1178
1179        if to_prepare.is_empty() {
1180            return Ok(Cow::Borrowed(init_batch));
1181        }
1182
1183        let mut prepared_queries = HashMap::<&str, PreparedStatement>::new();
1184
1185        for query in &to_prepare {
1186            let prepared = self.prepare(&Statement::new(query.to_string())).await?;
1187            prepared_queries.insert(query, prepared);
1188        }
1189
1190        let mut batch: Cow<Batch> = Cow::Owned(Batch::new_from(init_batch));
1191        for stmt in &init_batch.statements {
1192            match stmt {
1193                BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
1194                {
1195                    Some(prepared) => batch.to_mut().append_statement(prepared.clone()),
1196                    None => batch.to_mut().append_statement(query.clone()),
1197                },
1198                BatchStatement::PreparedStatement(prepared) => {
1199                    batch.to_mut().append_statement(prepared.clone());
1200                }
1201            }
1202        }
1203
1204        Ok(batch)
1205    }
1206
1207    pub(super) async fn use_keyspace(
1208        &self,
1209        keyspace_name: &VerifiedKeyspaceName,
1210    ) -> Result<(), UseKeyspaceError> {
1211        // Trying to pass keyspace_name as bound value doesn't work
1212        // We have to send "USE " + keyspace_name
1213        let query: Statement = match keyspace_name.is_case_sensitive {
1214            true => format!("USE \"{}\"", keyspace_name.as_str()).into(),
1215            false => format!("USE {}", keyspace_name.as_str()).into(),
1216        };
1217
1218        let query_response = self.query_raw_unpaged(&query).await?;
1219        Self::verify_use_keyspace_result(keyspace_name, query_response)
1220    }
1221
1222    fn verify_use_keyspace_result(
1223        keyspace_name: &VerifiedKeyspaceName,
1224        query_response: QueryResponse,
1225    ) -> Result<(), UseKeyspaceError> {
1226        match query_response.response {
1227            Response::Result(result::Result::SetKeyspace(set_keyspace)) => {
1228                if !set_keyspace
1229                    .keyspace_name
1230                    .eq_ignore_ascii_case(keyspace_name.as_str())
1231                {
1232                    let expected_keyspace_name_lowercase = keyspace_name.as_str().to_lowercase();
1233                    let result_keyspace_name_lowercase = set_keyspace.keyspace_name.to_lowercase();
1234
1235                    return Err(UseKeyspaceError::KeyspaceNameMismatch {
1236                        expected_keyspace_name_lowercase,
1237                        result_keyspace_name_lowercase,
1238                    });
1239                }
1240
1241                Ok(())
1242            }
1243            Response::Error(err) => Err(UseKeyspaceError::RequestError(
1244                RequestAttemptError::DbError(err.error, err.reason),
1245            )),
1246            _ => Err(UseKeyspaceError::RequestError(
1247                RequestAttemptError::UnexpectedResponse(query_response.response.to_response_kind()),
1248            )),
1249        }
1250    }
1251
1252    async fn register(
1253        &self,
1254        event_types_to_register_for: Vec<EventType>,
1255    ) -> Result<(), ConnectionSetupRequestError> {
1256        let err = |kind: ConnectionSetupRequestErrorKind| {
1257            ConnectionSetupRequestError::new(CqlRequestKind::Register, kind)
1258        };
1259
1260        let register_frame = register::Register {
1261            event_types_to_register_for,
1262        };
1263
1264        // Extract the response and tidy up the errors.
1265        match self.send_request(&register_frame, true, false, None).await {
1266            Ok(r) => match r.response {
1267                Response::Ready => Ok(()),
1268                Response::Error(Error { error, reason }) => {
1269                    Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason)))
1270                }
1271                _ => Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
1272                    r.response.to_response_kind(),
1273                ))),
1274            },
1275            Err(e) => match e {
1276                InternalRequestError::CqlRequestSerialization(e) => Err(err(e.into())),
1277                InternalRequestError::BodyExtensionsParseError(e) => Err(err(e.into())),
1278                InternalRequestError::CqlResponseParseError(e) => match e {
1279                    // Parsing the READY response cannot fail. Only remaining valid response is ERROR.
1280                    CqlResponseParseError::CqlErrorParseError(e) => Err(err(e.into())),
1281                    _ => Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
1282                        e.to_response_kind(),
1283                    ))),
1284                },
1285                InternalRequestError::BrokenConnection(e) => Err(err(e.into())),
1286                InternalRequestError::UnableToAllocStreamId => {
1287                    Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId))
1288                }
1289            },
1290        }
1291    }
1292
1293    pub(crate) async fn fetch_schema_version(&self) -> Result<Uuid, SchemaAgreementError> {
1294        let (version_id,) = self
1295            .query_unpaged(LOCAL_VERSION)
1296            .await?
1297            .into_rows_result()
1298            .map_err(SchemaAgreementError::TracesEventsIntoRowsResultError)?
1299            .single_row::<(Uuid,)>()
1300            .map_err(SchemaAgreementError::SingleRowError)?;
1301
1302        Ok(version_id)
1303    }
1304
1305    async fn send_request(
1306        &self,
1307        request: &impl SerializableRequest,
1308        compress: bool,
1309        tracing: bool,
1310        cached_metadata: Option<&Arc<ResultMetadata<'static>>>,
1311    ) -> Result<QueryResponse, InternalRequestError> {
1312        let compression = if compress {
1313            self.config.compression
1314        } else {
1315            None
1316        };
1317
1318        let task_response = self
1319            .router_handle
1320            .send_request(request, compression, tracing)
1321            .await?;
1322
1323        let response = Self::parse_response(
1324            task_response,
1325            self.config.compression,
1326            &self.features.protocol_features,
1327            cached_metadata,
1328        )?;
1329
1330        Ok(response)
1331    }
1332
1333    fn parse_response(
1334        task_response: TaskResponse,
1335        compression: Option<Compression>,
1336        features: &ProtocolFeatures,
1337        cached_metadata: Option<&Arc<ResultMetadata<'static>>>,
1338    ) -> Result<QueryResponse, ResponseParseError> {
1339        let body_with_ext = frame::parse_response_body_extensions(
1340            task_response.params.flags,
1341            compression,
1342            task_response.body,
1343        )?;
1344
1345        for warn_description in &body_with_ext.warnings {
1346            warn!(
1347                warning = warn_description.as_str(),
1348                "Response from the database contains a warning",
1349            );
1350        }
1351
1352        let response = Response::deserialize(
1353            features,
1354            task_response.opcode,
1355            body_with_ext.body,
1356            cached_metadata,
1357        )?;
1358
1359        Ok(QueryResponse {
1360            response,
1361            warnings: body_with_ext.warnings,
1362            tracing_id: body_with_ext.trace_id,
1363            custom_payload: body_with_ext.custom_payload,
1364        })
1365    }
1366
    /// Optionally upgrades `stream` to TLS, then spawns the connection router on it.
    ///
    /// When `config.tls_config` is set, a TLS session is created with
    /// `tls_config.new_tls()` and its handshake is performed before the router is
    /// spawned. Otherwise the router runs directly over the plain TCP stream.
    /// Returns the [`RemoteHandle`] of the spawned router future.
    ///
    /// # Errors
    /// Returns a `std::io::Error` if creating the TLS session or performing the TLS
    /// handshake fails; in that case no router is spawned.
    async fn run_router(
        config: HostConnectionConfig,
        stream: TcpStream,
        receiver: mpsc::Receiver<Task>,
        error_sender: tokio::sync::oneshot::Sender<ConnectionError>,
        orphan_notification_receiver: mpsc::UnboundedReceiver<RequestId>,
        router_handle: Arc<RouterHandle>,
        node_address: IpAddr,
    ) -> Result<RemoteHandle<()>, std::io::Error> {
        /// Builds the [`Connection::router`] future over `stream`, spawns it with
        /// `tokio::task::spawn`, and returns the handle obtained from
        /// `remote_handle()`.
        ///
        /// Generic over the stream type so the plain TCP stream and every TLS
        /// stream type can share this code path.
        async fn spawn_router_and_get_handle(
            config: HostConnectionConfig,
            stream: (impl AsyncRead + AsyncWrite + Send + 'static),
            receiver: mpsc::Receiver<Task>,
            error_sender: tokio::sync::oneshot::Sender<ConnectionError>,
            orphan_notification_receiver: mpsc::UnboundedReceiver<RequestId>,
            router_handle: Arc<RouterHandle>,
            node_address: IpAddr,
        ) -> RemoteHandle<()> {
            let (task, handle) = Connection::router(
                config,
                stream,
                receiver,
                error_sender,
                orphan_notification_receiver,
                router_handle,
                node_address,
            )
            .remote_handle();
            tokio::task::spawn(task);
            handle
        }

        if let Some(tls_config) = &config.tls_config {
            // When every TLS feature is disabled, the TLS enum matched below has no
            // variants; the allow silences the resulting `unreachable_code` warnings.
            #[allow(unreachable_code)]
            match tls_config.new_tls()? {
                #[cfg(feature = "openssl-010")]
                crate::network::tls::Tls::OpenSsl010(ssl) => {
                    // Wrap the TCP stream and run the TLS handshake (`connect`)
                    // before handing the encrypted stream to the router.
                    let mut stream = tokio_openssl::SslStream::new(ssl, stream)
                        .map_err(crate::network::tls::TlsError::OpenSsl010)?;
                    std::pin::Pin::new(&mut stream)
                        .connect()
                        .await
                        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
                    return Ok(spawn_router_and_get_handle(
                        config,
                        stream,
                        receiver,
                        error_sender,
                        orphan_notification_receiver,
                        router_handle,
                        node_address,
                    )
                    .await);
                }
                #[cfg(feature = "rustls-023")]
                crate::network::tls::Tls::Rustls023 {
                    connector,
                    #[cfg(feature = "unstable-cloud")]
                    sni,
                } => {
                    use rustls::pki_types::ServerName;
                    // TLS server name: with `unstable-cloud`, the SNI from the TLS
                    // session if present; otherwise the node's IP address.
                    #[cfg(feature = "unstable-cloud")]
                    let server_name =
                        sni.unwrap_or_else(|| ServerName::IpAddress(node_address.into()));
                    #[cfg(not(feature = "unstable-cloud"))]
                    let server_name = ServerName::IpAddress(node_address.into());
                    // TLS handshake through the rustls connector.
                    let stream = connector.connect(server_name, stream).await?;
                    return Ok(spawn_router_and_get_handle(
                        config,
                        stream,
                        receiver,
                        error_sender,
                        orphan_notification_receiver,
                        router_handle,
                        node_address,
                    )
                    .await);
                }
            }
        }

        // No TLS configured: run the router directly over the TCP stream.
        Ok(spawn_router_and_get_handle(
            config,
            stream,
            receiver,
            error_sender,
            orphan_notification_receiver,
            router_handle,
            node_address,
        )
        .await)
    }
1460
1461    async fn router(
1462        config: HostConnectionConfig,
1463        stream: (impl AsyncRead + AsyncWrite),
1464        receiver: mpsc::Receiver<Task>,
1465        error_sender: tokio::sync::oneshot::Sender<ConnectionError>,
1466        orphan_notification_receiver: mpsc::UnboundedReceiver<RequestId>,
1467        router_handle: Arc<RouterHandle>,
1468        node_address: IpAddr,
1469    ) {
1470        let (read_half, write_half) = split(stream);
1471        // Why are we using a mutex here?
1472        //
1473        // The handler_map is supposed to be shared between reader and writer
1474        // futures, which will be run on the same task. The mutex should not
1475        // be normally required here, but Rust complains if we try to use
1476        // a RefCell instead of Mutex (the whole future becomes !Sync and we
1477        // cannot use it in tasks).
1478        //
1479        // Notice that this lock will have no contention, because reader
1480        // and writer futures are run on the same fiber, and both of them
1481        // are carefully written in such a way that they do not hold the lock
1482        // across .await points. Therefore, it should not be too expensive.
1483        let handler_map = StdMutex::new(ResponseHandlerMap::new());
1484
1485        let write_coalescing_delay = config.write_coalescing_delay;
1486
1487        let k = Self::keepaliver(
1488            router_handle,
1489            config.keepalive_interval,
1490            config.keepalive_timeout,
1491            node_address,
1492        );
1493
1494        let r = Self::reader(
1495            BufReader::with_capacity(8192, read_half),
1496            &handler_map,
1497            config.event_sender,
1498            config.compression,
1499        );
1500        let w = Self::writer(
1501            BufWriter::with_capacity(8192, write_half),
1502            &handler_map,
1503            receiver,
1504            write_coalescing_delay,
1505        );
1506        let o = Self::orphaner(&handler_map, orphan_notification_receiver);
1507
1508        let result = futures::try_join!(r, w, o, k);
1509
1510        let error: BrokenConnectionError = match result {
1511            Ok(_) => return, // Connection was dropped, we can return
1512            Err(err) => err,
1513        };
1514
1515        // Respond to all pending requests with the error
1516        let response_handlers: HashMap<i16, ResponseHandler> =
1517            handler_map.into_inner().unwrap().into_handlers();
1518
1519        for (_, handler) in response_handlers {
1520            // Ignore sending error, request was dropped
1521            let _ = handler.response_sender.send(Err(error.clone().into()));
1522        }
1523
1524        // If someone is listening for connection errors notify them
1525        let _ = error_sender.send(error.into());
1526    }
1527
1528    async fn reader(
1529        mut read_half: (impl AsyncRead + Unpin),
1530        handler_map: &StdMutex<ResponseHandlerMap>,
1531        event_sender: Option<mpsc::Sender<Event>>,
1532        compression: Option<Compression>,
1533    ) -> Result<(), BrokenConnectionError> {
1534        loop {
1535            let (params, opcode, body) = frame::read_response_frame(&mut read_half)
1536                .await
1537                .map_err(BrokenConnectionErrorKind::FrameHeaderParseError)?;
1538            let response = TaskResponse {
1539                params,
1540                opcode,
1541                body,
1542            };
1543
1544            match params.stream.cmp(&-1) {
1545                Ordering::Less => {
1546                    // The spec reserves negative-numbered streams for server-generated
1547                    // events. As of writing this driver, there are no other negative
1548                    // streams used apart from -1, so ignore it.
1549                    continue;
1550                }
1551                Ordering::Equal => {
1552                    if let Some(event_sender) = event_sender.as_ref() {
1553                        Self::handle_event(response, compression, event_sender)
1554                            .await
1555                            .map_err(BrokenConnectionErrorKind::CqlEventHandlingError)?
1556                    }
1557                    continue;
1558                }
1559                _ => {}
1560            }
1561
1562            let handler_lookup_res = {
1563                // We are guaranteed here that handler_map will not be locked
1564                // by anybody else, so we can do try_lock().unwrap()
1565                let mut handler_map_guard = handler_map.try_lock().unwrap();
1566                handler_map_guard.lookup(params.stream)
1567            };
1568
1569            use HandlerLookupResult::*;
1570            match handler_lookup_res {
1571                Handler(handler) => {
1572                    // Don't care if sending of the response fails. This must
1573                    // mean that the receiver side was impatient and is not
1574                    // waiting for the result anymore.
1575                    let _ = handler.response_sender.send(Ok(response));
1576                }
1577                Missing => {
1578                    // Unsolicited frame. This should not happen and indicates
1579                    // a bug either in the driver, or in the database
1580                    debug!(
1581                        "Received response with unexpected StreamId {}",
1582                        params.stream
1583                    );
1584                    return Err(BrokenConnectionErrorKind::UnexpectedStreamId(params.stream).into());
1585                }
1586                Orphaned => {
1587                    // Do nothing, handler was freed because this stream_id has
1588                    // been marked as orphaned
1589                }
1590            }
1591        }
1592    }
1593
1594    fn alloc_stream_id(
1595        handler_map: &StdMutex<ResponseHandlerMap>,
1596        response_handler: ResponseHandler,
1597    ) -> Option<i16> {
1598        // We are guaranteed here that handler_map will not be locked
1599        // by anybody else, so we can do try_lock().unwrap()
1600        let mut handler_map_guard = handler_map.try_lock().unwrap();
1601        match handler_map_guard.allocate(response_handler) {
1602            Ok(stream_id) => Some(stream_id),
1603            Err(response_handler) => {
1604                error!("Could not allocate stream id");
1605                let _ = response_handler
1606                    .response_sender
1607                    .send(Err(InternalRequestError::UnableToAllocStreamId));
1608                None
1609            }
1610        }
1611    }
1612
1613    async fn writer(
1614        mut write_half: (impl AsyncWrite + Unpin),
1615        handler_map: &StdMutex<ResponseHandlerMap>,
1616        mut task_receiver: mpsc::Receiver<Task>,
1617        write_coalescing_delay: Option<WriteCoalescingDelay>,
1618    ) -> Result<(), BrokenConnectionError> {
1619        // When the Connection object is dropped, the sender half
1620        // of the channel will be dropped, this task will return an error
1621        // and the whole worker will be stopped
1622        while let Some(mut task) = task_receiver.recv().await {
1623            let mut num_requests = 0;
1624            let mut total_sent = 0;
1625            while let Some(stream_id) = Self::alloc_stream_id(handler_map, task.response_handler) {
1626                let mut req = task.serialized_request;
1627                req.set_stream(stream_id);
1628                let req_data: &[u8] = req.get_data();
1629                total_sent += req_data.len();
1630                num_requests += 1;
1631                write_half
1632                    .write_all(req_data)
1633                    .await
1634                    .map_err(BrokenConnectionErrorKind::WriteError)?;
1635                task = match task_receiver.try_recv() {
1636                    Ok(t) => t,
1637                    Err(_) => match write_coalescing_delay {
1638                        Some(WriteCoalescingDelay::SmallNondeterministic) => {
1639                            // Yielding was empirically tested to inject a 1-300µs delay,
1640                            // much better than tokio::time::sleep's 1ms granularity.
1641                            // Also, yielding in a busy system let's the queue catch up with new items.
1642                            tokio::task::yield_now().await;
1643                            match task_receiver.try_recv() {
1644                                Ok(t) => t,
1645                                Err(_) => break,
1646                            }
1647                        }
1648                        Some(WriteCoalescingDelay::Milliseconds(ms)) => {
1649                            tokio::time::sleep(Duration::from_millis(ms.get())).await;
1650                            match task_receiver.try_recv() {
1651                                Ok(t) => t,
1652                                Err(_) => break,
1653                            }
1654                        }
1655                        None => break,
1656                    },
1657                }
1658            }
1659            trace!("Sending {} requests; {} bytes", num_requests, total_sent);
1660            write_half
1661                .flush()
1662                .await
1663                .map_err(BrokenConnectionErrorKind::WriteError)?;
1664        }
1665
1666        Ok(())
1667    }
1668
1669    // This task receives notifications from `OrphanhoodNotifier`s and tries to
1670    // mark streams as orphaned. It also checks count of old orphans periodically.
1671    // After an ald orphan threshold is reached, `orphaner` returns an error
1672    // causing the connection to break.
1673    async fn orphaner(
1674        handler_map: &StdMutex<ResponseHandlerMap>,
1675        mut orphan_receiver: mpsc::UnboundedReceiver<RequestId>,
1676    ) -> Result<(), BrokenConnectionError> {
1677        let mut interval = tokio::time::interval(OLD_AGE_ORPHAN_THRESHOLD);
1678        loop {
1679            tokio::select! {
1680                _ = interval.tick() => {
1681                    // We are guaranteed here that handler_map will not be locked
1682                    // by anybody else, so we can do try_lock().unwrap()
1683                    let handler_map_guard = handler_map.try_lock().unwrap();
1684                    let old_orphan_count = handler_map_guard.old_orphans_count();
1685                    if old_orphan_count > OLD_ORPHAN_COUNT_THRESHOLD {
1686                        warn!(
1687                            "Too many old orphaned stream ids: {}",
1688                            old_orphan_count,
1689                        );
1690                        return Err(BrokenConnectionErrorKind::TooManyOrphanedStreamIds(old_orphan_count as u16).into())
1691                    }
1692                }
1693                Some(request_id) = orphan_receiver.recv() => {
1694                    trace!(
1695                        "Trying to orphan stream id associated with request_id = {}",
1696                        request_id,
1697                    );
1698                    let mut handler_map_guard = handler_map.try_lock().unwrap(); // Same as above
1699                    handler_map_guard.orphan(request_id);
1700                }
1701                else => { break }
1702            }
1703        }
1704
1705        Ok(())
1706    }
1707
1708    async fn keepaliver(
1709        router_handle: Arc<RouterHandle>,
1710        keepalive_interval: Option<Duration>,
1711        keepalive_timeout: Option<Duration>,
1712        node_address: IpAddr, // This address is only used to enrich the log messages
1713    ) -> Result<(), BrokenConnectionError> {
1714        async fn issue_keepalive_query(
1715            router_handle: &RouterHandle,
1716        ) -> Result<(), BrokenConnectionError> {
1717            router_handle
1718                .send_request(&Options, None, false)
1719                .await
1720                .map(|_| ())
1721                .map_err(|req_err| {
1722                    BrokenConnectionErrorKind::KeepaliveRequestError(Arc::new(req_err)).into()
1723                })
1724        }
1725
1726        if let Some(keepalive_interval) = keepalive_interval {
1727            let mut interval = tokio::time::interval(keepalive_interval);
1728            interval.tick().await; // Use up the first, instant tick.
1729
1730            // Default behaviour (Burst) is not suitable for sending keepalives.
1731            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
1732
1733            loop {
1734                interval.tick().await;
1735
1736                let keepalive_query = issue_keepalive_query(&router_handle);
1737                let query_result = if let Some(timeout) = keepalive_timeout {
1738                    match tokio::time::timeout(timeout, keepalive_query).await {
1739                        Ok(res) => res,
1740                        Err(_) => {
1741                            warn!(
1742                                "Timed out while waiting for response to keepalive request on connection to node {}",
1743                                node_address
1744                            );
1745                            return Err(
1746                                BrokenConnectionErrorKind::KeepaliveTimeout(node_address).into()
1747                            );
1748                        }
1749                    }
1750                } else {
1751                    keepalive_query.await
1752                };
1753                if let Err(err) = query_result {
1754                    warn!(
1755                        "Failed to execute keepalive request on connection to node {} - {}",
1756                        node_address, err
1757                    );
1758                    return Err(err);
1759                }
1760
1761                trace!(
1762                    "Keepalive request successful on connection to node {}",
1763                    node_address
1764                );
1765            }
1766        } else {
1767            // No keepalives are to be sent.
1768            Ok(())
1769        }
1770    }
1771
1772    async fn handle_event(
1773        task_response: TaskResponse,
1774        compression: Option<Compression>,
1775        event_sender: &mpsc::Sender<Event>,
1776    ) -> Result<(), CqlEventHandlingError> {
1777        // Protocol features are negotiated during connection handshake.
1778        // However, the router is already created and sent to a different tokio
1779        // task before the handshake begins, therefore it's hard to cleanly
1780        // update the protocol features in the router at this point.
1781        // Making it possible would require restructuring the handshake process,
1782        // or passing the negotiated features via a channel/mutex/etc.
1783        // Fortunately, events do not need information about protocol features
1784        // to be serialized (yet), therefore I'm leaving this problem for
1785        // future implementers.
1786        let features = ProtocolFeatures::default(); // TODO: Use the right features
1787
1788        let event = match Self::parse_response(task_response, compression, &features, None) {
1789            Ok(r) => match r.response {
1790                Response::Event(event) => event,
1791                _ => {
1792                    error!("Expected to receive Event response, got {:?}", r.response);
1793                    return Err(CqlEventHandlingError::UnexpectedResponse(
1794                        r.response.to_response_kind(),
1795                    ));
1796                }
1797            },
1798            Err(e) => match e {
1799                ResponseParseError::BodyExtensionsParseError(e) => return Err(e.into()),
1800                ResponseParseError::CqlResponseParseError(e) => match e {
1801                    CqlResponseParseError::CqlEventParseError(e) => return Err(e.into()),
1802                    // Received a response other than EVENT, but failed to deserialize it.
1803                    _ => {
1804                        return Err(CqlEventHandlingError::UnexpectedResponse(
1805                            e.to_response_kind(),
1806                        ))
1807                    }
1808                },
1809            },
1810        };
1811
1812        event_sender
1813            .send(event)
1814            .await
1815            .map_err(|_| CqlEventHandlingError::SendError)
1816    }
1817
1818    pub(crate) fn get_shard_info(&self) -> &Option<ShardInfo> {
1819        &self.features.shard_info
1820    }
1821
1822    pub(crate) fn get_shard_aware_port(&self) -> Option<u16> {
1823        self.features.shard_aware_port
1824    }
1825
1826    fn set_features(&mut self, features: ConnectionFeatures) {
1827        self.features = features;
1828    }
1829
1830    pub(crate) fn get_connect_address(&self) -> SocketAddr {
1831        self.connect_address
1832    }
1833
1834    async fn update_tablets_from_response(
1835        &self,
1836        table: &TableSpec<'_>,
1837        response: &QueryResponse,
1838    ) -> Result<(), TabletParsingError> {
1839        if let (Some(sender), Some(tablet_data)) = (
1840            self.config.tablet_sender.as_ref(),
1841            response.custom_payload.as_ref(),
1842        ) {
1843            let tablet = match RawTablet::from_custom_payload(tablet_data) {
1844                Some(Ok(v)) => v,
1845                Some(Err(e)) => return Err(e),
1846                None => return Ok(()),
1847            };
1848            tracing::trace!(
1849                "Received tablet info for table {}.{} in custom payload: {:?}",
1850                table.ks_name(),
1851                table.table_name(),
1852                tablet
1853            );
1854            let _ = sender.send((table.to_owned(), tablet)).await;
1855        }
1856
1857        Ok(())
1858    }
1859}
1860
1861async fn maybe_translated_addr(
1862    endpoint: &UntranslatedEndpoint,
1863    address_translator: Option<&dyn AddressTranslator>,
1864) -> Result<SocketAddr, TranslationError> {
1865    match *endpoint {
1866        UntranslatedEndpoint::ContactPoint(ref addr) => Ok(addr.address),
1867        UntranslatedEndpoint::Peer(PeerEndpoint {
1868            host_id,
1869            address,
1870            ref datacenter,
1871            ref rack,
1872        }) => match address {
1873            NodeAddr::Translatable(addr) => {
1874                // In this case, addr is subject to AddressTranslator.
1875                if let Some(translator) = address_translator {
1876                    let res = translator
1877                        .translate_address(&UntranslatedPeer {
1878                            host_id,
1879                            untranslated_address: addr,
1880                            datacenter: datacenter.as_deref(),
1881                            rack: rack.as_deref(),
1882                        })
1883                        .await;
1884                    if let Err(ref err) = res {
1885                        error!("Address translation failed for addr {}: {}", addr, err);
1886                    }
1887                    res
1888                } else {
1889                    Ok(addr)
1890                }
1891            }
1892            NodeAddr::Untranslatable(addr) => {
1893                // In this case, addr is considered to be translated, as it is the control connection's address.
1894                Ok(addr)
1895            }
1896        },
1897    }
1898}
1899
/// Opens a connection and performs its setup on CQL level:
/// - performs OPTIONS/STARTUP handshake (chooses desired connection options);
/// - registers for all event types using REGISTER request (if this is a control connection,
///   i.e. `config.event_sender` is set).
///
/// At the beginning, translates node's address, if it is subject to address translation.
///
/// Steps, in order:
/// 1. translate the endpoint's address with `maybe_translated_addr`;
/// 2. open the connection with `Connection::new`, passing `source_port` through;
/// 3. send OPTIONS and read SUPPORTED. Sharding info, the shard-aware port and protocol
///    extensions are extracted and stored with `set_features`;
/// 4. send STARTUP. Compression is requested only if the server advertises it; otherwise
///    `connection.config.compression` is reset to `None`. Authentication runs if the server
///    asks for it;
/// 5. REGISTER for topology, status and schema change events (control connections only).
///
/// # Errors
/// Returns a `ConnectionError` if address translation, connecting, or any handshake step fails.
pub(crate) async fn open_connection(
    endpoint: &UntranslatedEndpoint,
    source_port: Option<u16>,
    config: &HostConnectionConfig,
) -> Result<(Connection, ErrorReceiver), ConnectionError> {
    /* Translate the address, if applicable. */
    let addr = maybe_translated_addr(endpoint, config.address_translator.as_deref()).await?;

    /* Setup connection on TCP level and prepare for sending/receiving CQL frames. */
    let (mut connection, error_receiver) =
        Connection::new(addr, source_port, config.clone()).await?;

    /* Perform OPTIONS/SUPPORTED/STARTUP handshake. */

    // Get OPTIONS SUPPORTED by the cluster.
    let mut supported = connection.get_options().await?;

    // The shard-aware port is advertised under a different key for TLS connections.
    let shard_aware_port_key = match config.is_tls() {
        true => options::SCYLLA_SHARD_AWARE_PORT_SSL,
        false => options::SCYLLA_SHARD_AWARE_PORT,
    };

    // If this is ScyllaDB that we connected to, we received sharding information.
    // Missing or unparsable sharding info is logged and treated as "no sharding".
    let shard_info = match ShardInfo::try_from(&supported.options) {
        Ok(info) => Some(info),
        Err(ShardingError::NoShardInfo) => {
            tracing::info!(
                "[{}] No sharding information received. Proceeding with no sharding info.",
                addr
            );
            None
        }
        Err(e) => {
            tracing::error!(
                "[{}] Error while parsing sharding information: {}. Proceeding with no sharding info.",
                addr, e
            );
            None
        }
    };
    // Take the advertised compression algorithms out of the map; a missing key yields the
    // default (empty) value.
    let supported_compression = supported
        .options
        .remove(options::COMPRESSION)
        .unwrap_or_default();
    // Only the first advertised value is considered; a missing or unparsable value yields `None`.
    let shard_aware_port = supported
        .options
        .remove(shard_aware_port_key)
        .unwrap_or_default()
        .first()
        .and_then(|p| p.parse::<u16>().ok());

    // Parse nonstandard protocol extensions.
    // Note: the COMPRESSION and shard-aware-port keys have already been removed above.
    let protocol_features = ProtocolFeatures::parse_from_supported(&supported.options);

    // At the beginning, Connection assumes no sharding and no protocol extensions;
    // now that we know them, let's turn them on in the driver.
    let features = ConnectionFeatures {
        shard_info,
        shard_aware_port,
        protocol_features,
    };
    connection.set_features(features);

    /* Prepare options that the driver opts-in in STARTUP frame. */
    let mut options = HashMap::new();
    protocol_features.add_startup_options(&mut options);

    // The only CQL protocol version supported by the driver.
    options.insert(
        Cow::Borrowed(options::CQL_VERSION),
        Cow::Borrowed(options::DEFAULT_CQL_PROTOCOL_VERSION),
    );

    // Application & driver's identity.
    config.identity.add_startup_options(&mut options);

    // Optional compression.
    if let Some(compression) = &config.compression {
        let compression_str = compression.as_str();
        if supported_compression.iter().any(|c| c == compression_str) {
            // Compression is reported to be supported by the server,
            // request it from the server
            options.insert(
                Cow::Borrowed(options::COMPRESSION),
                Cow::Borrowed(compression_str),
            );
        } else {
            // Fall back to no compression
            tracing::warn!(
                "Requested compression <{}> is not supported by the cluster. Falling back to no compression",
                compression_str
            );
            connection.config.compression = None;
        }
    }

    /* Send the STARTUP frame with all the requested options. */
    let startup_result = connection.startup(options).await?;
    match startup_result {
        NonErrorStartupResponse::Ready => {}
        NonErrorStartupResponse::Authenticate(authenticate) => {
            connection.perform_authenticate(&authenticate).await?;
        }
    }

    /* If this is a control connection, REGISTER to receive all event types. */
    if connection.config.event_sender.is_some() {
        let all_event_types = vec![
            EventType::TopologyChange,
            EventType::StatusChange,
            EventType::SchemaChange,
        ];
        connection.register(all_event_types).await?;
    }

    Ok((connection, error_receiver))
}
2022
2023pub(super) async fn open_connection_to_shard_aware_port(
2024    endpoint: &UntranslatedEndpoint,
2025    shard: Shard,
2026    sharder: Sharder,
2027    config: &HostConnectionConfig,
2028) -> Result<(Connection, ErrorReceiver), ConnectionError> {
2029    // Create iterator over all possible source ports for this shard
2030    let source_port_iter =
2031        sharder.iter_source_ports_for_shard_from_range(shard, &config.shard_aware_local_port_range);
2032
2033    for port in source_port_iter {
2034        let connect_result = open_connection(endpoint, Some(port), config).await;
2035
2036        match connect_result {
2037            Err(err) if err.is_address_unavailable_for_use() => continue, // If we can't use this port, try the next one
2038            result => return result,
2039        }
2040    }
2041
2042    // Tried all source ports for that shard, give up
2043    Err(ConnectionError::NoSourcePortForShard(shard))
2044}
2045
2046async fn connect_with_source_ip_and_port(
2047    connect_address: SocketAddr,
2048    source_ip: Option<IpAddr>,
2049    source_port: Option<u16>,
2050) -> Result<TcpStream, std::io::Error> {
2051    // Binding to port 0 is equivalent to choosing random ephemeral port.
2052    let source_port = source_port.unwrap_or(0);
2053
2054    match connect_address {
2055        SocketAddr::V4(_) => {
2056            // If source_ip not provided, bind to INADDR_ANY.
2057            let source_ipv4 = source_ip.unwrap_or(Ipv4Addr::UNSPECIFIED.into());
2058            let socket = TcpSocket::new_v4()?;
2059            socket.bind(SocketAddr::new(source_ipv4, source_port))?;
2060            Ok(socket.connect(connect_address).await?)
2061        }
2062        SocketAddr::V6(_) => {
2063            // If source_ip not provided, bind to in6addr_any.
2064            let source_ipv6 = source_ip.unwrap_or(Ipv6Addr::UNSPECIFIED.into());
2065            let socket = TcpSocket::new_v6()?;
2066            socket.bind(SocketAddr::new(source_ipv6, source_port))?;
2067            Ok(socket.connect(connect_address).await?)
2068        }
2069    }
2070}
2071
/// Records when each orphaned stream id was orphaned.
///
/// The two fields index the same set of entries:
/// - `orphans` answers "is this stream orphaned, and since when?";
/// - `by_orphaning_times` orders the entries by orphaning time, so old orphans can be counted
///   without scanning every entry.
struct OrphanageTracker {
    orphans: HashMap<i16, Instant>,
    by_orphaning_times: BTreeSet<(Instant, i16)>,
}

impl OrphanageTracker {
    /// Creates an empty tracker.
    fn new() -> Self {
        Self {
            orphans: HashMap::new(),
            by_orphaning_times: BTreeSet::new(),
        }
    }

    /// Marks `stream_id` as orphaned at the current instant.
    ///
    /// If the id is already tracked, its previous entry is replaced in both indexes, so the
    /// two indexes never disagree.
    fn insert(&mut self, stream_id: i16) {
        let now = Instant::now();
        if let Some(previous) = self.orphans.insert(stream_id, now) {
            self.by_orphaning_times.remove(&(previous, stream_id));
        }
        self.by_orphaning_times.insert((now, stream_id));
    }

    /// Stops tracking `stream_id`; does nothing if it was not tracked.
    fn remove(&mut self, stream_id: i16) {
        if let Some(time) = self.orphans.remove(&stream_id) {
            self.by_orphaning_times.remove(&(time, stream_id));
        }
    }

    /// Returns whether `stream_id` is currently tracked as orphaned.
    fn contains(&self, stream_id: i16) -> bool {
        self.orphans.contains_key(&stream_id)
    }

    /// Counts orphans that were orphaned at least `age` ago.
    ///
    /// If `age` reaches back past the earliest instant the clock can represent, the count is 0
    /// instead of a panic: no stored instant can be that old. A plain `Instant - Duration`
    /// would panic in that case, e.g. shortly after boot.
    fn orphans_older_than(&self, age: std::time::Duration) -> usize {
        match Instant::now().checked_sub(age) {
            // The inclusive bound treats every stream id at the cutoff instant the same way.
            // This has linear time complexity, but in terms of the number of old orphans:
            // a healthy connection (one without old orphaned stream ids) computes it quickly.
            Some(cutoff) => self.by_orphaning_times.range(..=(cutoff, i16::MAX)).count(),
            None => 0,
        }
    }
}
2111
2112struct ResponseHandlerMap {
2113    stream_set: StreamIdSet,
2114    handlers: HashMap<i16, ResponseHandler>,
2115
2116    request_to_stream: HashMap<RequestId, i16>,
2117    orphanage_tracker: OrphanageTracker,
2118}
2119
2120enum HandlerLookupResult {
2121    Orphaned,
2122    Handler(ResponseHandler),
2123    Missing,
2124}
2125
2126impl ResponseHandlerMap {
2127    fn new() -> Self {
2128        Self {
2129            stream_set: StreamIdSet::new(),
2130            handlers: HashMap::new(),
2131            request_to_stream: HashMap::new(),
2132            orphanage_tracker: OrphanageTracker::new(),
2133        }
2134    }
2135
2136    fn allocate(&mut self, response_handler: ResponseHandler) -> Result<i16, ResponseHandler> {
2137        if let Some(stream_id) = self.stream_set.allocate() {
2138            self.request_to_stream
2139                .insert(response_handler.request_id, stream_id);
2140            let prev_handler = self.handlers.insert(stream_id, response_handler);
2141            assert!(prev_handler.is_none());
2142
2143            Ok(stream_id)
2144        } else {
2145            Err(response_handler)
2146        }
2147    }
2148
2149    // Orphan stream_id (associated with this request_id) by moving it to
2150    // `orphanage_tracker`, and freeing its handler
2151    fn orphan(&mut self, request_id: RequestId) {
2152        if let Some(stream_id) = self.request_to_stream.get(&request_id) {
2153            debug!(
2154                "Orphaning stream_id = {} associated with request_id = {}",
2155                stream_id, request_id
2156            );
2157            self.orphanage_tracker.insert(*stream_id);
2158            self.handlers.remove(stream_id);
2159            self.request_to_stream.remove(&request_id);
2160        }
2161    }
2162
2163    fn old_orphans_count(&self) -> usize {
2164        self.orphanage_tracker
2165            .orphans_older_than(OLD_AGE_ORPHAN_THRESHOLD)
2166    }
2167
2168    fn lookup(&mut self, stream_id: i16) -> HandlerLookupResult {
2169        self.stream_set.free(stream_id);
2170
2171        if self.orphanage_tracker.contains(stream_id) {
2172            self.orphanage_tracker.remove(stream_id);
2173            // This `stream_id` had been orphaned, so its handler got removed.
2174            // This is a valid state (as opposed to missing handler)
2175            return HandlerLookupResult::Orphaned;
2176        }
2177
2178        if let Some(handler) = self.handlers.remove(&stream_id) {
2179            // A mapping `request_id` -> `stream_id` must be removed, to
2180            // prevent marking this `stream_id` as orphaned by some late
2181            // orphan notification.
2182            self.request_to_stream.remove(&handler.request_id);
2183
2184            HandlerLookupResult::Handler(handler)
2185        } else {
2186            HandlerLookupResult::Missing
2187        }
2188    }
2189
2190    // Retrieves the map of handlers, used after connection breaks
2191    // and we have to respond to all of them with an error
2192    fn into_handlers(self) -> HashMap<i16, ResponseHandler> {
2193        self.handlers
2194    }
2195}
2196
/// Bitmap of in-use stream ids covering every non-negative `i16` (0..=32767).
///
/// Bit `b` of word `w` is set when stream id `w * 64 + b` is taken.
struct StreamIdSet {
    used_bitmap: Box<[u64]>,
}

impl StreamIdSet {
    /// Number of stream ids tracked by a single bitmap word.
    const IDS_PER_BLOCK: usize = u64::BITS as usize;

    /// Creates a set in which every stream id is free.
    fn new() -> Self {
        const BLOCK_COUNT: usize = (i16::MAX as usize + 1) / StreamIdSet::IDS_PER_BLOCK;
        Self {
            used_bitmap: vec![0u64; BLOCK_COUNT].into_boxed_slice(),
        }
    }

    /// Marks the lowest free stream id as used and returns it, or `None` if all are taken.
    fn allocate(&mut self) -> Option<i16> {
        let (block_index, block) = self
            .used_bitmap
            .iter_mut()
            .enumerate()
            .find(|(_, block)| **block != u64::MAX)?;
        let bit = block.trailing_ones() as usize;
        *block |= 1u64 << bit;
        Some((block_index * Self::IDS_PER_BLOCK + bit) as i16)
    }

    /// Marks `stream_id` as free again. Expects a non-negative id.
    fn free(&mut self, stream_id: i16) {
        let index = stream_id as usize;
        let mask = 1u64 << (index % Self::IDS_PER_BLOCK);
        self.used_bitmap[index / Self::IDS_PER_BLOCK] &= !mask;
    }
}
2227
2228/// This type can only hold a valid keyspace name
2229#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
2230pub(crate) struct VerifiedKeyspaceName {
2231    name: Arc<String>,
2232    pub(crate) is_case_sensitive: bool,
2233}
2234
2235impl VerifiedKeyspaceName {
2236    pub(crate) fn new(
2237        keyspace_name: String,
2238        case_sensitive: bool,
2239    ) -> Result<Self, BadKeyspaceName> {
2240        Self::verify_keyspace_name_is_valid(&keyspace_name)?;
2241
2242        Ok(VerifiedKeyspaceName {
2243            name: Arc::new(keyspace_name),
2244            is_case_sensitive: case_sensitive,
2245        })
2246    }
2247
2248    pub(crate) fn as_str(&self) -> &str {
2249        self.name.as_str()
2250    }
2251
2252    // "Keyspace names can have up to 48 alphanumeric characters and contain underscores;
2253    // only letters and numbers are supported as the first character."
2254    // https://docs.datastax.com/en/cql-oss/3.3/cql/cql_reference/cqlCreateKeyspace.html
2255    // Despite that cassandra accepts underscore as first character so we do too
2256    // https://github.com/scylladb/scylla/blob/62551b3bd382c7c47371eb3fc38173bd0cfed44d/test/cql-pytest/test_keyspace.py#L58
2257    // https://github.com/scylladb/scylla/blob/718976e794790253c4b24e2c78208e11f24e7502/cql3/statements/create_keyspace_statement.cc#L75
2258    fn verify_keyspace_name_is_valid(keyspace_name: &str) -> Result<(), BadKeyspaceName> {
2259        if keyspace_name.is_empty() {
2260            return Err(BadKeyspaceName::Empty);
2261        }
2262
2263        // Verify that length <= 48
2264        let keyspace_name_len: usize = keyspace_name.chars().count(); // Only ascii allowed so it's equal to .len()
2265        if keyspace_name_len > 48 {
2266            return Err(BadKeyspaceName::TooLong(
2267                keyspace_name.to_string(),
2268                keyspace_name_len,
2269            ));
2270        }
2271
2272        // Verify all chars are alphanumeric or underscore
2273        for character in keyspace_name.chars() {
2274            match character {
2275                'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => {}
2276                _ => {
2277                    return Err(BadKeyspaceName::IllegalCharacter(
2278                        keyspace_name.to_string(),
2279                        character,
2280                    ))
2281                }
2282            };
2283        }
2284
2285        Ok(())
2286    }
2287}
2288
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use scylla_cql::frame::protocol_features::{
        LWT_OPTIMIZATION_META_BIT_MASK_KEY, SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION,
    };
    use scylla_cql::frame::types;
    use scylla_proxy::{
        Condition, Node, Proxy, Reaction, RequestFrame, RequestOpcode, RequestReaction,
        RequestRule, ResponseFrame, ShardAwareness,
    };

    use tokio::select;
    use tokio::sync::mpsc;

    use super::{open_connection, HostConnectionConfig};
    use crate::cluster::metadata::UntranslatedEndpoint;
    use crate::cluster::node::ResolvedContactPoint;
    use crate::statement::unprepared::Statement;
    use crate::test_utils::setup_tracing;
    use crate::utils::test_utils::{resolve_hostname, unique_keyspace_name, PerformDDL};
    use futures::{StreamExt, TryStreamExt};
    use std::collections::HashMap;
    use std::net::SocketAddr;
    use std::sync::Arc;
    use std::time::Duration;

    /// Tests for Connection::query_iter
    /// 1. SELECT from an empty table.
    /// 2. Insert ints 0..100 into the table created in the preparation phase
    ///    (checking that every insert succeeded), then use query_iter with
    ///    page_size set to 7 to select all 100 rows across many pages.
    /// 3. INSERT query_iter should work and not return any rows.
    #[tokio::test]
    #[cfg(not(scylla_cloud_tests))]
    async fn connection_query_iter_test() {
        use crate::client::session_builder::SessionBuilder;

        setup_tracing();
        let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
        let addr: SocketAddr = resolve_hostname(&uri).await;

        let (connection, _) = super::open_connection(
            &UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
                address: addr,
                datacenter: None,
            }),
            None,
            &HostConnectionConfig::default(),
        )
        .await
        .unwrap();
        let connection = Arc::new(connection);

        let ks = unique_keyspace_name();

        {
            // Preparation phase
            let session = SessionBuilder::new()
                .known_node_addr(addr)
                .build()
                .await
                .unwrap();
            session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks.clone())).await.unwrap();
            session.use_keyspace(ks.clone(), false).await.unwrap();
            session
                .ddl("DROP TABLE IF EXISTS connection_query_iter_tab")
                .await
                .unwrap();
            session
                .ddl("CREATE TABLE IF NOT EXISTS connection_query_iter_tab (p int primary key)")
                .await
                .unwrap();
        }

        connection
            .use_keyspace(&super::VerifiedKeyspaceName::new(ks, false).unwrap())
            .await
            .unwrap();

        // 1. SELECT from an empty table yields no rows.
        let select_query =
            Statement::new("SELECT p FROM connection_query_iter_tab").with_page_size(7);
        let empty_res = connection
            .clone()
            .query_iter(select_query.clone())
            .await
            .unwrap()
            .rows_stream::<(i32,)>()
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        assert!(empty_res.is_empty());

        // 2. Insert ints 0..100 concurrently, then select them using query_iter
        //    with page_size 7, so that the result spans many pages.
        let values: Vec<i32> = (0..100).collect();
        // The INSERT is executed unpaged, so no page size is set on it.
        let insert_query = Statement::new("INSERT INTO connection_query_iter_tab (p) VALUES (?)");
        let prepared = connection.prepare(&insert_query).await.unwrap();
        let mut insert_futures = Vec::new();
        for v in &values {
            let serialized = prepared.serialize_values(&(*v,)).unwrap();
            let fut = async {
                let response = connection
                    .execute_raw_unpaged(&prepared, serialized)
                    .await
                    .unwrap();
                // QueryResponse might contain an error - make sure that there were no errors
                let _nonerror_response = response.into_non_error_query_response().unwrap();
            };
            insert_futures.push(fut);
        }

        let _joined: Vec<()> = futures::future::join_all(insert_futures).await;

        let mut results: Vec<i32> = connection
            .clone()
            .query_iter(select_query)
            .await
            .unwrap()
            .rows_stream::<(i32,)>()
            .unwrap()
            .map(|ret| ret.unwrap().0)
            .collect::<Vec<_>>()
            .await;
        // Row order is not guaranteed to match insertion order, so sort before comparing.
        results.sort_unstable();
        assert_eq!(results, values);

        // 3. INSERT query_iter should work and not return any rows.
        let insert_res1 = connection
            .query_iter(Statement::new(
                "INSERT INTO connection_query_iter_tab (p) VALUES (0)",
            ))
            .await
            .unwrap()
            .rows_stream::<()>()
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        assert!(insert_res1.is_empty());
    }

    #[tokio::test]
    #[cfg(not(scylla_cloud_tests))]
    async fn test_coalescing() {
        use std::num::NonZeroU64;

        use super::WriteCoalescingDelay;
        use crate::client::session_builder::SessionBuilder;

        setup_tracing();
        // It's difficult to write a reliable test that checks whether coalescing
        // works like intended or not. Instead, this is a smoke test which is supposed
        // to trigger the coalescing logic and check that everything works fine
        // no matter whether coalescing is enabled or not.

        let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
        let addr: SocketAddr = resolve_hostname(&uri).await;
        let ks = unique_keyspace_name();

        {
            // Preparation phase
            let session = SessionBuilder::new()
                .known_node_addr(addr)
                .build()
                .await
                .unwrap();
            session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks.clone())).await.unwrap();
            session.use_keyspace(ks.clone(), false).await.unwrap();
            session
                .ddl("CREATE TABLE IF NOT EXISTS t (p int primary key, v blob)")
                .await
                .unwrap();
        }

        let subtest = |write_coalescing_delay: Option<WriteCoalescingDelay>, ks: String| async move {
            let (connection, _) = super::open_connection(
                &UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
                    address: addr,
                    datacenter: None,
                }),
                None,
                &HostConnectionConfig {
                    write_coalescing_delay,
                    ..HostConnectionConfig::default()
                },
            )
            .await
            .unwrap();
            let connection = Arc::new(connection);

            connection
                .use_keyspace(&super::VerifiedKeyspaceName::new(ks, false).unwrap())
                .await
                .unwrap();

            connection.ddl("TRUNCATE t").await.unwrap();

            let mut futs = Vec::new();

            const NUM_BATCHES: i32 = 10;

            for batch_size in 0..NUM_BATCHES {
                // Each future should issue more and more queries in the first poll
                let base = arithmetic_sequence_sum(batch_size);
                let conn = connection.clone();
                futs.push(tokio::task::spawn(async move {
                    let futs = (base..base + batch_size).map(|j| {
                        let q = Statement::new("INSERT INTO t (p, v) VALUES (?, ?)");
                        let conn = conn.clone();
                        async move {
                            let prepared = conn.prepare(&q).await.unwrap();
                            let values = prepared
                                .serialize_values(&(j, vec![j as u8; j as usize]))
                                .unwrap();
                            let response =
                                conn.execute_raw_unpaged(&prepared, values).await.unwrap();
                            // QueryResponse might contain an error - make sure that there were no errors
                            let _nonerror_response =
                                response.into_non_error_query_response().unwrap();
                        }
                    });
                    let _joined: Vec<()> = futures::future::join_all(futs).await;
                }));

                tokio::task::yield_now().await;
            }

            let _joined: Vec<()> = futures::future::try_join_all(futs).await.unwrap();

            // Check that everything was written properly
            let range_end = arithmetic_sequence_sum(NUM_BATCHES);
            let mut results = connection
                .query_unpaged("SELECT p, v FROM t")
                .await
                .unwrap()
                .into_rows_result()
                .unwrap()
                .rows::<(i32, Vec<u8>)>()
                .unwrap()
                .collect::<Result<Vec<_>, _>>()
                .unwrap();
            results.sort();

            let expected = (0..range_end)
                .map(|i| (i, vec![i as u8; i as usize]))
                .collect::<Vec<_>>();

            assert_eq!(results, expected);
        };

        // Non-deterministic sub-millisecond delay
        subtest(
            Some(WriteCoalescingDelay::SmallNondeterministic),
            ks.clone(),
        )
        .await;
        // 1ms delay
        subtest(
            Some(WriteCoalescingDelay::Milliseconds(
                NonZeroU64::new(1).unwrap(),
            )),
            ks.clone(),
        )
        .await;
        // No delay - coalescing disabled
        subtest(None, ks.clone()).await;
    }

    /// Returns the sum of the integers in the half-open range `[0, n)`,
    /// i.e. `0 + 1 + ... + (n - 1)`.
    fn arithmetic_sequence_sum(n: i32) -> i32 {
        n * (n - 1) / 2
    }

    /// Checks that the driver requests the LWT metadata-mark extension in its
    /// STARTUP message only when the (forged) SUPPORTED response advertises it.
    #[tokio::test]
    async fn test_lwt_optimisation_mark_negotiation() {
        setup_tracing();
        const MASK: &str = "2137";

        let lwt_optimisation_entry = format!("{}={}", LWT_OPTIMIZATION_META_BIT_MASK_KEY, MASK);

        let proxy_addr = SocketAddr::new(scylla_proxy::get_exclusive_local_address(), 9042);

        let config = HostConnectionConfig::default();

        let (startup_tx, mut startup_rx) = mpsc::unbounded_channel();

        let options_without_lwt_optimisation_support = HashMap::<String, Vec<String>>::new();
        let options_with_lwt_optimisation_support = [(
            SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION.into(),
            vec![lwt_optimisation_entry.clone()],
        )]
        .into_iter()
        .collect::<HashMap<String, Vec<String>>>();

        let make_rules = |options| {
            vec![
                RequestRule(
                    Condition::RequestOpcode(RequestOpcode::Options),
                    RequestReaction::forge_response(Arc::new(move |frame: RequestFrame| {
                        ResponseFrame::forged_supported(frame.params, &options).unwrap()
                    })),
                ),
                RequestRule(
                    Condition::RequestOpcode(RequestOpcode::Startup),
                    RequestReaction::drop_frame().with_feedback_when_performed(startup_tx.clone()),
                ),
            ]
        };

        let mut proxy = Proxy::builder()
            .with_node(
                Node::builder()
                    .proxy_address(proxy_addr)
                    .request_rules(make_rules(options_without_lwt_optimisation_support))
                    .build_dry_mode(),
            )
            .build()
            .run()
            .await
            .unwrap();

        // We must interrupt the driver's full connection opening, because our proxy does not interact further after Startup.
        let endpoint = UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
            address: proxy_addr,
            datacenter: None,
        });
        let (startup_without_lwt_optimisation, _shard) = select! {
            _ = open_connection(&endpoint, None, &config) => unreachable!(),
            startup = startup_rx.recv() => startup.unwrap(),
        };

        proxy.running_nodes[0]
            .change_request_rules(Some(make_rules(options_with_lwt_optimisation_support)));

        let (startup_with_lwt_optimisation, _shard) = select! {
            _ = open_connection(&endpoint, None, &config) => unreachable!(),
            startup = startup_rx.recv() => startup.unwrap(),
        };

        let _ = proxy.finish().await;

        let chosen_options =
            types::read_string_map(&mut &*startup_without_lwt_optimisation.body).unwrap();
        assert!(!chosen_options.contains_key(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION));

        let chosen_options =
            types::read_string_map(&mut &*startup_with_lwt_optimisation.body).unwrap();
        assert!(chosen_options.contains_key(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION));
        assert_eq!(
            chosen_options
                .get(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION)
                .unwrap(),
            &lwt_optimisation_entry
        );
    }

    /// Checks that a connection whose keepalive (OPTIONS) requests go unanswered
    /// reports a `KeepaliveTimeout` through its error receiver and then rejects
    /// further queries.
    #[tokio::test]
    #[ntest::timeout(20000)]
    #[cfg(not(scylla_cloud_tests))]
    async fn connection_is_closed_on_no_response_to_keepalives() {
        use crate::errors::BrokenConnectionErrorKind;

        setup_tracing();

        let proxy_addr = SocketAddr::new(scylla_proxy::get_exclusive_local_address(), 9042);
        let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
        let node_addr: SocketAddr = resolve_hostname(&uri).await;

        let drop_options_rule = RequestRule(
            Condition::RequestOpcode(RequestOpcode::Options),
            RequestReaction::drop_frame(),
        );

        let config = HostConnectionConfig {
            keepalive_interval: Some(Duration::from_millis(500)),
            keepalive_timeout: Some(Duration::from_secs(1)),
            ..Default::default()
        };

        let mut proxy = Proxy::builder()
            .with_node(
                Node::builder()
                    .proxy_address(proxy_addr)
                    .real_address(node_addr)
                    .shard_awareness(ShardAwareness::QueryNode)
                    .build(),
            )
            .build()
            .run()
            .await
            .unwrap();

        // Setup connection normally, without obstruction
        let (conn, mut error_receiver) = open_connection(
            &UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
                address: proxy_addr,
                datacenter: None,
            }),
            None,
            &config,
        )
        .await
        .unwrap();

        // As everything is normal, these queries should succeed.
        for _ in 0..3 {
            tokio::time::sleep(Duration::from_millis(500)).await;
            conn.query_unpaged("SELECT host_id FROM system.local WHERE key='local'")
                .await
                .unwrap();
        }
        // As everything is normal, no error should have been reported.
        assert_matches!(
            error_receiver.try_recv(),
            Err(tokio::sync::oneshot::error::TryRecvError::Empty)
        );

        // Set up proxy to drop keepalive messages
        proxy.running_nodes[0].change_request_rules(Some(vec![drop_options_rule]));

        // Wait until the keepaliver gets impatient and terminates the router.
        // Then, the error from the keepaliver will be propagated to the error receiver.
        let err = error_receiver.await.unwrap();
        let err_inner: &BrokenConnectionErrorKind = match err {
            super::ConnectionError::BrokenConnection(ref e) => e.downcast_ref().unwrap(),
            _ => panic!("Bad error type. Expected keepalive timeout."),
        };
        assert_matches!(err_inner, BrokenConnectionErrorKind::KeepaliveTimeout(_));

        // As the router is invalidated, all further queries should immediately
        // return error.
        conn.query_unpaged("SELECT host_id FROM system.local WHERE key='local'")
            .await
            .unwrap_err();

        let _ = proxy.finish().await;
    }
}