scylla/network/
tls.rs

//! This module contains abstractions related to the TLS layer of driver connections.
//!
//! The full picture looks like this:
//!
//! ```text
//! ┌─←─ TlsContext (openssl::SslContext / rustls::ClientConfig)
//! │
//! ├─←─ CloudConfig (powered by either TLS backend)
//! │
//! │ gets wrapped in
//! │
//! ↳TlsProvider (same for all connections)
//!   │
//!   │ produces
//!   │
//!   ↳TlsConfig (specific to the particular connection)
//!     │
//!     │ produces
//!     │
//!     ↳Tls (wrapper over a TCP stream which adds encryption)
//! ```
20
21use std::io;
22#[cfg(feature = "unstable-cloud")]
23use std::sync::Arc;
24
25#[cfg(feature = "unstable-cloud")]
26use tracing::warn;
27#[cfg(feature = "unstable-cloud")]
28use uuid::Uuid;
29
30use crate::client::session::TlsContext;
31#[cfg(feature = "unstable-cloud")]
32use crate::cloud::CloudConfig;
33#[cfg(feature = "unstable-cloud")]
34use crate::cluster::metadata::PeerEndpoint;
35use crate::cluster::metadata::UntranslatedEndpoint;
36#[cfg(feature = "unstable-cloud")]
37use crate::cluster::node::ResolvedContactPoint;
38
39/// Abstraction capable of producing [TlsConfig] for connections on-demand.
40#[derive(Clone)] // Cheaply clonable (reference-counted)
41pub(crate) enum TlsProvider {
42    GlobalContext(TlsContext),
43    #[cfg(feature = "unstable-cloud")]
44    ScyllaCloud(Arc<CloudConfig>),
45}
46
47impl TlsProvider {
48    /// Used in case when the user provided their own [TlsContext] to be used in all connections.
49    pub(crate) fn new_with_global_context(context: TlsContext) -> Self {
50        Self::GlobalContext(context)
51    }
52
53    /// Used in the cloud case.
54    #[cfg(feature = "unstable-cloud")]
55    pub(crate) fn new_cloud(cloud_config: Arc<CloudConfig>) -> Self {
56        Self::ScyllaCloud(cloud_config)
57    }
58
59    /// Produces a [TlsConfig] that is specific for the given endpoint.
60    pub(crate) fn make_tls_config(
61        &self,
62        // Currently, this is only used for cloud; but it makes abstract sense to pass endpoint here
63        // also for non-cloud cases, so let's just allow(unused).
64        #[allow(unused)] endpoint: &UntranslatedEndpoint,
65    ) -> Option<TlsConfig> {
66        match self {
67            TlsProvider::GlobalContext(context) => {
68                Some(TlsConfig::new_with_global_context(context.clone()))
69            }
70            #[cfg(feature = "unstable-cloud")]
71            TlsProvider::ScyllaCloud(cloud_config) => {
72                let (host_id, address, dc) = match *endpoint {
73                    UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
74                        address,
75                        ref datacenter,
76                    }) => (None, address, datacenter.as_deref()), // FIXME: Pass DC in ContactPoint
77                    UntranslatedEndpoint::Peer(PeerEndpoint {
78                        host_id,
79                        address,
80                        ref datacenter,
81                        ..
82                    }) => (Some(host_id), address.into_inner(), datacenter.as_deref()),
83                };
84
85                cloud_config.make_tls_config_for_scylla_cloud_host(host_id, dc, address)
86                    // inspect_err() is stable since 1.76.
87                    // TODO: use inspect_err once we bump MSRV to at least 1.76.
88                    .map_err(|err| {
89                        warn!(
90                            "TlsProvider for SNI connection to Scylla Cloud node {{ host_id={:?}, dc={:?} at {} }} could not be set up: {}\n Proceeding with attempting probably nonworking connection",
91                            host_id,
92                            dc,
93                            address,
94                            err
95                        );
96                    }).ok().flatten()
97            }
98        }
99    }
100}
101
102/// Encapsulates TLS-regarding configuration that is specific for a particular endpoint.
103///
104/// Both use cases are supported:
105/// 1. User-provided global TlsContext. Then, the global TlsContext is simply cloned here.
106/// 2. Serverless Cloud. Then the TlsContext is customized for the given endpoint,
107///    and its SNI information is stored alongside.
108#[derive(Clone)]
109pub(crate) struct TlsConfig {
110    context: TlsContext,
111    #[cfg(feature = "unstable-cloud")]
112    sni: Option<String>,
113}
114
/// State and configuration of a single connection's TLS layer, ready to wrap a TCP stream.
///
/// There is one variant per enabled TLS backend.
pub(crate) enum Tls {
    /// OpenSSL backend: a prepared session object.
    #[cfg(feature = "openssl-010")]
    OpenSsl010(openssl::ssl::Ssl),
    /// Rustls backend: a connector plus, for cloud connections, the SNI server name to use.
    #[cfg(feature = "rustls-023")]
    Rustls023 {
        connector: tokio_rustls::TlsConnector,
        #[cfg(feature = "unstable-cloud")]
        sni: Option<rustls::pki_types::ServerName<'static>>,
    },
}
126
127/// A wrapper around a TLS error.
128///
129/// The original error came from one of the supported TLS backends.
130#[derive(Debug, thiserror::Error)]
131#[error(transparent)]
132#[non_exhaustive]
133pub enum TlsError {
134    #[cfg(feature = "openssl-010")]
135    OpenSsl010(#[from] openssl::error::ErrorStack),
136    #[cfg(feature = "rustls-023")]
137    InvalidName(#[from] rustls::pki_types::InvalidDnsNameError),
138    #[cfg(feature = "rustls-023")]
139    PemParse(#[from] rustls::pki_types::pem::Error),
140    #[cfg(feature = "rustls-023")]
141    Rustls023(#[from] rustls::Error),
142}
143
144impl From<TlsError> for io::Error {
145    fn from(value: TlsError) -> Self {
146        match value {
147            #[cfg(feature = "openssl-010")]
148            TlsError::OpenSsl010(e) => e.into(),
149            #[cfg(feature = "rustls-023")]
150            TlsError::InvalidName(e) => io::Error::new(io::ErrorKind::Other, e),
151            #[cfg(feature = "rustls-023")]
152            TlsError::PemParse(e) => io::Error::new(io::ErrorKind::Other, e),
153            #[cfg(feature = "rustls-023")]
154            TlsError::Rustls023(e) => io::Error::new(io::ErrorKind::Other, e),
155        }
156    }
157}
158
159impl TlsConfig {
160    /// Used in case when the user provided their own TlsContext to be used in all connections.
161    pub(crate) fn new_with_global_context(context: TlsContext) -> Self {
162        Self {
163            context,
164            #[cfg(feature = "unstable-cloud")]
165            sni: None,
166        }
167    }
168
169    /// Used in case of Serverless Cloud connections.
170    #[cfg(feature = "unstable-cloud")]
171    pub(crate) fn new_for_sni(
172        context: TlsContext,
173        domain_name: &str,
174        host_id: Option<Uuid>,
175    ) -> Self {
176        Self {
177            context,
178            #[cfg(feature = "unstable-cloud")]
179            sni: Some(if let Some(host_id) = host_id {
180                format!("{}.{}", host_id, domain_name)
181            } else {
182                domain_name.into()
183            }),
184        }
185    }
186
187    /// Produces a new Tls object that is able to wrap a TCP stream.
188    pub(crate) fn new_tls(&self) -> Result<Tls, TlsError> {
189        // To silence warnings when TlsContext is an empty enum (tls features are disabled).
190        #[allow(unreachable_code)]
191        match self.context {
192            #[cfg(feature = "openssl-010")]
193            TlsContext::OpenSsl010(ref context) => {
194                #[allow(unused_mut)]
195                let mut ssl = openssl::ssl::Ssl::new(context)?;
196                #[cfg(feature = "unstable-cloud")]
197                if let Some(sni) = self.sni.as_ref() {
198                    ssl.set_hostname(sni)?;
199                }
200                Ok(Tls::OpenSsl010(ssl))
201            }
202            #[cfg(feature = "rustls-023")]
203            TlsContext::Rustls023(ref config) => {
204                let connector = tokio_rustls::TlsConnector::from(config.clone());
205                #[cfg(feature = "unstable-cloud")]
206                let sni = self
207                    .sni
208                    .as_deref()
209                    .map(rustls::pki_types::ServerName::try_from)
210                    .transpose()?
211                    .map(|s| s.to_owned());
212
213                Ok(Tls::Rustls023 {
214                    connector,
215                    #[cfg(feature = "unstable-cloud")]
216                    sni,
217                })
218            }
219        }
220    }
221}