Skip to content

Commit

Permalink
Support to rustls 0.20 (#1388)
Browse files Browse the repository at this point in the history
  • Loading branch information
BiagioFesta authored Nov 29, 2021
1 parent 8fe22c4 commit 8b37ae4
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 101 deletions.
14 changes: 8 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ rustls-tls-native-roots = ["rustls-native-certs", "__rustls"]

blocking = ["futures-util/io", "tokio/rt-multi-thread", "tokio/sync"]

cookies = ["cookie_crate", "cookie_store"]
cookies = ["cookie_crate", "cookie_store", "proc-macro-hack"]

gzip = ["async-compression", "async-compression/gzip", "tokio-util"]

Expand All @@ -70,7 +70,7 @@ __tls = []

# Enables common rustls code.
# Equivalent to rustls-tls-manual-roots but shorter :)
__rustls = ["hyper-rustls", "tokio-rustls", "rustls", "__tls"]
__rustls = ["hyper-rustls", "tokio-rustls", "rustls", "__tls", "rustls-pemfile"]

# When enabled, disable using the cached SYS_PROXIES.
__internal_proxy_sys_no_cache = []
Expand Down Expand Up @@ -112,15 +112,17 @@ native-tls-crate = { version = "0.2.8", optional = true, package = "native-tls"
tokio-native-tls = { version = "0.3.0", optional = true }

# rustls-tls
hyper-rustls = { version = "0.22.1", default-features = false, optional = true }
rustls = { version = "0.19", features = ["dangerous_configuration"], optional = true }
tokio-rustls = { version = "0.22", optional = true }
hyper-rustls = { version = "0.23", default-features = false, optional = true }
rustls = { version = "0.20", features = ["dangerous_configuration"], optional = true }
tokio-rustls = { version = "0.23", optional = true }
webpki-roots = { version = "0.21", optional = true }
rustls-native-certs = { version = "0.5", optional = true }
rustls-native-certs = { version = "0.6", optional = true }
rustls-pemfile = { version = "0.2", optional = true }

## cookies
cookie_crate = { version = "0.15", package = "cookie", optional = true }
cookie_store = { version = "0.15", optional = true }
proc-macro-hack = { version = "0.5.19", optional = true }

## compression
async-compression = { version = "0.3.7", default-features = false, features = ["tokio"], optional = true }
Expand Down
120 changes: 69 additions & 51 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ use hyper::client::ResponseFuture;
#[cfg(feature = "native-tls-crate")]
use native_tls_crate::TlsConnector;
use pin_project_lite::pin_project;
#[cfg(feature = "rustls-tls-native-roots")]
use rustls::RootCertStore;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -322,68 +320,94 @@ impl ClientBuilder {
TlsBackend::Rustls => {
use crate::tls::NoVerifier;

let mut tls = rustls::ClientConfig::new();
match config.http_version_pref {
HttpVersionPref::Http1 => {
tls.set_protocols(&["http/1.1".into()]);
}
HttpVersionPref::Http2 => {
tls.set_protocols(&["h2".into()]);
}
HttpVersionPref::All => {
tls.set_protocols(&["h2".into(), "http/1.1".into()]);
}
// Set root certificates.
let mut root_cert_store = rustls::RootCertStore::empty();
for cert in config.root_certs {
cert.add_to_rustls(&mut root_cert_store)?;
}

#[cfg(feature = "rustls-tls-webpki-roots")]
if config.tls_built_in_root_certs {
tls.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
}
#[cfg(feature = "rustls-tls-native-roots")]
if config.tls_built_in_root_certs {
let roots_slice = NATIVE_ROOTS.as_ref().unwrap().roots.as_slice();
tls.root_store.roots.extend_from_slice(roots_slice);
}

if !config.certs_verification {
tls.dangerous()
.set_certificate_verifier(Arc::new(NoVerifier));
}
use rustls::OwnedTrustAnchor;

let trust_anchors =
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|trust_anchor| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
trust_anchor.subject,
trust_anchor.spki,
trust_anchor.name_constraints,
)
});

for cert in config.root_certs {
cert.add_to_rustls(&mut tls)?;
root_cert_store.add_server_trust_anchors(trust_anchors);
}

if let Some(id) = config.identity {
id.add_to_rustls(&mut tls)?;
#[cfg(feature = "rustls-tls-native-roots")]
if config.tls_built_in_root_certs {
for cert in rustls_native_certs::load_native_certs()
.map_err(crate::error::builder)?
{
root_cert_store
.add(&rustls::Certificate(cert.0))
.map_err(crate::error::builder)?
}
}

// rustls does not support TLS versions <1.2 and this is unlikely to change.
// https://github.com/rustls/rustls/issues/33

// As of writing, TLS 1.2 and 1.3 are the only implemented versions and are both
// enabled by default.
// rustls 0.20 will add ALL_VERSIONS and DEFAULT_VERSIONS. That will enable a more
// sophisticated approach.
// For now we assume the default tls.versions matches the future ALL_VERSIONS and
// act based on that.
// Set TLS versions.
let mut versions = rustls::ALL_VERSIONS.to_vec();

if let Some(min_tls_version) = config.min_tls_version {
tls.versions
.retain(|&version| match tls::Version::from_rustls(version) {
versions.retain(|&supported_version| {
match tls::Version::from_rustls(supported_version.version) {
Some(version) => version >= min_tls_version,
// Assume it's so new we don't know about it, allow it
// (as of writing this is unreachable)
None => true,
});
}
});
}

if let Some(max_tls_version) = config.max_tls_version {
tls.versions
.retain(|&version| match tls::Version::from_rustls(version) {
versions.retain(|&supported_version| {
match tls::Version::from_rustls(supported_version.version) {
Some(version) => version <= max_tls_version,
None => false,
});
}
});
}

// Build TLS config
let config_builder = rustls::ClientConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&versions)
.map_err(crate::error::builder)?
.with_root_certificates(root_cert_store);

// Finalize TLS config
let mut tls = if let Some(id) = config.identity {
id.add_to_rustls(config_builder)?
} else {
config_builder.with_no_client_auth()
};

// Certificate verifier
if !config.certs_verification {
tls.dangerous()
.set_certificate_verifier(Arc::new(NoVerifier));
}

// ALPN protocol
match config.http_version_pref {
HttpVersionPref::Http1 => {
tls.alpn_protocols = vec!["http/1.1".into()];
}
HttpVersionPref::Http2 => {
tls.alpn_protocols = vec!["h2".into()];
}
HttpVersionPref::All => {
tls.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
}
}

Connector::new_rustls_tls(
Expand Down Expand Up @@ -1848,12 +1872,6 @@ fn add_cookie_header(headers: &mut HeaderMap, cookie_store: &dyn cookie::CookieS
}
}

#[cfg(feature = "rustls-tls-native-roots")]
lazy_static! {
static ref NATIVE_ROOTS: std::io::Result<RootCertStore> =
rustls_native_certs::load_native_certs().map_err(|e| e.1);
}

#[cfg(test)]
mod tests {
#[tokio::test]
Expand Down
24 changes: 11 additions & 13 deletions src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,17 +325,16 @@ impl Connector {
#[cfg(feature = "__rustls")]
Inner::RustlsTls { tls_proxy, .. } => {
if dst.scheme() == Some(&Scheme::HTTPS) {
use tokio_rustls::webpki::DNSNameRef;
use std::convert::TryFrom;
use tokio_rustls::TlsConnector as RustlsConnector;

let tls = tls_proxy.clone();
let host = dst.host().ok_or("no host in url")?.to_string();
let conn = socks::connect(proxy, dst, dns).await?;
let dnsname = DNSNameRef::try_from_ascii_str(&host)
.map(|dnsname| dnsname.to_owned())
.map_err(|_| "Invalid DNS Name")?;
let server_name = rustls::ServerName::try_from(host.as_str())
.map_err(|_| "Invalid Server Name")?;
let io = RustlsConnector::from(tls)
.connect(dnsname.as_ref(), conn)
.connect(server_name, conn)
.await?;
return Ok(Conn {
inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
Expand Down Expand Up @@ -479,7 +478,8 @@ impl Connector {
tls_proxy,
} => {
if dst.scheme() == Some(&Scheme::HTTPS) {
use tokio_rustls::webpki::DNSNameRef;
use rustls::ServerName;
use std::convert::TryFrom;
use tokio_rustls::TlsConnector as RustlsConnector;

let host = dst.host().ok_or("no host in url")?.to_string();
Expand All @@ -489,13 +489,12 @@ impl Connector {
let tls = tls.clone();
let conn = http.call(proxy_dst).await?;
log::trace!("tunneling HTTPS over proxy");
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
.map(|dnsname| dnsname.to_owned())
.map_err(|_| "Invalid DNS Name");
let maybe_server_name =
ServerName::try_from(host.as_str()).map_err(|_| "Invalid Server Name");
let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
let dnsname = maybe_dnsname?;
let server_name = maybe_server_name?;
let io = RustlsConnector::from(tls)
.connect(dnsname.as_ref(), tunneled)
.connect(server_name, tunneled)
.await?;

return Ok(Conn {
Expand Down Expand Up @@ -820,7 +819,6 @@ mod native_tls_conn {
mod rustls_tls_conn {
use hyper::client::connect::{Connected, Connection};
use pin_project_lite::pin_project;
use rustls::Session;
use std::{
io::{self, IoSlice},
pin::Pin,
Expand All @@ -837,7 +835,7 @@ mod rustls_tls_conn {

impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> {
fn connected(&self) -> Connected {
if self.inner.get_ref().1.get_alpn_protocol() == Some(b"h2") {
if self.inner.get_ref().1.alpn_protocol() == Some(b"h2") {
self.inner.get_ref().0.connected().negotiated_h2()
} else {
self.inner.get_ref().0.connected()
Expand Down
Loading

0 comments on commit 8b37ae4

Please sign in to comment.