1use std::io;
22#[cfg(feature = "unstable-cloud")]
23use std::sync::Arc;
24
25#[cfg(feature = "unstable-cloud")]
26use tracing::warn;
27#[cfg(feature = "unstable-cloud")]
28use uuid::Uuid;
29
30use crate::client::session::TlsContext;
31#[cfg(feature = "unstable-cloud")]
32use crate::cloud::CloudConfig;
33#[cfg(feature = "unstable-cloud")]
34use crate::cluster::metadata::PeerEndpoint;
35use crate::cluster::metadata::UntranslatedEndpoint;
36#[cfg(feature = "unstable-cloud")]
37use crate::cluster::node::ResolvedContactPoint;
38
39#[derive(Clone)] pub(crate) enum TlsProvider {
42 GlobalContext(TlsContext),
43 #[cfg(feature = "unstable-cloud")]
44 ScyllaCloud(Arc<CloudConfig>),
45}
46
47impl TlsProvider {
48 pub(crate) fn new_with_global_context(context: TlsContext) -> Self {
50 Self::GlobalContext(context)
51 }
52
53 #[cfg(feature = "unstable-cloud")]
55 pub(crate) fn new_cloud(cloud_config: Arc<CloudConfig>) -> Self {
56 Self::ScyllaCloud(cloud_config)
57 }
58
59 pub(crate) fn make_tls_config(
61 &self,
62 #[allow(unused)] endpoint: &UntranslatedEndpoint,
65 ) -> Option<TlsConfig> {
66 match self {
67 TlsProvider::GlobalContext(context) => {
68 Some(TlsConfig::new_with_global_context(context.clone()))
69 }
70 #[cfg(feature = "unstable-cloud")]
71 TlsProvider::ScyllaCloud(cloud_config) => {
72 let (host_id, address, dc) = match *endpoint {
73 UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
74 address,
75 ref datacenter,
76 }) => (None, address, datacenter.as_deref()), UntranslatedEndpoint::Peer(PeerEndpoint {
78 host_id,
79 address,
80 ref datacenter,
81 ..
82 }) => (Some(host_id), address.into_inner(), datacenter.as_deref()),
83 };
84
85 cloud_config.make_tls_config_for_scylla_cloud_host(host_id, dc, address)
86 .map_err(|err| {
89 warn!(
90 "TlsProvider for SNI connection to Scylla Cloud node {{ host_id={:?}, dc={:?} at {} }} could not be set up: {}\n Proceeding with attempting probably nonworking connection",
91 host_id,
92 dc,
93 address,
94 err
95 );
96 }).ok().flatten()
97 }
98 }
99 }
100}
101
102#[derive(Clone)]
109pub(crate) struct TlsConfig {
110 context: TlsContext,
111 #[cfg(feature = "unstable-cloud")]
112 sni: Option<String>,
113}
114
/// Per-connection TLS handle produced by `TlsConfig::new_tls`.
///
/// There is one variant per enabled TLS backend feature.
pub(crate) enum Tls {
    /// An OpenSSL 0.10 session object, with the SNI hostname already applied if one was configured.
    #[cfg(feature = "openssl-010")]
    OpenSsl010(openssl::ssl::Ssl),
    /// A rustls 0.23 connector plus the server name to use, if one was configured.
    #[cfg(feature = "rustls-023")]
    Rustls023 {
        connector: tokio_rustls::TlsConnector,
        #[cfg(feature = "unstable-cloud")]
        sni: Option<rustls::pki_types::ServerName<'static>>,
    },
}
126
127#[derive(Debug, thiserror::Error)]
131#[error(transparent)]
132#[non_exhaustive]
133pub enum TlsError {
134 #[cfg(feature = "openssl-010")]
135 OpenSsl010(#[from] openssl::error::ErrorStack),
136 #[cfg(feature = "rustls-023")]
137 InvalidName(#[from] rustls::pki_types::InvalidDnsNameError),
138 #[cfg(feature = "rustls-023")]
139 PemParse(#[from] rustls::pki_types::pem::Error),
140 #[cfg(feature = "rustls-023")]
141 Rustls023(#[from] rustls::Error),
142}
143
144impl From<TlsError> for io::Error {
145 fn from(value: TlsError) -> Self {
146 match value {
147 #[cfg(feature = "openssl-010")]
148 TlsError::OpenSsl010(e) => e.into(),
149 #[cfg(feature = "rustls-023")]
150 TlsError::InvalidName(e) => io::Error::new(io::ErrorKind::Other, e),
151 #[cfg(feature = "rustls-023")]
152 TlsError::PemParse(e) => io::Error::new(io::ErrorKind::Other, e),
153 #[cfg(feature = "rustls-023")]
154 TlsError::Rustls023(e) => io::Error::new(io::ErrorKind::Other, e),
155 }
156 }
157}
158
159impl TlsConfig {
160 pub(crate) fn new_with_global_context(context: TlsContext) -> Self {
162 Self {
163 context,
164 #[cfg(feature = "unstable-cloud")]
165 sni: None,
166 }
167 }
168
169 #[cfg(feature = "unstable-cloud")]
171 pub(crate) fn new_for_sni(
172 context: TlsContext,
173 domain_name: &str,
174 host_id: Option<Uuid>,
175 ) -> Self {
176 Self {
177 context,
178 #[cfg(feature = "unstable-cloud")]
179 sni: Some(if let Some(host_id) = host_id {
180 format!("{}.{}", host_id, domain_name)
181 } else {
182 domain_name.into()
183 }),
184 }
185 }
186
187 pub(crate) fn new_tls(&self) -> Result<Tls, TlsError> {
189 #[allow(unreachable_code)]
191 match self.context {
192 #[cfg(feature = "openssl-010")]
193 TlsContext::OpenSsl010(ref context) => {
194 #[allow(unused_mut)]
195 let mut ssl = openssl::ssl::Ssl::new(context)?;
196 #[cfg(feature = "unstable-cloud")]
197 if let Some(sni) = self.sni.as_ref() {
198 ssl.set_hostname(sni)?;
199 }
200 Ok(Tls::OpenSsl010(ssl))
201 }
202 #[cfg(feature = "rustls-023")]
203 TlsContext::Rustls023(ref config) => {
204 let connector = tokio_rustls::TlsConnector::from(config.clone());
205 #[cfg(feature = "unstable-cloud")]
206 let sni = self
207 .sni
208 .as_deref()
209 .map(rustls::pki_types::ServerName::try_from)
210 .transpose()?
211 .map(|s| s.to_owned());
212
213 Ok(Tls::Rustls023 {
214 connector,
215 #[cfg(feature = "unstable-cloud")]
216 sni,
217 })
218 }
219 }
220 }
221}