diff --git a/Cargo.lock b/Cargo.lock index c66c23254..628dd8399 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -184,6 +184,12 @@ version = "1.0.62" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1485d4d2cc45e7b201ee3767015c96faa5904387c9d87c6efdd0fb511f12d305" +[[package]] +name = "arc-swap" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "983cd8b9d4b02a6dc6ffa557262eb5858a27a0038ffffe21a0f133eaa819a164" + [[package]] name = "arrayref" version = "0.3.6" @@ -831,6 +837,26 @@ dependencies = [ "mime", ] +[[package]] +name = "axum-server" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8456dab8f11484979a86651da8e619b355ede5d61a160755155f6c344bd18c47" +dependencies = [ + "arc-swap", + "bytes 1.1.0", + "futures-util", + "http 0.2.8", + "http-body", + "hyper", + "pin-project-lite 0.2.9", + "rustls 0.20.6", + "rustls-pemfile 1.0.1", + "tokio", + "tokio-rustls 0.23.4", + "tower-service", +] + [[package]] name = "base-x" version = "0.2.11" @@ -2802,9 +2828,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.7.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "496ce29bb5a52785b44e0f7ca2847ae0bb839c9bd28f69acac9b99d461c0c04c" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" [[package]] name = "httpdate" @@ -2820,9 +2846,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.20" +version = "0.14.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafbac" +checksum = "034711faac9d2166cb1baf1a2fb0b60b1f277f8492fd72176c17f3515e1abd3c" dependencies = [ "bytes 1.1.0", "futures-channel", @@ -5475,6 +5501,7 @@ dependencies = [ "anyhow", "async-trait", "axum", + "axum-server", "base64 0.13.0", "bollard", "chrono", @@ -5490,9 +5517,12 @@ dependencies = [ "opentelemetry", "opentelemetry-datadog", "opentelemetry-http", + "pem", "portpicker", "rand 0.8.5", "rcgen", + "rustls 0.20.6", + "rustls-pemfile 1.0.1", "serde", "serde_json", "shuttle-common", @@ -5677,9 +5707,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.4.4" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66d72b759436ae32898a2af0a14218dbf55efde3feeb170eb623637db85ee1e0" +checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" dependencies = [ "libc", "winapi", diff --git a/Makefile b/Makefile index c41a139da..f26084d5b 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,8 @@ APPS_FQDN=shuttleapp.rs DB_FQDN=db.shuttle.rs CONTAINER_REGISTRY=public.ecr.aws/shuttle DD_ENV=production +# make sure we only ever go to production with `--tls=enable` +USE_TLS=enable else DOCKER_COMPOSE_FILES=-f docker-compose.yml -f docker-compose.dev.yml STACK=shuttle-dev @@ -47,6 +49,7 @@ APPS_FQDN=unstable.shuttleapp.rs DB_FQDN=db.unstable.shuttle.rs CONTAINER_REGISTRY=public.ecr.aws/shuttle-dev DD_ENV=unstable +USE_TLS?=disable endif POSTGRES_EXTRA_PATH?=./extras/postgres @@ -54,7 +57,7 @@ POSTGRES_TAG?=14 RUST_LOG?=debug -DOCKER_COMPOSE_ENV=STACK=$(STACK) BACKEND_TAG=$(TAG) PROVISIONER_TAG=$(TAG) POSTGRES_TAG=${POSTGRES_TAG} APPS_FQDN=$(APPS_FQDN) DB_FQDN=$(DB_FQDN) POSTGRES_PASSWORD=$(POSTGRES_PASSWORD) RUST_LOG=$(RUST_LOG) CONTAINER_REGISTRY=$(CONTAINER_REGISTRY) MONGO_INITDB_ROOT_USERNAME=$(MONGO_INITDB_ROOT_USERNAME) MONGO_INITDB_ROOT_PASSWORD=$(MONGO_INITDB_ROOT_PASSWORD) DD_ENV=$(DD_ENV) +DOCKER_COMPOSE_ENV=STACK=$(STACK) BACKEND_TAG=$(TAG) PROVISIONER_TAG=$(TAG) POSTGRES_TAG=${POSTGRES_TAG} APPS_FQDN=$(APPS_FQDN) DB_FQDN=$(DB_FQDN) POSTGRES_PASSWORD=$(POSTGRES_PASSWORD) RUST_LOG=$(RUST_LOG) CONTAINER_REGISTRY=$(CONTAINER_REGISTRY) MONGO_INITDB_ROOT_USERNAME=$(MONGO_INITDB_ROOT_USERNAME) MONGO_INITDB_ROOT_PASSWORD=$(MONGO_INITDB_ROOT_PASSWORD) DD_ENV=$(DD_ENV) USE_TLS=$(USE_TLS) .PHONY: images clean src up down deploy shuttle-% postgres docker-compose.rendered.yml test bump-% deploy-examples publish publish-% --validate-version diff --git a/docker-compose.yml b/docker-compose.yml index e9e0fb672..83cd7a6d0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,6 +15,7 @@ services: depends_on: - provisioner ports: + - 7999:7999 - 8000:8000 - 8001:8001 deploy: @@ -42,16 +43,18 @@ services: environment: - RUST_LOG=${RUST_LOG} command: - - "--state=/var/lib/shuttle/gateway.sqlite" + - "--state=/var/lib/shuttle" - "start" - "--control=0.0.0.0:8001" - "--user=0.0.0.0:8000" + - "--bouncer=0.0.0.0:7999" - "--image=${CONTAINER_REGISTRY}/deployer:${BACKEND_TAG}" - "--prefix=shuttle_" - "--network-name=${STACK}_user-net" - "--docker-host=/var/run/docker.sock" - "--provisioner-host=provisioner" - "--proxy-fqdn=${APPS_FQDN}" + - "--use-tls=${USE_TLS}" healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8001"] interval: 1m diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index e2779cead..448173aba 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -7,12 +7,20 @@ publish = false [dependencies] acme2 = "0.5.1" async-trait = "0.1.52" + axum = { version = "0.5.8", features = [ "headers" ] } +axum-server = { version = "0.4.4", features = [ "tls-rustls" ] } +rustls = { version = "0.20.6" } +rustls-pemfile = { version = "1.0.1" } +pem = "1.1.0" + base64 = "0.13" bollard = "0.13" chrono = "0.4" clap = { version = "4.0.0", features = [ "derive" ] } + fqdn = "0.2.2" + futures = "0.3.21" http = "0.2.8" hyper = { version = "0.14.19", features = [ "stream" ] } diff --git a/gateway/src/custom_domain.rs b/gateway/src/acme.rs similarity index 75% rename from gateway/src/custom_domain.rs rename to gateway/src/acme.rs index 052cfe217..7ba9dd60d 100644 --- a/gateway/src/custom_domain.rs +++ b/gateway/src/acme.rs @@ -5,21 +5,24 @@ use std::time::Duration; use axum::body::boxed; use axum::response::Response; -use fqdn::Fqdn; use futures::future::BoxFuture; +use hyper::server::conn::AddrStream; use hyper::{Body, Request}; use instant_acme::{ - Account, AccountCredentials, Authorization, AuthorizationStatus, ChallengeType, Identifier, - KeyAuthorization, LetsEncrypt, NewAccount, NewOrder, Order, OrderStatus, + Account, AccountCredentials, Authorization, AuthorizationStatus, Challenge, ChallengeType, + Identifier, KeyAuthorization, LetsEncrypt, NewAccount, NewOrder, Order, OrderStatus, }; use rcgen::{Certificate, CertificateParams, DistinguishedName}; use tokio::sync::Mutex; use tokio::time::sleep; use tower::{Layer, Service}; -use tracing::{error, trace}; +use tracing::{error, trace, warn}; +use crate::proxy::AsResponderTo; use crate::{Error, ProjectName}; +const MAX_RETRIES: usize = 15; + #[derive(Debug, Eq, PartialEq)] pub struct CustomDomain { pub project_name: ProjectName, @@ -86,14 +89,15 @@ impl AcmeClient { Ok(credentials) } - /// Create a certificate and return it with the keys used to sign it + /// Create an ACME-signed certificate and return it and its + /// associated PEM-encoded private key pub async fn create_certificate( &self, - fqdn: &Fqdn, + identifier: &str, + challenge_type: ChallengeType, credentials: AccountCredentials<'_>, - ) -> Result<(String, Certificate), AcmeClientError> { - let fqdn = fqdn.to_string(); - trace!(fqdn, "requesting acme certificate"); + ) -> Result<(String, String), AcmeClientError> { + trace!(identifier, "requesting acme certificate"); let account = Account::from_credentials(credentials).map_err(|error| { error!( @@ -105,7 +109,7 @@ impl AcmeClient { let (mut order, state) = account .new_order(&NewOrder { - identifiers: &[Identifier::Dns(fqdn.to_string())], + identifiers: &[Identifier::Dns(identifier.to_string())], }) .await .map_err(|error| { @@ -128,12 +132,11 @@ impl AcmeClient { trace!(?authorization, "got authorization"); - self.complete_challenge(authorization, &mut order).await?; - - let Identifier::Dns(identifier) = &authorization.identifier; + self.complete_challenge(challenge_type, authorization, &mut order) + .await?; let certificate = { - let mut params = CertificateParams::new(vec![identifier.to_string()]); + let mut params = CertificateParams::new(vec![identifier.to_owned()]); params.distinguished_name = DistinguishedName::new(); Certificate::from_params(params).map_err(|error| { error!(%error, "failed to create certificate"); @@ -153,46 +156,26 @@ impl AcmeClient { AcmeClientError::OrderFinalizing })?; - Ok((certificate_chain, certificate)) + Ok((certificate_chain, certificate.serialize_private_key_pem())) } - async fn complete_challenge( - &self, + fn find_challenge( + ty: ChallengeType, authorization: &Authorization, - order: &mut Order, - ) -> Result<(), AcmeClientError> { - // Don't complete challenge for orders that are already valid - if let AuthorizationStatus::Valid = authorization.status { - return Ok(()); - } - - let challenge = authorization + ) -> Result<&Challenge, AcmeClientError> { + authorization .challenges .iter() - .find(|c| c.r#type == ChallengeType::Http01) + .find(|c| c.r#type == ty) .ok_or_else(|| { error!("http-01 challenge not found"); - AcmeClientError::MissingHttp01Challenge - })?; - - trace!(?challenge, "will complete challenge"); - - self.add_http01_challenge_authorization( - challenge.token.clone(), - order.key_authorization(challenge), - ) - .await; - - order - .set_challenge_ready(&challenge.url) - .await - .map_err(|error| { - error!(%error, "failed to mark challenge as ready"); - AcmeClientError::SetReadyFailed - })?; + AcmeClientError::MissingChallenge + }) + } + async fn wait_for_termination(&self, order: &mut Order) -> Result<(), AcmeClientError> { // Exponential backoff until order changes status - let mut tries = 1u8; + let mut tries = 1; let mut delay = Duration::from_millis(250); let state = loop { sleep(delay).await; @@ -205,19 +188,15 @@ impl AcmeClient { match state.status { OrderStatus::Ready => break state, OrderStatus::Invalid => { - self.remove_http01_challenge_authorization(&challenge.token) - .await; return Err(AcmeClientError::ChallengeInvalid); } OrderStatus::Pending => { delay *= 2; tries += 1; - if tries < 5 { + if tries < MAX_RETRIES { trace!(?state, tries, attempt_in=?delay, "order not yet ready"); } else { - error!(?state, tries, "order not ready in 5 tries"); - self.remove_http01_challenge_authorization(&challenge.token) - .await; + error!(?state, tries, "order not ready in {MAX_RETRIES} tries"); return Err(AcmeClientError::ChallengeTimeout); } } @@ -225,12 +204,85 @@ impl AcmeClient { } }; - trace!(challenge.token, ?state, "challenge completed"); + trace!(?state, "challenge completed"); + + Ok(()) + } + + async fn complete_challenge( + &self, + ty: ChallengeType, + authorization: &Authorization, + order: &mut Order, + ) -> Result<(), AcmeClientError> { + // Don't complete challenge for orders that are already valid + if let AuthorizationStatus::Valid = authorization.status { + return Ok(()); + } + let challenge = Self::find_challenge(ty, authorization)?; + match ty { + ChallengeType::Http01 => self.complete_http01_challenge(challenge, order).await, + ChallengeType::Dns01 => { + self.complete_dns01_challenge(&authorization.identifier, challenge, order) + .await + } + _ => Err(AcmeClientError::ChallengeNotSupported), + } + } + + async fn complete_dns01_challenge( + &self, + identifier: &Identifier, + challenge: &Challenge, + order: &mut Order, + ) -> Result<(), AcmeClientError> { + let Identifier::Dns(domain) = identifier; + + let digest = order.key_authorization(challenge).dns_value(); + warn!("dns-01 challenge: _acme-challenge.{domain} 300 IN TXT \"{digest}\""); + + // Wait 120 secs to insert the record manually and for it to + // propagate before moving on + sleep(Duration::from_secs(120)).await; + + order + .set_challenge_ready(&challenge.url) + .await + .map_err(|error| { + error!(%error, "failed to mark challenge as ready"); + AcmeClientError::SetReadyFailed + })?; + + self.wait_for_termination(order).await + } + + async fn complete_http01_challenge( + &self, + challenge: &Challenge, + order: &mut Order, + ) -> Result<(), AcmeClientError> { + trace!(?challenge, "will complete challenge"); + + self.add_http01_challenge_authorization( + challenge.token.clone(), + order.key_authorization(challenge), + ) + .await; + + order + .set_challenge_ready(&challenge.url) + .await + .map_err(|error| { + error!(%error, "failed to mark challenge as ready"); + AcmeClientError::SetReadyFailed + })?; + + let res = self.wait_for_termination(order).await; self.remove_http01_challenge_authorization(&challenge.token) .await; - Ok(()) + res } } @@ -245,7 +297,8 @@ pub enum AcmeClientError { FetchingState, OrderCreation, OrderFinalizing, - MissingHttp01Challenge, + MissingChallenge, + ChallengeNotSupported, Serializing, SetReadyFailed, } @@ -278,6 +331,18 @@ pub struct ChallengeResponder { inner: S, } +impl<'r, S> AsResponderTo<&'r AddrStream> for ChallengeResponder +where + S: AsResponderTo<&'r AddrStream>, +{ + fn as_responder_to(&self, req: &'r AddrStream) -> Self { + Self { + client: self.client.clone(), + inner: self.inner.as_responder_to(req), + } + } +} + impl Service> for ChallengeResponder where S: Service, Response = Response, Error = Error> + Send + 'static, diff --git a/gateway/src/api/latest.rs b/gateway/src/api/latest.rs index e84702fbf..4edfcd999 100644 --- a/gateway/src/api/latest.rs +++ b/gateway/src/api/latest.rs @@ -1,3 +1,5 @@ +use std::io::Cursor; +use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -8,8 +10,9 @@ use axum::response::Response; use axum::routing::{any, get, post}; use axum::{Json as AxumJson, Router}; use fqdn::FQDN; +use futures::Future; use http::StatusCode; -use instant_acme::AccountCredentials; +use instant_acme::{AccountCredentials, ChallengeType}; use serde::{Deserialize, Serialize}; use shuttle_common::models::error::ErrorKind; use shuttle_common::models::{project, user}; @@ -17,9 +20,10 @@ use tokio::sync::mpsc::Sender; use tower_http::trace::TraceLayer; use tracing::{debug, debug_span, field, Span}; +use crate::acme::AcmeClient; use crate::auth::{Admin, ScopedUser, User}; -use crate::custom_domain::AcmeClient; use crate::task::{self, BoxedTask}; +use crate::tls::GatewayCertResolver; use crate::worker::WORKER_QUEUE_SIZE; use crate::{AccountName, Error, GatewayService, ProjectName}; @@ -204,47 +208,85 @@ async fn request_acme_certificate( _: Admin, Extension(service): Extension>, Extension(acme_client): Extension, + Extension(resolver): Extension>, Path((project_name, fqdn)): Path<(ProjectName, String)>, AxumJson(credentials): AxumJson>, ) -> Result { let fqdn: FQDN = fqdn .parse() .map_err(|_err| Error::from(ErrorKind::InvalidCustomDomain))?; - let (chain, async_keys) = acme_client.create_certificate(&fqdn, credentials).await?; - let private_key = async_keys.serialize_private_key_pem(); + + let (certs, private_key) = acme_client + .create_certificate(&fqdn.to_string(), ChallengeType::Http01, credentials) + .await?; service - .create_custom_domain(project_name, &fqdn, &chain, &private_key) + .create_custom_domain(project_name, &fqdn, &certs, &private_key) + .await?; + + let mut buf = Vec::new(); + buf.extend(certs.as_bytes()); + buf.extend(private_key.as_bytes()); + resolver + .serve_pem(&fqdn.to_string(), Cursor::new(buf)) .await?; - Ok("Certificate created".to_string()) + Ok("certificate created".to_string()) } -pub fn make_api( - service: Arc, - acme_client: AcmeClient, - sender: Sender, -) -> Router { - debug!("making api route"); - - Router::::new() - .route( - "/", - get(get_status) - ) - .route( - "/projects/:project", - get(get_project).delete(delete_project).post(post_project) - ) - .route("/users/:account_name", get(get_user).post(post_user)) - .route("/projects/:project/*any", any(route_project)) - .route("/admin/revive", post(revive_projects)) - .route("/admin/acme/:email", post(create_acme_account)) - .route("/admin/acme/request/:project_name/:fqdn", post(request_acme_certificate)) - .layer(Extension(service)) - .layer(Extension(acme_client)) - .layer(Extension(sender)) - .layer( +pub struct ApiBuilder { + router: Router, + service: Option>, + sender: Option>, + bind: Option, +} + +impl Default for ApiBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ApiBuilder { + pub fn new() -> Self { + Self { + router: Router::new(), + service: None, + sender: None, + bind: None, + } + } + + pub fn with_acme(mut self, acme: AcmeClient, resolver: Arc) -> Self { + self.router = self + .router + .route("/admin/acme/:email", post(create_acme_account)) + .route( + "/admin/acme/request/:project_name/:fqdn", + post(request_acme_certificate), + ) + .layer(Extension(acme)) + .layer(Extension(resolver)); + self + } + + pub fn with_service(mut self, service: Arc) -> Self { + self.service = Some(service); + self + } + + pub fn with_sender(mut self, sender: Sender) -> Self { + self.sender = Some(sender); + self + } + + pub fn binding_to(mut self, addr: SocketAddr) -> Self { + self.bind = Some(addr); + self + } + + pub fn with_default_traces(mut self) -> Self { + self.router = self.router.layer( TraceLayer::new_for_http() .make_span_with(|request: &Request| { debug_span!("request", http.uri = %request.uri(), http.method = %request.method(), http.status_code = field::Empty, account.name = field::Empty, account.project = field::Empty) @@ -255,7 +297,37 @@ pub fn make_api( debug!(latency = format_args!("{} ns", latency.as_nanos()), "finished processing request"); }, ), - ) + ); + self + } + + pub fn with_default_routes(mut self) -> Self { + self.router = self + .router + .route("/", get(get_status)) + .route( + "/projects/:project", + get(get_project).delete(delete_project).post(post_project), + ) + .route("/users/:account_name", get(get_user).post(post_user)) + .route("/projects/:project/*any", any(route_project)) + .route("/admin/revive", post(revive_projects)); + self + } + + pub fn into_router(self) -> Router { + let service = self.service.expect("a GatewayService is required"); + let sender = self.sender.expect("a task Sender is required"); + self.router + .layer(Extension(service)) + .layer(Extension(sender)) + } + + pub fn serve(self) -> impl Future> { + let bind = self.bind.expect("a socket address to bind to is required"); + let router = self.into_router(); + axum::Server::bind(&bind).serve(router.into_make_service()) + } } #[cfg(test)] @@ -287,7 +359,11 @@ pub mod tests { } }); - let mut router = make_api(Arc::clone(&service), world.acme_client(), sender); + let mut router = ApiBuilder::new() + .with_service(Arc::clone(&service)) + .with_sender(sender) + .with_default_routes() + .into_router(); let neo = service.create_user("neo".parse().unwrap()).await?; @@ -431,7 +507,11 @@ pub mod tests { } }); - let mut router = make_api(Arc::clone(&service), world.acme_client(), sender); + let mut router = ApiBuilder::new() + .with_service(Arc::clone(&service)) + .with_sender(sender) + .with_default_routes() + .into_router(); let get_neo = || { Request::builder() @@ -527,7 +607,11 @@ pub mod tests { } }); - let mut router = make_api(Arc::clone(&service), world.acme_client(), sender.clone()); + let mut router = ApiBuilder::new() + .with_service(Arc::clone(&service)) + .with_sender(sender) + .with_default_routes() + .into_router(); let get_status = || { Request::builder() diff --git a/gateway/src/api/mod.rs b/gateway/src/api/mod.rs index e7eb5861f..27f571e54 100644 --- a/gateway/src/api/mod.rs +++ b/gateway/src/api/mod.rs @@ -1,2 +1 @@ pub mod latest; -pub use latest::make_api; diff --git a/gateway/src/args.rs b/gateway/src/args.rs index 8646cf5b1..fd720025b 100644 --- a/gateway/src/args.rs +++ b/gateway/src/args.rs @@ -1,20 +1,26 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, path::PathBuf}; -use clap::{Parser, Subcommand}; +use clap::{Parser, Subcommand, ValueEnum}; use fqdn::FQDN; use crate::auth::Key; #[derive(Parser, Debug)] pub struct Args { - /// Uri to the `.sqlite` file used to store state - #[arg(long, default_value = "./gateway.sqlite")] - pub state: String, + /// Where to store gateway state (such as sqlite state, and certs) + #[arg(long, default_value = "./")] + pub state: PathBuf, #[command(subcommand)] pub command: Commands, } +#[derive(Debug, Clone, Copy, ValueEnum)] +pub enum UseTls { + Disable, + Enable, +} + #[derive(Subcommand, Debug)] pub enum Commands { Start(StartArgs), @@ -26,9 +32,15 @@ pub struct StartArgs { /// Address to bind the control plane to #[arg(long, default_value = "127.0.0.1:8001")] pub control: SocketAddr, - /// Address to bind the user plane to + /// Address to bind the bouncer service to + #[arg(long, default_value = "127.0.0.1:7999")] + pub bouncer: SocketAddr, + /// Address to bind the user proxy to #[arg(long, default_value = "127.0.0.1:8000")] pub user: SocketAddr, + /// Allows to disable the use of TLS in the user proxy service (DANGEROUS) + #[arg(long, default_value = "enable")] + pub use_tls: UseTls, #[command(flatten)] pub context: ContextArgs, } @@ -60,7 +72,7 @@ pub struct ContextArgs { #[arg(long, default_value = "shuttle_default")] pub network_name: String, /// FQDN where the proxy can be reached at - #[arg(long)] + #[arg(long, default_value = "shuttleapp.rs")] pub proxy_fqdn: FQDN, /// The path to the docker daemon socket #[arg(long, default_value = "/var/run/docker.sock")] diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index 63fd84a7b..a36191533 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -8,24 +8,25 @@ use std::io; use std::pin::Pin; use std::str::FromStr; +use acme::AcmeClientError; use axum::response::{IntoResponse, Response}; use axum::Json; use bollard::Docker; -use custom_domain::AcmeClientError; use futures::prelude::*; use serde::{Deserialize, Deserializer, Serialize}; use shuttle_common::models::error::{ApiError, ErrorKind}; use tokio::sync::mpsc::error::SendError; use tracing::error; +pub mod acme; pub mod api; pub mod args; pub mod auth; -pub mod custom_domain; pub mod project; pub mod proxy; pub mod service; pub mod task; +pub mod tls; pub mod worker; use crate::service::{ContainerSettings, GatewayService}; @@ -83,6 +84,12 @@ impl From> for Error { } } +impl From for Error { + fn from(_: io::Error) -> Self { + Self::from(ErrorKind::Internal) + } +} + impl From for Error { fn from(error: AcmeClientError) -> Self { Self::source(ErrorKind::Internal, error) @@ -298,11 +305,11 @@ pub mod tests { use sqlx::SqlitePool; use tokio::sync::mpsc::channel; - use crate::api::make_api; - use crate::args::{ContextArgs, StartArgs}; + use crate::acme::AcmeClient; + use crate::api::latest::ApiBuilder; + use crate::args::{ContextArgs, StartArgs, UseTls}; use crate::auth::User; - use crate::custom_domain::AcmeClient; - use crate::proxy::make_proxy; + use crate::proxy::UserServiceBuilder; use crate::service::{ContainerSettings, GatewayService, MIGRATIONS}; use crate::worker::Worker; use crate::DockerContext; @@ -520,8 +527,10 @@ pub mod tests { let control: i16 = Uniform::from(9000..10000).sample(&mut rand::thread_rng()); let user = control + 1; + let bouncer = user + 1; let control = format!("127.0.0.1:{control}").parse().unwrap(); let user = format!("127.0.0.1:{user}").parse().unwrap(); + let bouncer = format!("127.0.0.1:{bouncer}").parse().unwrap(); let prefix = format!( "shuttle_test_{}_", @@ -541,6 +550,8 @@ pub mod tests { let args = StartArgs { control, user, + bouncer, + use_tls: UseTls::Disable, context: ContextArgs { docker_host, image, @@ -584,12 +595,8 @@ pub mod tests { Client::new(addr).with_hyper_client(self.hyper.clone()) } - pub fn fqdn(&self) -> String { - self.args() - .proxy_fqdn - .to_string() - .trim_end_matches('.') - .to_string() + pub fn fqdn(&self) -> FQDN { + self.args().proxy_fqdn } pub fn acme_client(&self) -> AcmeClient { @@ -644,21 +651,26 @@ pub mod tests { } }; - let api = make_api(Arc::clone(&service), world.acme_client(), log_out); let api_addr = format!("127.0.0.1:{}", base_port).parse().unwrap(); - let serve_api = hyper::Server::bind(&api_addr).serve(api.into_make_service()); let api_client = world.client(api_addr); - - let proxy = make_proxy(Arc::clone(&service), world.acme_client(), world.fqdn()); - let proxy_addr = format!("127.0.0.1:{}", base_port + 1).parse().unwrap(); - let serve_proxy = hyper::Server::bind(&proxy_addr).serve(proxy); - let proxy_client = world.client(proxy_addr); + let api = ApiBuilder::new() + .with_service(Arc::clone(&service)) + .with_sender(log_out) + .with_default_routes() + .binding_to(api_addr); + + let user_addr: SocketAddr = format!("127.0.0.1:{}", base_port + 1).parse().unwrap(); + let proxy_client = world.client(user_addr); + let user = UserServiceBuilder::new() + .with_service(Arc::clone(&service)) + .with_public(world.fqdn()) + .with_user_proxy_binding_to(user_addr); let _gateway = tokio::spawn(async move { tokio::select! { _ = worker.start() => {}, - _ = serve_api => {}, - _ = serve_proxy => {} + _ = api.serve() => {}, + _ = user.serve() => {} } }); diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 6e28c663e..f4eed5d24 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -1,21 +1,25 @@ use clap::Parser; +use fqdn::FQDN; use futures::prelude::*; +use instant_acme::{AccountCredentials, ChallengeType}; use opentelemetry::global; -use shuttle_gateway::args::{Args, Commands, InitArgs}; +use shuttle_gateway::acme::AcmeClient; +use shuttle_gateway::api::latest::ApiBuilder; +use shuttle_gateway::args::StartArgs; +use shuttle_gateway::args::{Args, Commands, InitArgs, UseTls}; use shuttle_gateway::auth::Key; -use shuttle_gateway::custom_domain::AcmeClient; -use shuttle_gateway::proxy::make_proxy; +use shuttle_gateway::proxy::UserServiceBuilder; use shuttle_gateway::service::{GatewayService, MIGRATIONS}; use shuttle_gateway::task; +use shuttle_gateway::tls::{make_tls_acceptor, ChainAndPrivateKey}; use shuttle_gateway::worker::Worker; -use shuttle_gateway::{api::make_api, args::StartArgs}; use sqlx::migrate::MigrateDatabase; use sqlx::{query, Sqlite, SqlitePool}; -use std::io; -use std::path::Path; +use std::io::{self, Cursor}; +use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; -use tracing::{debug, error, info, trace}; +use tracing::{debug, error, info, trace, warn}; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; #[tokio::main(flavor = "multi_thread")] @@ -43,8 +47,11 @@ async fn main() -> io::Result<()> { .with(opentelemetry) .init(); - if !Path::new(&args.state).exists() { - Sqlite::create_database(&args.state).await.unwrap(); + let db_path = args.state.join("gateway.sqlite"); + let db_uri = db_path.to_str().unwrap(); + + if !db_path.exists() { + Sqlite::create_database(db_uri).await.unwrap(); } info!( @@ -53,23 +60,17 @@ async fn main() -> io::Result<()> { .unwrap() .to_string_lossy() ); - let db = SqlitePool::connect(&args.state).await.unwrap(); + let db = SqlitePool::connect(db_uri).await.unwrap(); MIGRATIONS.run(&db).await.unwrap(); match args.command { - Commands::Start(start_args) => start(db, start_args).await, + Commands::Start(start_args) => start(db, args.state, start_args).await, Commands::Init(init_args) => init(db, init_args).await, } } -async fn start(db: SqlitePool, args: StartArgs) -> io::Result<()> { - let fqdn = args - .context - .proxy_fqdn - .to_string() - .trim_end_matches('.') - .to_string(); +async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { let gateway = Arc::new(GatewayService::init(args.context.clone(), db).await); let worker = Worker::new(); @@ -125,20 +126,48 @@ async fn start(db: SqlitePool, args: StartArgs) -> io::Result<()> { let acme_client = AcmeClient::new(); - let api = make_api(Arc::clone(&gateway), acme_client.clone(), sender); - - let api_handle = tokio::spawn(axum::Server::bind(&args.control).serve(api.into_make_service())); + let mut api_builder = ApiBuilder::new() + .with_service(Arc::clone(&gateway)) + .with_sender(sender) + .binding_to(args.control); + + let mut user_builder = UserServiceBuilder::new() + .with_service(Arc::clone(&gateway)) + .with_public(args.context.proxy_fqdn.clone()) + .with_user_proxy_binding_to(args.user) + .with_bouncer(args.bouncer); + + if let UseTls::Enable = args.use_tls { + let (resolver, tls_acceptor) = make_tls_acceptor(); + + user_builder = user_builder + .with_acme(acme_client.clone()) + .with_tls(tls_acceptor); + + api_builder = api_builder.with_acme(acme_client.clone(), resolver.clone()); + + tokio::spawn(async move { + // make sure we have a certificate for ourselves + let certs = init_certs(fs, args.context.proxy_fqdn.clone(), acme_client.clone()).await; + resolver.serve_default_der(certs).await.unwrap(); + }); + } else { + warn!("TLS is disabled in the proxy service. This is only acceptable in testing, and should *never* be used in deployments."); + }; - let proxy = make_proxy(gateway, acme_client, fqdn); + let api_handle = api_builder + .with_default_routes() + .with_default_traces() + .serve(); - let proxy_handle = tokio::spawn(hyper::Server::bind(&args.user).serve(proxy)); + let user_handle = user_builder.serve(); debug!("starting up all services"); tokio::select!( _ = worker_handle => info!("worker handle finished"), _ = api_handle => error!("api handle finished"), - _ = proxy_handle => error!("proxy handle finished"), + _ = user_handle => error!("user handle finished"), _ = ambulance_handle => error!("ambulance handle finished"), ); @@ -161,3 +190,47 @@ async fn init(db: SqlitePool, args: InitArgs) -> io::Result<()> { println!("`{}` created as super user with key: {key}", args.name); Ok(()) } + +async fn init_certs>(fs: P, public: FQDN, acme: AcmeClient) -> ChainAndPrivateKey { + let tls_path = fs.as_ref().join("ssl.pem"); + + match ChainAndPrivateKey::load_pem(&tls_path) { + Ok(valid) => valid, + Err(_) => { + let creds_path = fs.as_ref().join("acme.json"); + warn!( + "no valid certificate found at {}, creating one...", + tls_path.display() + ); + + if !creds_path.exists() { + panic!( + "no ACME credentials found at {}, cannot continue with certificate creation", + creds_path.display() + ); + } + + let creds = std::fs::File::open(creds_path).unwrap(); + let creds: AccountCredentials = serde_json::from_reader(&creds).unwrap(); + + let identifier = format!("*.{public}"); + + // Use ::Dns01 challenge because that's the only supported + // challenge type for wildcard domains + let (chain, private_key) = acme + .create_certificate(&identifier, ChallengeType::Dns01, creds) + .await + .unwrap(); + + let mut buf = Vec::new(); + buf.extend(chain.as_bytes()); + buf.extend(private_key.as_bytes()); + + let certs = ChainAndPrivateKey::parse_pem(Cursor::new(buf)).unwrap(); + + certs.clone().save_pem(&tls_path).unwrap(); + + certs + } + } +} diff --git a/gateway/src/proxy.rs b/gateway/src/proxy.rs index 600a511eb..0dd0c698e 100644 --- a/gateway/src/proxy.rs +++ b/gateway/src/proxy.rs @@ -1,13 +1,19 @@ +use std::convert::Infallible; use std::future::Future; +use std::io; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use axum::body::HttpBody; +use axum::headers::{HeaderMapExt, Host}; use axum::response::{IntoResponse, Response}; +use axum_server::accept::DefaultAcceptor; +use axum_server::tls_rustls::RustlsAcceptor; +use fqdn::{fqdn, FQDN}; +use futures::future::{ready, Ready}; use futures::prelude::*; -use hyper::body::Body; +use hyper::body::{Body, HttpBody}; use hyper::client::connect::dns::GaiResolver; use hyper::client::HttpConnector; use hyper::server::conn::AddrStream; @@ -17,94 +23,179 @@ use once_cell::sync::Lazy; use opentelemetry::global; use opentelemetry_http::HeaderInjector; use tower::{Service, ServiceBuilder}; -use tracing::{debug, debug_span, field, trace}; +use tracing::{debug, debug_span, error, field, trace}; use tracing_opentelemetry::OpenTelemetrySpanExt; -use crate::custom_domain::{AcmeClient, ChallengeResponder, ChallengeResponderLayer}; +use crate::acme::{AcmeClient, ChallengeResponderLayer}; use crate::service::GatewayService; use crate::{Error, ErrorKind, ProjectName}; static PROXY_CLIENT: Lazy>> = Lazy::new(|| ReverseProxy::new(Client::new())); -pub struct ProxyService { - gateway: Arc, - remote_addr: SocketAddr, - fqdn: String, +pub trait AsResponderTo { + fn as_responder_to(&self, req: R) -> Self; + + fn into_make_service(self) -> ResponderMakeService + where + Self: Sized, + { + ResponderMakeService { inner: self } + } } -impl Service> for ProxyService { - type Response = Response; - type Error = Error; - type Future = - Pin> + Send + 'static>>; +pub struct ResponderMakeService { + inner: S, +} + +impl<'r, S> Service<&'r AddrStream> for ResponderMakeService +where + S: AsResponderTo<&'r AddrStream>, +{ + type Response = S; + type Error = Infallible; + type Future = Ready>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, mut req: Request) -> Self::Future { - let remote_addr = self.remote_addr.ip(); - let gateway = Arc::clone(&self.gateway); - let fqdn = self.fqdn.clone(); + fn call(&mut self, req: &'r AddrStream) -> Self::Future { + ready(Ok(self.inner.as_responder_to(req))) + } +} + +#[derive(Clone)] +pub struct UserProxy { + gateway: Arc, + remote_addr: SocketAddr, + public: FQDN, +} - Box::pin( - async move { - let span = debug_span!("proxy", http.method = %req.method(), http.uri = %req.uri(), http.status_code = field::Empty, project = field::Empty); - trace!(?req, "serving proxy request"); - let project_str = req - .headers() - .get("Host") - .map(|head| head.to_str().unwrap()) - .and_then(|host| host.strip_suffix('.').unwrap_or(host).strip_suffix(&fqdn)) - .ok_or_else(|| Error::from_kind(ErrorKind::ProjectNotFound))?; +impl<'r> AsResponderTo<&'r AddrStream> for UserProxy { + fn as_responder_to(&self, addr_stream: &'r AddrStream) -> Self { + let mut responder = self.clone(); + responder.remote_addr = addr_stream.remote_addr(); + responder + } +} - let project_name: ProjectName = project_str - .parse() - .map_err(|_| Error::from_kind(ErrorKind::InvalidProjectName))?; +impl UserProxy { + async fn proxy(self, mut req: Request) -> Result { + let span = debug_span!("proxy", http.method = %req.method(), http.host = ?req.headers().get("Host"), http.uri = %req.uri(), http.status_code = field::Empty, project = field::Empty); + trace!(?req, "serving proxy request"); - let project = gateway.find_project(&project_name).await?; + let project_str = req + .headers() + .typed_get::() + .map(|host| fqdn!(host.hostname())) + .and_then(|fqdn| { + debug!(host = %fqdn, public = %self.public, "comparing host key"); + if fqdn.is_subdomain_of(&self.public) && fqdn.depth() - self.public.depth() == 1 { + Some(fqdn.labels().next().unwrap().to_owned()) + } else { + None + } + }) + .ok_or_else(|| Error::from_kind(ErrorKind::ProjectNotFound))?; - // Record current project for tracing purposes - span.record("project", &project_name.to_string()); + let project_name: ProjectName = project_str + .parse() + .map_err(|_| Error::from_kind(ErrorKind::InvalidProjectName))?; - let target_ip = project - .target_ip()? - .ok_or_else(|| Error::from_kind(ErrorKind::ProjectNotReady))?; + let project = self.gateway.find_project(&project_name).await?; - let target_url = format!("http://{}:{}", target_ip, 8000); + // Record current project for tracing purposes + span.record("project", &project_name.to_string()); - let cx = span.context(); + let target_ip = project + .target_ip()? + .ok_or_else(|| Error::from_kind(ErrorKind::ProjectNotReady))?; - global::get_text_map_propagator(|propagator| { - propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut())) - }); + let target_url = format!("http://{}:{}", target_ip, 8000); - let proxy = PROXY_CLIENT - .call(remote_addr, &target_url, req) - .await - .map_err(|_| Error::from_kind(ErrorKind::ProjectUnavailable))?; + let cx = span.context(); - let (parts, body) = proxy.into_parts(); - let body = ::map_err(body, axum::Error::new).boxed_unsync(); + global::get_text_map_propagator(|propagator| { + propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut())) + }); - span.record("http.status_code", parts.status.as_u16()); + let proxy = PROXY_CLIENT + .call(self.remote_addr.ip(), &target_url, req) + .await + .map_err(|_| Error::from_kind(ErrorKind::ProjectUnavailable))?; - Ok(Response::from_parts(parts, body)) - } - .or_else(|err: Error| future::ready(Ok(err.into_response()))), - ) + let (parts, body) = proxy.into_parts(); + let body = ::map_err(body, axum::Error::new).boxed_unsync(); + + span.record("http.status_code", parts.status.as_u16()); + + Ok(Response::from_parts(parts, body)) } } -pub struct MakeProxyService { +impl Service> for UserProxy { + type Response = Response; + type Error = Error; + type Future = + Pin> + Send + 'static>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + self.clone() + .proxy(req) + .or_else(|err: Error| future::ready(Ok(err.into_response()))) + .boxed() + } +} + +#[derive(Clone)] +pub struct Bouncer { gateway: Arc, - acme_client: AcmeClient, - fqdn: String, + public: FQDN, } -impl<'r> Service<&'r AddrStream> for MakeProxyService { - type Response = ChallengeResponder; +impl<'r> AsResponderTo<&'r AddrStream> for Bouncer { + fn as_responder_to(&self, _req: &'r AddrStream) -> Self { + self.clone() + } +} + +impl Bouncer { + async fn bounce(self, req: Request) -> Result { + let mut resp = Response::builder(); + + let host = req.headers().typed_get::().unwrap(); + let hostname = host.hostname(); + let fqdn = fqdn!(hostname); + + let path = req.uri(); + + if fqdn.is_subdomain_of(&self.public) + || self + .gateway + .project_details_for_custom_domain(&fqdn) + .await + .is_ok() + { + resp = resp + .status(301) + .header("Location", format!("https://{hostname}{path}")); + } else { + resp = resp.status(404); + } + + let body = ::map_err(Body::empty(), axum::Error::new).boxed_unsync(); + + Ok(resp.body(body).unwrap()) + } +} + +impl Service> for Bouncer { + type Response = Response; type Error = Error; type Future = Pin> + Send + 'static>>; @@ -113,39 +204,134 @@ impl<'r> Service<&'r AddrStream> for MakeProxyService { Poll::Ready(Ok(())) } - fn call(&mut self, target: &'r AddrStream) -> Self::Future { - let gateway = Arc::clone(&self.gateway); - let acme_client = self.acme_client.clone(); - let remote_addr = target.remote_addr(); - let fqdn = self.fqdn.clone(); - - Box::pin(async move { - let challenge_response_layer = ChallengeResponderLayer::new(acme_client); - let proxy_service = ProxyService { - remote_addr, - gateway, - fqdn, - }; + fn call(&mut self, req: Request) -> Self::Future { + self.clone().bounce(req).boxed() + } +} - let service = ServiceBuilder::new() - .layer(challenge_response_layer) - .service(proxy_service); +pub struct UserServiceBuilder { + service: Option>, + acme: Option, + tls_acceptor: Option>, + bouncer_binds_to: Option, + user_binds_to: Option, + public: Option, +} - Ok(service) - }) +impl Default for UserServiceBuilder { + fn default() -> Self { + Self::new() } } -pub fn make_proxy( - gateway: Arc, - acme_client: AcmeClient, - fqdn: String, -) -> MakeProxyService { - debug!("making proxy"); - - MakeProxyService { - gateway, - acme_client, - fqdn: format!(".{fqdn}"), +impl UserServiceBuilder { + pub fn new() -> Self { + Self { + service: None, + public: None, + acme: None, + tls_acceptor: None, + bouncer_binds_to: None, + user_binds_to: None, + } + } + + pub fn with_public(mut self, public: FQDN) -> Self { + self.public = Some(public); + self + } + + pub fn with_service(mut self, service: Arc) -> Self { + self.service = Some(service); + self + } + + pub fn with_bouncer(mut self, bound_to: SocketAddr) -> Self { + self.bouncer_binds_to = Some(bound_to); + self + } + + pub fn with_user_proxy_binding_to(mut self, bound_to: SocketAddr) -> Self { + self.user_binds_to = Some(bound_to); + self + } + + pub fn with_acme(mut self, acme: AcmeClient) -> Self { + self.acme = Some(acme); + self + } + + pub fn with_tls(mut self, acceptor: RustlsAcceptor) -> Self { + self.tls_acceptor = Some(acceptor); + self + } + + pub fn serve(self) -> impl Future> { + let service = self.service.expect("a GatewayService is required"); + let public = self.public.expect("a public FQDN is required"); + let user_binds_to = self + .user_binds_to + .expect("a socket address to bind to is required"); + + let user_proxy = UserProxy { + gateway: service.clone(), + remote_addr: "127.0.0.1:80".parse().unwrap(), + public: public.clone(), + }; + + let bouncer = self.bouncer_binds_to.as_ref().map(|_| Bouncer { + gateway: service.clone(), + public: public.clone(), + }); + + let mut futs = Vec::new(); + if let Some(tls_acceptor) = self.tls_acceptor { + // TLS is enabled + let bouncer = bouncer.expect("TLS cannot be enabled without a bouncer"); + let bouncer_binds_to = self.bouncer_binds_to.unwrap(); + + let acme = self + .acme + .expect("TLS cannot be enabled without an ACME client"); + + let bouncer = ServiceBuilder::new() + .layer(ChallengeResponderLayer::new(acme)) + .service(bouncer); + + let bouncer = axum_server::Server::bind(bouncer_binds_to) + .serve(bouncer.into_make_service()) + .map(|handle| ("bouncer (with challenge responder)", handle)) + .boxed(); + + futs.push(bouncer); + + let user_with_tls = axum_server::Server::bind(user_binds_to) + .acceptor(tls_acceptor) + .serve(user_proxy.into_make_service()) + .map(|handle| ("user proxy (with TLS)", handle)) + .boxed(); + futs.push(user_with_tls); + } else { + if let Some(bouncer) = bouncer { + // bouncer is enabled + let bouncer_binds_to = self.bouncer_binds_to.unwrap(); + let bouncer = axum_server::Server::bind(bouncer_binds_to) + .serve(bouncer.into_make_service()) + .map(|handle| ("bouncer (without challenge responder)", handle)) + .boxed(); + futs.push(bouncer); + } + + let user_without_tls = axum_server::Server::bind(user_binds_to) + .serve(user_proxy.into_make_service()) + .map(|handle| ("user proxy (no TLS)", handle)) + .boxed(); + futs.push(user_without_tls); + } + + future::select_all(futs.into_iter()).map(|((name, resolved), _, _)| { + error!(service = %name, "exited early"); + resolved + }) } } diff --git a/gateway/src/service.rs b/gateway/src/service.rs index f85ff74a8..62457cd30 100644 --- a/gateway/src/service.rs +++ b/gateway/src/service.rs @@ -23,9 +23,9 @@ use sqlx::{query, Error as SqlxError, Row}; use tracing::{debug, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; +use crate::acme::CustomDomain; use crate::args::ContextArgs; use crate::auth::{Key, Permissions, User}; -use crate::custom_domain::CustomDomain; use crate::project::Project; use crate::task::TaskBuilder; use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectName}; @@ -454,13 +454,13 @@ impl GatewayService { &self, project_name: ProjectName, fqdn: &Fqdn, - certificate: &str, + certs: &str, private_key: &str, ) -> Result<(), Error> { query("INSERT INTO custom_domains (fqdn, project_name, certificate, private_key) VALUES (?1, ?2, ?3, ?4)") .bind(fqdn.to_string()) .bind(&project_name) - .bind(certificate) + .bind(certs) .bind(private_key) .execute(&self.db) .await diff --git a/gateway/src/tls.rs b/gateway/src/tls.rs new file mode 100644 index 000000000..c19c6678e --- /dev/null +++ b/gateway/src/tls.rs @@ -0,0 +1,165 @@ +use std::collections::HashMap; +use std::fs::File; +use std::io::{BufReader, Read, Write}; +use std::path::Path; +use std::sync::Arc; + +use axum_server::accept::DefaultAcceptor; +use axum_server::tls_rustls::{RustlsAcceptor, RustlsConfig}; +use futures::executor::block_on; +use pem::Pem; +use rustls::server::{ClientHello, ResolvesServerCert}; +use rustls::sign::{self, CertifiedKey}; +use rustls::{Certificate, PrivateKey, ServerConfig}; +use rustls_pemfile::Item; +use shuttle_common::models::error::ErrorKind; +use tokio::runtime::Handle; +use tokio::sync::RwLock; + +use crate::Error; + +#[derive(Clone)] +pub struct ChainAndPrivateKey { + chain: Vec, + private_key: PrivateKey, +} + +impl ChainAndPrivateKey { + pub fn parse_pem(rd: R) -> Result { + let mut private_key = None; + let mut chain = Vec::new(); + + for item in rustls_pemfile::read_all(&mut BufReader::new(rd)) + .map_err(|_| Error::from_kind(ErrorKind::Internal))? + { + match item { + Item::X509Certificate(cert) => chain.push(Certificate(cert)), + Item::ECKey(key) | Item::PKCS8Key(key) | Item::RSAKey(key) => { + private_key = Some(PrivateKey(key)) + } + _ => return Err(Error::from_kind(ErrorKind::Internal)), + } + } + + Ok(Self { + chain, + private_key: private_key.unwrap(), + }) + } + + pub fn load_pem>(path: P) -> Result { + let rd = File::open(path)?; + Self::parse_pem(rd) + } + + pub fn into_pem(self) -> Result { + let mut pems = Vec::new(); + for cert in self.chain { + pems.push(Pem { + tag: "CERTIFICATE".to_string(), + contents: cert.0, + }); + } + + pems.push(Pem { + tag: "PRIVATE KEY".to_string(), + contents: self.private_key.0, + }); + + Ok(pem::encode_many(&pems)) + } + + pub fn into_certified_key(self) -> Result { + let signing_key = sign::any_supported_type(&self.private_key) + .map_err(|_| Error::from_kind(ErrorKind::Internal))?; + Ok(CertifiedKey::new(self.chain, signing_key)) + } + + pub fn save_pem>(self, path: P) -> Result<(), Error> { + let as_pem = self.into_pem()?; + let mut f = File::create(path)?; + f.write_all(as_pem.as_bytes())?; + Ok(()) + } +} + +pub struct GatewayCertResolver { + keys: RwLock>>, + default: RwLock>>, +} + +impl Default for GatewayCertResolver { + fn default() -> Self { + Self::new() + } +} + +impl GatewayCertResolver { + pub fn new() -> Self { + Self { + keys: RwLock::new(HashMap::default()), + default: RwLock::new(None), + } + } + + /// Get the loaded [CertifiedKey] associated with the given + /// domain. + pub async fn get(&self, sni: &str) -> Option> { + self.keys.read().await.get(sni).map(Arc::clone) + } + + pub async fn serve_default_der(&self, certs: ChainAndPrivateKey) -> Result<(), Error> { + *self.default.write().await = Some(Arc::new(certs.into_certified_key()?)); + Ok(()) + } + + pub async fn serve_default_pem(&self, rd: R) -> Result<(), Error> { + let certs = ChainAndPrivateKey::parse_pem(rd)?; + self.serve_default_der(certs).await + } + + /// Load a new certificate chain and private key to serve when + /// receiving incoming TLS connections for the given domain. + pub async fn serve_der(&self, fqdn: &str, certs: ChainAndPrivateKey) -> Result<(), Error> { + let certified_key = certs.into_certified_key()?; + self.keys + .write() + .await + .insert(fqdn.to_string(), Arc::new(certified_key)); + Ok(()) + } + + pub async fn serve_pem(&self, fqdn: &str, rd: R) -> Result<(), Error> { + let certs = ChainAndPrivateKey::parse_pem(rd)?; + self.serve_der(fqdn, certs).await + } +} + +impl ResolvesServerCert for GatewayCertResolver { + fn resolve(&self, client_hello: ClientHello) -> Option> { + let sni = client_hello.server_name()?; + let handle = Handle::current(); + let _ = handle.enter(); + block_on(async move { + if let Some(cert) = self.get(sni).await { + Some(cert) + } else { + self.default.read().await.clone() + } + }) + } +} + +pub fn make_tls_acceptor() -> (Arc, RustlsAcceptor) { + let resolver = Arc::new(GatewayCertResolver::new()); + + let mut server_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(Arc::clone(&resolver) as Arc); + server_config.alpn_protocols = vec![b"http/1.1".to_vec()]; + + let rustls_config = RustlsConfig::from_config(Arc::new(server_config)); + + (resolver, RustlsAcceptor::new(rustls_config)) +}