Skip to content

Commit

Permalink
Rebase fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibsG committed Feb 4, 2023
1 parent 8677264 commit 7ffc133
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 26 deletions.
3 changes: 1 addition & 2 deletions sqlx-core/src/net/tls/tls_native_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ pub async fn handshake<S: Socket>(
if let (Some(cert_path), Some(key_path)) = (config.client_cert_path, config.client_key_path) {
let cert_path = cert_path.data().await?;
let key_path = key_path.data().await?;
let identity =
Identity::from_pkcs8(&cert_path, &key_path).map_err(|e| Error::Tls(e.into()))?;
let identity = Identity::from_pkcs8(&cert_path, &key_path).map_err(Error::tls)?;
builder.identity(identity);
}

Expand Down
79 changes: 55 additions & 24 deletions sqlx-core/src/net/tls/tls_rustls.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use futures_util::future;
use std::io;
use std::io::{Cursor, Read, Write};
use std::io::{self, BufReader, Cursor, Read, Write};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::SystemTime;
Expand Down Expand Up @@ -48,7 +47,7 @@ impl<S: Socket> Socket for RustlsSocket<S> {
match self.state.writer().write(buf) {
// Returns a zero-length write when the buffer is full.
Ok(0) => Err(io::ErrorKind::WouldBlock.into()),
other => return other,
other => other,
}
}

Expand Down Expand Up @@ -81,10 +80,32 @@ where
{
let config = ClientConfig::builder().with_safe_defaults();

// authentication using user's key and its associated certificate
let user_auth = match (tls_config.client_cert_path, tls_config.client_key_path) {
(Some(cert_path), Some(key_path)) => {
let cert_chain = certs_from_pem(cert_path.data().await?)?;
let key_der = private_key_from_pem(key_path.data().await?)?;
Some((cert_chain, key_der))
}
(None, None) => None,
(_, _) => {
return Err(Error::Configuration(
"user auth key and certs must be given together".into(),
))
}
};

let config = if tls_config.accept_invalid_certs {
config
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
.with_no_client_auth()
if let Some(user_auth) = user_auth {
config
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
.with_single_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
.with_no_client_auth()
}
} else {
let mut cert_store = RootCertStore::empty();
cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
Expand All @@ -100,37 +121,22 @@ where
let mut cursor = Cursor::new(data);

for cert in rustls_pemfile::certs(&mut cursor)
.map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))?
.map_err(|_| Error::Tls(format!("Invalid certificate {ca}").into()))?
{
cert_store
.add(&rustls::Certificate(cert))
.map_err(|err| Error::Tls(err.into()))?;
}
}

// authentication using user's key and its associated certificate
let user_auth = match (tls_config.client_cert_path, tls_config.client_key_path) {
(Some(cert_path), Some(key_path)) => {
let cert_chain = certs_from_pem(cert_path.data().await?)?;
let key_der = private_key_from_pem(key_path.data().await?)?;
Some((cert_chain, key_der))
}
(None, None) => None,
(_, _) => {
return Err(Error::Configuration(
"user auth key and certs must be given together".into(),
))
}
};

if tls_config.accept_invalid_hostnames {
let verifier = WebPkiVerifier::new(cert_store, None);

if let Some(user_auth) = user_auth {
config
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_single_cert(user_auth.0, user_auth.1)
.map_err(|err| Error::Tls(err.into()))?
.map_err(Error::tls)?
} else {
config
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
Expand All @@ -140,7 +146,7 @@ where
config
.with_root_certificates(cert_store)
.with_single_cert(user_auth.0, user_auth.1)
.map_err(|err| Error::Tls(err.into()))?
.map_err(Error::tls)?
} else {
config
.with_root_certificates(cert_store)
Expand All @@ -162,6 +168,31 @@ where
Ok(socket)
}

fn certs_from_pem(pem: Vec<u8>) -> Result<Vec<rustls::Certificate>, Error> {
let cur = Cursor::new(pem);
let mut reader = BufReader::new(cur);
rustls_pemfile::certs(&mut reader)?
.into_iter()
.map(|v| Ok(rustls::Certificate(v)))
.collect()
}

fn private_key_from_pem(pem: Vec<u8>) -> Result<rustls::PrivateKey, Error> {
let cur = Cursor::new(pem);
let mut reader = BufReader::new(cur);

loop {
match rustls_pemfile::read_one(&mut reader)? {
Some(rustls_pemfile::Item::RSAKey(key)) => return Ok(rustls::PrivateKey(key)),
Some(rustls_pemfile::Item::PKCS8Key(key)) => return Ok(rustls::PrivateKey(key)),
None => break,
_ => {}
}
}

Err(Error::Configuration("no keys found pem file".into()))
}

struct DummyTlsVerifier;

impl ServerCertVerifier for DummyTlsVerifier {
Expand Down

0 comments on commit 7ffc133

Please sign in to comment.