use super::connection::{
    open_connection, open_connection_to_shard_aware_port, Connection, ConnectionConfig,
    ErrorReceiver, HostConnectionConfig, VerifiedKeyspaceName,
};

use crate::errors::{
    BrokenConnectionErrorKind, ConnectionError, ConnectionPoolError, UseKeyspaceError,
};
use crate::routing::{Shard, ShardCount, Sharder};

use crate::cluster::metadata::{PeerEndpoint, UntranslatedEndpoint};

#[cfg(feature = "metrics")]
use crate::observability::metrics::Metrics;

use crate::cluster::NodeAddr;

use arc_swap::ArcSwap;
use futures::{future::RemoteHandle, stream::FuturesUnordered, Future, FutureExt, StreamExt};
use itertools::Itertools;
use rand::Rng;
use std::convert::TryInto;
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::sync::{Arc, RwLock, Weak};
use std::time::Duration;

use tokio::sync::{broadcast, mpsc, Notify};
use tracing::{debug, error, trace, warn};

/// Specifies how many connections the driver should keep open to a single node.
#[derive(Debug, Clone, Copy)]
pub enum PoolSize {
    /// Keep this many connections to the node in total, regardless of how many
    /// shards it has.
    PerHost(NonZeroUsize),

    /// Keep this many connections to each shard of the node.
    PerShard(NonZeroUsize),
}

impl Default for PoolSize {
    fn default() -> Self {
        PoolSize::PerShard(NonZeroUsize::new(1).unwrap())
    }
}
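
// A rough sizing illustration (not additional configuration): with
// `PoolSize::PerShard(2)` against a node that reports 4 shards, the pool aims for
// 8 connections in total (2 per shard), while `PoolSize::PerHost(3)` aims for 3
// connections spread across the shards. For a node that does not shard (e.g.
// Cassandra) there is a single connection list, so `PerShard(n)` and `PerHost(n)`
// behave the same; see `PoolRefiller::is_full` below.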

#[derive(Clone)]
pub(crate) struct PoolConfig {
    pub(crate) connection_config: ConnectionConfig,
    pub(crate) pool_size: PoolSize,
    pub(crate) can_use_shard_aware_port: bool,
}

#[cfg(test)]
impl Default for PoolConfig {
    fn default() -> Self {
        Self {
            connection_config: Default::default(),
            pool_size: Default::default(),
            can_use_shard_aware_port: true,
        }
    }
}

impl PoolConfig {
    fn to_host_pool_config(&self, endpoint: &UntranslatedEndpoint) -> HostPoolConfig {
        HostPoolConfig {
            connection_config: self.connection_config.to_host_connection_config(endpoint),
            pool_size: self.pool_size,
            can_use_shard_aware_port: self.can_use_shard_aware_port,
        }
    }
}

#[derive(Clone)]
struct HostPoolConfig {
    pub(crate) connection_config: HostConnectionConfig,
    pub(crate) pool_size: PoolSize,
    pub(crate) can_use_shard_aware_port: bool,
}

#[cfg(test)]
impl Default for HostPoolConfig {
    fn default() -> Self {
        Self {
            connection_config: Default::default(),
            pool_size: Default::default(),
            can_use_shard_aware_port: true,
        }
    }
}

/// The current state of the pool, as published by the refiller task.
enum MaybePoolConnections {
    /// The pool has not yet established its first connection.
    Initializing,

    /// The pool is empty; the stored error explains the most recent failure.
    Broken(ConnectionError),

    /// The pool has at least one working connection.
    Ready(PoolConnections),
}

impl std::fmt::Debug for MaybePoolConnections {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MaybePoolConnections::Initializing => write!(f, "Initializing"),
            MaybePoolConnections::Broken(err) => write!(f, "Broken({:?})", err),
            MaybePoolConnections::Ready(conns) => write!(f, "{:?}", conns),
        }
    }
}

#[derive(Clone)]
enum PoolConnections {
    NotSharded(Vec<Arc<Connection>>),
    Sharded {
        sharder: Sharder,
        connections: Vec<Vec<Arc<Connection>>>,
    },
}

// Debug wrappers which print only the connect addresses instead of whole
// `Connection` structs, keeping pool state logs readable.
struct ConnectionVectorWrapper<'a>(&'a Vec<Arc<Connection>>);
impl std::fmt::Debug for ConnectionVectorWrapper<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_list()
            .entries(self.0.iter().map(|conn| conn.get_connect_address()))
            .finish()
    }
}

struct ShardedConnectionVectorWrapper<'a>(&'a Vec<Vec<Arc<Connection>>>);
impl std::fmt::Debug for ShardedConnectionVectorWrapper<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_list()
            .entries(
                self.0
                    .iter()
                    .enumerate()
                    .map(|(shard_no, conn_vec)| (shard_no, ConnectionVectorWrapper(conn_vec))),
            )
            .finish()
    }
}

impl std::fmt::Debug for PoolConnections {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PoolConnections::NotSharded(conns) => {
                write!(f, "non-sharded: {:?}", ConnectionVectorWrapper(conns))
            }
            PoolConnections::Sharded {
                sharder,
                connections,
            } => write!(
                f,
                "sharded(nr_shards:{}, msb_ignore_bits:{}): {:?}",
                sharder.nr_shards,
                sharder.msb_ignore,
                ShardedConnectionVectorWrapper(connections)
            ),
        }
    }
}

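// The pool is split into two cooperating pieces:
//  * `NodeConnectionPool` is the cheap, cloneable handle used by the rest of the
//    driver. It only reads the current `MaybePoolConnections` snapshot through an
//    `ArcSwap`, so picking a connection never takes a lock on the hot path.
//  * `PoolRefiller` runs as a background tokio task. It owns the mutable state,
//    opens and removes connections, and publishes fresh immutable snapshots for
//    the handle to read.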
#[derive(Clone)]
pub(crate) struct NodeConnectionPool {
    conns: Arc<ArcSwap<MaybePoolConnections>>,
    use_keyspace_request_sender: mpsc::Sender<UseKeyspaceRequest>,
    _refiller_handle: Arc<RemoteHandle<()>>,
    pool_updated_notify: Arc<Notify>,
    endpoint: Arc<RwLock<UntranslatedEndpoint>>,
}

impl std::fmt::Debug for NodeConnectionPool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NodeConnectionPool")
            .field("conns", &self.conns)
            .finish_non_exhaustive()
    }
}

impl NodeConnectionPool {
    pub(crate) fn new(
        endpoint: UntranslatedEndpoint,
        pool_config: &PoolConfig,
        current_keyspace: Option<VerifiedKeyspaceName>,
        pool_empty_notifier: broadcast::Sender<()>,
        #[cfg(feature = "metrics")] metrics: Arc<Metrics>,
    ) -> Self {
        let (use_keyspace_request_sender, use_keyspace_request_receiver) = mpsc::channel(1);
        let pool_updated_notify = Arc::new(Notify::new());

        let host_pool_config = pool_config.to_host_pool_config(&endpoint);

        let arced_endpoint = Arc::new(RwLock::new(endpoint));

        let refiller = PoolRefiller::new(
            arced_endpoint.clone(),
            host_pool_config,
            current_keyspace,
            pool_updated_notify.clone(),
            pool_empty_notifier,
            #[cfg(feature = "metrics")]
            metrics,
        );

        let conns = refiller.get_shared_connections();
        let (fut, refiller_handle) = refiller.run(use_keyspace_request_receiver).remote_handle();
        tokio::spawn(fut);

        Self {
            conns,
            use_keyspace_request_sender,
            _refiller_handle: Arc::new(refiller_handle),
            pool_updated_notify,
            endpoint: arced_endpoint,
        }
    }

    pub(crate) fn is_connected(&self) -> bool {
        let maybe_conns = self.conns.load();
        match maybe_conns.as_ref() {
            MaybePoolConnections::Initializing => false,
            MaybePoolConnections::Broken(_) => false,
            MaybePoolConnections::Ready(_pool_connections) => true,
        }
    }

    pub(crate) fn update_endpoint(&self, new_endpoint: PeerEndpoint) {
        *self.endpoint.write().unwrap() = UntranslatedEndpoint::Peer(new_endpoint);
    }

    pub(crate) fn sharder(&self) -> Option<Sharder> {
        self.with_connections(|pool_conns| match pool_conns {
            PoolConnections::NotSharded(_) => None,
            PoolConnections::Sharded { sharder, .. } => Some(sharder.clone()),
        })
        .unwrap_or(None)
    }

    pub(crate) fn connection_for_shard(
        &self,
        shard: Shard,
    ) -> Result<Arc<Connection>, ConnectionPoolError> {
        trace!(shard = shard, "Selecting connection for shard");
        self.with_connections(|pool_conns| match pool_conns {
            PoolConnections::NotSharded(conns) => {
                Self::choose_random_connection_from_slice(conns).unwrap()
            }
            PoolConnections::Sharded {
                connections,
                sharder,
            } => {
                let shard = shard
                    .try_into()
                    .unwrap_or_else(|_| {
                        error!("The provided shard number: {} does not fit u16! Using 0 as the shard number. Check your LoadBalancingPolicy implementation.", shard);
                        0
                    });
                Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
            }
        })
    }

    pub(crate) fn random_connection(&self) -> Result<Arc<Connection>, ConnectionPoolError> {
        trace!("Selecting random connection");
        self.with_connections(|pool_conns| match pool_conns {
            PoolConnections::NotSharded(conns) => {
                Self::choose_random_connection_from_slice(conns).unwrap()
            }
            PoolConnections::Sharded {
                sharder,
                connections,
            } => {
                let shard: u16 = rand::rng().random_range(0..sharder.nr_shards.get());
                Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
            }
        })
    }

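    // Tries to return a connection to the given shard. If that shard has no
    // connections, the remaining shards are tried in random order and the first
    // available connection is returned. The final `unreachable!` can only be hit
    // if every per-shard list is empty, which cannot happen here: callers reach
    // this function through `with_connections`, i.e. only in the `Ready` state,
    // and `Ready` is never published for an empty pool.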
    fn connection_for_shard_helper(
        shard: u16,
        nr_shards: ShardCount,
        shard_conns: &[Vec<Arc<Connection>>],
    ) -> Arc<Connection> {
        // Try getting a connection to the requested shard first.
        if let Some(conn) = Self::choose_random_connection_from_slice(&shard_conns[shard as usize])
        {
            trace!(shard = shard, "Found connection for the target shard");
            return conn;
        }

        // Otherwise, try the remaining shards in random order.
        let mut shards_to_try: Vec<u16> = (0..shard).chain(shard + 1..nr_shards.get()).collect();

        let orig_shard = shard;
        while !shards_to_try.is_empty() {
            let idx = rand::rng().random_range(0..shards_to_try.len());
            let shard = shards_to_try.swap_remove(idx);

            if let Some(conn) =
                Self::choose_random_connection_from_slice(&shard_conns[shard as usize])
            {
                trace!(
                    orig_shard = orig_shard,
                    shard = shard,
                    "Choosing connection for a different shard"
                );
                return conn;
            }
        }

        unreachable!("could not find any connection in supposedly non-empty pool")
    }

    pub(crate) async fn use_keyspace(
        &self,
        keyspace_name: VerifiedKeyspaceName,
    ) -> Result<(), UseKeyspaceError> {
        let (response_sender, response_receiver) = tokio::sync::oneshot::channel();

        self.use_keyspace_request_sender
            .send(UseKeyspaceRequest {
                keyspace_name,
                response_sender,
            })
            .await
            .expect("Bug in NodeConnectionPool::use_keyspace sending");

        // The other end of this channel lives in the refiller task, which responds
        // to every request, so awaiting the response should not fail.
        response_receiver
            .await
            .expect("Bug in NodeConnectionPool::use_keyspace receiving")
    }

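    // Waits until the pool leaves the `Initializing` state (becoming either
    // `Ready` or `Broken`). The `Notified` future is created *before* the state
    // is inspected, so an update published between the check and the await is
    // not missed.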
    pub(crate) async fn wait_until_initialized(&self) {
        let notified = self.pool_updated_notify.notified();

        if let MaybePoolConnections::Initializing = **self.conns.load() {
            notified.await;
        }
    }

    pub(crate) fn get_working_connections(
        &self,
    ) -> Result<Vec<Arc<Connection>>, ConnectionPoolError> {
        self.with_connections(|pool_conns| match pool_conns {
            PoolConnections::NotSharded(conns) => conns.clone(),
            PoolConnections::Sharded { connections, .. } => {
                connections.iter().flatten().cloned().collect()
            }
        })
    }

    fn choose_random_connection_from_slice(v: &[Arc<Connection>]) -> Option<Arc<Connection>> {
        trace!(
            connections = tracing::field::display(
                v.iter().map(|conn| conn.get_connect_address()).format(", ")
            ),
            "Available"
        );
        if v.is_empty() {
            None
        } else if v.len() == 1 {
            Some(v[0].clone())
        } else {
            let idx = rand::rng().random_range(0..v.len());
            Some(v[idx].clone())
        }
    }

    fn with_connections<T>(
        &self,
        f: impl FnOnce(&PoolConnections) -> T,
    ) -> Result<T, ConnectionPoolError> {
        let conns = self.conns.load_full();
        match &*conns {
            MaybePoolConnections::Ready(pool_connections) => Ok(f(pool_connections)),
            MaybePoolConnections::Broken(err) => Err(ConnectionPoolError::Broken {
                last_connection_error: err.clone(),
            }),
            MaybePoolConnections::Initializing => Err(ConnectionPoolError::Initializing),
        }
    }
}

// Limits the number of excess connections kept around while the pool is filling.
const EXCESS_CONNECTION_BOUND_PER_SHARD_MULTIPLIER: usize = 10;

// Exponential backoff parameters for refill attempts after an error.
const MIN_FILL_BACKOFF: Duration = Duration::from_millis(50);
const MAX_FILL_BACKOFF: Duration = Duration::from_secs(10);
const FILL_BACKOFF_MULTIPLIER: u32 = 2;

struct RefillDelayStrategy {
    current_delay: Duration,
}

impl RefillDelayStrategy {
    fn new() -> Self {
        Self {
            current_delay: MIN_FILL_BACKOFF,
        }
    }

    fn get_delay(&self) -> Duration {
        self.current_delay
    }

    fn on_successful_fill(&mut self) {
        self.current_delay = MIN_FILL_BACKOFF;
    }

    fn on_fill_error(&mut self) {
        self.current_delay = std::cmp::min(
            MAX_FILL_BACKOFF,
            self.current_delay * FILL_BACKOFF_MULTIPLIER,
        );
    }
}
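
// A minimal, illustrative sanity check of the backoff progression implemented
// above (not part of the original test suite). It assumes the constants keep
// their current values: 50 ms start, x2 growth, 10 s cap.
#[cfg(test)]
mod refill_delay_strategy_tests {
    use super::*;

    #[test]
    fn backoff_grows_on_errors_and_resets_on_success() {
        let mut strategy = RefillDelayStrategy::new();
        assert_eq!(strategy.get_delay(), MIN_FILL_BACKOFF);

        // Each failed refill doubles the delay...
        strategy.on_fill_error();
        assert_eq!(strategy.get_delay(), Duration::from_millis(100));
        strategy.on_fill_error();
        assert_eq!(strategy.get_delay(), Duration::from_millis(200));

        // ...until it saturates at MAX_FILL_BACKOFF.
        for _ in 0..20 {
            strategy.on_fill_error();
        }
        assert_eq!(strategy.get_delay(), MAX_FILL_BACKOFF);

        // A successful refill resets the delay to the minimum.
        strategy.on_successful_fill();
        assert_eq!(strategy.get_delay(), MIN_FILL_BACKOFF);
    }
}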

struct PoolRefiller {
    // The pool-wide configuration, fixed at construction time.
    pool_config: HostPoolConfig,

    // The endpoint to which the pool connects; may be updated at runtime.
    endpoint: Arc<RwLock<UntranslatedEndpoint>>,

    // Sharding information learned from established connections, if any.
    shard_aware_port: Option<u16>,
    sharder: Option<Sharder>,

    // The snapshot shared with `NodeConnectionPool` and the mutable per-shard
    // connection lists from which that snapshot is produced.
    shared_conns: Arc<ArcSwap<MaybePoolConnections>>,
    conns: Vec<Vec<Arc<Connection>>>,

    // Set if a connection error occurred since the last refill; used to decide
    // how long to wait before the next refill attempt.
    had_error_since_last_refill: bool,

    refill_delay_strategy: RefillDelayStrategy,

    // Futures that resolve when a connection attempt completes.
    ready_connections:
        FuturesUnordered<Pin<Box<dyn Future<Output = OpenedConnectionEvent> + Send + 'static>>>,

    // Futures that resolve when an established connection breaks.
    connection_errors:
        FuturesUnordered<Pin<Box<dyn Future<Output = BrokenConnectionEvent> + Send + 'static>>>,

    // Connections that were opened but do not currently fit in the pool (e.g.
    // they landed on a shard which is already full). They are kept until the
    // pool becomes full or their number exceeds `excess_connection_limit`.
    excess_connections: Vec<Arc<Connection>>,

    current_keyspace: Option<VerifiedKeyspaceName>,

    // Notified whenever a new snapshot is published.
    pool_updated_notify: Arc<Notify>,

    // Notified when the pool loses its last connection.
    pool_empty_notifier: broadcast::Sender<()>,

    #[cfg(feature = "metrics")]
    metrics: Arc<Metrics>,
}

#[derive(Debug)]
struct UseKeyspaceRequest {
    keyspace_name: VerifiedKeyspaceName,
    response_sender: tokio::sync::oneshot::Sender<Result<(), UseKeyspaceError>>,
}

impl PoolRefiller {
    pub(crate) fn new(
        endpoint: Arc<RwLock<UntranslatedEndpoint>>,
        pool_config: HostPoolConfig,
        current_keyspace: Option<VerifiedKeyspaceName>,
        pool_updated_notify: Arc<Notify>,
        pool_empty_notifier: broadcast::Sender<()>,
        #[cfg(feature = "metrics")] metrics: Arc<Metrics>,
    ) -> Self {
        // Start with a single, empty connection list; it is resized after the
        // first connection reveals the node's shard count.
        let conns = vec![Vec::new()];
        let shared_conns = Arc::new(ArcSwap::new(Arc::new(MaybePoolConnections::Initializing)));

        Self {
            endpoint,
            pool_config,

            shard_aware_port: None,
            sharder: None,

            shared_conns,
            conns,

            had_error_since_last_refill: false,
            refill_delay_strategy: RefillDelayStrategy::new(),

            ready_connections: FuturesUnordered::new(),
            connection_errors: FuturesUnordered::new(),

            excess_connections: Vec::new(),

            current_keyspace,

            pool_updated_notify,
            pool_empty_notifier,

            #[cfg(feature = "metrics")]
            metrics,
        }
    }

    fn endpoint_description(&self) -> NodeAddr {
        self.endpoint.read().unwrap().address()
    }

    pub(crate) fn get_shared_connections(&self) -> Arc<ArcSwap<MaybePoolConnections>> {
        self.shared_conns.clone()
    }

    // The main loop of the pool worker.
    pub(crate) async fn run(
        mut self,
        mut use_keyspace_request_receiver: mpsc::Receiver<UseKeyspaceRequest>,
    ) {
        debug!(
            "[{}] Started asynchronous pool worker",
            self.endpoint_description()
        );

        let mut next_refill_time = tokio::time::Instant::now();
        let mut refill_scheduled = true;

        loop {
            tokio::select! {
                _ = tokio::time::sleep_until(next_refill_time), if refill_scheduled => {
                    self.had_error_since_last_refill = false;
                    self.start_filling();
                    refill_scheduled = false;
                }

                evt = self.ready_connections.select_next_some(), if !self.ready_connections.is_empty() => {
                    self.handle_ready_connection(evt);

                    if self.is_full() {
                        debug!(
                            "[{}] Pool is full, clearing {} excess connections",
                            self.endpoint_description(),
                            self.excess_connections.len()
                        );
                        self.excess_connections.clear();
                    }
                }

                evt = self.connection_errors.select_next_some(), if !self.connection_errors.is_empty() => {
                    if let Some(conn) = evt.connection.upgrade() {
                        debug!("[{}] Got error for connection {:p}: {:?}", self.endpoint_description(), Arc::as_ptr(&conn), evt.error);
                        self.remove_connection(conn, evt.error);
                    }
                }

                req = use_keyspace_request_receiver.recv() => {
                    if let Some(req) = req {
                        debug!("[{}] Requested keyspace change: {}", self.endpoint_description(), req.keyspace_name.as_str());
                        self.use_keyspace(req.keyspace_name, req.response_sender);
                    } else {
                        // The keyspace request channel is closed only when the owning
                        // NodeConnectionPool is dropped, so stop the worker.
                        trace!("[{}] Keyspace request channel dropped, stopping asynchronous pool worker", self.endpoint_description());
                        return;
                    }
                }
            }
            trace!(
                pool_state = ?ShardedConnectionVectorWrapper(&self.conns)
            );

            // Schedule the next refill if the pool still needs more connections
            // and no refill is already scheduled or in progress.
            if !refill_scheduled && self.need_filling() {
                if self.had_error_since_last_refill {
                    self.refill_delay_strategy.on_fill_error();
                } else {
                    self.refill_delay_strategy.on_successful_fill();
                }
                let delay = self.refill_delay_strategy.get_delay();
                debug!(
                    "[{}] Scheduling next refill in {} ms",
                    self.endpoint_description(),
                    delay.as_millis(),
                );

                next_refill_time = tokio::time::Instant::now() + delay;
                refill_scheduled = true;
            }
        }
    }

    fn is_filling(&self) -> bool {
        !self.ready_connections.is_empty()
    }

    fn is_full(&self) -> bool {
        match self.pool_config.pool_size {
            PoolSize::PerHost(target) => self.active_connection_count() >= target.get(),
            PoolSize::PerShard(target) => {
                self.conns.iter().all(|conns| conns.len() >= target.get())
            }
        }
    }

    fn is_empty(&self) -> bool {
        self.conns.iter().all(|conns| conns.is_empty())
    }

    fn need_filling(&self) -> bool {
        !self.is_filling() && !self.is_full()
    }

    fn can_use_shard_aware_port(&self) -> bool {
        self.sharder.is_some()
            && self.shard_aware_port.is_some()
            && self.pool_config.can_use_shard_aware_port
    }

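    // Begins opening as many connections as are needed to reach the configured
    // pool size. If the pool is empty, only a single "probe" connection is opened
    // first, so that the node's shard count and shard-aware port can be learned
    // before the rest of the pool is filled. When the shard-aware port can be
    // used, connections are requested per shard; otherwise they are opened to the
    // regular port and land on whatever shard the node assigns.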
    fn start_filling(&mut self) {
        if self.is_empty() {
            // The node might not be alive at all; open a single connection first
            // to learn whether it is reachable and what its sharding parameters are.
            trace!(
                "[{}] Will open the first connection to the node",
                self.endpoint_description()
            );
            self.start_opening_connection(None);
            return;
        }

        if self.can_use_shard_aware_port() {
            // The shard-aware port is only useful with a per-shard target.
            if let PoolSize::PerShard(target) = self.pool_config.pool_size {
                // Try to fill each shard up to `target` connections.
                for (shard_id, shard_conns) in self.conns.iter().enumerate() {
                    let to_open_count = target.get().saturating_sub(shard_conns.len());
                    if to_open_count == 0 {
                        continue;
                    }
                    trace!(
                        "[{}] Will open {} connections to shard {}",
                        self.endpoint_description(),
                        to_open_count,
                        shard_id,
                    );
                    for _ in 0..to_open_count {
                        self.start_opening_connection(Some(shard_id as Shard));
                    }
                }
                return;
            }
        }
        // Calculate how many more connections are needed to reach the target.
        let to_open_count = match self.pool_config.pool_size {
            PoolSize::PerHost(target) => {
                target.get().saturating_sub(self.active_connection_count())
            }
            PoolSize::PerShard(target) => self
                .conns
                .iter()
                .map(|conns| target.get().saturating_sub(conns.len()))
                .sum::<usize>(),
        };
        // These connections go to the regular port, so the node decides which
        // shard each of them lands on; they are sorted into shards as they arrive.
        trace!(
            "[{}] Will open {} non-shard-aware connections",
            self.endpoint_description(),
            to_open_count,
        );
        for _ in 0..to_open_count {
            self.start_opening_connection(None);
        }
    }

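    // Handles the result of a connection attempt:
    //  * a failed shard-aware attempt is retried on the regular port,
    //  * a failure on the regular port marks the refill as errored (and publishes
    //    a `Broken` state if the pool is empty and nothing else is in flight),
    //  * a successful connection may first need its keyspace set, and is then
    //    either added to its shard's list, retried on the regular port (if it was
    //    a shard-aware miss), or kept as an excess connection.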
    fn handle_ready_connection(&mut self, evt: OpenedConnectionEvent) {
        match evt.result {
            Err(err) => {
                if evt.requested_shard.is_some() {
                    // A shard-aware attempt failed; fall back to the regular port.
                    debug!(
                        "[{}] Failed to open connection to the shard-aware port: {:?}, will retry with regular port",
                        self.endpoint_description(),
                        err,
                    );
                    self.start_opening_connection(None);
                } else {
                    // A regular-port attempt failed; remember the error so the
                    // refill backoff grows before the next attempt.
                    self.had_error_since_last_refill = true;
                    debug!(
                        "[{}] Failed to open connection to the non-shard-aware port: {:?}",
                        self.endpoint_description(),
                        err,
                    );

                    // If the pool is empty and no other attempts are in flight,
                    // publish the error so that users can see why requests fail.
                    if !self.is_filling() && self.is_empty() {
                        self.update_shared_conns(Some(err));
                    }
                }
            }
            Ok((connection, error_receiver)) => {
                // Update the sharding information based on the new connection.
                let shard_info = connection.get_shard_info().as_ref();
                let sharder = shard_info.map(|s| s.get_sharder());
                let shard_id = shard_info.map_or(0, |s| s.shard as usize);
                self.maybe_reshard(sharder);

                if self.shard_aware_port != connection.get_shard_aware_port() {
                    debug!(
                        "[{}] Updating shard aware port: {:?}",
                        self.endpoint_description(),
                        connection.get_shard_aware_port(),
                    );
                    self.shard_aware_port = connection.get_shard_aware_port();
                }

                // If the connection was opened with a different keyspace than the
                // current one, switch it before adding it to the pool.
                if let Some(keyspace) = &self.current_keyspace {
                    if evt.keyspace_name.as_ref() != Some(keyspace) {
                        self.start_setting_keyspace_for_connection(
                            connection,
                            error_receiver,
                            evt.requested_shard,
                        );
                        return;
                    }
                }

                let can_be_accepted = match self.pool_config.pool_size {
                    PoolSize::PerHost(target) => self.active_connection_count() < target.get(),
                    PoolSize::PerShard(target) => self.conns[shard_id].len() < target.get(),
                };

                if can_be_accepted {
                    // Add the connection to the pool and start watching it for errors.
                    let conn = Arc::new(connection);
                    trace!(
                        "[{}] Adding connection {:p} to shard {} pool, now there are {} for the shard, total {}",
                        self.endpoint_description(),
                        Arc::as_ptr(&conn),
                        shard_id,
                        self.conns[shard_id].len() + 1,
                        self.active_connection_count() + 1,
                    );

                    self.connection_errors
                        .push(wait_for_error(Arc::downgrade(&conn), error_receiver).boxed());
                    self.conns[shard_id].push(conn);

                    self.update_shared_conns(None);
                } else if evt.requested_shard.is_some() {
                    // A shard-aware connection landed on a shard that is already
                    // full; retry through the regular port instead of keeping it.
                    debug!(
                        "[{}] Excess shard-aware port connection for shard {}; will retry with non-shard-aware port",
                        self.endpoint_description(),
                        shard_id,
                    );

                    self.start_opening_connection(None);
                } else {
                    // A regular-port connection landed on a full shard; keep it as
                    // an excess connection instead of dropping it immediately.
                    let conn = Arc::new(connection);
                    trace!(
                        "[{}] Storing excess connection {:p} for shard {}",
                        self.endpoint_description(),
                        Arc::as_ptr(&conn),
                        shard_id,
                    );

                    self.connection_errors
                        .push(wait_for_error(Arc::downgrade(&conn), error_receiver).boxed());
                    self.excess_connections.push(conn);

                    let excess_connection_limit = self.excess_connection_limit();
                    if self.excess_connections.len() > excess_connection_limit {
                        debug!(
                            "[{}] Excess connection pool exceeded limit of {} connections - clearing",
                            self.endpoint_description(),
                            excess_connection_limit,
                        );
                        self.excess_connections.clear();
                    }
                }
            }
        }
    }

    // Starts opening a new connection in the background. Its result will arrive
    // through the `ready_connections` stream.
    fn start_opening_connection(&self, shard: Option<Shard>) {
        let cfg = self.pool_config.connection_config.clone();
        let mut endpoint = self.endpoint.read().unwrap().clone();

        #[cfg(feature = "metrics")]
        let count_in_metrics = {
            let metrics = Arc::clone(&self.metrics);
            move |connect_result: &Result<_, ConnectionError>| {
                if connect_result.is_ok() {
                    metrics.inc_total_connections();
                } else if let Err(ConnectionError::ConnectTimeout) = &connect_result {
                    metrics.inc_connection_timeouts();
                }
            }
        };

        let fut = match (self.sharder.clone(), self.shard_aware_port, shard) {
            (Some(sharder), Some(port), Some(shard)) => async move {
                let shard_aware_endpoint = {
                    endpoint.set_port(port);
                    endpoint
                };
                let result = open_connection_to_shard_aware_port(
                    &shard_aware_endpoint,
                    shard,
                    sharder.clone(),
                    &cfg,
                )
                .await;

                #[cfg(feature = "metrics")]
                count_in_metrics(&result);

                OpenedConnectionEvent {
                    result,
                    requested_shard: Some(shard),
                    keyspace_name: None,
                }
            }
            .boxed(),
            _ => async move {
                let non_shard_aware_endpoint = endpoint;
                let result = open_connection(&non_shard_aware_endpoint, None, &cfg).await;

                #[cfg(feature = "metrics")]
                count_in_metrics(&result);

                OpenedConnectionEvent {
                    result,
                    requested_shard: None,
                    keyspace_name: None,
                }
            }
            .boxed(),
        };
        self.ready_connections.push(fut);
    }

    fn maybe_reshard(&mut self, new_sharder: Option<Sharder>) {
        if self.sharder == new_sharder {
            return;
        }

        debug!(
            "[{}] New sharder: {:?}, clearing all connections",
            self.endpoint_description(),
            new_sharder,
        );

        self.sharder.clone_from(&new_sharder);

        // The existing per-shard lists no longer match the node's sharding;
        // drop everything and resize to the new shard count.
        self.conns.clear();

        let shard_count = new_sharder.map_or(1, |s| s.nr_shards.get() as usize);
        self.conns.resize_with(shard_count, Vec::new);

        self.excess_connections.clear();
    }

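    // Publishes a fresh snapshot of the pool state for `NodeConnectionPool`:
    // `Broken` with the last error if the pool is empty, otherwise a clone of the
    // per-shard connection lists wrapped in `Ready`. Readers see the new `Arc`
    // atomically and never block this task; waiters on `pool_updated_notify` are
    // woken afterwards.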
    fn update_shared_conns(&mut self, last_error: Option<ConnectionError>) {
        let new_conns = if self.is_empty() {
            Arc::new(MaybePoolConnections::Broken(last_error.unwrap()))
        } else {
            let new_conns = if let Some(sharder) = self.sharder.as_ref() {
                debug_assert_eq!(self.conns.len(), sharder.nr_shards.get() as usize);
                PoolConnections::Sharded {
                    sharder: sharder.clone(),
                    connections: self.conns.clone(),
                }
            } else {
                debug_assert_eq!(self.conns.len(), 1);
                PoolConnections::NotSharded(self.conns[0].clone())
            };
            Arc::new(MaybePoolConnections::Ready(new_conns))
        };

        self.shared_conns.store(new_conns);

        self.pool_updated_notify.notify_waiters();
    }

    // Removes a connection from the pool (or from the excess connection list)
    // after it reported an error, and publishes an updated snapshot.
    fn remove_connection(&mut self, connection: Arc<Connection>, last_error: ConnectionError) {
        let ptr = Arc::as_ptr(&connection);

        let maybe_remove_in_vec = |v: &mut Vec<Arc<Connection>>| -> bool {
            let maybe_idx = v
                .iter()
                .enumerate()
                .find(|(_, other_conn)| Arc::ptr_eq(&connection, other_conn))
                .map(|(idx, _)| idx);
            match maybe_idx {
                Some(idx) => {
                    v.swap_remove(idx);
                    #[cfg(feature = "metrics")]
                    self.metrics.dec_total_connections();
                    true
                }
                None => false,
            }
        };

        // First, look it up in the per-shard list indicated by its shard info.
        let shard_id = connection
            .get_shard_info()
            .as_ref()
            .map_or(0, |s| s.shard as usize);
        if shard_id < self.conns.len() && maybe_remove_in_vec(&mut self.conns[shard_id]) {
            trace!(
                "[{}] Connection {:p} removed from shard {} pool, now there is {} for the shard, total {}",
                self.endpoint_description(),
                ptr,
                shard_id,
                self.conns[shard_id].len(),
                self.active_connection_count(),
            );
            if self.is_empty() {
                let _ = self.pool_empty_notifier.send(());
            }
            self.update_shared_conns(Some(last_error));
            return;
        }

        // If it was not there, it might be an excess connection.
        if maybe_remove_in_vec(&mut self.excess_connections) {
            trace!(
                "[{}] Connection {:p} removed from excess connection pool",
                self.endpoint_description(),
                ptr,
            );
            return;
        }

        trace!(
            "[{}] Connection {:p} was already removed",
            self.endpoint_description(),
            ptr,
        );
    }

    // Sets the keyspace on all connections in the pool. The operation runs in a
    // separate task and its result is sent back through `response_sender`.
    fn use_keyspace(
        &mut self,
        keyspace_name: VerifiedKeyspaceName,
        response_sender: tokio::sync::oneshot::Sender<Result<(), UseKeyspaceError>>,
    ) {
        self.current_keyspace = Some(keyspace_name.clone());

        let mut conns = self.conns.clone();
        let address = self.endpoint.read().unwrap().address();
        let connect_timeout = self.pool_config.connection_config.connect_timeout;

        let fut = async move {
            let mut use_keyspace_futures = Vec::new();

            for shard_conns in conns.iter_mut() {
                for conn in shard_conns.iter_mut() {
                    let fut = conn.use_keyspace(&keyspace_name);
                    use_keyspace_futures.push(fut);
                }
            }

            if use_keyspace_futures.is_empty() {
                return Ok(());
            }

            let use_keyspace_results: Vec<Result<(), UseKeyspaceError>> = tokio::time::timeout(
                connect_timeout,
                futures::future::join_all(use_keyspace_futures),
            )
            .await
            .map_err(|_| UseKeyspaceError::RequestTimeout(connect_timeout))?;

            crate::cluster::use_keyspace_result(use_keyspace_results.into_iter())
        };

        tokio::task::spawn(async move {
            let res = fut.await;
            match &res {
                Ok(()) => debug!("[{}] Successfully changed current keyspace", address),
                Err(err) => warn!("[{}] Failed to change keyspace: {:?}", address, err),
            }
            let _ = response_sender.send(res);
        });
    }

    // Sets the current keyspace on a newly opened connection before it is added
    // to the pool; the connection is put back into `ready_connections` once the
    // keyspace change completes.
    fn start_setting_keyspace_for_connection(
        &mut self,
        connection: Connection,
        error_receiver: ErrorReceiver,
        requested_shard: Option<Shard>,
    ) {
        // This function is only called when the current keyspace is set.
        let keyspace_name = self.current_keyspace.as_ref().cloned().unwrap();
        self.ready_connections.push(
            async move {
                let result = connection.use_keyspace(&keyspace_name).await;
                if let Err(err) = result {
                    warn!(
                        "[{}] Failed to set keyspace for new connection: {}",
                        connection.get_connect_address().ip(),
                        err,
                    );
                }
                OpenedConnectionEvent {
                    result: Ok((connection, error_receiver)),
                    requested_shard,
                    keyspace_name: Some(keyspace_name),
                }
            }
            .boxed(),
        );
    }

    fn active_connection_count(&self) -> usize {
        self.conns.iter().map(Vec::len).sum::<usize>()
    }

    fn excess_connection_limit(&self) -> usize {
        match self.pool_config.pool_size {
            PoolSize::PerShard(_) => {
                EXCESS_CONNECTION_BOUND_PER_SHARD_MULTIPLIER
                    * self
                        .sharder
                        .as_ref()
                        .map_or(1, |s| s.nr_shards.get() as usize)
            }

            // In PerHost mode there is no need to keep excess connections around.
            PoolSize::PerHost(_) => 0,
        }
    }
}

struct BrokenConnectionEvent {
    connection: Weak<Connection>,
    error: ConnectionError,
}

async fn wait_for_error(
    connection: Weak<Connection>,
    error_receiver: ErrorReceiver,
) -> BrokenConnectionEvent {
    BrokenConnectionEvent {
        connection,
        error: error_receiver.await.unwrap_or_else(|_| {
            ConnectionError::BrokenConnection(BrokenConnectionErrorKind::ChannelError.into())
        }),
    }
}

struct OpenedConnectionEvent {
    result: Result<(Connection, ErrorReceiver), ConnectionError>,
    requested_shard: Option<Shard>,
    keyspace_name: Option<VerifiedKeyspaceName>,
}

#[cfg(test)]
mod tests {
    use super::super::connection::{open_connection_to_shard_aware_port, HostConnectionConfig};
    use crate::cluster::metadata::UntranslatedEndpoint;
    use crate::cluster::node::ResolvedContactPoint;
    use crate::routing::{ShardCount, Sharder};
    use crate::test_utils::setup_tracing;
    use std::net::{SocketAddr, ToSocketAddrs};

    // Opens many connections to the shard-aware port at once. This is an
    // integration test which requires a running Scylla instance (SCYLLA_URI).
    #[tokio::test]
    #[cfg(not(scylla_cloud_tests))]
    async fn many_connections() {
        setup_tracing();
        let connections_number = 512;

        let connect_address: SocketAddr = std::env::var("SCYLLA_URI")
            .unwrap_or_else(|_| "127.0.0.1:9042".to_string())
            .to_socket_addrs()
            .unwrap()
            .next()
            .unwrap();

        let connection_config = HostConnectionConfig {
            compression: None,
            tcp_nodelay: true,
            tls_config: None,
            ..Default::default()
        };

        // These do not have to be the node's real sharding parameters; the test
        // only checks that the connection attempts succeed.
        let sharder = Sharder::new(ShardCount::new(3).unwrap(), 12);

        let endpoint = UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
            address: connect_address,
            datacenter: None,
        });

        // Open the connections concurrently and make sure all of them succeeded.
        let conns = (0..connections_number).map(|_| {
            open_connection_to_shard_aware_port(&endpoint, 0, sharder.clone(), &connection_config)
        });

        let joined = futures::future::join_all(conns).await;

        for res in joined {
            res.unwrap();
        }
    }
}