1use super::tls::{TlsConfig, TlsProvider};
2use crate::authentication::AuthenticatorProvider;
3use crate::client::pager::{NextRowError, QueryPager};
4use crate::client::Compression;
5use crate::client::SelfIdentity;
6use crate::cluster::metadata::{PeerEndpoint, UntranslatedEndpoint};
7use crate::cluster::NodeAddr;
8use crate::errors::{
9 BadKeyspaceName, BrokenConnectionError, BrokenConnectionErrorKind, ConnectionError,
10 ConnectionSetupRequestError, ConnectionSetupRequestErrorKind, CqlEventHandlingError, DbError,
11 InternalRequestError, RequestAttemptError, ResponseParseError, SchemaAgreementError,
12 TranslationError, UseKeyspaceError,
13};
14use crate::frame::protocol_features::ProtocolFeatures;
15use crate::frame::{
16 self,
17 request::{self, batch, execute, query, register, SerializableRequest},
18 response::{event::Event, result, Response, ResponseOpcode},
19 server_event_type::EventType,
20 FrameParams, SerializedRequest,
21};
22use crate::policies::address_translator::{AddressTranslator, UntranslatedPeer};
23use crate::policies::timestamp_generator::TimestampGenerator;
24use crate::response::query_result::QueryResult;
25use crate::response::{
26 NonErrorAuthResponse, NonErrorStartupResponse, PagingState, PagingStateResponse, QueryResponse,
27};
28use crate::routing::locator::tablets::{RawTablet, TabletParsingError};
29use crate::routing::{Shard, ShardAwarePortRange, ShardInfo, Sharder, ShardingError};
30use crate::statement::batch::{Batch, BatchStatement};
31use crate::statement::prepared::PreparedStatement;
32use crate::statement::unprepared::Statement;
33use crate::statement::{Consistency, PageSize};
34use bytes::Bytes;
35use futures::{future::RemoteHandle, FutureExt};
36use scylla_cql::frame::frame_errors::CqlResponseParseError;
37use scylla_cql::frame::request::options::{self, Options};
38use scylla_cql::frame::request::CqlRequestKind;
39use scylla_cql::frame::response::authenticate::Authenticate;
40use scylla_cql::frame::response::result::{ResultMetadata, TableSpec};
41use scylla_cql::frame::response::Error;
42use scylla_cql::frame::response::{self, error};
43use scylla_cql::frame::types::SerialConsistency;
44use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
45use scylla_cql::serialize::raw_batch::RawBatchValuesAdapter;
46use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
47use socket2::{SockRef, TcpKeepalive};
48use std::borrow::Cow;
49use std::collections::{BTreeSet, HashMap, HashSet};
50use std::convert::TryFrom;
51use std::net::{IpAddr, SocketAddr};
52use std::num::NonZeroU64;
53use std::sync::atomic::AtomicU64;
54use std::sync::Arc;
55use std::sync::Mutex as StdMutex;
56use std::time::Duration;
57use std::{
58 cmp::Ordering,
59 net::{Ipv4Addr, Ipv6Addr},
60};
61use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
62use tokio::net::{TcpSocket, TcpStream};
63use tokio::sync::{mpsc, oneshot};
64use tokio::time::Instant;
65use tracing::{debug, error, trace, warn};
66use uuid::Uuid;
67
/// CQL query reading the schema version reported by the node on the other end of the connection.
const LOCAL_VERSION: &str = "SELECT schema_version FROM system.local WHERE key='local'";

/// Count threshold for "old" orphaned requests (see [`OLD_AGE_ORPHAN_THRESHOLD`]).
// NOTE(review): the consumer of this constant is outside this chunk; presumably exceeding this
// many old orphans marks the connection as unhealthy — confirm against the router code.
const OLD_ORPHAN_COUNT_THRESHOLD: usize = 1_024;

/// Age after which an orphaned request counts as "old".
// NOTE(review): semantics inferred from the name — confirm against the router code.
const OLD_AGE_ORPHAN_THRESHOLD: Duration = Duration::from_secs(1);
79
/// Delay policy for write coalescing on a connection.
// NOTE(review): the code consuming this enum is outside this chunk; the per-variant notes below
// are inferred from the variant names — confirm against the writer implementation.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum WriteCoalescingDelay {
    /// A small delay of unspecified, nondeterministic length.
    SmallNondeterministic,

    /// A fixed delay in milliseconds; `NonZeroU64` rules out a zero delay.
    Milliseconds(NonZeroU64),
}
98
/// A single CQL connection to one node.
///
/// Requests are submitted through `router_handle` to a background router future whose
/// lifetime is tied to `_worker_handle`.
pub(crate) struct Connection {
    // Handle returned by `run_router` in `Connection::new`. Per the `futures` docs, dropping a
    // `RemoteHandle` drops the remote future, so this field keeps the router alive.
    // Declared first so it is dropped first.
    _worker_handle: RemoteHandle<()>,

    // Address the TCP stream was connected to.
    connect_address: SocketAddr,
    // Per-host configuration this connection was created with.
    config: HostConnectionConfig,
    // Negotiated features; `Connection::new` initialises this with `Default::default()`.
    features: ConnectionFeatures,
    // Shared with the router; used to submit requests and allocate request ids.
    router_handle: Arc<RouterHandle>,
}
107
/// Caller-side handle used to submit requests to a connection's router.
struct RouterHandle {
    // Bounded queue of tasks for the router (`Connection::new` creates it with capacity 1024).
    submit_channel: mpsc::Sender<Task>,

    // Source of per-connection request ids; see `allocate_request_id`.
    request_id_generator: AtomicU64,
    // Carries ids of requests whose caller stopped waiting before a response arrived
    // (sent from `OrphanhoodNotifier::drop`).
    orphan_notification_sender: mpsc::UnboundedSender<RequestId>,
}
121
122impl RouterHandle {
123 fn allocate_request_id(&self) -> RequestId {
124 self.request_id_generator
125 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
126 }
127
128 async fn send_request(
129 &self,
130 request: &impl SerializableRequest,
131 compression: Option<Compression>,
132 tracing: bool,
133 ) -> Result<TaskResponse, InternalRequestError> {
134 let serialized_request = SerializedRequest::make(request, compression, tracing)?;
135 let request_id = self.allocate_request_id();
136
137 let (response_sender, receiver) = oneshot::channel();
138 let response_handler = ResponseHandler {
139 response_sender,
140 request_id,
141 };
142
143 let notifier = OrphanhoodNotifier::new(request_id, &self.orphan_notification_sender);
147
148 self.submit_channel
149 .send(Task {
150 serialized_request,
151 response_handler,
152 })
153 .await
154 .map_err(|_| -> BrokenConnectionError {
155 BrokenConnectionErrorKind::ChannelError.into()
156 })?;
157
158 let task_response = receiver.await.map_err(|_| -> BrokenConnectionError {
159 BrokenConnectionErrorKind::ChannelError.into()
160 })?;
161
162 notifier.disable();
165
166 task_response
167 }
168}
169
/// Per-connection properties that start out empty/defaulted.
// NOTE(review): the code that fills these in is outside this chunk; presumably they come from the
// server's SUPPORTED options — confirm.
#[derive(Default)]
pub(crate) struct ConnectionFeatures {
    // Sharding information, if known.
    shard_info: Option<ShardInfo>,
    // Shard-aware port, if known.
    shard_aware_port: Option<u16>,
    // Protocol extensions; `Connection::prepare` uses it to detect the LWT flag.
    protocol_features: ProtocolFeatures,
}
176
/// Identifier of a request on a connection; produced by `RouterHandle::allocate_request_id`.
type RequestId = u64;

/// Where the outcome of one request is delivered, together with that request's id.
struct ResponseHandler {
    // The receiving end is awaited in `RouterHandle::send_request`.
    // NOTE(review): presumably completed by the router task — its code is outside this chunk.
    response_sender: oneshot::Sender<Result<TaskResponse, InternalRequestError>>,
    request_id: RequestId,
}

/// Drop guard that reports a request as orphaned unless it is disarmed.
///
/// While `enabled` is `true`, dropping the guard sends `request_id` over
/// `notification_sender`. `disable` consumes the guard after clearing `enabled`.
struct OrphanhoodNotifier<'a> {
    enabled: bool,
    request_id: RequestId,
    notification_sender: &'a mpsc::UnboundedSender<RequestId>,
}
191
192impl<'a> OrphanhoodNotifier<'a> {
193 fn new(
194 request_id: RequestId,
195 notification_sender: &'a mpsc::UnboundedSender<RequestId>,
196 ) -> Self {
197 Self {
198 enabled: true,
199 request_id,
200 notification_sender,
201 }
202 }
203
204 fn disable(mut self) {
205 self.enabled = false;
206 }
207}
208
209impl Drop for OrphanhoodNotifier<'_> {
210 fn drop(&mut self) {
211 if self.enabled {
212 let _ = self.notification_sender.send(self.request_id);
213 }
214 }
215}
216
/// Unit of work submitted to the router: the serialized frame plus where its outcome goes.
struct Task {
    serialized_request: SerializedRequest,
    response_handler: ResponseHandler,
}

/// Response frame pieces handed back for a task.
// NOTE(review): `body` appears to be the not-yet-parsed frame body; parsing happens outside this
// chunk — confirm.
struct TaskResponse {
    // Frame header parameters.
    params: FrameParams,
    // Response opcode.
    opcode: ResponseOpcode,
    // Frame body bytes.
    body: Bytes,
}
227
228impl<'id: 'map, 'map> SelfIdentity<'id> {
229 fn add_startup_options(&'id self, options: &'map mut HashMap<Cow<'id, str>, Cow<'id, str>>) {
230 let driver_name = self
232 .get_custom_driver_name()
233 .unwrap_or(options::DEFAULT_DRIVER_NAME);
234 options.insert(
235 Cow::Borrowed(options::DRIVER_NAME),
236 Cow::Borrowed(driver_name),
237 );
238
239 let driver_version = self
240 .get_custom_driver_version()
241 .unwrap_or(options::DEFAULT_DRIVER_VERSION);
242 options.insert(
243 Cow::Borrowed(options::DRIVER_VERSION),
244 Cow::Borrowed(driver_version),
245 );
246
247 if let Some(application_name) = self.get_application_name() {
249 options.insert(
250 Cow::Borrowed(options::APPLICATION_NAME),
251 Cow::Borrowed(application_name),
252 );
253 }
254
255 if let Some(application_version) = self.get_application_version() {
256 options.insert(
257 Cow::Borrowed(options::APPLICATION_VERSION),
258 Cow::Borrowed(application_version),
259 );
260 }
261
262 if let Some(client_id) = self.get_client_id() {
264 options.insert(Cow::Borrowed(options::CLIENT_ID), Cow::Borrowed(client_id));
265 }
266 }
267}
268
/// Connection configuration that is not yet specialized for a particular host.
///
/// `to_host_connection_config` turns it into a [`HostConnectionConfig`], resolving
/// `tls_provider` into a concrete TLS config for the given endpoint.
#[derive(Clone)]
pub(crate) struct ConnectionConfig {
    // Local address for outgoing sockets, if any (passed to `connect_with_source_ip_and_port`).
    pub(crate) local_ip_address: Option<IpAddr>,
    // Local port range for shard-aware connections.
    pub(crate) shard_aware_local_port_range: ShardAwarePortRange,
    // Frame compression, if any.
    pub(crate) compression: Option<Compression>,
    // Value passed to `TcpStream::set_nodelay`.
    pub(crate) tcp_nodelay: bool,
    // Passed to `TcpKeepalive::with_time`; `None` skips TCP keepalive setup.
    pub(crate) tcp_keepalive_interval: Option<Duration>,
    // Supplies a timestamp when a statement carries none.
    pub(crate) timestamp_generator: Option<Arc<dyn TimestampGenerator>>,
    // Produces a per-endpoint TLS config; `None` yields no TLS config.
    pub(crate) tls_provider: Option<TlsProvider>,
    // Upper bound on establishing the TCP connection.
    pub(crate) connect_timeout: std::time::Duration,
    // NOTE(review): presumably where server events are forwarded — consumer outside this chunk.
    pub(crate) event_sender: Option<mpsc::Sender<Event>>,
    // Fallback handed to `determine_consistency` for statements.
    pub(crate) default_consistency: Consistency,
    // Required if the server answers STARTUP with AUTHENTICATE.
    pub(crate) authenticator: Option<Arc<dyn AuthenticatorProvider>>,
    pub(crate) address_translator: Option<Arc<dyn AddressTranslator>>,
    pub(crate) write_coalescing_delay: Option<WriteCoalescingDelay>,

    // NOTE(review): presumably CQL-level keepalive, separate from `tcp_keepalive_interval`;
    // handled outside this chunk.
    pub(crate) keepalive_interval: Option<Duration>,
    pub(crate) keepalive_timeout: Option<Duration>,
    // NOTE(review): presumably receives tablet info parsed from responses — confirm.
    pub(crate) tablet_sender: Option<mpsc::Sender<(TableSpec<'static>, RawTablet)>>,

    // Client identity; `SelfIdentity::add_startup_options` turns it into STARTUP options.
    pub(crate) identity: SelfIdentity<'static>,
}
297
298impl ConnectionConfig {
299 pub(crate) fn to_host_connection_config(
301 &self,
302 #[allow(unused)] endpoint: &UntranslatedEndpoint,
305 ) -> HostConnectionConfig {
306 let tls_config = self
307 .tls_provider
308 .as_ref()
309 .and_then(|provider| provider.make_tls_config(endpoint));
310
311 HostConnectionConfig {
312 local_ip_address: self.local_ip_address,
313 shard_aware_local_port_range: self.shard_aware_local_port_range.clone(),
314 compression: self.compression,
315 tcp_nodelay: self.tcp_nodelay,
316 tcp_keepalive_interval: self.tcp_keepalive_interval,
317 timestamp_generator: self.timestamp_generator.clone(),
318 tls_config,
319 connect_timeout: self.connect_timeout,
320 event_sender: self.event_sender.clone(),
321 default_consistency: self.default_consistency,
322 authenticator: self.authenticator.clone(),
323 address_translator: self.address_translator.clone(),
324 write_coalescing_delay: self.write_coalescing_delay.clone(),
325 keepalive_interval: self.keepalive_interval,
326 keepalive_timeout: self.keepalive_timeout,
327 tablet_sender: self.tablet_sender.clone(),
328 identity: self.identity.clone(),
329 }
330 }
331}
332
/// Connection configuration specialized for one host.
///
/// Produced by `ConnectionConfig::to_host_connection_config`. The only difference from
/// `ConnectionConfig` is that the TLS provider has been resolved into `tls_config`.
#[derive(Clone)]
pub(crate) struct HostConnectionConfig {
    pub(crate) local_ip_address: Option<IpAddr>,
    pub(crate) shard_aware_local_port_range: ShardAwarePortRange,
    pub(crate) compression: Option<Compression>,
    // Value passed to `TcpStream::set_nodelay` in `Connection::new`.
    pub(crate) tcp_nodelay: bool,
    // Passed to `TcpKeepalive::with_time`; `None` skips TCP keepalive setup.
    pub(crate) tcp_keepalive_interval: Option<Duration>,
    pub(crate) timestamp_generator: Option<Arc<dyn TimestampGenerator>>,
    // `Some` exactly when `is_tls()` returns true.
    pub(crate) tls_config: Option<TlsConfig>,
    // Bound applied to the TCP connect in `Connection::new`.
    pub(crate) connect_timeout: std::time::Duration,
    pub(crate) event_sender: Option<mpsc::Sender<Event>>,
    pub(crate) default_consistency: Consistency,
    pub(crate) authenticator: Option<Arc<dyn AuthenticatorProvider>>,
    pub(crate) address_translator: Option<Arc<dyn AddressTranslator>>,
    pub(crate) write_coalescing_delay: Option<WriteCoalescingDelay>,

    // NOTE(review): presumably CQL-level keepalive settings — handled outside this chunk.
    pub(crate) keepalive_interval: Option<Duration>,
    pub(crate) keepalive_timeout: Option<Duration>,
    pub(crate) tablet_sender: Option<mpsc::Sender<(TableSpec<'static>, RawTablet)>>,

    pub(crate) identity: SelfIdentity<'static>,
}
359
#[cfg(test)]
impl Default for HostConnectionConfig {
    /// Test-only defaults.
    ///
    /// No TLS, authentication or address translation. `TCP_NODELAY` is on, the connect
    /// timeout is 5 s, and the small nondeterministic write-coalescing delay is used.
    fn default() -> Self {
        let connect_timeout = Duration::from_secs(5);
        let write_coalescing_delay = Some(WriteCoalescingDelay::SmallNondeterministic);

        Self {
            local_ip_address: None,
            shard_aware_local_port_range: ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
            compression: None,
            tcp_nodelay: true,
            tcp_keepalive_interval: None,
            timestamp_generator: None,
            tls_config: None,
            connect_timeout,
            event_sender: None,
            default_consistency: Consistency::default(),
            authenticator: None,
            address_translator: None,
            write_coalescing_delay,
            keepalive_interval: None,
            keepalive_timeout: None,
            tablet_sender: None,
            identity: SelfIdentity::default(),
        }
    }
}
388
#[cfg(test)]
impl Default for ConnectionConfig {
    /// Test-only defaults, mirroring those of `HostConnectionConfig`.
    ///
    /// No TLS provider, authentication or address translation. `TCP_NODELAY` is on, the
    /// connect timeout is 5 s, and the small nondeterministic write-coalescing delay is used.
    fn default() -> Self {
        let connect_timeout = Duration::from_secs(5);
        let write_coalescing_delay = Some(WriteCoalescingDelay::SmallNondeterministic);

        Self {
            local_ip_address: None,
            shard_aware_local_port_range: ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
            compression: None,
            tcp_nodelay: true,
            tcp_keepalive_interval: None,
            timestamp_generator: None,
            tls_provider: None,
            connect_timeout,
            event_sender: None,
            default_consistency: Consistency::default(),
            authenticator: None,
            address_translator: None,
            write_coalescing_delay,
            keepalive_interval: None,
            keepalive_timeout: None,
            tablet_sender: None,
            identity: SelfIdentity::default(),
        }
    }
}
417
418impl HostConnectionConfig {
419 fn is_tls(&self) -> bool {
420 self.tls_config.is_some()
421 }
422}
423
/// Receiving half of the oneshot whose sender `Connection::new` hands to the router.
// NOTE(review): presumably resolved once with the error that broke the connection — the sending
// side lives in the router code outside this chunk.
pub(crate) type ErrorReceiver = tokio::sync::oneshot::Receiver<ConnectionError>;
426
427impl Connection {
428 async fn new(
432 connect_address: SocketAddr,
433 source_port: Option<u16>,
434 config: HostConnectionConfig,
435 ) -> Result<(Self, ErrorReceiver), ConnectionError> {
436 let stream_connector = tokio::time::timeout(
437 config.connect_timeout,
438 connect_with_source_ip_and_port(connect_address, config.local_ip_address, source_port),
439 )
440 .await;
441 let stream = match stream_connector {
442 Ok(stream) => stream?,
443 Err(_) => {
444 return Err(ConnectionError::ConnectTimeout);
445 }
446 };
447 stream.set_nodelay(config.tcp_nodelay)?;
448
449 if let Some(tcp_keepalive_interval) = config.tcp_keepalive_interval {
450 Self::setup_tcp_keepalive(&stream, tcp_keepalive_interval)?;
451 }
452
453 let (sender, receiver) = mpsc::channel(1024);
455 let (error_sender, error_receiver) = tokio::sync::oneshot::channel();
456 let (orphan_notification_sender, orphan_notification_receiver) = mpsc::unbounded_channel();
458
459 let router_handle = Arc::new(RouterHandle {
460 submit_channel: sender,
461 request_id_generator: AtomicU64::new(0),
462 orphan_notification_sender,
463 });
464
465 let _worker_handle = Self::run_router(
466 config.clone(),
467 stream,
468 receiver,
469 error_sender,
470 orphan_notification_receiver,
471 router_handle.clone(),
472 connect_address.ip(),
473 )
474 .await?;
475
476 let connection = Connection {
477 _worker_handle,
478 config,
479 features: Default::default(),
480 connect_address,
481 router_handle,
482 };
483
484 Ok((connection, error_receiver))
485 }
486
487 fn setup_tcp_keepalive(
488 stream: &TcpStream,
489 tcp_keepalive_interval: Duration,
490 ) -> std::io::Result<()> {
491 let mut tcp_keepalive = TcpKeepalive::new().with_time(tcp_keepalive_interval);
498
499 #[cfg(any(
501 target_os = "android",
502 target_os = "dragonfly",
503 target_os = "freebsd",
504 target_os = "fuchsia",
505 target_os = "illumos",
506 target_os = "ios",
507 target_os = "linux",
508 target_os = "macos",
509 target_os = "netbsd",
510 target_os = "tvos",
511 target_os = "watchos",
512 target_os = "windows",
513 ))]
514 {
515 tcp_keepalive = tcp_keepalive.with_interval(Duration::from_secs(1));
516 }
517
518 #[cfg(any(
519 target_os = "android",
520 target_os = "dragonfly",
521 target_os = "freebsd",
522 target_os = "fuchsia",
523 target_os = "illumos",
524 target_os = "ios",
525 target_os = "linux",
526 target_os = "macos",
527 target_os = "netbsd",
528 target_os = "tvos",
529 target_os = "watchos",
530 ))]
531 {
532 tcp_keepalive = tcp_keepalive.with_retries(10);
533 }
534
535 let sf = SockRef::from(&stream);
536 sf.set_tcp_keepalive(&tcp_keepalive)
537 }
538
539 async fn startup(
540 &self,
541 options: HashMap<Cow<'_, str>, Cow<'_, str>>,
542 ) -> Result<NonErrorStartupResponse, ConnectionSetupRequestError> {
543 let err = |kind: ConnectionSetupRequestErrorKind| {
544 ConnectionSetupRequestError::new(CqlRequestKind::Startup, kind)
545 };
546
547 let req_result = self
548 .send_request(&request::Startup { options }, false, false, None)
549 .await;
550
551 let response = match req_result {
553 Ok(r) => match r.response {
554 Response::Ready => NonErrorStartupResponse::Ready,
555 Response::Authenticate(auth) => NonErrorStartupResponse::Authenticate(auth),
556 Response::Error(Error { error, reason }) => {
557 return Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason)))
558 }
559 _ => {
560 return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
561 r.response.to_response_kind(),
562 )))
563 }
564 },
565 Err(e) => match e {
566 InternalRequestError::CqlRequestSerialization(e) => return Err(err(e.into())),
567 InternalRequestError::BodyExtensionsParseError(e) => return Err(err(e.into())),
568 InternalRequestError::CqlResponseParseError(e) => match e {
569 CqlResponseParseError::CqlAuthenticateParseError(e) => {
572 return Err(err(e.into()))
573 }
574 CqlResponseParseError::CqlErrorParseError(e) => return Err(err(e.into())),
575 _ => {
576 return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
577 e.to_response_kind(),
578 )))
579 }
580 },
581 InternalRequestError::BrokenConnection(e) => return Err(err(e.into())),
582 InternalRequestError::UnableToAllocStreamId => {
583 return Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId))
584 }
585 },
586 };
587
588 Ok(response)
589 }
590
591 async fn get_options(&self) -> Result<response::Supported, ConnectionSetupRequestError> {
592 let err = |kind: ConnectionSetupRequestErrorKind| {
593 ConnectionSetupRequestError::new(CqlRequestKind::Options, kind)
594 };
595
596 let req_result = self
597 .send_request(&request::Options {}, false, false, None)
598 .await;
599
600 let supported = match req_result {
602 Ok(r) => match r.response {
603 Response::Supported(supported) => supported,
604 Response::Error(Error { error, reason }) => {
605 return Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason)))
606 }
607 _ => {
608 return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
609 r.response.to_response_kind(),
610 )))
611 }
612 },
613 Err(e) => match e {
614 InternalRequestError::CqlRequestSerialization(e) => return Err(err(e.into())),
615 InternalRequestError::BodyExtensionsParseError(e) => return Err(err(e.into())),
616 InternalRequestError::CqlResponseParseError(e) => match e {
617 CqlResponseParseError::CqlSupportedParseError(e) => return Err(err(e.into())),
618 CqlResponseParseError::CqlErrorParseError(e) => return Err(err(e.into())),
619 _ => {
620 return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
621 e.to_response_kind(),
622 )))
623 }
624 },
625 InternalRequestError::BrokenConnection(e) => return Err(err(e.into())),
626 InternalRequestError::UnableToAllocStreamId => {
627 return Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId))
628 }
629 },
630 };
631
632 Ok(supported)
633 }
634
635 pub(crate) async fn prepare(
636 &self,
637 query: &Statement,
638 ) -> Result<PreparedStatement, RequestAttemptError> {
639 let query_response = self
640 .send_request(
641 &request::Prepare {
642 query: &query.contents,
643 },
644 true,
645 query.config.tracing,
646 None,
647 )
648 .await?;
649
650 let mut prepared_statement = match query_response.response {
651 Response::Error(error::Error { error, reason }) => {
652 return Err(RequestAttemptError::DbError(error, reason))
653 }
654 Response::Result(result::Result::Prepared(p)) => PreparedStatement::new(
655 p.id,
656 self.features
657 .protocol_features
658 .prepared_flags_contain_lwt_mark(p.prepared_metadata.flags as u32),
659 p.prepared_metadata,
660 Arc::new(p.result_metadata),
661 query.contents.clone(),
662 query.get_validated_page_size(),
663 query.config.clone(),
664 ),
665 _ => {
666 return Err(RequestAttemptError::UnexpectedResponse(
667 query_response.response.to_response_kind(),
668 ))
669 }
670 };
671
672 if let Some(tracing_id) = query_response.tracing_id {
673 prepared_statement.prepare_tracing_ids.push(tracing_id);
674 }
675 Ok(prepared_statement)
676 }
677
678 async fn reprepare(
679 &self,
680 query: impl Into<Statement>,
681 previous_prepared: &PreparedStatement,
682 ) -> Result<(), RequestAttemptError> {
683 let reprepare_query: Statement = query.into();
684 let reprepared = self.prepare(&reprepare_query).await?;
685 if reprepared.get_id() != previous_prepared.get_id() {
688 Err(RequestAttemptError::RepreparedIdChanged {
689 statement: reprepare_query.contents,
690 expected_id: previous_prepared.get_id().clone().into(),
691 reprepared_id: reprepared.get_id().clone().into(),
692 })
693 } else {
694 Ok(())
695 }
696 }
697
698 async fn perform_authenticate(
699 &mut self,
700 authenticate: &Authenticate,
701 ) -> Result<(), ConnectionSetupRequestError> {
702 let err = |kind: ConnectionSetupRequestErrorKind| {
703 ConnectionSetupRequestError::new(CqlRequestKind::AuthResponse, kind)
704 };
705
706 let authenticator = &authenticate.authenticator_name as &str;
707
708 match self.config.authenticator {
709 Some(ref authenticator_provider) => {
710 let (mut response, mut auth_session) = authenticator_provider
711 .start_authentication_session(authenticator)
712 .await
713 .map_err(|e| err(ConnectionSetupRequestErrorKind::StartAuthSessionError(e)))?;
714
715 loop {
716 match self.authenticate_response(response).await? {
717 NonErrorAuthResponse::AuthChallenge(challenge) => {
718 response = auth_session
719 .evaluate_challenge(challenge.authenticate_message.as_deref())
720 .await
721 .map_err(|e| {
722 err(
723 ConnectionSetupRequestErrorKind::AuthChallengeEvaluationError(
724 e,
725 ),
726 )
727 })?;
728 }
729 NonErrorAuthResponse::AuthSuccess(success) => {
730 auth_session
731 .success(success.success_message.as_deref())
732 .await
733 .map_err(|e| {
734 err(ConnectionSetupRequestErrorKind::AuthFinishError(e))
735 })?;
736 break;
737 }
738 }
739 }
740 }
741 None => return Err(err(ConnectionSetupRequestErrorKind::MissingAuthentication)),
742 }
743
744 Ok(())
745 }
746
747 async fn authenticate_response(
748 &self,
749 response: Option<Vec<u8>>,
750 ) -> Result<NonErrorAuthResponse, ConnectionSetupRequestError> {
751 let err = |kind: ConnectionSetupRequestErrorKind| {
752 ConnectionSetupRequestError::new(CqlRequestKind::AuthResponse, kind)
753 };
754
755 let req_result = self
756 .send_request(&request::AuthResponse { response }, false, false, None)
757 .await;
758
759 let response = match req_result {
761 Ok(r) => match r.response {
762 Response::AuthSuccess(auth_success) => {
763 NonErrorAuthResponse::AuthSuccess(auth_success)
764 }
765 Response::AuthChallenge(auth_challenge) => {
766 NonErrorAuthResponse::AuthChallenge(auth_challenge)
767 }
768 Response::Error(Error { error, reason }) => {
769 return Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason)))
770 }
771 _ => {
772 return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
773 r.response.to_response_kind(),
774 )))
775 }
776 },
777 Err(e) => match e {
778 InternalRequestError::CqlRequestSerialization(e) => return Err(err(e.into())),
779 InternalRequestError::BodyExtensionsParseError(e) => return Err(err(e.into())),
780 InternalRequestError::CqlResponseParseError(e) => match e {
781 CqlResponseParseError::CqlAuthSuccessParseError(e) => {
782 return Err(err(e.into()))
783 }
784 CqlResponseParseError::CqlAuthChallengeParseError(e) => {
785 return Err(err(e.into()))
786 }
787 CqlResponseParseError::CqlErrorParseError(e) => return Err(err(e.into())),
788 _ => {
789 return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
790 e.to_response_kind(),
791 )))
792 }
793 },
794 InternalRequestError::BrokenConnection(e) => return Err(err(e.into())),
795 InternalRequestError::UnableToAllocStreamId => {
796 return Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId))
797 }
798 },
799 };
800
801 Ok(response)
802 }
803
804 #[allow(dead_code)]
805 pub(crate) async fn query_single_page(
806 &self,
807 query: impl Into<Statement>,
808 paging_state: PagingState,
809 ) -> Result<(QueryResult, PagingStateResponse), RequestAttemptError> {
810 let query: Statement = query.into();
811
812 let consistency = query
814 .config
815 .determine_consistency(self.config.default_consistency);
816 let serial_consistency = query.config.serial_consistency;
817
818 self.query_single_page_with_consistency(
819 query,
820 paging_state,
821 consistency,
822 serial_consistency.flatten(),
823 )
824 .await
825 }
826
827 #[allow(dead_code)]
828 pub(crate) async fn query_single_page_with_consistency(
829 &self,
830 query: impl Into<Statement>,
831 paging_state: PagingState,
832 consistency: Consistency,
833 serial_consistency: Option<SerialConsistency>,
834 ) -> Result<(QueryResult, PagingStateResponse), RequestAttemptError> {
835 let query: Statement = query.into();
836 let page_size = query.get_validated_page_size();
837
838 self.query_raw_with_consistency(
839 &query,
840 consistency,
841 serial_consistency,
842 Some(page_size),
843 paging_state,
844 )
845 .await?
846 .into_query_result_and_paging_state()
847 }
848
849 #[allow(dead_code)]
850 pub(crate) async fn query_unpaged(
851 &self,
852 statement: impl Into<Statement>,
853 ) -> Result<QueryResult, RequestAttemptError> {
854 let statement: Statement = statement.into();
856
857 self.query_raw_unpaged(&statement)
858 .await
859 .and_then(QueryResponse::into_query_result)
860 }
861
862 pub(crate) async fn query_raw_unpaged(
863 &self,
864 statement: &Statement,
865 ) -> Result<QueryResponse, RequestAttemptError> {
866 self.query_raw_with_consistency(
868 statement,
869 statement
870 .config
871 .determine_consistency(self.config.default_consistency),
872 statement.config.serial_consistency.flatten(),
873 None,
874 PagingState::start(),
875 )
876 .await
877 }
878
879 pub(crate) async fn query_raw_with_consistency(
880 &self,
881 statement: &Statement,
882 consistency: Consistency,
883 serial_consistency: Option<SerialConsistency>,
884 page_size: Option<PageSize>,
885 paging_state: PagingState,
886 ) -> Result<QueryResponse, RequestAttemptError> {
887 let get_timestamp_from_gen = || {
888 self.config
889 .timestamp_generator
890 .as_ref()
891 .map(|gen| gen.next_timestamp())
892 };
893 let timestamp = statement.get_timestamp().or_else(get_timestamp_from_gen);
894
895 let query_frame = query::Query {
896 contents: Cow::Borrowed(&statement.contents),
897 parameters: query::QueryParameters {
898 consistency,
899 serial_consistency,
900 values: Cow::Borrowed(SerializedValues::EMPTY),
901 page_size: page_size.map(Into::into),
902 paging_state,
903 skip_metadata: false,
904 timestamp,
905 },
906 };
907
908 let response = self
909 .send_request(&query_frame, true, statement.config.tracing, None)
910 .await?;
911
912 Ok(response)
913 }
914
915 #[allow(dead_code)]
916 pub(crate) async fn execute_unpaged(
917 &self,
918 prepared: &PreparedStatement,
919 values: SerializedValues,
920 ) -> Result<QueryResult, RequestAttemptError> {
921 self.execute_raw_unpaged(prepared, values)
923 .await
924 .and_then(QueryResponse::into_query_result)
925 }
926
927 #[allow(dead_code)]
928 pub(crate) async fn execute_raw_unpaged(
929 &self,
930 prepared: &PreparedStatement,
931 values: SerializedValues,
932 ) -> Result<QueryResponse, RequestAttemptError> {
933 self.execute_raw_with_consistency(
935 prepared,
936 &values,
937 prepared
938 .config
939 .determine_consistency(self.config.default_consistency),
940 prepared.config.serial_consistency.flatten(),
941 None,
942 PagingState::start(),
943 )
944 .await
945 }
946
947 pub(crate) async fn execute_raw_with_consistency(
948 &self,
949 prepared_statement: &PreparedStatement,
950 values: &SerializedValues,
951 consistency: Consistency,
952 serial_consistency: Option<SerialConsistency>,
953 page_size: Option<PageSize>,
954 paging_state: PagingState,
955 ) -> Result<QueryResponse, RequestAttemptError> {
956 let get_timestamp_from_gen = || {
957 self.config
958 .timestamp_generator
959 .as_ref()
960 .map(|gen| gen.next_timestamp())
961 };
962 let timestamp = prepared_statement
963 .get_timestamp()
964 .or_else(get_timestamp_from_gen);
965
966 let execute_frame = execute::Execute {
967 id: prepared_statement.get_id().to_owned(),
968 parameters: query::QueryParameters {
969 consistency,
970 serial_consistency,
971 values: Cow::Borrowed(values),
972 page_size: page_size.map(Into::into),
973 timestamp,
974 skip_metadata: prepared_statement.get_use_cached_result_metadata(),
975 paging_state,
976 },
977 };
978
979 let cached_metadata = prepared_statement
980 .get_use_cached_result_metadata()
981 .then(|| prepared_statement.get_result_metadata());
982
983 let query_response = self
984 .send_request(
985 &execute_frame,
986 true,
987 prepared_statement.config.tracing,
988 cached_metadata,
989 )
990 .await?;
991
992 if let Some(spec) = prepared_statement.get_table_spec() {
993 if let Err(e) = self
994 .update_tablets_from_response(spec, &query_response)
995 .await
996 {
997 tracing::warn!("Error while parsing tablet info from custom payload: {}", e);
998 }
999 }
1000
1001 match &query_response.response {
1002 Response::Error(frame::response::Error {
1003 error: DbError::Unprepared { statement_id },
1004 ..
1005 }) => {
1006 debug!("Connection::execute: Got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
1007 self.reprepare(prepared_statement.get_statement(), prepared_statement)
1009 .await?;
1010 let new_response = self
1011 .send_request(
1012 &execute_frame,
1013 true,
1014 prepared_statement.config.tracing,
1015 cached_metadata,
1016 )
1017 .await?;
1018
1019 if let Some(spec) = prepared_statement.get_table_spec() {
1020 if let Err(e) = self.update_tablets_from_response(spec, &new_response).await {
1021 tracing::warn!(
1022 "Error while parsing tablet info from custom payload: {}",
1023 e
1024 );
1025 }
1026 }
1027
1028 Ok(new_response)
1029 }
1030 _ => Ok(query_response),
1031 }
1032 }
1033
1034 pub(crate) async fn query_iter(
1037 self: Arc<Self>,
1038 query: Statement,
1039 ) -> Result<QueryPager, NextRowError> {
1040 let consistency = query
1041 .config
1042 .determine_consistency(self.config.default_consistency);
1043 let serial_consistency = query.config.serial_consistency.flatten();
1044
1045 QueryPager::new_for_connection_query_iter(query, self, consistency, serial_consistency)
1046 .await
1047 .map_err(NextRowError::NextPageError)
1048 }
1049
1050 pub(crate) async fn execute_iter(
1053 self: Arc<Self>,
1054 prepared_statement: PreparedStatement,
1055 values: SerializedValues,
1056 ) -> Result<QueryPager, NextRowError> {
1057 let consistency = prepared_statement
1058 .config
1059 .determine_consistency(self.config.default_consistency);
1060 let serial_consistency = prepared_statement.config.serial_consistency.flatten();
1061
1062 QueryPager::new_for_connection_execute_iter(
1063 prepared_statement,
1064 values,
1065 self,
1066 consistency,
1067 serial_consistency,
1068 )
1069 .await
1070 .map_err(NextRowError::NextPageError)
1071 }
1072
1073 #[allow(dead_code)]
1074 pub(crate) async fn batch(
1075 &self,
1076 batch: &Batch,
1077 values: impl BatchValues,
1078 ) -> Result<QueryResult, RequestAttemptError> {
1079 self.batch_with_consistency(
1080 batch,
1081 values,
1082 batch
1083 .config
1084 .determine_consistency(self.config.default_consistency),
1085 batch.config.serial_consistency.flatten(),
1086 )
1087 .await
1088 .and_then(QueryResponse::into_query_result)
1089 }
1090
1091 pub(crate) async fn batch_with_consistency(
1092 &self,
1093 init_batch: &Batch,
1094 values: impl BatchValues,
1095 consistency: Consistency,
1096 serial_consistency: Option<SerialConsistency>,
1097 ) -> Result<QueryResponse, RequestAttemptError> {
1098 let batch = self.prepare_batch(init_batch, &values).await?;
1099
1100 let contexts = batch.statements.iter().map(|bs| match bs {
1101 BatchStatement::Query(_) => RowSerializationContext::empty(),
1102 BatchStatement::PreparedStatement(ps) => {
1103 RowSerializationContext::from_prepared(ps.get_prepared_metadata())
1104 }
1105 });
1106
1107 let values = RawBatchValuesAdapter::new(values, contexts);
1108
1109 let get_timestamp_from_gen = || {
1110 self.config
1111 .timestamp_generator
1112 .as_ref()
1113 .map(|gen| gen.next_timestamp())
1114 };
1115 let timestamp = batch.get_timestamp().or_else(get_timestamp_from_gen);
1116
1117 let batch_frame = batch::Batch {
1118 statements: Cow::Borrowed(&batch.statements),
1119 values,
1120 batch_type: batch.get_type(),
1121 consistency,
1122 serial_consistency,
1123 timestamp,
1124 };
1125
1126 loop {
1127 let query_response = self
1128 .send_request(&batch_frame, true, batch.config.tracing, None)
1129 .await
1130 .map_err(RequestAttemptError::from)?;
1131
1132 return match query_response.response {
1133 Response::Error(err) => match err.error {
1134 DbError::Unprepared { statement_id } => {
1135 debug!("Connection::batch: got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
1136 let prepared_statement = batch.statements.iter().find_map(|s| match s {
1137 BatchStatement::PreparedStatement(s) if *s.get_id() == statement_id => {
1138 Some(s)
1139 }
1140 _ => None,
1141 });
1142 if let Some(p) = prepared_statement {
1143 self.reprepare(p.get_statement(), p).await?;
1144 continue;
1145 } else {
1146 return Err(RequestAttemptError::RepreparedIdMissingInBatch);
1147 }
1148 }
1149 _ => Err(err.into()),
1150 },
1151 Response::Result(_) => Ok(query_response),
1152 _ => Err(RequestAttemptError::UnexpectedResponse(
1153 query_response.response.to_response_kind(),
1154 )),
1155 };
1156 }
1157 }
1158
1159 async fn prepare_batch<'b>(
1160 &self,
1161 init_batch: &'b Batch,
1162 values: impl BatchValues,
1163 ) -> Result<Cow<'b, Batch>, RequestAttemptError> {
1164 let mut to_prepare = HashSet::<&str>::new();
1165
1166 {
1167 let mut values_iter = values.batch_values_iter();
1168 for stmt in &init_batch.statements {
1169 if let BatchStatement::Query(query) = stmt {
1170 if let Some(false) = values_iter.is_empty_next() {
1171 to_prepare.insert(&query.contents);
1172 }
1173 } else {
1174 values_iter.skip_next();
1175 }
1176 }
1177 }
1178
1179 if to_prepare.is_empty() {
1180 return Ok(Cow::Borrowed(init_batch));
1181 }
1182
1183 let mut prepared_queries = HashMap::<&str, PreparedStatement>::new();
1184
1185 for query in &to_prepare {
1186 let prepared = self.prepare(&Statement::new(query.to_string())).await?;
1187 prepared_queries.insert(query, prepared);
1188 }
1189
1190 let mut batch: Cow<Batch> = Cow::Owned(Batch::new_from(init_batch));
1191 for stmt in &init_batch.statements {
1192 match stmt {
1193 BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
1194 {
1195 Some(prepared) => batch.to_mut().append_statement(prepared.clone()),
1196 None => batch.to_mut().append_statement(query.clone()),
1197 },
1198 BatchStatement::PreparedStatement(prepared) => {
1199 batch.to_mut().append_statement(prepared.clone());
1200 }
1201 }
1202 }
1203
1204 Ok(batch)
1205 }
1206
1207 pub(super) async fn use_keyspace(
1208 &self,
1209 keyspace_name: &VerifiedKeyspaceName,
1210 ) -> Result<(), UseKeyspaceError> {
1211 let query: Statement = match keyspace_name.is_case_sensitive {
1214 true => format!("USE \"{}\"", keyspace_name.as_str()).into(),
1215 false => format!("USE {}", keyspace_name.as_str()).into(),
1216 };
1217
1218 let query_response = self.query_raw_unpaged(&query).await?;
1219 Self::verify_use_keyspace_result(keyspace_name, query_response)
1220 }
1221
1222 fn verify_use_keyspace_result(
1223 keyspace_name: &VerifiedKeyspaceName,
1224 query_response: QueryResponse,
1225 ) -> Result<(), UseKeyspaceError> {
1226 match query_response.response {
1227 Response::Result(result::Result::SetKeyspace(set_keyspace)) => {
1228 if !set_keyspace
1229 .keyspace_name
1230 .eq_ignore_ascii_case(keyspace_name.as_str())
1231 {
1232 let expected_keyspace_name_lowercase = keyspace_name.as_str().to_lowercase();
1233 let result_keyspace_name_lowercase = set_keyspace.keyspace_name.to_lowercase();
1234
1235 return Err(UseKeyspaceError::KeyspaceNameMismatch {
1236 expected_keyspace_name_lowercase,
1237 result_keyspace_name_lowercase,
1238 });
1239 }
1240
1241 Ok(())
1242 }
1243 Response::Error(err) => Err(UseKeyspaceError::RequestError(
1244 RequestAttemptError::DbError(err.error, err.reason),
1245 )),
1246 _ => Err(UseKeyspaceError::RequestError(
1247 RequestAttemptError::UnexpectedResponse(query_response.response.to_response_kind()),
1248 )),
1249 }
1250 }
1251
1252 async fn register(
1253 &self,
1254 event_types_to_register_for: Vec<EventType>,
1255 ) -> Result<(), ConnectionSetupRequestError> {
1256 let err = |kind: ConnectionSetupRequestErrorKind| {
1257 ConnectionSetupRequestError::new(CqlRequestKind::Register, kind)
1258 };
1259
1260 let register_frame = register::Register {
1261 event_types_to_register_for,
1262 };
1263
1264 match self.send_request(®ister_frame, true, false, None).await {
1266 Ok(r) => match r.response {
1267 Response::Ready => Ok(()),
1268 Response::Error(Error { error, reason }) => {
1269 Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason)))
1270 }
1271 _ => Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
1272 r.response.to_response_kind(),
1273 ))),
1274 },
1275 Err(e) => match e {
1276 InternalRequestError::CqlRequestSerialization(e) => Err(err(e.into())),
1277 InternalRequestError::BodyExtensionsParseError(e) => Err(err(e.into())),
1278 InternalRequestError::CqlResponseParseError(e) => match e {
1279 CqlResponseParseError::CqlErrorParseError(e) => Err(err(e.into())),
1281 _ => Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse(
1282 e.to_response_kind(),
1283 ))),
1284 },
1285 InternalRequestError::BrokenConnection(e) => Err(err(e.into())),
1286 InternalRequestError::UnableToAllocStreamId => {
1287 Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId))
1288 }
1289 },
1290 }
1291 }
1292
1293 pub(crate) async fn fetch_schema_version(&self) -> Result<Uuid, SchemaAgreementError> {
1294 let (version_id,) = self
1295 .query_unpaged(LOCAL_VERSION)
1296 .await?
1297 .into_rows_result()
1298 .map_err(SchemaAgreementError::TracesEventsIntoRowsResultError)?
1299 .single_row::<(Uuid,)>()
1300 .map_err(SchemaAgreementError::SingleRowError)?;
1301
1302 Ok(version_id)
1303 }
1304
1305 async fn send_request(
1306 &self,
1307 request: &impl SerializableRequest,
1308 compress: bool,
1309 tracing: bool,
1310 cached_metadata: Option<&Arc<ResultMetadata<'static>>>,
1311 ) -> Result<QueryResponse, InternalRequestError> {
1312 let compression = if compress {
1313 self.config.compression
1314 } else {
1315 None
1316 };
1317
1318 let task_response = self
1319 .router_handle
1320 .send_request(request, compression, tracing)
1321 .await?;
1322
1323 let response = Self::parse_response(
1324 task_response,
1325 self.config.compression,
1326 &self.features.protocol_features,
1327 cached_metadata,
1328 )?;
1329
1330 Ok(response)
1331 }
1332
1333 fn parse_response(
1334 task_response: TaskResponse,
1335 compression: Option<Compression>,
1336 features: &ProtocolFeatures,
1337 cached_metadata: Option<&Arc<ResultMetadata<'static>>>,
1338 ) -> Result<QueryResponse, ResponseParseError> {
1339 let body_with_ext = frame::parse_response_body_extensions(
1340 task_response.params.flags,
1341 compression,
1342 task_response.body,
1343 )?;
1344
1345 for warn_description in &body_with_ext.warnings {
1346 warn!(
1347 warning = warn_description.as_str(),
1348 "Response from the database contains a warning",
1349 );
1350 }
1351
1352 let response = Response::deserialize(
1353 features,
1354 task_response.opcode,
1355 body_with_ext.body,
1356 cached_metadata,
1357 )?;
1358
1359 Ok(QueryResponse {
1360 response,
1361 warnings: body_with_ext.warnings,
1362 tracing_id: body_with_ext.trace_id,
1363 custom_payload: body_with_ext.custom_payload,
1364 })
1365 }
1366
    /// Optionally wraps `stream` in TLS, then spawns the connection router on the tokio
    /// runtime and returns a [`RemoteHandle`] to it.
    ///
    /// TLS is used when `config.tls_config` is set. The backend (OpenSSL or rustls) depends
    /// on which cargo features are enabled.
    ///
    /// # Errors
    /// Returns an I/O error if building the TLS session fails, or if the TLS handshake fails.
    async fn run_router(
        config: HostConnectionConfig,
        stream: TcpStream,
        receiver: mpsc::Receiver<Task>,
        error_sender: tokio::sync::oneshot::Sender<ConnectionError>,
        orphan_notification_receiver: mpsc::UnboundedReceiver<RequestId>,
        router_handle: Arc<RouterHandle>,
        node_address: IpAddr,
    ) -> Result<RemoteHandle<()>, std::io::Error> {
        // Spawns `Connection::router` over any byte stream (plain TCP or TLS-wrapped).
        // Returns the remote handle of the spawned future.
        async fn spawn_router_and_get_handle(
            config: HostConnectionConfig,
            stream: (impl AsyncRead + AsyncWrite + Send + 'static),
            receiver: mpsc::Receiver<Task>,
            error_sender: tokio::sync::oneshot::Sender<ConnectionError>,
            orphan_notification_receiver: mpsc::UnboundedReceiver<RequestId>,
            router_handle: Arc<RouterHandle>,
            node_address: IpAddr,
        ) -> RemoteHandle<()> {
            let (task, handle) = Connection::router(
                config,
                stream,
                receiver,
                error_sender,
                orphan_notification_receiver,
                router_handle,
                node_address,
            )
            .remote_handle();
            tokio::task::spawn(task);
            handle
        }

        if let Some(tls_config) = &config.tls_config {
            // Every arm below is feature-gated. With no TLS feature enabled the match has no
            // arms, so the code after it is unreachable; this silences that lint.
            #[allow(unreachable_code)]
            match tls_config.new_tls()? {
                #[cfg(feature = "openssl-010")]
                crate::network::tls::Tls::OpenSsl010(ssl) => {
                    let mut stream = tokio_openssl::SslStream::new(ssl, stream)
                        .map_err(crate::network::tls::TlsError::OpenSsl010)?;
                    // Perform the client-side TLS handshake before handing the stream over.
                    std::pin::Pin::new(&mut stream)
                        .connect()
                        .await
                        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
                    return Ok(spawn_router_and_get_handle(
                        config,
                        stream,
                        receiver,
                        error_sender,
                        orphan_notification_receiver,
                        router_handle,
                        node_address,
                    )
                    .await);
                }
                #[cfg(feature = "rustls-023")]
                crate::network::tls::Tls::Rustls023 {
                    connector,
                    #[cfg(feature = "unstable-cloud")]
                    sni,
                } => {
                    use rustls::pki_types::ServerName;
                    // Prefer the SNI from the TLS setup (cloud feature), otherwise verify
                    // against the node's IP address.
                    #[cfg(feature = "unstable-cloud")]
                    let server_name =
                        sni.unwrap_or_else(|| ServerName::IpAddress(node_address.into()));
                    #[cfg(not(feature = "unstable-cloud"))]
                    let server_name = ServerName::IpAddress(node_address.into());
                    let stream = connector.connect(server_name, stream).await?;
                    return Ok(spawn_router_and_get_handle(
                        config,
                        stream,
                        receiver,
                        error_sender,
                        orphan_notification_receiver,
                        router_handle,
                        node_address,
                    )
                    .await);
                }
            }
        }

        // No TLS configured: run the router directly over the TCP stream.
        Ok(spawn_router_and_get_handle(
            config,
            stream,
            receiver,
            error_sender,
            orphan_notification_receiver,
            router_handle,
            node_address,
        )
        .await)
    }
1460
    /// Drives one connection by running the reader, writer, orphaner and keepaliver
    /// futures concurrently until one of them fails.
    ///
    /// On failure, every response handler that is still pending receives the error. The
    /// error is also sent through `error_sender`. If all four futures finish successfully,
    /// the function returns without reporting anything.
    async fn router(
        config: HostConnectionConfig,
        stream: (impl AsyncRead + AsyncWrite),
        receiver: mpsc::Receiver<Task>,
        error_sender: tokio::sync::oneshot::Sender<ConnectionError>,
        orphan_notification_receiver: mpsc::UnboundedReceiver<RequestId>,
        router_handle: Arc<RouterHandle>,
        node_address: IpAddr,
    ) {
        let (read_half, write_half) = split(stream);
        // Shared by reader, writer and orphaner. They are all polled by the single
        // `try_join!` below, and none holds the guard across an `.await`, so `try_lock`
        // elsewhere does not contend.
        let handler_map = StdMutex::new(ResponseHandlerMap::new());

        let write_coalescing_delay = config.write_coalescing_delay;

        let k = Self::keepaliver(
            router_handle,
            config.keepalive_interval,
            config.keepalive_timeout,
            node_address,
        );

        let r = Self::reader(
            BufReader::with_capacity(8192, read_half),
            &handler_map,
            config.event_sender,
            config.compression,
        );
        let w = Self::writer(
            BufWriter::with_capacity(8192, write_half),
            &handler_map,
            receiver,
            write_coalescing_delay,
        );
        let o = Self::orphaner(&handler_map, orphan_notification_receiver);

        // Completes at the first error (or when all four succeed).
        let result = futures::try_join!(r, w, o, k);

        let error: BrokenConnectionError = match result {
            Ok(_) => return,
            Err(err) => err,
        };

        // The futures borrowing `handler_map` are gone now, so it can be consumed.
        let response_handlers: HashMap<i16, ResponseHandler> =
            handler_map.into_inner().unwrap().into_handlers();

        // Fail every request still waiting for a response; receivers may already be gone.
        for (_, handler) in response_handlers {
            let _ = handler.response_sender.send(Err(error.clone().into()));
        }

        // Report the error to the connection owner; ignore it if nobody listens anymore.
        let _ = error_sender.send(error.into());
    }
1527
    /// Reads response frames from `read_half` in a loop and dispatches each one by stream id.
    ///
    /// - Stream id `-1`: a server event. It is parsed and forwarded to `event_sender` when
    ///   one is configured, and dropped otherwise.
    /// - Stream ids below `-1`: ignored.
    /// - Other ids: routed to the registered response handler. Responses for orphaned
    ///   ids are dropped.
    ///
    /// # Errors
    /// Returns a `BrokenConnectionError` in three cases:
    /// - a frame cannot be read or its header parsed;
    /// - event handling fails;
    /// - a response arrives for a stream id with neither a handler nor an orphan entry.
    async fn reader(
        mut read_half: (impl AsyncRead + Unpin),
        handler_map: &StdMutex<ResponseHandlerMap>,
        event_sender: Option<mpsc::Sender<Event>>,
        compression: Option<Compression>,
    ) -> Result<(), BrokenConnectionError> {
        loop {
            let (params, opcode, body) = frame::read_response_frame(&mut read_half)
                .await
                .map_err(BrokenConnectionErrorKind::FrameHeaderParseError)?;
            let response = TaskResponse {
                params,
                opcode,
                body,
            };

            match params.stream.cmp(&-1) {
                Ordering::Less => {
                    // Stream ids below -1 are skipped.
                    continue;
                }
                Ordering::Equal => {
                    // Stream id -1 carries server-pushed events.
                    if let Some(event_sender) = event_sender.as_ref() {
                        Self::handle_event(response, compression, event_sender)
                            .await
                            .map_err(BrokenConnectionErrorKind::CqlEventHandlingError)?
                    }
                    continue;
                }
                _ => {}
            }

            let handler_lookup_res = {
                // Scoped so the guard is released before the next `.await`.
                let mut handler_map_guard = handler_map.try_lock().unwrap();
                handler_map_guard.lookup(params.stream)
            };

            use HandlerLookupResult::*;
            match handler_lookup_res {
                Handler(handler) => {
                    // The requester may have stopped waiting; a failed send is ignored.
                    let _ = handler.response_sender.send(Ok(response));
                }
                Missing => {
                    debug!(
                        "Received response with unexpected StreamId {}",
                        params.stream
                    );
                    return Err(BrokenConnectionErrorKind::UnexpectedStreamId(params.stream).into());
                }
                Orphaned => {
                    // `lookup` already freed the stream id; the response is dropped.
                }
            }
        }
    }
1593
1594 fn alloc_stream_id(
1595 handler_map: &StdMutex<ResponseHandlerMap>,
1596 response_handler: ResponseHandler,
1597 ) -> Option<i16> {
1598 let mut handler_map_guard = handler_map.try_lock().unwrap();
1601 match handler_map_guard.allocate(response_handler) {
1602 Ok(stream_id) => Some(stream_id),
1603 Err(response_handler) => {
1604 error!("Could not allocate stream id");
1605 let _ = response_handler
1606 .response_sender
1607 .send(Err(InternalRequestError::UnableToAllocStreamId));
1608 None
1609 }
1610 }
1611 }
1612
    /// Writes queued request frames to `write_half` and flushes them in batches.
    ///
    /// After waiting for a task, it keeps taking already-queued tasks without blocking. Each
    /// one gets a fresh stream id and is written into the buffered writer. When the queue
    /// runs empty, `write_coalescing_delay` (if set) is applied once to let more tasks
    /// arrive; then the batch is flushed. If a stream id cannot be allocated, the current
    /// task is dropped (its handler was already notified by `alloc_stream_id`), and what was
    /// written so far is flushed.
    ///
    /// Returns `Ok(())` once the task channel is closed.
    ///
    /// # Errors
    /// Returns a `WriteError` if writing or flushing fails.
    async fn writer(
        mut write_half: (impl AsyncWrite + Unpin),
        handler_map: &StdMutex<ResponseHandlerMap>,
        mut task_receiver: mpsc::Receiver<Task>,
        write_coalescing_delay: Option<WriteCoalescingDelay>,
    ) -> Result<(), BrokenConnectionError> {
        while let Some(mut task) = task_receiver.recv().await {
            // Counters used only for the trace log below.
            let mut num_requests = 0;
            let mut total_sent = 0;
            while let Some(stream_id) = Self::alloc_stream_id(handler_map, task.response_handler) {
                let mut req = task.serialized_request;
                req.set_stream(stream_id);
                let req_data: &[u8] = req.get_data();
                total_sent += req_data.len();
                num_requests += 1;
                write_half
                    .write_all(req_data)
                    .await
                    .map_err(BrokenConnectionErrorKind::WriteError)?;
                task = match task_receiver.try_recv() {
                    // Another request is already queued: keep batching.
                    Ok(t) => t,
                    // Queue empty: optionally wait a bit for more before flushing.
                    Err(_) => match write_coalescing_delay {
                        Some(WriteCoalescingDelay::SmallNondeterministic) => {
                            // Yield once so other tasks get a chance to enqueue requests.
                            tokio::task::yield_now().await;
                            match task_receiver.try_recv() {
                                Ok(t) => t,
                                Err(_) => break,
                            }
                        }
                        Some(WriteCoalescingDelay::Milliseconds(ms)) => {
                            tokio::time::sleep(Duration::from_millis(ms.get())).await;
                            match task_receiver.try_recv() {
                                Ok(t) => t,
                                Err(_) => break,
                            }
                        }
                        None => break,
                    },
                }
            }
            trace!("Sending {} requests; {} bytes", num_requests, total_sent);
            write_half
                .flush()
                .await
                .map_err(BrokenConnectionErrorKind::WriteError)?;
        }

        Ok(())
    }
1668
    /// Orphans the stream ids of abandoned requests and periodically checks how many
    /// orphans are older than `OLD_AGE_ORPHAN_THRESHOLD`.
    ///
    /// Request ids arrive on `orphan_receiver` and are marked as orphaned in `handler_map`.
    /// Every `OLD_AGE_ORPHAN_THRESHOLD`, the number of old orphans is compared against
    /// `OLD_ORPHAN_COUNT_THRESHOLD`.
    ///
    /// # Errors
    /// Returns `TooManyOrphanedStreamIds` when that threshold is exceeded.
    async fn orphaner(
        handler_map: &StdMutex<ResponseHandlerMap>,
        mut orphan_receiver: mpsc::UnboundedReceiver<RequestId>,
    ) -> Result<(), BrokenConnectionError> {
        let mut interval = tokio::time::interval(OLD_AGE_ORPHAN_THRESHOLD);
        loop {
            tokio::select! {
                _ = interval.tick() => {
                    let handler_map_guard = handler_map.try_lock().unwrap();
                    let old_orphan_count = handler_map_guard.old_orphans_count();
                    if old_orphan_count > OLD_ORPHAN_COUNT_THRESHOLD {
                        warn!(
                            "Too many old orphaned stream ids: {}",
                            old_orphan_count,
                        );
                        // At most 32768 stream ids exist (StreamIdSet), so the cast fits u16.
                        return Err(BrokenConnectionErrorKind::TooManyOrphanedStreamIds(old_orphan_count as u16).into())
                    }
                }
                Some(request_id) = orphan_receiver.recv() => {
                    trace!(
                        "Trying to orphan stream id associated with request_id = {}",
                        request_id,
                    );
                    let mut handler_map_guard = handler_map.try_lock().unwrap();
                    handler_map_guard.orphan(request_id);
                }
                // Runs only when every other branch is disabled.
                else => { break }
            }
        }

        Ok(())
    }
1707
1708 async fn keepaliver(
1709 router_handle: Arc<RouterHandle>,
1710 keepalive_interval: Option<Duration>,
1711 keepalive_timeout: Option<Duration>,
1712 node_address: IpAddr, ) -> Result<(), BrokenConnectionError> {
1714 async fn issue_keepalive_query(
1715 router_handle: &RouterHandle,
1716 ) -> Result<(), BrokenConnectionError> {
1717 router_handle
1718 .send_request(&Options, None, false)
1719 .await
1720 .map(|_| ())
1721 .map_err(|req_err| {
1722 BrokenConnectionErrorKind::KeepaliveRequestError(Arc::new(req_err)).into()
1723 })
1724 }
1725
1726 if let Some(keepalive_interval) = keepalive_interval {
1727 let mut interval = tokio::time::interval(keepalive_interval);
1728 interval.tick().await; interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
1732
1733 loop {
1734 interval.tick().await;
1735
1736 let keepalive_query = issue_keepalive_query(&router_handle);
1737 let query_result = if let Some(timeout) = keepalive_timeout {
1738 match tokio::time::timeout(timeout, keepalive_query).await {
1739 Ok(res) => res,
1740 Err(_) => {
1741 warn!(
1742 "Timed out while waiting for response to keepalive request on connection to node {}",
1743 node_address
1744 );
1745 return Err(
1746 BrokenConnectionErrorKind::KeepaliveTimeout(node_address).into()
1747 );
1748 }
1749 }
1750 } else {
1751 keepalive_query.await
1752 };
1753 if let Err(err) = query_result {
1754 warn!(
1755 "Failed to execute keepalive request on connection to node {} - {}",
1756 node_address, err
1757 );
1758 return Err(err);
1759 }
1760
1761 trace!(
1762 "Keepalive request successful on connection to node {}",
1763 node_address
1764 );
1765 }
1766 } else {
1767 Ok(())
1769 }
1770 }
1771
1772 async fn handle_event(
1773 task_response: TaskResponse,
1774 compression: Option<Compression>,
1775 event_sender: &mpsc::Sender<Event>,
1776 ) -> Result<(), CqlEventHandlingError> {
1777 let features = ProtocolFeatures::default(); let event = match Self::parse_response(task_response, compression, &features, None) {
1789 Ok(r) => match r.response {
1790 Response::Event(event) => event,
1791 _ => {
1792 error!("Expected to receive Event response, got {:?}", r.response);
1793 return Err(CqlEventHandlingError::UnexpectedResponse(
1794 r.response.to_response_kind(),
1795 ));
1796 }
1797 },
1798 Err(e) => match e {
1799 ResponseParseError::BodyExtensionsParseError(e) => return Err(e.into()),
1800 ResponseParseError::CqlResponseParseError(e) => match e {
1801 CqlResponseParseError::CqlEventParseError(e) => return Err(e.into()),
1802 _ => {
1804 return Err(CqlEventHandlingError::UnexpectedResponse(
1805 e.to_response_kind(),
1806 ))
1807 }
1808 },
1809 },
1810 };
1811
1812 event_sender
1813 .send(event)
1814 .await
1815 .map_err(|_| CqlEventHandlingError::SendError)
1816 }
1817
1818 pub(crate) fn get_shard_info(&self) -> &Option<ShardInfo> {
1819 &self.features.shard_info
1820 }
1821
1822 pub(crate) fn get_shard_aware_port(&self) -> Option<u16> {
1823 self.features.shard_aware_port
1824 }
1825
1826 fn set_features(&mut self, features: ConnectionFeatures) {
1827 self.features = features;
1828 }
1829
1830 pub(crate) fn get_connect_address(&self) -> SocketAddr {
1831 self.connect_address
1832 }
1833
1834 async fn update_tablets_from_response(
1835 &self,
1836 table: &TableSpec<'_>,
1837 response: &QueryResponse,
1838 ) -> Result<(), TabletParsingError> {
1839 if let (Some(sender), Some(tablet_data)) = (
1840 self.config.tablet_sender.as_ref(),
1841 response.custom_payload.as_ref(),
1842 ) {
1843 let tablet = match RawTablet::from_custom_payload(tablet_data) {
1844 Some(Ok(v)) => v,
1845 Some(Err(e)) => return Err(e),
1846 None => return Ok(()),
1847 };
1848 tracing::trace!(
1849 "Received tablet info for table {}.{} in custom payload: {:?}",
1850 table.ks_name(),
1851 table.table_name(),
1852 tablet
1853 );
1854 let _ = sender.send((table.to_owned(), tablet)).await;
1855 }
1856
1857 Ok(())
1858 }
1859}
1860
1861async fn maybe_translated_addr(
1862 endpoint: &UntranslatedEndpoint,
1863 address_translator: Option<&dyn AddressTranslator>,
1864) -> Result<SocketAddr, TranslationError> {
1865 match *endpoint {
1866 UntranslatedEndpoint::ContactPoint(ref addr) => Ok(addr.address),
1867 UntranslatedEndpoint::Peer(PeerEndpoint {
1868 host_id,
1869 address,
1870 ref datacenter,
1871 ref rack,
1872 }) => match address {
1873 NodeAddr::Translatable(addr) => {
1874 if let Some(translator) = address_translator {
1876 let res = translator
1877 .translate_address(&UntranslatedPeer {
1878 host_id,
1879 untranslated_address: addr,
1880 datacenter: datacenter.as_deref(),
1881 rack: rack.as_deref(),
1882 })
1883 .await;
1884 if let Err(ref err) = res {
1885 error!("Address translation failed for addr {}: {}", addr, err);
1886 }
1887 res
1888 } else {
1889 Ok(addr)
1890 }
1891 }
1892 NodeAddr::Untranslatable(addr) => {
1893 Ok(addr)
1895 }
1896 },
1897 }
1898}
1899
/// Opens a connection to `endpoint` and performs the CQL handshake on it.
///
/// Steps, in order:
/// 1. Translate the endpoint address (if a translator is configured).
/// 2. Create the connection, optionally from `source_port`.
/// 3. Fetch the server's supported options (`get_options`).
/// 4. Record sharding info, the shard-aware port and protocol features on the connection.
/// 5. Call `startup`, negotiating compression.
/// 6. Authenticate if the server asks for it.
/// 7. Register for topology, status and schema events when an event sender is configured.
///
/// Missing or malformed sharding info is logged and tolerated. Unsupported compression
/// falls back to none with a warning.
///
/// # Errors
/// Returns a `ConnectionError` if any of the steps above fails.
pub(crate) async fn open_connection(
    endpoint: &UntranslatedEndpoint,
    source_port: Option<u16>,
    config: &HostConnectionConfig,
) -> Result<(Connection, ErrorReceiver), ConnectionError> {
    let addr = maybe_translated_addr(endpoint, config.address_translator.as_deref()).await?;

    let (mut connection, error_receiver) =
        Connection::new(addr, source_port, config.clone()).await?;

    let mut supported = connection.get_options().await?;

    // The shard-aware port is advertised under a different key for TLS connections.
    let shard_aware_port_key = match config.is_tls() {
        true => options::SCYLLA_SHARD_AWARE_PORT_SSL,
        false => options::SCYLLA_SHARD_AWARE_PORT,
    };

    let shard_info = match ShardInfo::try_from(&supported.options) {
        Ok(info) => Some(info),
        Err(ShardingError::NoShardInfo) => {
            tracing::info!(
                "[{}] No sharding information received. Proceeding with no sharding info.",
                addr
            );
            None
        }
        Err(e) => {
            tracing::error!(
                "[{}] Error while parsing sharding information: {}. Proceeding with no sharding info.",
                addr, e
            );
            None
        }
    };
    // Remove these keys from the map so they are not seen by the protocol-feature parser below.
    let supported_compression = supported
        .options
        .remove(options::COMPRESSION)
        .unwrap_or_default();
    let shard_aware_port = supported
        .options
        .remove(shard_aware_port_key)
        .unwrap_or_default()
        .first()
        .and_then(|p| p.parse::<u16>().ok());

    let protocol_features = ProtocolFeatures::parse_from_supported(&supported.options);

    let features = ConnectionFeatures {
        shard_info,
        shard_aware_port,
        protocol_features,
    };
    connection.set_features(features);

    // STARTUP options: enabled protocol extensions, CQL version, client identity and
    // (if supported by the server) compression.
    let mut options = HashMap::new();
    protocol_features.add_startup_options(&mut options);

    options.insert(
        Cow::Borrowed(options::CQL_VERSION),
        Cow::Borrowed(options::DEFAULT_CQL_PROTOCOL_VERSION),
    );

    config.identity.add_startup_options(&mut options);

    if let Some(compression) = &config.compression {
        let compression_str = compression.as_str();
        if supported_compression.iter().any(|c| c == compression_str) {
            options.insert(
                Cow::Borrowed(options::COMPRESSION),
                Cow::Borrowed(compression_str),
            );
        } else {
            tracing::warn!(
                "Requested compression <{}> is not supported by the cluster. Falling back to no compression",
                compression_str
            );
            // Disable compression on this connection so frames are sent uncompressed.
            connection.config.compression = None;
        }
    }

    let startup_result = connection.startup(options).await?;
    match startup_result {
        NonErrorStartupResponse::Ready => {}
        NonErrorStartupResponse::Authenticate(authenticate) => {
            connection.perform_authenticate(&authenticate).await?;
        }
    }

    // Subscribe to server events only if someone is going to consume them.
    if connection.config.event_sender.is_some() {
        let all_event_types = vec![
            EventType::TopologyChange,
            EventType::StatusChange,
            EventType::SchemaChange,
        ];
        connection.register(all_event_types).await?;
    }

    Ok((connection, error_receiver))
}
2022
2023pub(super) async fn open_connection_to_shard_aware_port(
2024 endpoint: &UntranslatedEndpoint,
2025 shard: Shard,
2026 sharder: Sharder,
2027 config: &HostConnectionConfig,
2028) -> Result<(Connection, ErrorReceiver), ConnectionError> {
2029 let source_port_iter =
2031 sharder.iter_source_ports_for_shard_from_range(shard, &config.shard_aware_local_port_range);
2032
2033 for port in source_port_iter {
2034 let connect_result = open_connection(endpoint, Some(port), config).await;
2035
2036 match connect_result {
2037 Err(err) if err.is_address_unavailable_for_use() => continue, result => return result,
2039 }
2040 }
2041
2042 Err(ConnectionError::NoSourcePortForShard(shard))
2044}
2045
2046async fn connect_with_source_ip_and_port(
2047 connect_address: SocketAddr,
2048 source_ip: Option<IpAddr>,
2049 source_port: Option<u16>,
2050) -> Result<TcpStream, std::io::Error> {
2051 let source_port = source_port.unwrap_or(0);
2053
2054 match connect_address {
2055 SocketAddr::V4(_) => {
2056 let source_ipv4 = source_ip.unwrap_or(Ipv4Addr::UNSPECIFIED.into());
2058 let socket = TcpSocket::new_v4()?;
2059 socket.bind(SocketAddr::new(source_ipv4, source_port))?;
2060 Ok(socket.connect(connect_address).await?)
2061 }
2062 SocketAddr::V6(_) => {
2063 let source_ipv6 = source_ip.unwrap_or(Ipv6Addr::UNSPECIFIED.into());
2065 let socket = TcpSocket::new_v6()?;
2066 socket.bind(SocketAddr::new(source_ipv6, source_port))?;
2067 Ok(socket.connect(connect_address).await?)
2068 }
2069 }
2070}
2071
/// Bookkeeping for orphaned stream ids.
///
/// An orphaned id belongs to a request whose requester stopped waiting, but a response may
/// still arrive for it. Such ids stay allocated until that response is seen.
struct OrphanageTracker {
    /// Orphaned stream id -> the instant it was orphaned.
    orphans: HashMap<i16, Instant>,
    /// The same entries as `(orphaning instant, stream id)`. They are kept ordered so that
    /// old orphans can be counted with a range query.
    by_orphaning_times: BTreeSet<(Instant, i16)>,
}

impl OrphanageTracker {
    fn new() -> Self {
        Self {
            orphans: HashMap::new(),
            by_orphaning_times: BTreeSet::new(),
        }
    }

    /// Records `stream_id` as orphaned now.
    ///
    /// Re-inserting an id that is already tracked replaces its timestamp. Without the removal
    /// below, a stale entry would stay in `by_orphaning_times` forever.
    fn insert(&mut self, stream_id: i16) {
        let now = Instant::now();
        if let Some(previous) = self.orphans.insert(stream_id, now) {
            self.by_orphaning_times.remove(&(previous, stream_id));
        }
        self.by_orphaning_times.insert((now, stream_id));
    }

    /// Stops tracking `stream_id`; does nothing if it is not tracked.
    fn remove(&mut self, stream_id: i16) {
        if let Some(time) = self.orphans.remove(&stream_id) {
            self.by_orphaning_times.remove(&(time, stream_id));
        }
    }

    fn contains(&self, stream_id: i16) -> bool {
        self.orphans.contains_key(&stream_id)
    }

    /// Counts orphans that have been orphaned for at least `age`. These are the entries
    /// ordered before `(now - age, i16::MAX)`.
    ///
    /// If `now - age` is not representable as an `Instant`, no orphan can be that old, so 0
    /// is returned. Plain subtraction would panic on that underflow instead.
    fn orphans_older_than(&self, age: std::time::Duration) -> usize {
        match Instant::now().checked_sub(age) {
            Some(cutoff) => self
                .by_orphaning_times
                .range(..(cutoff, i16::MAX))
                .count(),
            None => 0,
        }
    }
}
2111
/// Per-connection table that routes incoming responses, by stream id, to the handlers
/// waiting for them.
struct ResponseHandlerMap {
    /// Which stream ids are currently allocated.
    stream_set: StreamIdSet,
    /// Handler awaiting the response for each in-flight, non-orphaned stream id.
    handlers: HashMap<i16, ResponseHandler>,

    /// Reverse index used to find a request's stream id when the request is orphaned.
    request_to_stream: HashMap<RequestId, i16>,
    /// Stream ids of orphaned requests that still await a response.
    orphanage_tracker: OrphanageTracker,
}

/// Outcome of `ResponseHandlerMap::lookup` for an incoming response's stream id.
enum HandlerLookupResult {
    /// The stream id belonged to an orphaned request; the response is not delivered.
    Orphaned,
    /// The handler waiting for this response.
    Handler(ResponseHandler),
    /// Neither a handler nor an orphan entry exists for the stream id.
    Missing,
}
2125
2126impl ResponseHandlerMap {
2127 fn new() -> Self {
2128 Self {
2129 stream_set: StreamIdSet::new(),
2130 handlers: HashMap::new(),
2131 request_to_stream: HashMap::new(),
2132 orphanage_tracker: OrphanageTracker::new(),
2133 }
2134 }
2135
2136 fn allocate(&mut self, response_handler: ResponseHandler) -> Result<i16, ResponseHandler> {
2137 if let Some(stream_id) = self.stream_set.allocate() {
2138 self.request_to_stream
2139 .insert(response_handler.request_id, stream_id);
2140 let prev_handler = self.handlers.insert(stream_id, response_handler);
2141 assert!(prev_handler.is_none());
2142
2143 Ok(stream_id)
2144 } else {
2145 Err(response_handler)
2146 }
2147 }
2148
2149 fn orphan(&mut self, request_id: RequestId) {
2152 if let Some(stream_id) = self.request_to_stream.get(&request_id) {
2153 debug!(
2154 "Orphaning stream_id = {} associated with request_id = {}",
2155 stream_id, request_id
2156 );
2157 self.orphanage_tracker.insert(*stream_id);
2158 self.handlers.remove(stream_id);
2159 self.request_to_stream.remove(&request_id);
2160 }
2161 }
2162
2163 fn old_orphans_count(&self) -> usize {
2164 self.orphanage_tracker
2165 .orphans_older_than(OLD_AGE_ORPHAN_THRESHOLD)
2166 }
2167
2168 fn lookup(&mut self, stream_id: i16) -> HandlerLookupResult {
2169 self.stream_set.free(stream_id);
2170
2171 if self.orphanage_tracker.contains(stream_id) {
2172 self.orphanage_tracker.remove(stream_id);
2173 return HandlerLookupResult::Orphaned;
2176 }
2177
2178 if let Some(handler) = self.handlers.remove(&stream_id) {
2179 self.request_to_stream.remove(&handler.request_id);
2183
2184 HandlerLookupResult::Handler(handler)
2185 } else {
2186 HandlerLookupResult::Missing
2187 }
2188 }
2189
2190 fn into_handlers(self) -> HashMap<i16, ResponseHandler> {
2193 self.handlers
2194 }
2195}
2196
/// Bitmap allocator for the non-negative `i16` stream ids (0..=32767).
///
/// Each bit marks one id as in use. Allocation always returns the lowest free id.
struct StreamIdSet {
    used_bitmap: Box<[u64]>,
}

impl StreamIdSet {
    fn new() -> Self {
        // One bit per id: 32768 bits packed into 512 u64 words.
        const WORD_COUNT: usize = (i16::MAX as usize + 1) / 64;
        Self {
            used_bitmap: vec![0u64; WORD_COUNT].into_boxed_slice(),
        }
    }

    /// Marks the lowest free stream id as used and returns it, or `None` if all are taken.
    fn allocate(&mut self) -> Option<i16> {
        let (word_idx, word) = self
            .used_bitmap
            .iter_mut()
            .enumerate()
            .find(|(_, word)| **word != u64::MAX)?;
        // The trailing ones are the used low bits; the next bit is the first free one.
        let bit = word.trailing_ones();
        *word |= 1u64 << bit;
        Some(word_idx as i16 * 64 + bit as i16)
    }

    /// Marks `stream_id` as free. Expects a non-negative id.
    fn free(&mut self, stream_id: i16) {
        let index = stream_id as usize;
        self.used_bitmap[index / 64] &= !(1u64 << (index % 64));
    }
}
2227
2228#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
2230pub(crate) struct VerifiedKeyspaceName {
2231 name: Arc<String>,
2232 pub(crate) is_case_sensitive: bool,
2233}
2234
2235impl VerifiedKeyspaceName {
2236 pub(crate) fn new(
2237 keyspace_name: String,
2238 case_sensitive: bool,
2239 ) -> Result<Self, BadKeyspaceName> {
2240 Self::verify_keyspace_name_is_valid(&keyspace_name)?;
2241
2242 Ok(VerifiedKeyspaceName {
2243 name: Arc::new(keyspace_name),
2244 is_case_sensitive: case_sensitive,
2245 })
2246 }
2247
2248 pub(crate) fn as_str(&self) -> &str {
2249 self.name.as_str()
2250 }
2251
2252 fn verify_keyspace_name_is_valid(keyspace_name: &str) -> Result<(), BadKeyspaceName> {
2259 if keyspace_name.is_empty() {
2260 return Err(BadKeyspaceName::Empty);
2261 }
2262
2263 let keyspace_name_len: usize = keyspace_name.chars().count(); if keyspace_name_len > 48 {
2266 return Err(BadKeyspaceName::TooLong(
2267 keyspace_name.to_string(),
2268 keyspace_name_len,
2269 ));
2270 }
2271
2272 for character in keyspace_name.chars() {
2274 match character {
2275 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => {}
2276 _ => {
2277 return Err(BadKeyspaceName::IllegalCharacter(
2278 keyspace_name.to_string(),
2279 character,
2280 ))
2281 }
2282 };
2283 }
2284
2285 Ok(())
2286 }
2287}
2288
2289#[cfg(test)]
2290mod tests {
2291 use assert_matches::assert_matches;
2292 use scylla_cql::frame::protocol_features::{
2293 LWT_OPTIMIZATION_META_BIT_MASK_KEY, SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION,
2294 };
2295 use scylla_cql::frame::types;
2296 use scylla_proxy::{
2297 Condition, Node, Proxy, Reaction, RequestFrame, RequestOpcode, RequestReaction,
2298 RequestRule, ResponseFrame, ShardAwareness,
2299 };
2300
2301 use tokio::select;
2302 use tokio::sync::mpsc;
2303
2304 use super::{open_connection, HostConnectionConfig};
2305 use crate::cluster::metadata::UntranslatedEndpoint;
2306 use crate::cluster::node::ResolvedContactPoint;
2307 use crate::statement::unprepared::Statement;
2308 use crate::test_utils::setup_tracing;
2309 use crate::utils::test_utils::{resolve_hostname, unique_keyspace_name, PerformDDL};
2310 use futures::{StreamExt, TryStreamExt};
2311 use std::collections::HashMap;
2312 use std::net::SocketAddr;
2313 use std::sync::Arc;
2314 use std::time::Duration;
2315
2316 #[tokio::test]
2322 #[cfg(not(scylla_cloud_tests))]
2323 async fn connection_query_iter_test() {
2324 use crate::client::session_builder::SessionBuilder;
2325
2326 setup_tracing();
2327 let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
2328 let addr: SocketAddr = resolve_hostname(&uri).await;
2329
2330 let (connection, _) = super::open_connection(
2331 &UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
2332 address: addr,
2333 datacenter: None,
2334 }),
2335 None,
2336 &HostConnectionConfig::default(),
2337 )
2338 .await
2339 .unwrap();
2340 let connection = Arc::new(connection);
2341
2342 let ks = unique_keyspace_name();
2343
2344 {
2345 let session = SessionBuilder::new()
2347 .known_node_addr(addr)
2348 .build()
2349 .await
2350 .unwrap();
2351 session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks.clone())).await.unwrap();
2352 session.use_keyspace(ks.clone(), false).await.unwrap();
2353 session
2354 .ddl("DROP TABLE IF EXISTS connection_query_iter_tab")
2355 .await
2356 .unwrap();
2357 session
2358 .ddl("CREATE TABLE IF NOT EXISTS connection_query_iter_tab (p int primary key)")
2359 .await
2360 .unwrap();
2361 }
2362
2363 connection
2364 .use_keyspace(&super::VerifiedKeyspaceName::new(ks, false).unwrap())
2365 .await
2366 .unwrap();
2367
2368 let select_query =
2370 Statement::new("SELECT p FROM connection_query_iter_tab").with_page_size(7);
2371 let empty_res = connection
2372 .clone()
2373 .query_iter(select_query.clone())
2374 .await
2375 .unwrap()
2376 .rows_stream::<(i32,)>()
2377 .unwrap()
2378 .try_collect::<Vec<_>>()
2379 .await
2380 .unwrap();
2381 assert!(empty_res.is_empty());
2382
2383 let values: Vec<i32> = (0..100).collect();
2385 let insert_query = Statement::new("INSERT INTO connection_query_iter_tab (p) VALUES (?)")
2386 .with_page_size(7);
2387 let prepared = connection.prepare(&insert_query).await.unwrap();
2388 let mut insert_futures = Vec::new();
2389 for v in &values {
2390 let values = prepared.serialize_values(&(*v,)).unwrap();
2391 let fut = async { connection.execute_raw_unpaged(&prepared, values).await };
2392 insert_futures.push(fut);
2393 }
2394
2395 futures::future::try_join_all(insert_futures).await.unwrap();
2396
2397 let mut results: Vec<i32> = connection
2398 .clone()
2399 .query_iter(select_query.clone())
2400 .await
2401 .unwrap()
2402 .rows_stream::<(i32,)>()
2403 .unwrap()
2404 .map(|ret| ret.unwrap().0)
2405 .collect::<Vec<_>>()
2406 .await;
2407 results.sort_unstable(); assert_eq!(results, values);
2409
2410 let insert_res1 = connection
2412 .query_iter(Statement::new(
2413 "INSERT INTO connection_query_iter_tab (p) VALUES (0)",
2414 ))
2415 .await
2416 .unwrap()
2417 .rows_stream::<()>()
2418 .unwrap()
2419 .try_collect::<Vec<_>>()
2420 .await
2421 .unwrap();
2422 assert!(insert_res1.is_empty());
2423 }
2424
2425 #[tokio::test]
2426 #[cfg(not(scylla_cloud_tests))]
2427 async fn test_coalescing() {
2428 use std::num::NonZeroU64;
2429
2430 use super::WriteCoalescingDelay;
2431 use crate::client::session_builder::SessionBuilder;
2432
2433 setup_tracing();
2434 let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
2440 let addr: SocketAddr = resolve_hostname(&uri).await;
2441 let ks = unique_keyspace_name();
2442
2443 {
2444 let session = SessionBuilder::new()
2446 .known_node_addr(addr)
2447 .build()
2448 .await
2449 .unwrap();
2450 session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks.clone())).await.unwrap();
2451 session.use_keyspace(ks.clone(), false).await.unwrap();
2452 session
2453 .ddl("CREATE TABLE IF NOT EXISTS t (p int primary key, v blob)")
2454 .await
2455 .unwrap();
2456 }
2457
2458 let subtest = |write_coalescing_delay: Option<WriteCoalescingDelay>, ks: String| async move {
2459 let (connection, _) = super::open_connection(
2460 &UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
2461 address: addr,
2462 datacenter: None,
2463 }),
2464 None,
2465 &HostConnectionConfig {
2466 write_coalescing_delay,
2467 ..HostConnectionConfig::default()
2468 },
2469 )
2470 .await
2471 .unwrap();
2472 let connection = Arc::new(connection);
2473
2474 connection
2475 .use_keyspace(&super::VerifiedKeyspaceName::new(ks, false).unwrap())
2476 .await
2477 .unwrap();
2478
2479 connection.ddl("TRUNCATE t").await.unwrap();
2480
2481 let mut futs = Vec::new();
2482
2483 const NUM_BATCHES: i32 = 10;
2484
2485 for batch_size in 0..NUM_BATCHES {
2486 let base = arithmetic_sequence_sum(batch_size);
2488 let conn = connection.clone();
2489 futs.push(tokio::task::spawn(async move {
2490 let futs = (base..base + batch_size).map(|j| {
2491 let q = Statement::new("INSERT INTO t (p, v) VALUES (?, ?)");
2492 let conn = conn.clone();
2493 async move {
2494 let prepared = conn.prepare(&q).await.unwrap();
2495 let values = prepared
2496 .serialize_values(&(j, vec![j as u8; j as usize]))
2497 .unwrap();
2498 let response =
2499 conn.execute_raw_unpaged(&prepared, values).await.unwrap();
2500 let _nonerror_response =
2502 response.into_non_error_query_response().unwrap();
2503 }
2504 });
2505 let _joined: Vec<()> = futures::future::join_all(futs).await;
2506 }));
2507
2508 tokio::task::yield_now().await;
2509 }
2510
2511 let _joined: Vec<()> = futures::future::try_join_all(futs).await.unwrap();
2512
2513 let range_end = arithmetic_sequence_sum(NUM_BATCHES);
2515 let mut results = connection
2516 .query_unpaged("SELECT p, v FROM t")
2517 .await
2518 .unwrap()
2519 .into_rows_result()
2520 .unwrap()
2521 .rows::<(i32, Vec<u8>)>()
2522 .unwrap()
2523 .collect::<Result<Vec<_>, _>>()
2524 .unwrap();
2525 results.sort();
2526
2527 let expected = (0..range_end)
2528 .map(|i| (i, vec![i as u8; i as usize]))
2529 .collect::<Vec<_>>();
2530
2531 assert_eq!(results, expected);
2532 };
2533
2534 subtest(
2536 Some(WriteCoalescingDelay::SmallNondeterministic),
2537 ks.clone(),
2538 )
2539 .await;
2540 subtest(
2542 Some(WriteCoalescingDelay::Milliseconds(
2543 NonZeroU64::new(1).unwrap(),
2544 )),
2545 ks.clone(),
2546 )
2547 .await;
2548 subtest(None, ks.clone()).await;
2550 }
2551
/// Closed form `n * (n - 1) / 2`.
///
/// For non-negative `n` this is the sum `0 + 1 + ... + (n - 1)`. `test_coalescing`
/// uses it both as the first key of batch `n` and as the total number of keys
/// written by batches `0..n`.
fn arithmetic_sequence_sum(n: i32) -> i32 {
    let consecutive_product = n * (n - 1);
    consecutive_product / 2
}
2556
2557 #[tokio::test]
2558 async fn test_lwt_optimisation_mark_negotiation() {
2559 setup_tracing();
2560 const MASK: &str = "2137";
2561
2562 let lwt_optimisation_entry = format!("{}={}", LWT_OPTIMIZATION_META_BIT_MASK_KEY, MASK);
2563
2564 let proxy_addr = SocketAddr::new(scylla_proxy::get_exclusive_local_address(), 9042);
2565
2566 let config = HostConnectionConfig::default();
2567
2568 let (startup_tx, mut startup_rx) = mpsc::unbounded_channel();
2569
2570 let options_without_lwt_optimisation_support = HashMap::<String, Vec<String>>::new();
2571 let options_with_lwt_optimisation_support = [(
2572 SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION.into(),
2573 vec![lwt_optimisation_entry.clone()],
2574 )]
2575 .into_iter()
2576 .collect::<HashMap<String, Vec<String>>>();
2577
2578 let make_rules = |options| {
2579 vec![
2580 RequestRule(
2581 Condition::RequestOpcode(RequestOpcode::Options),
2582 RequestReaction::forge_response(Arc::new(move |frame: RequestFrame| {
2583 ResponseFrame::forged_supported(frame.params, &options).unwrap()
2584 })),
2585 ),
2586 RequestRule(
2587 Condition::RequestOpcode(RequestOpcode::Startup),
2588 RequestReaction::drop_frame().with_feedback_when_performed(startup_tx.clone()),
2589 ),
2590 ]
2591 };
2592
2593 let mut proxy = Proxy::builder()
2594 .with_node(
2595 Node::builder()
2596 .proxy_address(proxy_addr)
2597 .request_rules(make_rules(options_without_lwt_optimisation_support))
2598 .build_dry_mode(),
2599 )
2600 .build()
2601 .run()
2602 .await
2603 .unwrap();
2604
2605 let endpoint = UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
2607 address: proxy_addr,
2608 datacenter: None,
2609 });
2610 let (startup_without_lwt_optimisation, _shard) = select! {
2611 _ = open_connection(&endpoint, None, &config) => unreachable!(),
2612 startup = startup_rx.recv() => startup.unwrap(),
2613 };
2614
2615 proxy.running_nodes[0]
2616 .change_request_rules(Some(make_rules(options_with_lwt_optimisation_support)));
2617
2618 let (startup_with_lwt_optimisation, _shard) = select! {
2619 _ = open_connection(&endpoint, None, &config) => unreachable!(),
2620 startup = startup_rx.recv() => startup.unwrap(),
2621 };
2622
2623 let _ = proxy.finish().await;
2624
2625 let chosen_options =
2626 types::read_string_map(&mut &*startup_without_lwt_optimisation.body).unwrap();
2627 assert!(!chosen_options.contains_key(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION));
2628
2629 let chosen_options =
2630 types::read_string_map(&mut &startup_with_lwt_optimisation.body[..]).unwrap();
2631 assert!(chosen_options.contains_key(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION));
2632 assert_eq!(
2633 chosen_options
2634 .get(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION)
2635 .unwrap(),
2636 &lwt_optimisation_entry
2637 )
2638 }
2639
2640 #[tokio::test]
2641 #[ntest::timeout(20000)]
2642 #[cfg(not(scylla_cloud_tests))]
2643 async fn connection_is_closed_on_no_response_to_keepalives() {
2644 use crate::errors::BrokenConnectionErrorKind;
2645
2646 setup_tracing();
2647
2648 let proxy_addr = SocketAddr::new(scylla_proxy::get_exclusive_local_address(), 9042);
2649 let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
2650 let node_addr: SocketAddr = resolve_hostname(&uri).await;
2651
2652 let drop_options_rule = RequestRule(
2653 Condition::RequestOpcode(RequestOpcode::Options),
2654 RequestReaction::drop_frame(),
2655 );
2656
2657 let config = HostConnectionConfig {
2658 keepalive_interval: Some(Duration::from_millis(500)),
2659 keepalive_timeout: Some(Duration::from_secs(1)),
2660 ..Default::default()
2661 };
2662
2663 let mut proxy = Proxy::builder()
2664 .with_node(
2665 Node::builder()
2666 .proxy_address(proxy_addr)
2667 .real_address(node_addr)
2668 .shard_awareness(ShardAwareness::QueryNode)
2669 .build(),
2670 )
2671 .build()
2672 .run()
2673 .await
2674 .unwrap();
2675
2676 let (conn, mut error_receiver) = open_connection(
2678 &UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
2679 address: proxy_addr,
2680 datacenter: None,
2681 }),
2682 None,
2683 &config,
2684 )
2685 .await
2686 .unwrap();
2687
2688 for _ in 0..3 {
2690 tokio::time::sleep(Duration::from_millis(500)).await;
2691 conn.query_unpaged("SELECT host_id FROM system.local WHERE key='local'")
2692 .await
2693 .unwrap();
2694 }
2695 assert_matches!(
2697 error_receiver.try_recv(),
2698 Err(tokio::sync::oneshot::error::TryRecvError::Empty)
2699 );
2700
2701 proxy.running_nodes[0].change_request_rules(Some(vec![drop_options_rule]));
2703
2704 let err = error_receiver.await.unwrap();
2707 let err_inner: &BrokenConnectionErrorKind = match err {
2708 super::ConnectionError::BrokenConnection(ref e) => e.downcast_ref().unwrap(),
2709 _ => panic!("Bad error type. Expected keepalive timeout."),
2710 };
2711 assert_matches!(err_inner, BrokenConnectionErrorKind::KeepaliveTimeout(_));
2712
2713 conn.query_unpaged("SELECT host_id FROM system.local WHERE key='local'")
2716 .await
2717 .unwrap_err();
2718
2719 let _ = proxy.finish().await;
2720 }
2721}