Skip to content

Commit

Permalink
scylla: Add support for rustls Fixes scylladb#293
Browse files Browse the repository at this point in the history
  • Loading branch information
nemosupremo committed Jan 11, 2024
1 parent befa148 commit 8edd689
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 10 deletions.
6 changes: 4 additions & 2 deletions scylla/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[features]
default = []
default = ["rustls"]
ssl = ["dep:tokio-openssl", "dep:openssl"]
cloud = ["ssl", "scylla-cql/serde", "dep:serde_yaml", "dep:serde", "dep:url", "dep:base64"]
rustls = ["dep:tokio-rustls"]
cloud = ["scylla-cql/serde", "dep:serde_yaml", "dep:serde", "dep:url", "dep:base64"]
secret = ["scylla-cql/secret"]
chrono = ["scylla-cql/chrono"]
time = ["scylla-cql/time"]
Expand All @@ -42,6 +43,7 @@ tracing = "0.1.36"
chrono = { version = "0.4.20", default-features = false, features = ["clock"] }
openssl = { version = "0.10.32", optional = true }
tokio-openssl = { version = "0.6.1", optional = true }
tokio-rustls = { version = "0.25", optional = true }
arc-swap = "1.3.0"
dashmap = "5.2"
strum = "0.23"
Expand Down
6 changes: 6 additions & 0 deletions scylla/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,9 @@ pub use transport::retry_policy;
pub use transport::speculative_execution;

pub use transport::metrics::Metrics;

#[cfg(all(feature = "ssl", feature = "rustls"))]
compile_error!("both rustls and ssl should not be enabled together.");

#[cfg(all(feature = "cloud", not(any(feature = "ssl", feature = "rustls"))))]
compile_error!("cloud feature requires either the rustls or ssl feature.");
96 changes: 90 additions & 6 deletions scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ use std::sync::atomic::AtomicU64;
use std::time::Duration;
#[cfg(feature = "ssl")]
use tokio_openssl::SslStream;
#[cfg(feature = "rustls")]
use tokio_rustls::TlsConnector;

#[cfg(feature = "ssl")]
#[cfg(any(feature = "ssl", feature = "rustls"))]
pub(crate) use ssl_config::SslConfig;

use crate::authentication::AuthenticatorProvider;
Expand Down Expand Up @@ -280,12 +282,19 @@ impl NonErrorQueryResponse {
})
}
}
#[cfg(feature = "ssl")]
#[cfg(any(feature = "ssl", feature = "rustls"))]
mod ssl_config {
#[cfg(feature = "ssl")]
use openssl::{
error::ErrorStack,
ssl::{Ssl, SslContext},
};
#[cfg(feature = "rustls")]
use std::{net::IpAddr, sync::Arc};
#[cfg(feature = "rustls")]
use tokio_rustls::rustls::pki_types::ServerName;
#[cfg(feature = "rustls")]
use tokio_rustls::rustls::ClientConfig;
#[cfg(feature = "cloud")]
use uuid::Uuid;

Expand All @@ -299,13 +308,15 @@ mod ssl_config {
// NodeConnectionPool::new(). Inside that function, the field is mutated to contain SslConfig specific
// for the particular node. (The SslConfig must be different, because SNIs differ for different nodes.)
// Thenceforth, all connections to that node share the same SslConfig.
#[cfg(feature = "ssl")]
#[derive(Clone)]
pub struct SslConfig {
context: SslContext,
#[cfg(feature = "cloud")]
sni: Option<String>,
}

#[cfg(feature = "ssl")]
impl SslConfig {
// Used in case when the user provided their own SslContext to be used in all connections.
pub fn new_with_global_context(context: SslContext) -> Self {
Expand Down Expand Up @@ -345,14 +356,66 @@ mod ssl_config {
Ok(ssl)
}
}

#[cfg(feature = "rustls")]
#[derive(Clone)]
pub struct SslConfig {
config: Arc<ClientConfig>,
#[cfg(feature = "cloud")]
sni: Option<ServerName<'static>>,
}

impl SslConfig {
// Used in case when the user provided their own ClientConfig to be used in all connections.
pub fn new_with_global_config(config: &Arc<ClientConfig>) -> Self {
Self {
config: config.clone(),
#[cfg(feature = "cloud")]
sni: None,
}
}

// Used in case of Serverless Cloud connections.
#[cfg(feature = "cloud")]
pub(crate) fn new_for_sni(
config: &Arc<ClientConfig>,
domain_name: &str,
host_id: Option<Uuid>,
) -> Self {
Self {
config: config.clone(),
#[cfg(feature = "cloud")]
sni: Some(if let Some(host_id) = host_id {
ServerName::try_from(&format!("{}.{}", host_id, domain_name))
.expect("invalid DNS name")
.to_owned()
} else {
ServerName::try_from(domain_name.into().expect("invalid DNS name")).to_owned()
}),
}
}

pub(crate) fn server_name(&self, node_addr: IpAddr) -> ServerName<'static> {
#[cfg(feature = "cloud")]
if let Some(sni) = self.sni.as_ref() {
return sni.clone();
}
ServerName::IpAddress(node_addr.into())
}

// A reference to the rustls Client Config to produce a TlsConnection
pub(crate) fn config(&self) -> &Arc<ClientConfig> {
&self.config
}
}
}

#[derive(Clone)]
pub struct ConnectionConfig {
pub compression: Option<Compression>,
pub tcp_nodelay: bool,
pub tcp_keepalive_interval: Option<Duration>,
#[cfg(feature = "ssl")]
#[cfg(any(feature = "ssl", feature = "rustls"))]
pub ssl_config: Option<SslConfig>,
pub connect_timeout: std::time::Duration,
// should be Some only in control connections,
Expand All @@ -375,7 +438,7 @@ impl Default for ConnectionConfig {
tcp_nodelay: true,
tcp_keepalive_interval: None,
event_sender: None,
#[cfg(feature = "ssl")]
#[cfg(any(feature = "ssl", feature = "rustls"))]
ssl_config: None,
connect_timeout: std::time::Duration::from_secs(5),
default_consistency: Default::default(),
Expand All @@ -393,7 +456,7 @@ impl Default for ConnectionConfig {
}

impl ConnectionConfig {
#[cfg(feature = "ssl")]
#[cfg(any(feature = "ssl", feature = "rustls"))]
pub fn is_ssl(&self) -> bool {
#[cfg(feature = "cloud")]
if self.cloud_config.is_some() {
Expand All @@ -402,7 +465,7 @@ impl ConnectionConfig {
self.ssl_config.is_some()
}

#[cfg(not(feature = "ssl"))]
#[cfg(not(any(feature = "ssl", feature = "rustls")))]
pub fn is_ssl(&self) -> bool {
false
}
Expand Down Expand Up @@ -1034,6 +1097,27 @@ impl Connection {
return Ok(handle);
}

#[cfg(feature = "rustls")]
if let Some(rustls_config) = &config.ssl_config {
let connector = TlsConnector::from(rustls_config.config().clone());
let stream = connector
.connect(rustls_config.server_name(node_address), stream)
.await?;

let (task, handle) = Self::router(
config,
stream,
receiver,
error_sender,
orphan_notification_receiver,
router_handle,
node_address,
)
.remote_handle();
tokio::task::spawn(task.with_current_subscriber());
return Ok(handle);
}

let (task, handle) = Self::router(
config,
stream,
Expand Down
2 changes: 1 addition & 1 deletion scylla/src/transport/connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,7 @@ mod tests {
let connection_config = ConnectionConfig {
compression: None,
tcp_nodelay: true,
#[cfg(feature = "ssl")]
#[cfg(any(feature = "ssl", feature = "rustls"))]
ssl_config: None,
..Default::default()
};
Expand Down
15 changes: 14 additions & 1 deletion scylla/src/transport/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use uuid::Uuid;

use super::connection::NonErrorQueryResponse;
use super::connection::QueryResponse;
#[cfg(feature = "ssl")]
#[cfg(any(feature = "ssl", feature = "rustls"))]
use super::connection::SslConfig;
use super::errors::{NewSessionError, QueryError};
use super::execution_profile::{ExecutionProfile, ExecutionProfileHandle, ExecutionProfileInner};
Expand Down Expand Up @@ -77,6 +77,8 @@ use crate::authentication::AuthenticatorProvider;
#[cfg(feature = "ssl")]
use openssl::ssl::SslContext;
use scylla_cql::errors::BadQuery;
#[cfg(feature = "rustls")]
use tokio_rustls::rustls::ClientConfig;

/// Translates IP addresses received from ScyllaDB nodes into locally reachable addresses.
///
Expand Down Expand Up @@ -196,6 +198,10 @@ pub struct SessionConfig {
#[cfg(feature = "ssl")]
pub ssl_context: Option<SslContext>,

/// Provide our Session with TLS
#[cfg(feature = "rustls")]
pub rustls_config: Option<Arc<ClientConfig>>,

pub authenticator: Option<Arc<dyn AuthenticatorProvider>>,

pub connect_timeout: Duration,
Expand Down Expand Up @@ -312,6 +318,8 @@ impl SessionConfig {
keyspace_case_sensitive: false,
#[cfg(feature = "ssl")]
ssl_context: None,
#[cfg(feature = "rustls")]
rustls_config: None,
authenticator: None,
connect_timeout: Duration::from_secs(5),
connection_pool_size: Default::default(),
Expand Down Expand Up @@ -499,6 +507,11 @@ impl Session {
tcp_keepalive_interval: config.tcp_keepalive_interval,
#[cfg(feature = "ssl")]
ssl_config: config.ssl_context.map(SslConfig::new_with_global_context),
#[cfg(feature = "rustls")]
ssl_config: config
.rustls_config
.as_ref()
.map(SslConfig::new_with_global_config),
authenticator: config.authenticator.clone(),
connect_timeout: config.connect_timeout,
event_sender: None,
Expand Down
34 changes: 34 additions & 0 deletions scylla/src/transport/session_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use std::time::Duration;
use crate::authentication::{AuthenticatorProvider, PlainTextAuthenticator};
#[cfg(feature = "ssl")]
use openssl::ssl::SslContext;
#[cfg(feature = "rustls")]
use tokio_rustls::rustls::ClientConfig;
use tracing::warn;

mod sealed {
Expand Down Expand Up @@ -334,6 +336,38 @@ impl GenericSessionBuilder<DefaultMode> {
self.config.ssl_context = ssl_context;
self
}

/// rustls feature
/// Provide SessionBuilder with ClientConfig from rustls crate that will be
/// used to create an ssl connection to the database.
/// If set to None SSL connection won't be used.
/// Default is None.
///
/// # Example
/// ```
/// # use std::fs;
/// # use std::path::PathBuf;
/// # use scylla::{Session, SessionBuilder};
/// # use openssl::ssl::{SslContextBuilder, SslVerifyMode, SslMethod, SslFiletype};
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let certdir = fs::canonicalize(PathBuf::from("./examples/certs/scylla.crt"))?;
/// let mut context_builder = SslContextBuilder::new(SslMethod::tls())?;
/// context_builder.set_certificate_file(certdir.as_path(), SslFiletype::PEM)?;
/// context_builder.set_verify(SslVerifyMode::NONE);
///
/// let session: Session = SessionBuilder::new()
/// .known_node("127.0.0.1:9042")
/// .ssl_context(Some(context_builder.build()))
/// .build()
/// .await?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "rustls")]
pub fn rustls_config(mut self, config: Option<Arc<ClientConfig>>) -> Self {
self.config.rustls_config = config;
self
}
}

// NOTE: this `impl` block contains configuration options specific for **Cloud** [`Session`].
Expand Down

0 comments on commit 8edd689

Please sign in to comment.