scylla/network/
connection_pool.rs

1use super::connection::{
2    open_connection, open_connection_to_shard_aware_port, Connection, ConnectionConfig,
3    ErrorReceiver, HostConnectionConfig, VerifiedKeyspaceName,
4};
5
6use crate::errors::{
7    BrokenConnectionErrorKind, ConnectionError, ConnectionPoolError, UseKeyspaceError,
8};
9use crate::routing::{Shard, ShardCount, Sharder};
10
11use crate::cluster::metadata::{PeerEndpoint, UntranslatedEndpoint};
12
13#[cfg(feature = "metrics")]
14use crate::observability::metrics::Metrics;
15
16use crate::cluster::NodeAddr;
17
18use arc_swap::ArcSwap;
19use futures::{future::RemoteHandle, stream::FuturesUnordered, Future, FutureExt, StreamExt};
20use itertools::Itertools;
21use rand::Rng;
22use std::convert::TryInto;
23use std::num::NonZeroUsize;
24use std::pin::Pin;
25use std::sync::{Arc, RwLock, Weak};
26use std::time::Duration;
27
28use tokio::sync::{broadcast, mpsc, Notify};
29use tracing::{debug, error, trace, warn};
30
/// Desired number of connections kept in the pool of a single node.
#[derive(Debug, Clone, Copy)]
pub enum PoolSize {
    /// The pool should hold the given number of connections to the node as a whole.
    ///
    /// With a Scylla cluster there is no guarantee that these connections end up
    /// evenly distributed across shards. Pick this variant if you cannot use the
    /// shard-aware port and you suffer from the "connection storm" problems.
    PerHost(NonZeroUsize),

    /// The pool should hold the given number of connections to every shard of the node.
    ///
    /// Cassandra nodes are treated as if they had exactly one shard.
    ///
    /// For Scylla the recommended setting is one connection per shard - `PerShard(1)`.
    PerShard(NonZeroUsize),
}

impl Default for PoolSize {
    /// Returns the recommended setting: one connection per shard.
    fn default() -> Self {
        let one = NonZeroUsize::new(1).expect("1 is non-zero");
        Self::PerShard(one)
    }
}
54
// Pool settings that are not yet tied to a particular node.
// `PoolConfig::to_host_pool_config` turns them into a `HostPoolConfig`
// for a concrete endpoint.
#[derive(Clone)]
pub(crate) struct PoolConfig {
    // Connection settings; specialized per endpoint via `to_host_connection_config`.
    pub(crate) connection_config: ConnectionConfig,
    // Target number of connections, counted per host or per shard.
    pub(crate) pool_size: PoolSize,
    // One of the conditions checked by `PoolRefiller::can_use_shard_aware_port`.
    pub(crate) can_use_shard_aware_port: bool,
}
61
#[cfg(test)]
impl Default for PoolConfig {
    // Test-only defaults: default connection settings and pool size,
    // with the shard-aware port allowed.
    fn default() -> Self {
        let connection_config: ConnectionConfig = Default::default();
        let pool_size: PoolSize = Default::default();

        Self {
            connection_config,
            pool_size,
            can_use_shard_aware_port: true,
        }
    }
}
72
73impl PoolConfig {
74    fn to_host_pool_config(&self, endpoint: &UntranslatedEndpoint) -> HostPoolConfig {
75        HostPoolConfig {
76            connection_config: self.connection_config.to_host_connection_config(endpoint),
77            pool_size: self.pool_size,
78            can_use_shard_aware_port: self.can_use_shard_aware_port,
79        }
80    }
81}
82
// Pool settings of a single node, as produced by
// `PoolConfig::to_host_pool_config`. This is what `PoolRefiller` works with.
#[derive(Clone)]
struct HostPoolConfig {
    // Connection settings already specialized for the node's endpoint.
    pub(crate) connection_config: HostConnectionConfig,
    // Target number of connections, counted per host or per shard.
    pub(crate) pool_size: PoolSize,
    // One of the conditions checked by `PoolRefiller::can_use_shard_aware_port`.
    pub(crate) can_use_shard_aware_port: bool,
}
89
#[cfg(test)]
impl Default for HostPoolConfig {
    // Test-only defaults: default connection settings and pool size,
    // with the shard-aware port allowed.
    fn default() -> Self {
        let connection_config: HostConnectionConfig = Default::default();
        let pool_size: PoolSize = Default::default();

        Self {
            connection_config,
            pool_size,
            can_use_shard_aware_port: true,
        }
    }
}
100
// State of a node's pool as published by `PoolRefiller` (through
// `shared_conns`) and read by `NodeConnectionPool`.
enum MaybePoolConnections {
    // The pool is being filled for the first time
    Initializing,

    // The pool is empty because either initial filling failed or all connections
    // became broken; will be asynchronously refilled. Contains an error
    // from the last connection attempt.
    Broken(ConnectionError),

    // The pool has some connections which are usable (or will be removed soon)
    Ready(PoolConnections),
}
113
114impl std::fmt::Debug for MaybePoolConnections {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        match self {
117            MaybePoolConnections::Initializing => write!(f, "Initializing"),
118            MaybePoolConnections::Broken(err) => write!(f, "Broken({:?})", err),
119            MaybePoolConnections::Ready(conns) => write!(f, "{:?}", conns),
120        }
121    }
122}
123
// Connections of a pool in the `Ready` state.
#[derive(Clone)]
enum PoolConnections {
    // Flat list of connections, used when there is no sharding information.
    NotSharded(Vec<Arc<Connection>>),
    // Connections grouped by shard: `connections[shard]` holds the connections
    // of that shard (a shard's list may be empty).
    Sharded {
        sharder: Sharder,
        connections: Vec<Vec<Arc<Connection>>>,
    },
}
132
133struct ConnectionVectorWrapper<'a>(&'a Vec<Arc<Connection>>);
134impl std::fmt::Debug for ConnectionVectorWrapper<'_> {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        f.debug_list()
137            .entries(self.0.iter().map(|conn| conn.get_connect_address()))
138            .finish()
139    }
140}
141
142struct ShardedConnectionVectorWrapper<'a>(&'a Vec<Vec<Arc<Connection>>>);
143impl std::fmt::Debug for ShardedConnectionVectorWrapper<'_> {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        f.debug_list()
146            .entries(
147                self.0
148                    .iter()
149                    .enumerate()
150                    .map(|(shard_no, conn_vec)| (shard_no, ConnectionVectorWrapper(conn_vec))),
151            )
152            .finish()
153    }
154}
155
156impl std::fmt::Debug for PoolConnections {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        match self {
159            PoolConnections::NotSharded(conns) => {
160                write!(f, "non-sharded: {:?}", ConnectionVectorWrapper(conns))
161            }
162            PoolConnections::Sharded {
163                sharder,
164                connections,
165            } => write!(
166                f,
167                "sharded(nr_shards:{}, msb_ignore_bits:{}): {:?}",
168                sharder.nr_shards,
169                sharder.msb_ignore,
170                ShardedConnectionVectorWrapper(connections)
171            ),
172        }
173    }
174}
175
// Handle to the connection pool of a single node. The pool itself is
// maintained by a `PoolRefiller` task spawned in `NodeConnectionPool::new`;
// this handle only reads the published state and sends requests to it.
#[derive(Clone)]
pub(crate) struct NodeConnectionPool {
    // Pool state shared with the refiller (see `PoolRefiller::get_shared_connections`).
    conns: Arc<ArcSwap<MaybePoolConnections>>,
    // Sends keyspace change requests to the refiller's main loop.
    use_keyspace_request_sender: mpsc::Sender<UseKeyspaceRequest>,
    // Remote handle of the spawned refiller future.
    // NOTE(review): presumably kept so that the refiller is stopped once the
    // last clone of the pool is dropped - confirm against `RemoteHandle` docs.
    _refiller_handle: Arc<RemoteHandle<()>>,
    // Awaited in `wait_until_initialized`; shared with the refiller.
    pool_updated_notify: Arc<Notify>,
    // Endpoint of the node, shared with the refiller; replaced by `update_endpoint`.
    endpoint: Arc<RwLock<UntranslatedEndpoint>>,
}
184
185impl std::fmt::Debug for NodeConnectionPool {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        f.debug_struct("NodeConnectionPool")
188            .field("conns", &self.conns)
189            .finish_non_exhaustive()
190    }
191}
192
193impl NodeConnectionPool {
194    pub(crate) fn new(
195        endpoint: UntranslatedEndpoint,
196        pool_config: &PoolConfig,
197        current_keyspace: Option<VerifiedKeyspaceName>,
198        pool_empty_notifier: broadcast::Sender<()>,
199        #[cfg(feature = "metrics")] metrics: Arc<Metrics>,
200    ) -> Self {
201        let (use_keyspace_request_sender, use_keyspace_request_receiver) = mpsc::channel(1);
202        let pool_updated_notify = Arc::new(Notify::new());
203
204        let host_pool_config = pool_config.to_host_pool_config(&endpoint);
205
206        let arced_endpoint = Arc::new(RwLock::new(endpoint));
207
208        let refiller = PoolRefiller::new(
209            arced_endpoint.clone(),
210            host_pool_config,
211            current_keyspace,
212            pool_updated_notify.clone(),
213            pool_empty_notifier,
214            #[cfg(feature = "metrics")]
215            metrics,
216        );
217
218        let conns = refiller.get_shared_connections();
219        let (fut, refiller_handle) = refiller.run(use_keyspace_request_receiver).remote_handle();
220        tokio::spawn(fut);
221
222        Self {
223            conns,
224            use_keyspace_request_sender,
225            _refiller_handle: Arc::new(refiller_handle),
226            pool_updated_notify,
227            endpoint: arced_endpoint,
228        }
229    }
230
231    pub(crate) fn is_connected(&self) -> bool {
232        let maybe_conns = self.conns.load();
233        match maybe_conns.as_ref() {
234            MaybePoolConnections::Initializing => false,
235            MaybePoolConnections::Broken(_) => false,
236            // Here we use the assumption that _pool_connections is always non-empty.
237            MaybePoolConnections::Ready(_pool_connections) => true,
238        }
239    }
240
241    pub(crate) fn update_endpoint(&self, new_endpoint: PeerEndpoint) {
242        *self.endpoint.write().unwrap() = UntranslatedEndpoint::Peer(new_endpoint);
243    }
244
245    pub(crate) fn sharder(&self) -> Option<Sharder> {
246        self.with_connections(|pool_conns| match pool_conns {
247            PoolConnections::NotSharded(_) => None,
248            PoolConnections::Sharded { sharder, .. } => Some(sharder.clone()),
249        })
250        .unwrap_or(None)
251    }
252
253    pub(crate) fn connection_for_shard(
254        &self,
255        shard: Shard,
256    ) -> Result<Arc<Connection>, ConnectionPoolError> {
257        trace!(shard = shard, "Selecting connection for shard");
258        self.with_connections(|pool_conns| match pool_conns {
259            PoolConnections::NotSharded(conns) => {
260                Self::choose_random_connection_from_slice(conns).unwrap()
261            }
262            PoolConnections::Sharded {
263                connections,
264                sharder
265            } => {
266                let shard = shard
267                    .try_into()
268                    // It's safer to use 0 rather that panic here, as shards are returned by `LoadBalancingPolicy`
269                    // now, which can be implemented by a user in an arbitrary way.
270                    .unwrap_or_else(|_| {
271                        error!("The provided shard number: {} does not fit u16! Using 0 as the shard number. Check your LoadBalancingPolicy implementation.", shard);
272                        0
273                    });
274                Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
275            }
276        })
277    }
278
279    pub(crate) fn random_connection(&self) -> Result<Arc<Connection>, ConnectionPoolError> {
280        trace!("Selecting random connection");
281        self.with_connections(|pool_conns| match pool_conns {
282            PoolConnections::NotSharded(conns) => {
283                Self::choose_random_connection_from_slice(conns).unwrap()
284            }
285            PoolConnections::Sharded {
286                sharder,
287                connections,
288            } => {
289                let shard: u16 = rand::rng().random_range(0..sharder.nr_shards.get());
290                Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
291            }
292        })
293    }
294
295    // Tries to get a connection to given shard, if it's broken returns any working connection
296    fn connection_for_shard_helper(
297        shard: u16,
298        nr_shards: ShardCount,
299        shard_conns: &[Vec<Arc<Connection>>],
300    ) -> Arc<Connection> {
301        // Try getting the desired connection
302        if let Some(conn) = Self::choose_random_connection_from_slice(&shard_conns[shard as usize])
303        {
304            trace!(shard = shard, "Found connection for the target shard");
305            return conn;
306        }
307
308        // If this fails try getting any other in random order
309        let mut shards_to_try: Vec<u16> = (0..shard).chain(shard + 1..nr_shards.get()).collect();
310
311        let orig_shard = shard;
312        while !shards_to_try.is_empty() {
313            let idx = rand::rng().random_range(0..shards_to_try.len());
314            let shard = shards_to_try.swap_remove(idx);
315
316            if let Some(conn) =
317                Self::choose_random_connection_from_slice(&shard_conns[shard as usize])
318            {
319                trace!(
320                    orig_shard = orig_shard,
321                    shard = shard,
322                    "Choosing connection for a different shard"
323                );
324                return conn;
325            }
326        }
327
328        unreachable!("could not find any connection in supposedly non-empty pool")
329    }
330
331    pub(crate) async fn use_keyspace(
332        &self,
333        keyspace_name: VerifiedKeyspaceName,
334    ) -> Result<(), UseKeyspaceError> {
335        let (response_sender, response_receiver) = tokio::sync::oneshot::channel();
336
337        self.use_keyspace_request_sender
338            .send(UseKeyspaceRequest {
339                keyspace_name,
340                response_sender,
341            })
342            .await
343            .expect("Bug in NodeConnectionPool::use_keyspace sending");
344        // Other end of this channel is in the PoolRefiller, can't be dropped while we have &self to _refiller_handle
345
346        response_receiver.await.unwrap() // PoolRefiller always responds
347    }
348
349    // Waits until the pool becomes initialized.
350    // The pool is considered initialized either if the first connection has been
351    // established or after first filling ends, whichever comes first.
352    pub(crate) async fn wait_until_initialized(&self) {
353        // First, register for the notification
354        // so that we don't miss it
355        let notified = self.pool_updated_notify.notified();
356
357        if let MaybePoolConnections::Initializing = **self.conns.load() {
358            // If the pool is not initialized yet, wait until we get a notification
359            notified.await;
360        }
361    }
362
363    pub(crate) fn get_working_connections(
364        &self,
365    ) -> Result<Vec<Arc<Connection>>, ConnectionPoolError> {
366        self.with_connections(|pool_conns| match pool_conns {
367            PoolConnections::NotSharded(conns) => conns.clone(),
368            PoolConnections::Sharded { connections, .. } => {
369                connections.iter().flatten().cloned().collect()
370            }
371        })
372    }
373
374    fn choose_random_connection_from_slice(v: &[Arc<Connection>]) -> Option<Arc<Connection>> {
375        trace!(
376            connections = tracing::field::display(
377                v.iter().map(|conn| conn.get_connect_address()).format(", ")
378            ),
379            "Available"
380        );
381        if v.is_empty() {
382            None
383        } else if v.len() == 1 {
384            Some(v[0].clone())
385        } else {
386            let idx = rand::rng().random_range(0..v.len());
387            Some(v[idx].clone())
388        }
389    }
390
391    fn with_connections<T>(
392        &self,
393        f: impl FnOnce(&PoolConnections) -> T,
394    ) -> Result<T, ConnectionPoolError> {
395        let conns = self.conns.load_full();
396        match &*conns {
397            MaybePoolConnections::Ready(pool_connections) => Ok(f(pool_connections)),
398            MaybePoolConnections::Broken(err) => Err(ConnectionPoolError::Broken {
399                last_connection_error: err.clone(),
400            }),
401            MaybePoolConnections::Initializing => Err(ConnectionPoolError::Initializing),
402        }
403    }
404}
405
// NOTE(review): presumably the "constant factor" mentioned in the comment on
// `PoolRefiller::excess_connections` (bound on excess connections = number of
// shards times this value) - its use is outside this part of the file, confirm.
const EXCESS_CONNECTION_BOUND_PER_SHARD_MULTIPLIER: usize = 10;
407
// TODO: Make it configurable through a policy (issue #184)
const MIN_FILL_BACKOFF: Duration = Duration::from_millis(50);
const MAX_FILL_BACKOFF: Duration = Duration::from_secs(10);
const FILL_BACKOFF_MULTIPLIER: u32 = 2;

// Exponential backoff between pool refills: the delay starts at
// `MIN_FILL_BACKOFF`, is multiplied by `FILL_BACKOFF_MULTIPLIER` after each
// failed fill (capped at `MAX_FILL_BACKOFF`) and drops back to the minimum
// after a successful one.
struct RefillDelayStrategy {
    current_delay: Duration,
}

impl RefillDelayStrategy {
    // Starts with the shortest delay.
    fn new() -> Self {
        let current_delay = MIN_FILL_BACKOFF;
        Self { current_delay }
    }

    // Delay to wait before the next refill.
    fn get_delay(&self) -> Duration {
        self.current_delay
    }

    // Resets the delay to the minimum.
    fn on_successful_fill(&mut self) {
        *self = Self::new();
    }

    // Lengthens the delay, never exceeding `MAX_FILL_BACKOFF`.
    fn on_fill_error(&mut self) {
        let grown = self.current_delay * FILL_BACKOFF_MULTIPLIER;
        self.current_delay = grown.min(MAX_FILL_BACKOFF);
    }
}
440
// State of the background worker that maintains the connection pool of
// a single node. It is created in `NodeConnectionPool::new` and driven by
// `PoolRefiller::run`, which opens connections, reacts to broken ones and
// handles keyspace change requests.
struct PoolRefiller {
    // Following information identify the pool and do not change
    pool_config: HostPoolConfig,

    // Following information is subject to updates on topology refresh
    endpoint: Arc<RwLock<UntranslatedEndpoint>>,

    // Following fields are updated with information from OPTIONS
    shard_aware_port: Option<u16>,
    sharder: Option<Sharder>,

    // `shared_conns` is updated only after `conns` change
    shared_conns: Arc<ArcSwap<MaybePoolConnections>>,
    // Connections owned by the refiller, indexed by shard number.
    // Starts as a single empty "shard" (see `PoolRefiller::new`).
    conns: Vec<Vec<Arc<Connection>>>,

    // Set to true if there was an error since the last refill,
    // set to false when refilling starts.
    had_error_since_last_refill: bool,

    // Decides how long to wait before the next refill (see `run`).
    refill_delay_strategy: RefillDelayStrategy,

    // Receives information about connections becoming ready, i.e. newly connected
    // or after its keyspace was correctly set.
    // TODO: This should probably be a channel
    ready_connections:
        FuturesUnordered<Pin<Box<dyn Future<Output = OpenedConnectionEvent> + Send + 'static>>>,

    // Receives information about breaking connections
    connection_errors:
        FuturesUnordered<Pin<Box<dyn Future<Output = BrokenConnectionEvent> + Send + 'static>>>,

    // When connecting, Scylla always assigns the shard which handles the least
    // number of connections. If there are some non-shard-aware clients
    // connected to the same node, they might cause the shard distribution
    // to be heavily biased and Scylla will be very reluctant to assign some shards.
    //
    // In order to combat this, if the pool is not full and we get a connection
    // for a shard which was already filled, we keep those additional connections
    // in order to affect how Scylla assigns shards. A similar method is used
    // in Scylla's forks of the java and gocql drivers.
    //
    // The number of those connections is bounded by the number of shards multiplied
    // by a constant factor, and are all closed when they exceed this number.
    excess_connections: Vec<Arc<Connection>>,

    // Keyspace that connections must use before they are put into the pool
    // (see `handle_ready_connection`); `None` if no keyspace was set.
    current_keyspace: Option<VerifiedKeyspaceName>,

    // Signaled when the connection pool is updated
    pool_updated_notify: Arc<Notify>,

    // Signaled when the connection pool becomes empty
    pool_empty_notifier: broadcast::Sender<()>,

    #[cfg(feature = "metrics")]
    metrics: Arc<Metrics>,
}
497
// Keyspace change request sent by `NodeConnectionPool::use_keyspace`
// to the main loop of the `PoolRefiller`.
#[derive(Debug)]
struct UseKeyspaceRequest {
    // Keyspace the pool should switch to.
    keyspace_name: VerifiedKeyspaceName,
    // Channel on which the requester awaits the outcome of the change.
    response_sender: tokio::sync::oneshot::Sender<Result<(), UseKeyspaceError>>,
}
503
504impl PoolRefiller {
505    pub(crate) fn new(
506        endpoint: Arc<RwLock<UntranslatedEndpoint>>,
507        pool_config: HostPoolConfig,
508        current_keyspace: Option<VerifiedKeyspaceName>,
509        pool_updated_notify: Arc<Notify>,
510        pool_empty_notifier: broadcast::Sender<()>,
511        #[cfg(feature = "metrics")] metrics: Arc<Metrics>,
512    ) -> Self {
513        // At the beginning, we assume the node does not have any shards
514        // and assume that the node is a Cassandra node
515        let conns = vec![Vec::new()];
516        let shared_conns = Arc::new(ArcSwap::new(Arc::new(MaybePoolConnections::Initializing)));
517
518        Self {
519            endpoint,
520            pool_config,
521
522            shard_aware_port: None,
523            sharder: None,
524
525            shared_conns,
526            conns,
527
528            had_error_since_last_refill: false,
529            refill_delay_strategy: RefillDelayStrategy::new(),
530
531            ready_connections: FuturesUnordered::new(),
532            connection_errors: FuturesUnordered::new(),
533
534            excess_connections: Vec::new(),
535
536            current_keyspace,
537
538            pool_updated_notify,
539            pool_empty_notifier,
540
541            #[cfg(feature = "metrics")]
542            metrics,
543        }
544    }
545
546    fn endpoint_description(&self) -> NodeAddr {
547        self.endpoint.read().unwrap().address()
548    }
549
550    pub(crate) fn get_shared_connections(&self) -> Arc<ArcSwap<MaybePoolConnections>> {
551        self.shared_conns.clone()
552    }
553
    // The main loop of the pool refiller.
    //
    // Consumes the refiller and runs until the keyspace request channel is
    // closed. Each iteration waits for exactly one of:
    //  - the scheduled refill time,
    //  - the result of a connection attempt,
    //  - an error of a connection that is in the pool,
    //  - a keyspace change request,
    // and afterwards schedules the next refill if the pool needs one.
    pub(crate) async fn run(
        mut self,
        mut use_keyspace_request_receiver: mpsc::Receiver<UseKeyspaceRequest>,
    ) {
        debug!(
            "[{}] Started asynchronous pool worker",
            self.endpoint_description()
        );

        // The first refill is scheduled immediately.
        let mut next_refill_time = tokio::time::Instant::now();
        let mut refill_scheduled = true;

        loop {
            tokio::select! {
                // Branch is enabled only while a refill is scheduled.
                _ = tokio::time::sleep_until(next_refill_time), if refill_scheduled => {
                    self.had_error_since_last_refill = false;
                    self.start_filling();
                    refill_scheduled = false;
                }

                // Branch is enabled only while connection attempts are pending.
                evt = self.ready_connections.select_next_some(), if !self.ready_connections.is_empty() => {
                    self.handle_ready_connection(evt);

                    if self.is_full() {
                        debug!(
                            "[{}] Pool is full, clearing {} excess connections",
                            self.endpoint_description(),
                            self.excess_connections.len()
                        );
                        self.excess_connections.clear();
                    }
                }

                // Branch is enabled only while some connection is being watched for errors.
                evt = self.connection_errors.select_next_some(), if !self.connection_errors.is_empty() => {
                    // The event holds a weak reference; nothing is done if the
                    // connection no longer exists.
                    if let Some(conn) = evt.connection.upgrade() {
                        debug!("[{}] Got error for connection {:p}: {:?}", self.endpoint_description(), Arc::as_ptr(&conn), evt.error);
                        self.remove_connection(conn, evt.error);
                    }
                }

                req = use_keyspace_request_receiver.recv() => {
                    if let Some(req) = req {
                        debug!("[{}] Requested keyspace change: {}", self.endpoint_description(), req.keyspace_name.as_str());
                        self.use_keyspace(req.keyspace_name, req.response_sender);
                    } else {
                        // The keyspace request channel is dropped.
                        // This means that the corresponding pool is dropped.
                        // We can stop here.
                        trace!("[{}] Keyspace request channel dropped, stopping asynchronous pool worker", self.endpoint_description());
                        return;
                    }
                }
            }
            trace!(
                pool_state = ?ShardedConnectionVectorWrapper(&self.conns)
            );

            // Schedule refilling here
            if !refill_scheduled && self.need_filling() {
                // The delay grows if the fill that just ended saw an error
                // and is reset otherwise.
                if self.had_error_since_last_refill {
                    self.refill_delay_strategy.on_fill_error();
                } else {
                    self.refill_delay_strategy.on_successful_fill();
                }
                let delay = self.refill_delay_strategy.get_delay();
                debug!(
                    "[{}] Scheduling next refill in {} ms",
                    self.endpoint_description(),
                    delay.as_millis(),
                );

                next_refill_time = tokio::time::Instant::now() + delay;
                refill_scheduled = true;
            }
        }
    }
631
632    fn is_filling(&self) -> bool {
633        !self.ready_connections.is_empty()
634    }
635
636    fn is_full(&self) -> bool {
637        match self.pool_config.pool_size {
638            PoolSize::PerHost(target) => self.active_connection_count() >= target.get(),
639            PoolSize::PerShard(target) => {
640                self.conns.iter().all(|conns| conns.len() >= target.get())
641            }
642        }
643    }
644
645    fn is_empty(&self) -> bool {
646        self.conns.iter().all(|conns| conns.is_empty())
647    }
648
649    fn need_filling(&self) -> bool {
650        !self.is_filling() && !self.is_full()
651    }
652
653    fn can_use_shard_aware_port(&self) -> bool {
654        self.sharder.is_some()
655            && self.shard_aware_port.is_some()
656            && self.pool_config.can_use_shard_aware_port
657    }
658
659    // Begins opening a number of connections in order to fill the connection pool.
660    // Futures which open the connections are pushed to the `ready_connections`
661    // FuturesUnordered structure, and their results are processed in the main loop.
662    fn start_filling(&mut self) {
663        if self.is_empty() {
664            // If the pool is empty, it might mean that the node is not alive.
665            // It is more likely than not that the next connection attempt will
666            // fail, so there is no use in opening more than one connection now.
667            trace!(
668                "[{}] Will open the first connection to the node",
669                self.endpoint_description()
670            );
671            self.start_opening_connection(None);
672            return;
673        }
674
675        if self.can_use_shard_aware_port() {
676            // Only use the shard-aware port if we have a PerShard strategy
677            if let PoolSize::PerShard(target) = self.pool_config.pool_size {
678                // Try to fill up each shard up to `target` connections
679                for (shard_id, shard_conns) in self.conns.iter().enumerate() {
680                    let to_open_count = target.get().saturating_sub(shard_conns.len());
681                    if to_open_count == 0 {
682                        continue;
683                    }
684                    trace!(
685                        "[{}] Will open {} connections to shard {}",
686                        self.endpoint_description(),
687                        to_open_count,
688                        shard_id,
689                    );
690                    for _ in 0..to_open_count {
691                        self.start_opening_connection(Some(shard_id as Shard));
692                    }
693                }
694                return;
695            }
696        }
697        // Calculate how many more connections we need to open in order
698        // to achieve the target connection count.
699        let to_open_count = match self.pool_config.pool_size {
700            PoolSize::PerHost(target) => {
701                target.get().saturating_sub(self.active_connection_count())
702            }
703            PoolSize::PerShard(target) => self
704                .conns
705                .iter()
706                .map(|conns| target.get().saturating_sub(conns.len()))
707                .sum::<usize>(),
708        };
709        // When connecting to Scylla through non-shard-aware port,
710        // Scylla alone will choose shards for us. We hope that
711        // they will distribute across shards in the way we want,
712        // but we have no guarantee, so we might have to retry
713        // connecting later.
714        trace!(
715            "[{}] Will open {} non-shard-aware connections",
716            self.endpoint_description(),
717            to_open_count,
718        );
719        for _ in 0..to_open_count {
720            self.start_opening_connection(None);
721        }
722    }
723
724    // Handles a newly opened connection and decides what to do with it.
725    fn handle_ready_connection(&mut self, evt: OpenedConnectionEvent) {
726        match evt.result {
727            Err(err) => {
728                if evt.requested_shard.is_some() {
729                    // If we failed to connect to a shard-aware port,
730                    // fall back to the non-shard-aware port.
731                    // Don't set `had_error_since_last_refill` here;
732                    // the shard-aware port might be unreachable, but
733                    // the regular port might be reachable. If we set
734                    // `had_error_since_last_refill` here, it would cause
735                    // the backoff to increase on each refill. With
736                    // the non-shard aware port, multiple refills are sometimes
737                    // necessary, so increasing the backoff would delay
738                    // filling the pool even if the non-shard-aware port works
739                    // and does not cause any errors.
740                    debug!(
741                        "[{}] Failed to open connection to the shard-aware port: {:?}, will retry with regular port",
742                        self.endpoint_description(),
743                        err,
744                    );
745                    self.start_opening_connection(None);
746                } else {
747                    // Encountered an error while connecting to the non-shard-aware
748                    // port. Set the `had_error_since_last_refill` flag so that
749                    // the next refill will be delayed more than this one.
750                    self.had_error_since_last_refill = true;
751                    debug!(
752                        "[{}] Failed to open connection to the non-shard-aware port: {:?}",
753                        self.endpoint_description(),
754                        err,
755                    );
756
757                    // If all connection attempts in this fill attempt failed
758                    // and the pool is empty, report this error.
759                    if !self.is_filling() && self.is_empty() {
760                        self.update_shared_conns(Some(err));
761                    }
762                }
763            }
764            Ok((connection, error_receiver)) => {
765                // Update sharding and optionally reshard
766                let shard_info = connection.get_shard_info().as_ref();
767                let sharder = shard_info.map(|s| s.get_sharder());
768                let shard_id = shard_info.map_or(0, |s| s.shard as usize);
769                self.maybe_reshard(sharder);
770
771                // Update the shard-aware port
772                if self.shard_aware_port != connection.get_shard_aware_port() {
773                    debug!(
774                        "[{}] Updating shard aware port: {:?}",
775                        self.endpoint_description(),
776                        connection.get_shard_aware_port(),
777                    );
778                    self.shard_aware_port = connection.get_shard_aware_port();
779                }
780
781                // Before the connection can be put to the pool, we need
782                // to make sure that it uses appropriate keyspace
783                if let Some(keyspace) = &self.current_keyspace {
784                    if evt.keyspace_name.as_ref() != Some(keyspace) {
785                        // Asynchronously start setting keyspace for this
786                        // connection. It will be received on the ready
787                        // connections channel and will travel through
788                        // this logic again, to be finally put into
789                        // the conns.
790                        self.start_setting_keyspace_for_connection(
791                            connection,
792                            error_receiver,
793                            evt.requested_shard,
794                        );
795                        return;
796                    }
797                }
798
799                // Decide if the connection can be accepted, according to
800                // the pool filling strategy
801                let can_be_accepted = match self.pool_config.pool_size {
802                    PoolSize::PerHost(target) => self.active_connection_count() < target.get(),
803                    PoolSize::PerShard(target) => self.conns[shard_id].len() < target.get(),
804                };
805
806                if can_be_accepted {
807                    // Don't complain and just put the connection to the pool.
808                    // If this was a shard-aware port connection which missed
809                    // the right shard, we still want to accept it
810                    // because it fills our pool.
811                    let conn = Arc::new(connection);
812                    trace!(
813                        "[{}] Adding connection {:p} to shard {} pool, now there are {} for the shard, total {}",
814                        self.endpoint_description(),
815                        Arc::as_ptr(&conn),
816                        shard_id,
817                        self.conns[shard_id].len() + 1,
818                        self.active_connection_count() + 1,
819                    );
820
821                    self.connection_errors
822                        .push(wait_for_error(Arc::downgrade(&conn), error_receiver).boxed());
823                    self.conns[shard_id].push(conn);
824
825                    self.update_shared_conns(None);
826                } else if evt.requested_shard.is_some() {
827                    // This indicates that some shard-aware connections
828                    // missed the target shard (probably due to NAT).
829                    // Because we don't know how address translation
830                    // works here, it's better to leave the task
831                    // of choosing the shard to Scylla. We will retry
832                    // immediately with a non-shard-aware port here.
833                    debug!(
834                        "[{}] Excess shard-aware port connection for shard {}; will retry with non-shard-aware port",
835                        self.endpoint_description(),
836                        shard_id,
837                    );
838
839                    self.start_opening_connection(None);
840                } else {
841                    // We got unlucky and Scylla didn't distribute
842                    // shards across connections evenly.
843                    // We will retry in the next iteration,
844                    // for now put it into the excess connection
845                    // pool.
846                    let conn = Arc::new(connection);
847                    trace!(
848                        "[{}] Storing excess connection {:p} for shard {}",
849                        self.endpoint_description(),
850                        Arc::as_ptr(&conn),
851                        shard_id,
852                    );
853
854                    self.connection_errors
855                        .push(wait_for_error(Arc::downgrade(&conn), error_receiver).boxed());
856                    self.excess_connections.push(conn);
857
858                    let excess_connection_limit = self.excess_connection_limit();
859                    if self.excess_connections.len() > excess_connection_limit {
860                        debug!(
861                            "[{}] Excess connection pool exceeded limit of {} connections - clearing",
862                            self.endpoint_description(),
863                            excess_connection_limit,
864                        );
865                        self.excess_connections.clear();
866                    }
867                }
868            }
869        }
870    }
871
872    // Starts opening a new connection in the background. The result of connecting
873    // will be available on `ready_connections`. If the shard is specified and
874    // the shard aware port is available, it will attempt to connect directly
875    // to the shard using the port.
876    fn start_opening_connection(&self, shard: Option<Shard>) {
877        let cfg = self.pool_config.connection_config.clone();
878        let mut endpoint = self.endpoint.read().unwrap().clone();
879
880        #[cfg(feature = "metrics")]
881        let count_in_metrics = {
882            let metrics = Arc::clone(&self.metrics);
883            move |connect_result: &Result<_, ConnectionError>| {
884                if connect_result.is_ok() {
885                    metrics.inc_total_connections();
886                } else if let Err(ConnectionError::ConnectTimeout) = &connect_result {
887                    metrics.inc_connection_timeouts();
888                }
889            }
890        };
891
892        let fut = match (self.sharder.clone(), self.shard_aware_port, shard) {
893            (Some(sharder), Some(port), Some(shard)) => async move {
894                let shard_aware_endpoint = {
895                    endpoint.set_port(port);
896                    endpoint
897                };
898                let result = open_connection_to_shard_aware_port(
899                    &shard_aware_endpoint,
900                    shard,
901                    sharder.clone(),
902                    &cfg,
903                )
904                .await;
905
906                #[cfg(feature = "metrics")]
907                count_in_metrics(&result);
908
909                OpenedConnectionEvent {
910                    result,
911                    requested_shard: Some(shard),
912                    keyspace_name: None,
913                }
914            }
915            .boxed(),
916            _ => async move {
917                let non_shard_aware_endpoint = endpoint;
918                let result = open_connection(&non_shard_aware_endpoint, None, &cfg).await;
919
920                #[cfg(feature = "metrics")]
921                count_in_metrics(&result);
922
923                OpenedConnectionEvent {
924                    result,
925                    requested_shard: None,
926                    keyspace_name: None,
927                }
928            }
929            .boxed(),
930        };
931        self.ready_connections.push(fut);
932    }
933
934    fn maybe_reshard(&mut self, new_sharder: Option<Sharder>) {
935        if self.sharder == new_sharder {
936            return;
937        }
938
939        debug!(
940            "[{}] New sharder: {:?}, clearing all connections",
941            self.endpoint_description(),
942            new_sharder,
943        );
944
945        self.sharder.clone_from(&new_sharder);
946
947        // If the sharder has changed, we can throw away all previous connections.
948        // All connections to the same live node will have the same sharder,
949        // so the old ones will become dead very soon anyway.
950        self.conns.clear();
951
952        let shard_count = new_sharder.map_or(1, |s| s.nr_shards.get() as usize);
953        self.conns.resize_with(shard_count, Vec::new);
954
955        self.excess_connections.clear();
956    }
957
958    // Updates `shared_conns` based on `conns`.
959    // `last_error` must not be `None` if there is a possibility of the pool
960    // being empty.
961    fn update_shared_conns(&mut self, last_error: Option<ConnectionError>) {
962        let new_conns = if self.is_empty() {
963            Arc::new(MaybePoolConnections::Broken(last_error.unwrap()))
964        } else {
965            let new_conns = if let Some(sharder) = self.sharder.as_ref() {
966                debug_assert_eq!(self.conns.len(), sharder.nr_shards.get() as usize);
967                PoolConnections::Sharded {
968                    sharder: sharder.clone(),
969                    connections: self.conns.clone(),
970                }
971            } else {
972                debug_assert_eq!(self.conns.len(), 1);
973                PoolConnections::NotSharded(self.conns[0].clone())
974            };
975            Arc::new(MaybePoolConnections::Ready(new_conns))
976        };
977
978        // Make the connection list available
979        self.shared_conns.store(new_conns);
980
981        // Notify potential waiters
982        self.pool_updated_notify.notify_waiters();
983    }
984
985    // Removes given connection from the pool. It looks both into active
986    // connections and excess connections.
987    fn remove_connection(&mut self, connection: Arc<Connection>, last_error: ConnectionError) {
988        let ptr = Arc::as_ptr(&connection);
989
990        let maybe_remove_in_vec = |v: &mut Vec<Arc<Connection>>| -> bool {
991            let maybe_idx = v
992                .iter()
993                .enumerate()
994                .find(|(_, other_conn)| Arc::ptr_eq(&connection, other_conn))
995                .map(|(idx, _)| idx);
996            match maybe_idx {
997                Some(idx) => {
998                    v.swap_remove(idx);
999                    #[cfg(feature = "metrics")]
1000                    self.metrics.dec_total_connections();
1001                    true
1002                }
1003                None => false,
1004            }
1005        };
1006
1007        // First, look it up in the shard bucket
1008        // We might have resharded, so the bucket might not exist anymore
1009        let shard_id = connection
1010            .get_shard_info()
1011            .as_ref()
1012            .map_or(0, |s| s.shard as usize);
1013        if shard_id < self.conns.len() && maybe_remove_in_vec(&mut self.conns[shard_id]) {
1014            trace!(
1015                "[{}] Connection {:p} removed from shard {} pool, now there is {} for the shard, total {}",
1016                self.endpoint_description(),
1017                ptr,
1018                shard_id,
1019                self.conns[shard_id].len(),
1020                self.active_connection_count(),
1021            );
1022            if self.is_empty() {
1023                let _ = self.pool_empty_notifier.send(());
1024            }
1025            self.update_shared_conns(Some(last_error));
1026            return;
1027        }
1028
1029        // If we didn't find it, it might sit in the excess_connections bucket
1030        if maybe_remove_in_vec(&mut self.excess_connections) {
1031            trace!(
1032                "[{}] Connection {:p} removed from excess connection pool",
1033                self.endpoint_description(),
1034                ptr,
1035            );
1036            return;
1037        }
1038
1039        trace!(
1040            "[{}] Connection {:p} was already removed",
1041            self.endpoint_description(),
1042            ptr,
1043        );
1044    }
1045
1046    // Sets current keyspace for available connections.
1047    // Connections which are being currently opened and future connections
1048    // will have this keyspace set when they appear on `ready_connections`.
1049    // Sends response to the `response_sender` when all current connections
1050    // have their keyspace set.
1051    fn use_keyspace(
1052        &mut self,
1053        keyspace_name: VerifiedKeyspaceName,
1054        response_sender: tokio::sync::oneshot::Sender<Result<(), UseKeyspaceError>>,
1055    ) {
1056        self.current_keyspace = Some(keyspace_name.clone());
1057
1058        let mut conns = self.conns.clone();
1059        let address = self.endpoint.read().unwrap().address();
1060        let connect_timeout = self.pool_config.connection_config.connect_timeout;
1061
1062        let fut = async move {
1063            let mut use_keyspace_futures = Vec::new();
1064
1065            for shard_conns in conns.iter_mut() {
1066                for conn in shard_conns.iter_mut() {
1067                    let fut = conn.use_keyspace(&keyspace_name);
1068                    use_keyspace_futures.push(fut);
1069                }
1070            }
1071
1072            if use_keyspace_futures.is_empty() {
1073                return Ok(());
1074            }
1075
1076            let use_keyspace_results: Vec<Result<(), UseKeyspaceError>> = tokio::time::timeout(
1077                connect_timeout,
1078                futures::future::join_all(use_keyspace_futures),
1079            )
1080            .await
1081            // FIXME: We could probably make USE KEYSPACE request timeout configurable in the future.
1082            .map_err(|_| UseKeyspaceError::RequestTimeout(connect_timeout))?;
1083
1084            crate::cluster::use_keyspace_result(use_keyspace_results.into_iter())
1085        };
1086
1087        tokio::task::spawn(async move {
1088            let res = fut.await;
1089            match &res {
1090                Ok(()) => debug!("[{}] Successfully changed current keyspace", address),
1091                Err(err) => warn!("[{}] Failed to change keyspace: {:?}", address, err),
1092            }
1093            let _ = response_sender.send(res);
1094        });
1095    }
1096
1097    // Requires the keyspace to be set
1098    // Requires that the event is for a successful connection
1099    fn start_setting_keyspace_for_connection(
1100        &mut self,
1101        connection: Connection,
1102        error_receiver: ErrorReceiver,
1103        requested_shard: Option<Shard>,
1104    ) {
1105        // TODO: There should be a timeout for this
1106
1107        let keyspace_name = self.current_keyspace.as_ref().cloned().unwrap();
1108        self.ready_connections.push(
1109            async move {
1110                let result = connection.use_keyspace(&keyspace_name).await;
1111                if let Err(err) = result {
1112                    warn!(
1113                        "[{}] Failed to set keyspace for new connection: {}",
1114                        connection.get_connect_address().ip(),
1115                        err,
1116                    );
1117                }
1118                OpenedConnectionEvent {
1119                    result: Ok((connection, error_receiver)),
1120                    requested_shard,
1121                    keyspace_name: Some(keyspace_name),
1122                }
1123            }
1124            .boxed(),
1125        );
1126    }
1127
1128    fn active_connection_count(&self) -> usize {
1129        self.conns.iter().map(Vec::len).sum::<usize>()
1130    }
1131
1132    fn excess_connection_limit(&self) -> usize {
1133        match self.pool_config.pool_size {
1134            PoolSize::PerShard(_) => {
1135                EXCESS_CONNECTION_BOUND_PER_SHARD_MULTIPLIER
1136                    * self
1137                        .sharder
1138                        .as_ref()
1139                        .map_or(1, |s| s.nr_shards.get() as usize)
1140            }
1141
1142            // In PerHost mode we do not need to keep excess connections
1143            PoolSize::PerHost(_) => 0,
1144        }
1145    }
1146}
1147
1148struct BrokenConnectionEvent {
1149    connection: Weak<Connection>,
1150    error: ConnectionError,
1151}
1152
1153async fn wait_for_error(
1154    connection: Weak<Connection>,
1155    error_receiver: ErrorReceiver,
1156) -> BrokenConnectionEvent {
1157    BrokenConnectionEvent {
1158        connection,
1159        error: error_receiver.await.unwrap_or_else(|_| {
1160            ConnectionError::BrokenConnection(BrokenConnectionErrorKind::ChannelError.into())
1161        }),
1162    }
1163}
1164
1165struct OpenedConnectionEvent {
1166    result: Result<(Connection, ErrorReceiver), ConnectionError>,
1167    requested_shard: Option<Shard>,
1168    keyspace_name: Option<VerifiedKeyspaceName>,
1169}
1170
#[cfg(test)]
mod tests {
    use super::super::connection::{open_connection_to_shard_aware_port, HostConnectionConfig};
    use crate::cluster::metadata::UntranslatedEndpoint;
    use crate::cluster::node::ResolvedContactPoint;
    use crate::routing::{ShardCount, Sharder};
    use crate::test_utils::setup_tracing;
    use std::net::{SocketAddr, ToSocketAddrs};

    // Opens a large number of shard-aware connections to a single node at once.
    // With that many attempts, local port collisions are expected to happen;
    // if they were not handled, some attempt would most likely fail and
    // the test with it.
    //
    // The node address is taken from `SCYLLA_URI`, defaulting to 127.0.0.1:9042.
    #[tokio::test]
    #[cfg(not(scylla_cloud_tests))]
    async fn many_connections() {
        setup_tracing();
        let attempt_count = 512;

        let node_uri =
            std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
        let node_address: SocketAddr = node_uri.to_socket_addrs().unwrap().next().unwrap();

        let connection_config = HostConnectionConfig {
            compression: None,
            tcp_nodelay: true,
            tls_config: None,
            ..Default::default()
        };

        // The sharder does not need to match the node's real one: the test
        // is about port collisions, not about landing on the right shard.
        let sharder = Sharder::new(ShardCount::new(3).unwrap(), 12);

        let endpoint = UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
            address: node_address,
            datacenter: None,
        });

        // Start all the attempts and drive them concurrently.
        let attempts = (0..attempt_count).map(|_| {
            open_connection_to_shard_aware_port(&endpoint, 0, sharder.clone(), &connection_config)
        });
        let outcomes = futures::future::join_all(attempts).await;

        // Every single attempt must have succeeded.
        for outcome in outcomes {
            outcome.unwrap();
        }
    }
}