diff --git a/Cargo.lock b/Cargo.lock index a532321ee7..475d77d7a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4108,6 +4108,7 @@ dependencies = [ "db_common", "derive_more", "futures 0.3.15", + "futures-rustls 0.21.1", "gstuff", "hex 0.4.3", "lazy_static", @@ -4288,10 +4289,13 @@ dependencies = [ "primitives", "rand 0.6.5", "rand 0.7.3", + "rcgen", "regex", "rmp-serde", "rpc", "rpc_task", + "rustls 0.20.4", + "rustls-pemfile 1.0.2", "script", "secp256k1 0.20.3", "ser_error", @@ -4370,6 +4374,7 @@ dependencies = [ "derive_more", "ethkey", "futures 0.3.15", + "futures-util", "gstuff", "http 0.2.7", "hyper", @@ -4379,8 +4384,11 @@ dependencies = [ "mm2_err_handle", "prost", "rand 0.7.3", + "rustls 0.20.4", "serde", "serde_json", + "tokio", + "tokio-rustls", "wasm-bindgen", "wasm-bindgen-futures", "wasm-bindgen-test", @@ -4944,6 +4952,15 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c719dcf55f09a3a7e764c6649ab594c18a177e3599c467983cdf644bfc0a4088" +[[package]] +name = "pem" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8835c273a76a90455d7344889b0964598e3316e2a79ede8e36f16bdcf2228b8" +dependencies = [ + "base64 0.13.0", +] + [[package]] name = "percent-encoding" version = "2.1.0" @@ -5626,6 +5643,18 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "rcgen" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbe84efe2f38dea12e9bfc1f65377fdf03e53a18cb3b995faedf7934c7e785b" +dependencies = [ + "pem", + "ring", + "time 0.3.11", + "yasna", +] + [[package]] name = "rdrand" version = "0.4.0" @@ -5996,11 +6025,11 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "1.0.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7522c9de787ff061458fe9a829dc790a3f5b22dc571694fc5883f448b94d9a9" +checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64 0.13.0", + "base64 0.21.2", ] [[package]] @@ -7952,7 +7981,7 @@ dependencies = [ "pin-project 1.0.10", "prost", "prost-derive", - "rustls-pemfile 1.0.0", + "rustls-pemfile 1.0.2", "tokio", "tokio-rustls", "tokio-stream", @@ -8865,6 +8894,15 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time 0.3.11", +] + [[package]] name = "zbase32" version = "0.1.2" diff --git a/mm2src/mm2_core/Cargo.toml b/mm2src/mm2_core/Cargo.toml index 6d967b852b..fd7ea0c6be 100644 --- a/mm2src/mm2_core/Cargo.toml +++ b/mm2src/mm2_core/Cargo.toml @@ -29,4 +29,5 @@ gstuff = { version = "0.7", features = ["nightly"] } mm2_rpc = { path = "../mm2_rpc" } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] +futures-rustls = { version = "0.21.1" } gstuff = { version = "0.7", features = ["nightly"] } diff --git a/mm2src/mm2_core/src/mm_ctx.rs b/mm2src/mm2_core/src/mm_ctx.rs index c390b1da7f..533f5f64f7 100644 --- a/mm2src/mm2_core/src/mm_ctx.rs +++ b/mm2src/mm2_core/src/mm_ctx.rs @@ -26,10 +26,12 @@ cfg_wasm32! { cfg_native! { use db_common::sqlite::rusqlite::Connection; + use futures_rustls::webpki::DNSNameRef; use mm2_metrics::prometheus; use mm2_metrics::MmMetricsError; use std::net::{IpAddr, SocketAddr, AddrParseError}; use std::path::{Path, PathBuf}; + use std::str::FromStr; use std::sync::MutexGuard; } @@ -197,6 +199,52 @@ impl MmCtx { Ok(SocketAddr::new(ip, port as u16)) } + /// Whether to use HTTPS for RPC server or not. + #[cfg(not(target_arch = "wasm32"))] + pub fn is_https(&self) -> bool { self.conf["https"].as_bool().unwrap_or(false) } + + /// SANs for self-signed certificate generation. + #[cfg(not(target_arch = "wasm32"))] + pub fn alt_names(&self) -> Result, String> { + // Helper function to validate `alt_names` entries + fn validate_alt_name(name: &str) -> Result<(), String> { + // Check if it is a valid IP address + if let Ok(ip) = IpAddr::from_str(name) { + if ip.is_unspecified() { + return ERR!("IP address {} must be specified", ip); + } + return Ok(()); + } + + // Check if it is a valid DNS name + if DNSNameRef::try_from_ascii_str(name).is_ok() { + return Ok(()); + } + + ERR!( + "`alt_names` contains {} which is neither a valid IP address nor a valid DNS name", + name + ) + } + + if self.conf["alt_names"].is_null() { + // Default SANs + return Ok(vec!["localhost".to_string(), "127.0.0.1".to_string()]); + } + + json::from_value(self.conf["alt_names"].clone()) + .map_err(|e| format!("`alt_names` is not a valid JSON array of strings: {}", e)) + .and_then(|names: Vec| { + if names.is_empty() { + return ERR!("alt_names is empty"); + } + for name in &names { + try_s!(validate_alt_name(name)); + } + Ok(names) + }) + } + /// MM database path. /// Defaults to a relative "DB". /// diff --git a/mm2src/mm2_main/Cargo.toml b/mm2src/mm2_main/Cargo.toml index 23f9c2b0f9..1dd68e3b02 100644 --- a/mm2src/mm2_main/Cargo.toml +++ b/mm2src/mm2_main/Cargo.toml @@ -44,12 +44,12 @@ enum-primitive-derive = "0.2" futures01 = { version = "0.1", package = "futures" } futures = { version = "0.3.1", package = "futures", features = ["compat", "async-await"] } gstuff = { version = "0.7", features = ["nightly"] } -mm2_gui_storage = { path = "../mm2_gui_storage" } hash256-std-hasher = "0.15.2" hash-db = "0.15.2" hex = "0.4.2" http = "0.2" hw_common = { path = "../hw_common" } +instant = { version = "0.1.12" } itertools = "0.10" keys = { path = "../mm2_bitcoin/keys" } lazy_static = "1.4" @@ -57,6 +57,7 @@ lazy_static = "1.4" libc = "0.2" mm2_core = { path = "../mm2_core" } mm2_err_handle = { path = "../mm2_err_handle" } +mm2_gui_storage = { path = "../mm2_gui_storage" } mm2_io = { path = "../mm2_io" } mm2-libp2p = { path = "../mm2_libp2p" } mm2_metrics = { path = "../mm2_metrics" } @@ -90,7 +91,6 @@ sp-trie = { version = "6.0", default-features = false } trie-db = { version = "0.23.1", default-features = false } trie-root = "0.16.0" uuid = { version = "1.2.2", features = ["fast-rng", "serde", "v4"] } -instant = { version = "0.1.12" } [target.'cfg(target_arch = "wasm32")'.dependencies] instant = { version = "0.1.12", features = ["wasm-bindgen"] } @@ -106,6 +106,9 @@ web-sys = { version = "0.3.55", features = ["console"] } dirs = { version = "1" } futures-rustls = { version = "0.21.1" } hyper = { version = "0.14.26", features = ["client", "http2", "server", "tcp"] } +rcgen = "0.10" +rustls = { version = "0.20", default-features = false } +rustls-pemfile = "1.0.2" tokio = { version = "1.20", features = ["io-util", "rt-multi-thread", "net"] } [target.'cfg(windows)'.dependencies] diff --git a/mm2src/mm2_main/src/rpc.rs b/mm2src/mm2_main/src/rpc.rs index c0c338dab6..8ca80d5274 100644 --- a/mm2src/mm2_main/src/rpc.rs +++ b/mm2src/mm2_main/src/rpc.rs @@ -21,7 +21,7 @@ // use crate::mm2::rpc::rate_limiter::RateLimitError; -use common::log::error; +use common::log::{error, info}; use common::{err_to_rpc_json_string, err_tp_rpc_json, HttpStatusCode, APPLICATION_JSON}; use derive_more::Display; use futures::future::{join_all, FutureExt}; @@ -303,10 +303,47 @@ async fn rpc_service(req: Request, ctx_h: u32, client: SocketAddr) -> Resp #[cfg(not(target_arch = "wasm32"))] pub extern "C" fn spawn_rpc(ctx_h: u32) { + use common::now_sec; use common::wio::CORE; - use hyper::server::conn::AddrStream; + use hyper::server::conn::{AddrIncoming, AddrStream}; use hyper::service::{make_service_fn, service_fn}; + use mm2_net::native_tls::{TlsAcceptor, TlsStream}; + use rcgen::{generate_simple_self_signed, RcgenError}; + use rustls::{Certificate, PrivateKey}; + use rustls_pemfile as pemfile; use std::convert::Infallible; + use std::env; + use std::fs::File; + use std::io::{self, BufReader}; + + // Reads a certificate and a key from the specified files. + fn read_certificate_and_key( + cert_file: &File, + cert_key_path: &str, + ) -> Result<(Vec, PrivateKey), io::Error> { + let cert_file = &mut BufReader::new(cert_file); + let cert_chain = pemfile::certs(cert_file)?.into_iter().map(Certificate).collect(); + let key_file = &mut BufReader::new(File::open(cert_key_path)?); + let key = pemfile::read_all(key_file)? + .into_iter() + .find_map(|item| match item { + pemfile::Item::RSAKey(key) | pemfile::Item::PKCS8Key(key) | pemfile::Item::ECKey(key) => Some(key), + _ => None, + }) + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "No private key found"))?; + Ok((cert_chain, PrivateKey(key))) + } + + // Generates a self-signed certificate + fn generate_self_signed_cert(subject_alt_names: Vec) -> Result<(Vec, PrivateKey), RcgenError> { + // Generate the certificate + let cert = generate_simple_self_signed(subject_alt_names)?; + let cert_der = cert.serialize_der()?; + let privkey = PrivateKey(cert.serialize_private_key_der()); + let cert = Certificate(cert_der); + let cert_chain = vec![cert]; + Ok((cert_chain, privkey)) + } // NB: We need to manually handle the incoming connections in order to get the remote IP address, // cf. https://github.com/hyperium/hyper/issues/1410#issuecomment-419510220. @@ -314,62 +351,126 @@ pub extern "C" fn spawn_rpc(ctx_h: u32) { // then we might want to refactor into starting it ideomatically in order to benefit from a more graceful shutdown, // cf. https://github.com/hyperium/hyper/pull/1640. + let make_svc_fut = move |remote_addr: SocketAddr| async move { + Ok::<_, Infallible>(service_fn(move |req: Request| async move { + let res = rpc_service(req, ctx_h, remote_addr).await; + Ok::<_, Infallible>(res) + })) + }; + + //The `make_svc` macro creates a `make_service_fn` for a specified socket type. + // `$socket_type`: The socket type with a `remote_addr` method that returns a `SocketAddr`. + macro_rules! make_svc { + ($socket_type:ty) => { + make_service_fn(move |socket: &$socket_type| { + let remote_addr = socket.remote_addr(); + make_svc_fut(remote_addr) + }) + }; + } + + // The `get_shutdown_future` macro registers a graceful shutdown listener by calling the `register_listener` + // method of `GracefulShutdownRegistry`. + // If the `register_listener` method fails, it implies that the application is already in a shutdown state. + // In this case, the macro logs an error and immediately returns. + macro_rules! get_shutdown_future { + ($ctx:expr) => { + match $ctx.graceful_shutdown_registry.register_listener() { + Ok(shutdown_fut) => shutdown_fut, + Err(e) => { + error!("MmCtx seems to be stopped already: {e}"); + return; + }, + } + }; + } + + // Macro for spawning a server with error handling and logging + macro_rules! spawn_server { + ($server:expr, $ctx:expr, $ip:expr, $port:expr) => { + { + let server = $server.then(|r| { + if let Err(err) = r { + error!("{}", err); + }; + futures::future::ready(()) + }); + + // As it's said in the [issue](https://github.com/hyperium/tonic/issues/330): + // + // Aborting the server future will forcefully cancel all connections and not perform a proper drain/shutdown. + // While using the special shutdown methods on the server will allow hyper to gracefully drain all connections + // and gracefully close connections. + common::executor::spawn({ + log_tag!( + $ctx, + "😉"; + fmt = ">>>>>>>>>> DEX stats {}:{} DEX stats API enabled at unixtime.{} <<<<<<<<<", + $ip, + $port, + now_sec() + ); + let _ = $ctx.rpc_started.pin(true); + server + }); + } + }; + } + let ctx = MmArc::from_ffi_handle(ctx_h).expect("No context"); - let rpc_ip_port = ctx.rpc_ip_port().unwrap(); + let rpc_ip_port = ctx + .rpc_ip_port() + .unwrap_or_else(|err| panic!("Invalid RPC port: {}", err)); // By entering the context, we tie `tokio::spawn` to this executor. let _runtime_guard = CORE.0.enter(); - let server = Server::try_bind(&rpc_ip_port).unwrap_or_else(|_| panic!("Can't bind on {}", rpc_ip_port)); - let make_svc = make_service_fn(move |socket: &AddrStream| { - let remote_addr = socket.remote_addr(); - async move { - Ok::<_, Infallible>(service_fn(move |req: Request| async move { - let res = rpc_service(req, ctx_h, remote_addr).await; - Ok::<_, Infallible>(res) - })) - } - }); - - let shutdown_fut = match ctx.graceful_shutdown_registry.register_listener() { - Ok(shutdown_fut) => shutdown_fut, - Err(e) => { - error!("MmCtx seems to be stopped already: {e}"); - return; - }, - }; + if ctx.is_https() { + let cert_path = env::var("MM_CERT_PATH").unwrap_or_else(|_| "cert.pem".to_string()); + let (cert_chain, privkey) = match File::open(cert_path.clone()) { + Ok(cert_file) => { + let cert_key_path = env::var("MM_CERT_KEY_PATH").unwrap_or_else(|_| "key.pem".to_string()); + read_certificate_and_key(&cert_file, &cert_key_path) + .unwrap_or_else(|err| panic!("Can't read certificate and/or key from {:?}: {}", cert_path, err)) + }, + Err(ref err) if err.kind() == io::ErrorKind::NotFound => { + info!( + "No certificate found at {:?}, generating a self-signed certificate", + cert_path + ); + let subject_alt_names = ctx + .alt_names() + .unwrap_or_else(|err| panic!("Invalid `alt_names` config: {}", err)); + generate_self_signed_cert(subject_alt_names) + .unwrap_or_else(|err| panic!("Can't generate self-signed certificate: {}", err)) + }, + Err(err) => panic!("Can't open {:?}: {}", cert_path, err), + }; - let server = server - .http1_half_close(false) - .serve(make_svc) - .with_graceful_shutdown(shutdown_fut); + // Create a TcpListener + let incoming = + AddrIncoming::bind(&rpc_ip_port).unwrap_or_else(|err| panic!("Can't bind on {}: {}", rpc_ip_port, err)); + let acceptor = TlsAcceptor::builder() + .with_single_cert(cert_chain, privkey) + .unwrap_or_else(|err| panic!("Can't set certificate for TlsAcceptor: {}", err)) + .with_all_versions_alpn() + .with_incoming(incoming); + + let server = Server::builder(acceptor) + .http1_half_close(false) + .serve(make_svc!(TlsStream)) + .with_graceful_shutdown(get_shutdown_future!(ctx)); + + spawn_server!(server, ctx, rpc_ip_port.ip(), rpc_ip_port.port()); + } else { + let server = Server::try_bind(&rpc_ip_port) + .unwrap_or_else(|err| panic!("Can't bind on {}: {}", rpc_ip_port, err)) + .http1_half_close(false) + .serve(make_svc!(AddrStream)) + .with_graceful_shutdown(get_shutdown_future!(ctx)); - let server = server.then(|r| { - if let Err(err) = r { - error!("{}", err); - }; - futures::future::ready(()) - }); - - let rpc_ip_port = ctx.rpc_ip_port().unwrap(); - - // As it's said in the [issue](https://github.com/hyperium/tonic/issues/330): - // - // Aborting the server future will forcefully cancel all connections and not perform a proper drain/shutdown. - // While using the special shutdown methods on the server will allow hyper to gracefully drain all connections - // and gracefully close connections. - common::executor::spawn({ - log_tag!( - ctx, - "😉"; - fmt = ">>>>>>>>>> DEX stats {}:{} DEX stats API enabled at unixtime.{} <<<<<<<<<", - rpc_ip_port.ip(), - rpc_ip_port.port(), - gstuff::now_ms() / 1000 - ); - let _ = ctx.rpc_started.pin(true); - server - }); + spawn_server!(server, ctx, rpc_ip_port.ip(), rpc_ip_port.port()); + } } #[cfg(target_arch = "wasm32")] diff --git a/mm2src/mm2_net/Cargo.toml b/mm2src/mm2_net/Cargo.toml index eb9d624ff7..e567546f6c 100644 --- a/mm2src/mm2_net/Cargo.toml +++ b/mm2src/mm2_net/Cargo.toml @@ -32,5 +32,9 @@ web-sys = { version = "0.3.55", features = ["console", "CloseEvent", "DomExcepti js-sys = "0.3.27" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] +futures-util = { version = "0.3" } hyper = { version = "0.14.26", features = ["client", "http2", "server", "tcp"] } gstuff = { version = "0.7", features = ["nightly"] } +rustls = { version = "0.20", default-features = false } +tokio = { version = "1.20" } +tokio-rustls = { version = "0.23", default-features = false } diff --git a/mm2src/mm2_net/src/lib.rs b/mm2src/mm2_net/src/lib.rs index 30a951a0fb..99935bd25b 100644 --- a/mm2src/mm2_net/src/lib.rs +++ b/mm2src/mm2_net/src/lib.rs @@ -3,5 +3,6 @@ pub mod transport; #[cfg(not(target_arch = "wasm32"))] pub mod ip_addr; #[cfg(not(target_arch = "wasm32"))] pub mod native_http; +#[cfg(not(target_arch = "wasm32"))] pub mod native_tls; #[cfg(target_arch = "wasm32")] pub mod wasm_http; #[cfg(target_arch = "wasm32")] pub mod wasm_ws; diff --git a/mm2src/mm2_net/src/native_tls/README.md b/mm2src/mm2_net/src/native_tls/README.md new file mode 100644 index 0000000000..7c3983bf55 --- /dev/null +++ b/mm2src/mm2_net/src/native_tls/README.md @@ -0,0 +1,13 @@ +# HTTPS Support with TLSAcceptor and Builder + +This mod provides HTTPS support for [hyper](https://github.com/hyperium/hyper) using [rustls](https://github.com/rustls/rustls). The code in this mod is a port of the [acceptor](https://github.com/rustls/hyper-rustls/tree/286e1fa57ff5cac99994fab355f91c3454d6d83d/src/acceptor) module and the [acceptor.rs](https://github.com/rustls/hyper-rustls/blob/286e1fa57ff5cac99994fab355f91c3454d6d83d/src/acceptor.rs) file from the [hyper-rustls](https://github.com/rustls/hyper-rustls) repository at revision [286e1fa57ff5cac99994fab355f91c3454d6d83d](https://github.com/rustls/hyper-rustls/tree/286e1fa57ff5cac99994fab355f91c3454d6d83d). +> **Note:** Please be aware that the acceptor module was not available in the latest version of [hyper-rustls](https://docs.rs/hyper-rustls/0.24.0/hyper_rustls/index.html) at the time of writing this, the latest version was 0.24.0 at this time. + +## Compatibility + +The ported mod is compatible with hyper 0.14 and rustls 0.20. + +## Purpose + +The purpose of porting these files is to enable retrieving the remote address from the incoming connection and to expose the `TlsStream` type. +> **Note:** The following commit [7eca34d](https://github.com/KomodoPlatform/atomicDEX-API/pull/1861/commits/7eca34dd4621a7de0033f8a81cc11ad117aeb3c3) show the changes applied to the ported code. \ No newline at end of file diff --git a/mm2src/mm2_net/src/native_tls/acceptor.rs b/mm2src/mm2_net/src/native_tls/acceptor.rs new file mode 100644 index 0000000000..cd5c86db84 --- /dev/null +++ b/mm2src/mm2_net/src/native_tls/acceptor.rs @@ -0,0 +1,123 @@ +use crate::native_tls::builder::{AcceptorBuilder, WantsTlsConfig}; +use core::task::{Context, Poll}; +use futures_util::ready; +use hyper::server::{accept::Accept, + conn::{AddrIncoming, AddrStream}}; +use rustls::ServerConfig; +use std::future::Future; +use std::io; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +enum State { + Handshaking(tokio_rustls::Accept), + Streaming(tokio_rustls::server::TlsStream), +} + +// tokio_rustls::server::TlsStream doesn't expose constructor methods, +// so we have to TlsAcceptor::accept and handshake to have access to it +// TlsStream implements AsyncRead/AsyncWrite by handshaking with tokio_rustls::Accept first +pub struct TlsStream { + state: State, + remote_addr: SocketAddr, +} + +impl TlsStream { + fn new(stream: AddrStream, config: Arc) -> TlsStream { + let remote_addr = stream.remote_addr(); + let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); + TlsStream { + state: State::Handshaking(accept), + remote_addr, + } + } + + #[inline] + pub fn remote_addr(&self) -> SocketAddr { self.remote_addr } +} + +impl AsyncRead for TlsStream { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll> { + let pin = self.get_mut(); + match pin.state { + State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { + Ok(mut stream) => { + let result = Pin::new(&mut stream).poll_read(cx, buf); + pin.state = State::Streaming(stream); + result + }, + Err(err) => Poll::Ready(Err(err)), + }, + State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let pin = self.get_mut(); + match pin.state { + State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { + Ok(mut stream) => { + let result = Pin::new(&mut stream).poll_write(cx, buf); + pin.state = State::Streaming(stream); + result + }, + Err(err) => Poll::Ready(Err(err)), + }, + State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.state { + State::Handshaking(_) => Poll::Ready(Ok(())), + State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.state { + State::Handshaking(_) => Poll::Ready(Ok(())), + State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} + +/// A TLS acceptor that can be used with hyper servers. +pub struct TlsAcceptor { + pub(crate) config: Arc, + pub(crate) incoming: AddrIncoming, +} + +/// An Acceptor for the `https` scheme. +impl TlsAcceptor { + /// Provides a builder for a `TlsAcceptor`. + pub fn builder() -> AcceptorBuilder { AcceptorBuilder::new() } + /// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`. + pub fn new(config: Arc, incoming: AddrIncoming) -> TlsAcceptor { TlsAcceptor { config, incoming } } +} + +impl From<(C, I)> for TlsAcceptor +where + C: Into>, + I: Into, +{ + fn from((config, incoming): (C, I)) -> TlsAcceptor { TlsAcceptor::new(config.into(), incoming.into()) } +} + +impl Accept for TlsAcceptor { + type Conn = TlsStream; + type Error = io::Error; + + fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { + let pin = self.get_mut(); + match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { + Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } +} diff --git a/mm2src/mm2_net/src/native_tls/builder.rs b/mm2src/mm2_net/src/native_tls/builder.rs new file mode 100644 index 0000000000..415a998b45 --- /dev/null +++ b/mm2src/mm2_net/src/native_tls/builder.rs @@ -0,0 +1,92 @@ +use crate::native_tls::TlsAcceptor; +use hyper::server::conn::AddrIncoming; +use rustls::ServerConfig; +use std::sync::Arc; + +/// Builder for [`TlsAcceptor`] +pub struct AcceptorBuilder(State); + +/// State of a builder that needs a TLS client config next +pub struct WantsTlsConfig(()); + +impl AcceptorBuilder { + #[inline] + /// Creates a new [`AcceptorBuilder`] + pub fn new() -> Self { Self(WantsTlsConfig(())) } + + #[inline] + /// Passes a rustls [`ServerConfig`] to configure the TLS connection + pub fn with_tls_config(self, config: ServerConfig) -> AcceptorBuilder { + AcceptorBuilder(WantsAlpn(config)) + } + + /// Use rustls [defaults][with_safe_defaults] without [client authentication][with_no_client_auth] + /// + /// [with_safe_defaults]: rustls::ConfigBuilder::with_safe_defaults + /// [with_no_client_auth]: rustls::ConfigBuilder::with_no_client_auth + pub fn with_single_cert( + self, + cert_chain: Vec, + key_der: rustls::PrivateKey, + ) -> Result, rustls::Error> { + Ok(AcceptorBuilder(WantsAlpn( + ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert_chain, key_der)?, + ))) + } +} + +impl Default for AcceptorBuilder { + fn default() -> Self { Self::new() } +} + +/// State of a builder that needs a incoming address next +pub struct WantsAlpn(ServerConfig); + +impl AcceptorBuilder { + /// Configure ALPN accept protocols in order + pub fn with_alpn_protocols(mut self, alpn_protocols: Vec>) -> AcceptorBuilder { + self.0 .0.alpn_protocols = alpn_protocols; + AcceptorBuilder(WantsIncoming(self.0 .0)) + } + + /// Configure ALPN to accept HTTP/2 + pub fn with_http2_alpn(mut self) -> AcceptorBuilder { + self.0 .0.alpn_protocols = vec![b"h2".to_vec()]; + AcceptorBuilder(WantsIncoming(self.0 .0)) + } + + /// Configure ALPN to accept HTTP/1.0 + pub fn with_http10_alpn(mut self) -> AcceptorBuilder { + self.0 .0.alpn_protocols = vec![b"http/1.0".to_vec()]; + AcceptorBuilder(WantsIncoming(self.0 .0)) + } + + /// Configure ALPN to accept HTTP/1.1 + pub fn with_http11_alpn(mut self) -> AcceptorBuilder { + self.0 .0.alpn_protocols = vec![b"http/1.1".to_vec()]; + AcceptorBuilder(WantsIncoming(self.0 .0)) + } + + /// Configure ALPN to accept HTTP/2, HTTP/1.1, HTTP/1.0 in that order. + pub fn with_all_versions_alpn(mut self) -> AcceptorBuilder { + self.0 .0.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + AcceptorBuilder(WantsIncoming(self.0 .0)) + } +} + +/// State of a builder that needs a incoming address next +pub struct WantsIncoming(ServerConfig); + +impl AcceptorBuilder { + /// Passes a [`AddrIncoming`] to configure the TLS connection and + /// creates the [`TlsAcceptor`] + pub fn with_incoming(self, incoming: impl Into) -> TlsAcceptor { + TlsAcceptor { + config: Arc::new(self.0 .0), + incoming: incoming.into(), + } + } +} diff --git a/mm2src/mm2_net/src/native_tls/mod.rs b/mm2src/mm2_net/src/native_tls/mod.rs new file mode 100644 index 0000000000..46970f19c1 --- /dev/null +++ b/mm2src/mm2_net/src/native_tls/mod.rs @@ -0,0 +1,4 @@ +mod acceptor; +pub use acceptor::{TlsAcceptor, TlsStream}; + +mod builder;