diff --git a/Cargo.toml b/Cargo.toml index 7ac3088b..68c01c36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,13 +34,15 @@ async = ["aerospike-core"] sync = ["aerospike-sync"] rt-tokio = ["aerospike-core/rt-tokio", "aerospike-macro/rt-tokio"] rt-async-std = ["aerospike-core/rt-async-std", "aerospike-macro/rt-async-std"] +tokio-rustls = ["aerospike-core/tokio-rustls"] +tokio-native-tls = ["aerospike-core/tokio-native-tls"] [[bench]] name = "client_server" harness = false [workspace] -members = ["tools/benchmark", "aerospike-core", "aerospike-rt", "aerospike-sync", "aerospike-macro"] +members = ["tools/benchmark", "aerospike-core", "aerospike-rt", "aerospike-sync", "aerospike-macro", "aerospike-tls"] [dev-dependencies] env_logger = "0.7" diff --git a/aerospike-core/Cargo.toml b/aerospike-core/Cargo.toml index a4f4fa1a..39f034f2 100644 --- a/aerospike-core/Cargo.toml +++ b/aerospike-core/Cargo.toml @@ -19,11 +19,15 @@ serde = { version = "1.0", features = ["derive"], optional = true } aerospike-rt = {path = "../aerospike-rt"} futures = {version = "0.3.16" } async-trait = "0.1.51" +aerospike-tls = { path = "../aerospike-tls", optional = true } [features] serialization = ["serde"] rt-tokio = ["aerospike-rt/rt-tokio"] rt-async-std = ["aerospike-rt/rt-async-std"] +tls = [] +tokio-rustls = ["aerospike-tls/tokio-rustls", "rt-tokio", "tls"] +tokio-native-tls = ["aerospike-tls/tokio-native-tls", "rt-tokio", "tls"] [dev-dependencies] env_logger = "0.7" diff --git a/aerospike-core/src/cluster/info_helper.rs b/aerospike-core/src/cluster/info_helper.rs new file mode 100644 index 00000000..eece952a --- /dev/null +++ b/aerospike-core/src/cluster/info_helper.rs @@ -0,0 +1,256 @@ +use std::str::FromStr; + +use crate::errors::{ErrorKind, Result}; + + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ServicesResponse<'a> { + pub peers_generation: u32, + pub port: u16, + pub nodes: Vec>, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct NodeResponse<'a> { + pub node_id: &'a str, + pub tls_name: Option<&'a str>, + pub endpoints: Vec<&'a str>, +} + +pub fn parse_services_response<'a>(response: &'a str) -> Result> { + // peers-generation, port, [ list of [ NodeIDs/Names, TLSName(if defined), [ List of endpoints/IPaddresses ]]] + + const COMMON : char = ','; + const OPEN_BRACE : char = '['; + const CLOSE_BRACE : char = ']'; + + fn remove_outer_bracers<'a>(input: &'a str) -> Result<&'a str> { + match (input.chars().next(), input.chars().nth(input.len() - 1)) { + (Some(OPEN_BRACE), Some(CLOSE_BRACE)) => Ok(&input[1..input.len()-1]), + _ => Result::Err(ErrorKind::BadResponse(format!("Missing outer bracers {input}")).into()) + } + } + + fn read_generation_and_port_and_nodes<'a>(input: &'a str) -> Result<(u32, u16, Vec)> { + + fn read_nodes<'a>(nodes: &'a str) -> Result> { + // [ list of [ NodeIDs/Names, TLSName(if defined), [ List of endpoints/IPaddresses ]]] + let mut result : Vec = vec![]; + + fn read_node<'a>(node: &'a str) -> Result { + // [ NodeIDs/Names, TLSName(if defined), [ List of endpoints/IPaddresses ]] + + fn read_endpoints<'a>(endpoints: &'a str) -> Result> { + // [ List of endpoints/IPaddresses ] + + let endpoints = remove_outer_bracers(endpoints)?; + + Ok(endpoints.split(COMMON).collect::>()) + } + + let node = remove_outer_bracers(node)?; + + let first_common = node.find(COMMON) + .ok_or(ErrorKind::BadResponse("Missing section after node id".to_string()))?; + + let node_id = node.get(0..first_common) + .ok_or(ErrorKind::BadResponse("Missing node id".to_string()))?; + + let second_common = node[first_common+1..] + .find(COMMON) + .map(|x| x + first_common + 1) + .ok_or(ErrorKind::BadResponse("Missing section after tls name".to_string()))?; + + let tls_name = node.get(first_common+1..second_common) + .ok_or(ErrorKind::BadResponse("Missing tls name".to_string()))?; + let tls_name = if tls_name.is_empty() { None } else { Some(tls_name) }; + + let endpoints_slice = node.get(second_common+1..) + .ok_or(ErrorKind::BadResponse("Missing endpoints list".to_string()))?; + + let endpoints = read_endpoints(endpoints_slice)?; + + Ok(NodeResponse { node_id, tls_name, endpoints }) + } + + let nodes = remove_outer_bracers(nodes)?; + + let mut opened_bracers : i32 = 0; + let mut first_opened_brace_pos : Option = None; + for (pos, ch) in nodes.char_indices() { + match ch { + OPEN_BRACE => { + if first_opened_brace_pos.is_none() { + first_opened_brace_pos = Some(pos); + } + opened_bracers += 1; + + }, + CLOSE_BRACE => { + opened_bracers -= 1; + if opened_bracers < 0 { + return Result::Err(ErrorKind::BadResponse("Malformed nodes list".to_string()).into()); + } + + if opened_bracers == 0 { + let Some(opened_brace_pos) = first_opened_brace_pos else { + return Result::Err(ErrorKind::BadResponse("Wrong node list parser state".to_string()).into()); + }; + + let node_slice = nodes.get(opened_brace_pos..pos+1) + .ok_or(ErrorKind::BadResponse("Invalid node slice in list".to_string()))?; + + let node = read_node(node_slice)?; + + result.push(node); + + first_opened_brace_pos = None; + } + }, + _ => {}, + } + } + + Ok(result) + } + + let first_common = input + .find(COMMON) + .ok_or(ErrorKind::BadResponse("Missing peers generation".to_string()))?; + + let peers_generation_slice = input.get(0..first_common) + .ok_or(ErrorKind::BadResponse("Missing peers generation".to_string()))?; + + let peers_generation = u32::from_str(peers_generation_slice) + .map_err(|_| ErrorKind::BadResponse("Peers generation should be u32".to_string()))?; + + let second_common = input[first_common+1..] + .find(COMMON) + .map(|x| x + first_common + 1) + .ok_or(ErrorKind::BadResponse("Missing port".to_string()))?; + + let port_slice = input.get(first_common+1..second_common) + .ok_or(ErrorKind::BadResponse("Missing port".to_string()))?; + + let port = u16::from_str(port_slice) + .map_err(|_| ErrorKind::BadResponse("TCP port should be u16".to_string()))?; + + let nodes = input.get(second_common+1..) + .ok_or(ErrorKind::BadResponse("Missing node list".to_string()))?; + + let nodes = read_nodes(nodes)?; + + Ok((peers_generation, port, nodes)) + } + + let (peers_generation, port, nodes) = read_generation_and_port_and_nodes(response)?; + + Ok(ServicesResponse { + peers_generation, + port, + nodes, + }) +} + +mod tests { + use crate::cluster::info_helper::{ServicesResponse, NodeResponse, parse_services_response}; + + #[test] + fn positive_cases() { + let responses = [ + "9,3000,[[BB9040011AC4202,,[172.17.0.4]],[BB9050011AC4202,,[172.17.0.5]]]", + "9,3000,[[BB9060011AC4202,,[74.125.239.53]],[BB9070011AC4202,,[74.125.239.54]]]", + "10,4333,[[BB9060011AC4202,clusternode,[74.125.239.53]],[BB9070011AC4202,clusternode,[74.125.239.54]]]", + "10,4333,[[BB9040011AC4202,clusternode,[172.17.0.4,74.125.239.53]],[BB9050011AC4202,clusternode,[172.17.0.5,74.125.239.54]]]", + ]; + + let parsed_responses = [ + ServicesResponse { + peers_generation: 9, + port: 3000, + nodes: vec![ + NodeResponse { + node_id: "BB9040011AC4202", + tls_name: None, + endpoints: vec![ + "172.17.0.4", + ] + }, + NodeResponse { + node_id: "BB9050011AC4202", + tls_name: None, + endpoints: vec![ + "172.17.0.5", + ] + }, + ], + }, + ServicesResponse { + peers_generation: 9, + port: 3000, + nodes: vec![ + NodeResponse { + node_id: "BB9060011AC4202", + tls_name: None, + endpoints: vec![ + "74.125.239.53", + ] + }, + NodeResponse { + node_id: "BB9070011AC4202", + tls_name: None, + endpoints: vec![ + "74.125.239.54", + ] + }, + ], + }, + ServicesResponse { + peers_generation: 10, + port: 4333, + nodes: vec![ + NodeResponse { + node_id: "BB9060011AC4202", + tls_name: Some("clusternode"), + endpoints: vec![ + "74.125.239.53", + ] + }, + NodeResponse { + node_id: "BB9070011AC4202", + tls_name: Some("clusternode"), + endpoints: vec![ + "74.125.239.54", + ] + }, + ], + }, + ServicesResponse { + peers_generation: 10, + port: 4333, + nodes: vec![ + NodeResponse { + node_id: "BB9040011AC4202", + tls_name: Some("clusternode"), + endpoints: vec![ + "172.17.0.4", + "74.125.239.53", + ] + }, + NodeResponse { + node_id: "BB9050011AC4202", + tls_name: Some("clusternode"), + endpoints: vec![ + "172.17.0.5", + "74.125.239.54", + ] + }, + ], + }, + ]; + + for (parsed, response) in parsed_responses.iter().zip(responses.iter()) { + assert_eq!(parsed, &parse_services_response(response).unwrap()); + } + } +} \ No newline at end of file diff --git a/aerospike-core/src/cluster/mod.rs b/aerospike-core/src/cluster/mod.rs index a3127649..4e8ba295 100644 --- a/aerospike-core/src/cluster/mod.rs +++ b/aerospike-core/src/cluster/mod.rs @@ -17,6 +17,7 @@ pub mod node; pub mod node_validator; pub mod partition; pub mod partition_tokenizer; +mod info_helper; use aerospike_rt::time::{Duration, Instant}; use std::collections::HashMap; diff --git a/aerospike-core/src/cluster/node.rs b/aerospike-core/src/cluster/node.rs index b2f4b602..672c2acb 100644 --- a/aerospike-core/src/cluster/node.rs +++ b/aerospike-core/src/cluster/node.rs @@ -17,7 +17,6 @@ use std::collections::HashMap; use std::fmt; use std::hash::{Hash, Hasher}; use std::result::Result as StdResult; -use std::str::FromStr; use std::sync::atomic::{AtomicBool, AtomicIsize, AtomicUsize, Ordering}; use std::sync::Arc; @@ -28,6 +27,8 @@ use crate::net::{ConnectionPool, Host, PooledConnection}; use crate::policy::ClientPolicy; use aerospike_rt::RwLock; +use super::info_helper::parse_services_response; + pub const PARTITIONS: usize = 4096; /// The node instance holding connections and node settings. @@ -144,10 +145,16 @@ impl Node { // Returns the services that the client should use for the cluster tend const fn services_name(&self) -> &'static str { - if self.client_policy.use_services_alternate { - "services-alternate" - } else { - "services" + #[cfg(not(feature = "tls"))] + let has_tls = false; + #[cfg(feature = "tls")] + let has_tls = self.client_policy.tls_policy.is_some(); + + match (self.client_policy.use_services_alternate, has_tls) { + (true, true) => "peers-tls-alt", + (true, false) => "peers-clear-alt", + (false, true) => "peers-tls-std", + (false, false) => "peers-clear-std", } } @@ -204,31 +211,26 @@ impl Node { Some(friend_string) => friend_string, }; - let friend_names = friend_string.split(';'); - for friend in friend_names { - let mut friend_info = friend.split(':'); - if friend_info.clone().count() != 2 { - error!( - "Node info from asinfo:services is malformed. Expected HOST:PORT, but got \ - '{}'", - friend - ); - continue; - } + let services_response = parse_services_response(friend_string)?; - let host = friend_info.next().unwrap(); - let port = u16::from_str(friend_info.next().unwrap())?; - let alias = match self.client_policy.ip_map { - Some(ref ip_map) if ip_map.contains_key(host) => { - Host::new(ip_map.get(host).unwrap(), port) - } - _ => Host::new(host, port), - }; + let empty_ip_map : HashMap = HashMap::new(); - if current_aliases.contains_key(&alias) { - self.reference_count.fetch_add(1, Ordering::Relaxed); - } else if !friends.contains(&alias) { - friends.push(alias); + let ip_map = self.client_policy.ip_map.as_ref().unwrap_or(&empty_ip_map); + + for node in services_response.nodes { + for endpoint in node.endpoints { + let mapped_ip = ip_map + .get(endpoint) + .map(|x| x.as_str()) + .unwrap_or(endpoint); + + let alias = Host::new(mapped_ip, services_response.port, node.tls_name); + + if current_aliases.contains_key(&alias) { + self.reference_count.fetch_add(1, Ordering::Relaxed); + } else if !friends.contains(&alias) { + friends.push(alias); + } } } diff --git a/aerospike-core/src/cluster/node_validator.rs b/aerospike-core/src/cluster/node_validator.rs index 6739c4f6..bd647ffd 100644 --- a/aerospike-core/src/cluster/node_validator.rs +++ b/aerospike-core/src/cluster/node_validator.rs @@ -80,7 +80,11 @@ impl NodeValidator { fn resolve_aliases(&mut self, host: &Host) -> Result<()> { self.aliases = (host.name.as_ref(), host.port) .to_socket_addrs()? - .map(|addr| Host::new(&addr.ip().to_string(), addr.port())) + .map(|addr| Host::new( + &addr.ip().to_string(), + addr.port(), + host.tls_name.as_ref().map(|x| x.as_str()) + )) .collect(); debug!("Resolved aliases for host {}: {:?}", host, self.aliases); if self.aliases.is_empty() { @@ -91,7 +95,7 @@ impl NodeValidator { } async fn validate_alias(&mut self, cluster: &Cluster, alias: &Host) -> Result<()> { - let mut conn = Connection::new(&alias.address(), &self.client_policy).await?; + let mut conn = Connection::new(&alias, &self.client_policy).await?; let info_map = Message::info(&mut conn, &["node", "cluster-name", "features"]).await?; match info_map.get("node") { diff --git a/aerospike-core/src/lib.rs b/aerospike-core/src/lib.rs index 1a98c3f7..e5ab4d7d 100644 --- a/aerospike-core/src/lib.rs +++ b/aerospike-core/src/lib.rs @@ -147,6 +147,8 @@ extern crate lazy_static; extern crate log; extern crate pwhash; extern crate rand; +#[cfg(feature = "tls")] +pub extern crate aerospike_tls; pub use batch::BatchRead; pub use bin::{Bin, Bins}; @@ -170,6 +172,8 @@ pub use result_code::ResultCode; pub use task::{IndexTask, RegisterTask, Task}; pub use user::User; pub use value::{FloatValue, Value}; +#[cfg(feature = "tls")] +pub use aerospike_tls::{self as tls, *}; #[macro_use] pub mod errors; diff --git a/aerospike-core/src/net/connection.rs b/aerospike-core/src/net/connection.rs index d17f220c..a312334d 100644 --- a/aerospike-core/src/net/connection.rs +++ b/aerospike-core/src/net/connection.rs @@ -13,6 +13,7 @@ // License for the specific language governing permissions and limitations under // the License. +use crate::Host; use crate::commands::admin_command::AdminCommand; use crate::commands::buffer::Buffer; use crate::errors::{ErrorKind, Result}; @@ -26,6 +27,7 @@ use aerospike_rt::time::{Duration, Instant}; #[cfg(all(any(feature = "rt-async-std"), not(feature = "rt-tokio")))] use futures::{AsyncReadExt, AsyncWriteExt}; use std::ops::Add; +use super::connection_stream::ConnectionStream; #[derive(Debug)] pub struct Connection { @@ -36,7 +38,7 @@ pub struct Connection { idle_deadline: Option, // connection object - conn: TcpStream, + conn: ConnectionStream, bytes_read: usize, @@ -44,18 +46,31 @@ pub struct Connection { } impl Connection { - pub async fn new(addr: &str, policy: &ClientPolicy) -> Result { - let stream = aerospike_rt::timeout(Duration::from_secs(10), TcpStream::connect(addr)).await; - if stream.is_err() { - bail!(ErrorKind::Connection( + pub async fn new(host: &Host, policy: &ClientPolicy) -> Result { + let stream = aerospike_rt::timeout(Duration::from_secs(10), TcpStream::connect(host.address())).await + .map_err(|_| ErrorKind::Connection( "Could not open network connection".to_string() - )); - } + ))??; + + #[cfg(feature = "tls")] + let stream = match &policy.tls_policy { + Some(tls_policy) => { + ConnectionStream::Tls(tls_policy.tls_connector.connect(host.tls_name().unwrap_or(host.name.as_str()), stream).await + .map_err(|_| ErrorKind::Connection( + "Could not open TLS network connection".to_string() + ))?) + }, + None => ConnectionStream::Tcp(stream), + }; + + #[cfg(not(feature = "tls"))] + let stream = ConnectionStream::Tcp(stream); + let mut conn = Connection { buffer: Buffer::new(policy.buffer_reclaim_threshold), bytes_read: 0, timeout: policy.timeout, - conn: stream.unwrap()?, + conn: stream.into(), idle_timeout: policy.idle_timeout, idle_deadline: policy.idle_timeout.map(|timeout| Instant::now() + timeout), }; diff --git a/aerospike-core/src/net/connection_pool.rs b/aerospike-core/src/net/connection_pool.rs index 95575dd1..f0180d8f 100644 --- a/aerospike-core/src/net/connection_pool.rs +++ b/aerospike-core/src/net/connection_pool.rs @@ -85,7 +85,7 @@ impl Queue { let conn = aerospike_rt::timeout( Duration::from_secs(5), - Connection::new(&self.0.host.address(), &self.0.policy), + Connection::new(&self.0.host, &self.0.policy), ) .await; diff --git a/aerospike-core/src/net/connection_stream/async_std.rs b/aerospike-core/src/net/connection_stream/async_std.rs new file mode 100644 index 00000000..06debbba --- /dev/null +++ b/aerospike-core/src/net/connection_stream/async_std.rs @@ -0,0 +1,53 @@ +use std::{pin::Pin, task::{Poll, Context}}; +use aerospike_rt::net::TcpStream; +use futures::{AsyncRead, AsyncWrite}; +use aerospike_rt::async_std::net::Shutdown; + +#[derive(Debug)] +pub enum ConnectionStream { + Tcp(TcpStream), +} + +impl ConnectionStream { + pub fn shutdown(&mut self, how: Shutdown) -> Result<(), std::io::Error> { + match self { + Self::Tcp(stream) => stream.shutdown(how) + } + } +} + +impl AsyncRead for ConnectionStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf) + } + } +} + +impl AsyncWrite for ConnectionStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf) + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_flush(cx) + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_close(cx) + } + } +} diff --git a/aerospike-core/src/net/connection_stream/mod.rs b/aerospike-core/src/net/connection_stream/mod.rs new file mode 100644 index 00000000..5ef88817 --- /dev/null +++ b/aerospike-core/src/net/connection_stream/mod.rs @@ -0,0 +1,11 @@ +#[cfg(all(any(not(feature = "rt-async-std")), feature = "rt-tokio"))] +mod tokio; + +#[cfg(all(any(feature = "rt-async-std"), not(feature = "rt-tokio")))] +mod async_std; + +#[cfg(all(any(not(feature = "rt-async-std")), feature = "rt-tokio"))] +pub use self::tokio::ConnectionStream; + +#[cfg(all(any(feature = "rt-async-std"), not(feature = "rt-tokio")))] +pub use self::async_std::ConnectionStream; \ No newline at end of file diff --git a/aerospike-core/src/net/connection_stream/tokio.rs b/aerospike-core/src/net/connection_stream/tokio.rs new file mode 100644 index 00000000..e40696cd --- /dev/null +++ b/aerospike-core/src/net/connection_stream/tokio.rs @@ -0,0 +1,60 @@ +use std::{pin::Pin, task::{Poll, Context}}; +use aerospike_rt::net::TcpStream; +use aerospike_rt::io::{AsyncRead, AsyncWrite}; + +#[cfg(feature = "tls")] +use aerospike_tls::TlsStream; + +#[derive(Debug)] +pub enum ConnectionStream { + Tcp(TcpStream), + #[cfg(feature = "tls")] + Tls(TlsStream), +} + +impl AsyncRead for ConnectionStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut aerospike_rt::io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + #[cfg(feature = "tls")] + Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf), + Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for ConnectionStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + #[cfg(feature = "tls")] + Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf), + Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + #[cfg(feature = "tls")] + Self::Tls(stream) => Pin::new(stream).poll_flush(cx), + Self::Tcp(stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.get_mut() { + #[cfg(feature = "tls")] + Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx), + Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} diff --git a/aerospike-core/src/net/host.rs b/aerospike-core/src/net/host.rs index 8bd52a33..a0244c26 100644 --- a/aerospike-core/src/net/host.rs +++ b/aerospike-core/src/net/host.rs @@ -29,14 +29,18 @@ pub struct Host { /// Port of database server. pub port: u16, + + /// Name used for TLS connection + pub tls_name: Option, } impl Host { /// Create a new host instance given a hostname/IP and a port number. - pub fn new(name: &str, port: u16) -> Self { + pub fn new(name: &str, port: u16, tls_name: Option<&str>) -> Self { Host { name: name.to_string(), port, + tls_name: tls_name.map(|x| x.to_owned()), } } @@ -44,6 +48,11 @@ impl Host { pub fn address(&self) -> String { format!("{}:{}", self.name, self.port) } + + /// Returns a string representation of domain name for TLS conenction. + pub fn tls_name(&self) -> Option<&str> { + self.tls_name.as_ref().map(|x| x.as_str()) + } } impl ToSocketAddrs for Host { @@ -55,7 +64,10 @@ impl ToSocketAddrs for Host { impl fmt::Display for Host { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}:{}", self.name, self.port) + match &self.tls_name { + Some(tls_name) => write!(f, "{}:{}:{}", self.name, tls_name, self.port), + None => write!(f, "{}:{}", self.name, self.port), + } } } @@ -97,14 +109,18 @@ mod tests { #[test] fn to_hosts() { assert_eq!( - vec![Host::new("foo", 3000)], + vec![Host::new("foo", 3000, None)], String::from("foo").to_hosts().unwrap() ); - assert_eq!(vec![Host::new("foo", 3000)], "foo".to_hosts().unwrap()); - assert_eq!(vec![Host::new("foo", 1234)], "foo:1234".to_hosts().unwrap()); + assert_eq!(vec![Host::new("foo", 3000, None)], "foo".to_hosts().unwrap()); + assert_eq!(vec![Host::new("foo", 1234, None)], "foo:1234".to_hosts().unwrap()); assert_eq!( - vec![Host::new("foo", 1234), Host::new("bar", 1234)], + vec![Host::new("foo", 1234, None), Host::new("bar", 1234, None)], "foo:1234,bar:1234".to_hosts().unwrap() ); + assert_eq!( + vec![Host::new("foo", 1234, Some("bar")), Host::new("bar", 1234, Some("foo"))], + "foo:bar:1234,bar:foo:1234".to_hosts().unwrap() + ); } } diff --git a/aerospike-core/src/net/mod.rs b/aerospike-core/src/net/mod.rs index f27d1f26..66b31be1 100644 --- a/aerospike-core/src/net/mod.rs +++ b/aerospike-core/src/net/mod.rs @@ -21,5 +21,6 @@ pub use self::host::ToHosts; mod connection; mod connection_pool; +mod connection_stream; pub mod host; mod parser; diff --git a/aerospike-core/src/net/parser.rs b/aerospike-core/src/net/parser.rs index 69a84fe8..68ff8294 100644 --- a/aerospike-core/src/net/parser.rs +++ b/aerospike-core/src/net/parser.rs @@ -35,7 +35,7 @@ impl<'a> Parser<'a> { let mut hosts = Vec::new(); loop { let addr = self.read_addr_tuple()?; - let (host, _tls_name, port) = match addr.len() { + let (host, tls_name, port) = match addr.len() { 3 => (addr[0].clone(), Some(addr[1].clone()), addr[2].parse()?), 2 => { if let Ok(port) = addr[1].parse() { @@ -49,8 +49,8 @@ impl<'a> Parser<'a> { "Invalid address string".to_string() )), }; - // TODO: add TLS name - hosts.push(Host::new(&host, port)); + + hosts.push(Host::new(&host, port, tls_name.as_ref().map(|x| x.as_str()))); match self.peek() { Some(&c) if c == ',' => self.next_char(), @@ -150,25 +150,43 @@ mod tests { #[test] fn read_hosts() { assert_eq!( - vec![Host::new("foo", 3000)], + vec![Host::new("foo", 3000, None)], Parser::new("foo", 3000).read_hosts().unwrap() ); assert_eq!( - vec![Host::new("foo", 3000)], + vec![Host::new("foo", 3000, Some("bar"))], Parser::new("foo:bar", 3000).read_hosts().unwrap() ); assert_eq!( - vec![Host::new("foo", 1234)], + vec![Host::new("foo", 1234, None)], Parser::new("foo:1234", 3000).read_hosts().unwrap() ); assert_eq!( - vec![Host::new("foo", 1234)], + vec![Host::new("foo", 1234, Some("bar"))], Parser::new("foo:bar:1234", 3000).read_hosts().unwrap() ); assert_eq!( - vec![Host::new("foo", 1234), Host::new("bar", 1234)], + vec![Host::new("foo", 1234, None), Host::new("bar", 1234, None)], Parser::new("foo:1234,bar:1234", 3000).read_hosts().unwrap() ); + + assert_eq!( + vec![Host::new("foo", 3000, Some("bar"))], + Parser::new("foo:bar", 3000).read_hosts().unwrap() + ); + assert_eq!( + vec![Host::new("foo", 1234, Some("bar"))], + Parser::new("foo:bar:1234", 3000).read_hosts().unwrap() + ); + assert_eq!( + vec![Host::new("foo", 1234, Some("bar"))], + Parser::new("foo:bar:1234", 3000).read_hosts().unwrap() + ); + assert_eq!( + vec![Host::new("foo", 1234, Some("bar")), Host::new("bar", 1234, Some("barbar"))], + Parser::new("foo:bar:1234,bar:barbar:1234", 3000).read_hosts().unwrap() + ); + assert!(Parser::new("", 3000).read_hosts().is_err()); assert!(Parser::new(",", 3000).read_hosts().is_err()); assert!(Parser::new("foo,", 3000).read_hosts().is_err()); diff --git a/aerospike-core/src/policy/client_policy.rs b/aerospike-core/src/policy/client_policy.rs index 1ceb7a86..8780d63d 100644 --- a/aerospike-core/src/policy/client_policy.rs +++ b/aerospike-core/src/policy/client_policy.rs @@ -17,6 +17,8 @@ use std::time::Duration; use crate::commands::admin_command::AdminCommand; use crate::errors::Result; +#[cfg(feature = "tls")] +use crate::policy::TlsPolicy; /// `ClientPolicy` encapsulates parameters for client policy command. #[derive(Debug, Clone)] @@ -83,6 +85,9 @@ pub struct ClientPolicy { /// to join the client's view of the cluster. Should only be set when connecting to servers /// that support the "cluster-name" info command. pub cluster_name: Option, + + #[cfg(feature = "tls")] + pub tls_policy: Option, } impl Default for ClientPolicy { @@ -100,6 +105,8 @@ impl Default for ClientPolicy { thread_pool_size: 128, cluster_name: None, buffer_reclaim_threshold: 65536, + #[cfg(feature = "tls")] + tls_policy: None, } } } diff --git a/aerospike-core/src/policy/mod.rs b/aerospike-core/src/policy/mod.rs index 88687120..6d48a020 100644 --- a/aerospike-core/src/policy/mod.rs +++ b/aerospike-core/src/policy/mod.rs @@ -29,6 +29,8 @@ mod query_policy; mod read_policy; mod record_exists_action; mod scan_policy; +#[cfg(feature = "tls")] +mod tls_policy; mod write_policy; pub use self::admin_policy::AdminPolicy; @@ -44,6 +46,8 @@ pub use self::query_policy::QueryPolicy; pub use self::read_policy::ReadPolicy; pub use self::record_exists_action::RecordExistsAction; pub use self::scan_policy::ScanPolicy; +#[cfg(feature = "tls")] +pub use self::tls_policy::TlsPolicy; pub use self::write_policy::WritePolicy; use crate::expressions::FilterExpression; diff --git a/aerospike-core/src/policy/tls_policy.rs b/aerospike-core/src/policy/tls_policy.rs new file mode 100644 index 00000000..2267c9ee --- /dev/null +++ b/aerospike-core/src/policy/tls_policy.rs @@ -0,0 +1,12 @@ +use aerospike_tls::TlsConnector; + +#[derive(Debug, Clone)] +pub struct TlsPolicy { + pub tls_connector: TlsConnector, +} + +impl TlsPolicy { + pub fn new(tls_connector: TlsConnector) -> Self { + Self { tls_connector, } + } +} \ No newline at end of file diff --git a/aerospike-tls/Cargo.toml b/aerospike-tls/Cargo.toml new file mode 100644 index 00000000..e5240d7e --- /dev/null +++ b/aerospike-tls/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "aerospike-tls" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +futures = { version = "0.3.16" } +aerospike-rt = { path = "../aerospike-rt" } +tokio-rustls = { version = "0.23.4", optional = true } +webpki = { version = "0.22.0", optional = true } +tokio-native-tls = { version = "0.3.0", optional = true } + +[features] +default = ["tokio-rustls"] +tokio-rustls = ["dep:tokio-rustls", "dep:webpki", "aerospike-rt/rt-tokio"] +tokio-native-tls = ["dep:tokio-native-tls", "aerospike-rt/rt-tokio"] diff --git a/aerospike-tls/src/lib.rs b/aerospike-tls/src/lib.rs new file mode 100644 index 00000000..eb0cc58c --- /dev/null +++ b/aerospike-tls/src/lib.rs @@ -0,0 +1,19 @@ +#[cfg(not(any(feature = "tokio-rustls", feature = "tokio-native-tls")))] +compile_error!("Please select a tls implementation from ['tokio-rustls', 'tokio-native-tls']"); + +#[cfg(feature = "tokio-rustls")] +extern crate tokio_rustls; + +#[cfg(feature = "tokio-native-tls")] +extern crate tokio_native_tls; + +mod tls_connector; +mod tls_stream; + +pub use tls_connector::{TlsConnectError, TlsConnector}; +pub use tls_stream::TlsStream; +#[cfg(feature = "tokio-rustls")] +pub use tokio_rustls::*; + +#[cfg(feature = "tokio-native-tls")] +pub use tokio_native_tls::*; diff --git a/aerospike-tls/src/tls_connector.rs b/aerospike-tls/src/tls_connector.rs new file mode 100644 index 00000000..1cc5781e --- /dev/null +++ b/aerospike-tls/src/tls_connector.rs @@ -0,0 +1,68 @@ +use crate::tls_stream::TlsStream; +use aerospike_rt::net::TcpStream; +use std::fmt::Debug; +#[cfg(feature = "tokio-rustls")] +use tokio_rustls::rustls::ServerName; + +#[derive(Clone)] +pub enum TlsConnector { + #[cfg(feature = "tokio-rustls")] + Rustls(tokio_rustls::TlsConnector), + #[cfg(feature = "tokio-native-tls")] + NativeTls(tokio_native_tls::TlsConnector), +} + +impl Debug for TlsConnector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + #[cfg(feature = "tokio-rustls")] + Self::Rustls(_) => f.debug_tuple("Rustls").finish(), + #[cfg(feature = "tokio-native-tls")] + Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(), + } + } +} + +pub enum TlsConnectError { + IO(std::io::Error), + InvalidDnsName, + #[cfg(feature = "tokio-native-tls")] + NativeTlsError(tokio_native_tls::native_tls::Error), +} + +impl TlsConnector { + #[cfg(feature = "tokio-rustls")] + pub fn new_rustls(connector: tokio_rustls::TlsConnector) -> Self { + Self::Rustls(connector) + } + + #[cfg(feature = "tokio-native-tls")] + pub fn new_native_tls(connector: tokio_native_tls::TlsConnector) -> Self { + Self::NativeTls(connector) + } + + pub async fn connect( + &self, + domain: &str, + stream: TcpStream, + ) -> Result { + match self { + #[cfg(feature = "tokio-rustls")] + Self::Rustls(connector) => { + let domain = + ServerName::try_from(domain).map_err(|_| TlsConnectError::InvalidDnsName)?; + connector + .connect(domain, stream) + .await + .map_err(|err| TlsConnectError::IO(err)) + .map(|stream| TlsStream::Rustls(stream)) + } + #[cfg(feature = "tokio-native-tls")] + Self::NativeTls(connector) => connector + .connect(domain, stream) + .await + .map_err(|err| TlsConnectError::NativeTlsError(err)) + .map(|stream| TlsStream::NativeTls(stream)), + } + } +} diff --git a/aerospike-tls/src/tls_stream.rs b/aerospike-tls/src/tls_stream.rs new file mode 100644 index 00000000..f4736eee --- /dev/null +++ b/aerospike-tls/src/tls_stream.rs @@ -0,0 +1,73 @@ +use std::{fmt::Debug, pin::Pin}; + +use aerospike_rt::{io::AsyncRead, io::AsyncWrite, net::TcpStream}; +use std::task::{Context, Poll}; + +pub enum TlsStream { + #[cfg(feature = "tokio-rustls")] + Rustls(tokio_rustls::client::TlsStream), + #[cfg(feature = "tokio-native-tls")] + NativeTls(tokio_native_tls::TlsStream), +} + +impl Debug for TlsStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + #[cfg(feature = "tokio-rustls")] + Self::Rustls(_) => f.debug_tuple("Rustls").finish(), + #[cfg(feature = "tokio-native-tls")] + Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(), + } + } +} + +impl AsyncRead for TlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut aerospike_rt::io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + #[cfg(feature = "tokio-rustls")] + Self::Rustls(stream) => Pin::new(stream).poll_read(cx, buf), + #[cfg(feature = "tokio-native-tls")] + Self::NativeTls(stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + #[cfg(feature = "tokio-rustls")] + Self::Rustls(stream) => Pin::new(stream).poll_write(cx, buf), + #[cfg(feature = "tokio-native-tls")] + Self::NativeTls(stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + #[cfg(feature = "tokio-rustls")] + Self::Rustls(stream) => Pin::new(stream).poll_flush(cx), + #[cfg(feature = "tokio-native-tls")] + Self::NativeTls(stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.get_mut() { + #[cfg(feature = "tokio-rustls")] + Self::Rustls(stream) => Pin::new(stream).poll_shutdown(cx), + #[cfg(feature = "tokio-native-tls")] + Self::NativeTls(stream) => Pin::new(stream).poll_shutdown(cx), + } + } +}