Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: custom domain routing #484

Merged
merged 6 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions deployer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ pub async fn start_proxy(
fqdn: FQDN,
address_getter: impl AddressGetter,
) {
let make_service = make_service_fn(|socket: &AddrStream| {
let make_service = make_service_fn(move |socket: &AddrStream| {
let remote_address = socket.remote_addr();
let fqdn = format!(".{}", fqdn.to_string().trim_end_matches('.'));
let address_getter = address_getter.clone();
let fqdn = fqdn.clone();

async move {
Ok::<_, Infallible>(service_fn(move |req| {
Expand Down
39 changes: 27 additions & 12 deletions deployer/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::{
};

use async_trait::async_trait;
use fqdn::FQDN;
use hyper::{
client::{connect::dns::GaiResolver, HttpConnector},
header::{HeaderValue, HOST, SERVER},
Expand All @@ -23,7 +24,7 @@ static SERVER_HEADER: Lazy<HeaderValue> = Lazy::new(|| "shuttle.rs".parse().unwr
#[instrument(name = "proxy_request", skip(address_getter), fields(http.method = %req.method(), http.uri = %req.uri(), http.status_code = field::Empty, service = field::Empty))]
pub async fn handle(
remote_address: SocketAddr,
fqdn: String,
fqdn: FQDN,
req: Request<Body>,
address_getter: impl AddressGetter,
) -> Result<Response<Body>, Infallible> {
Expand All @@ -33,44 +34,58 @@ pub async fn handle(
});
span.set_parent(parent_context);

let host = match req.headers().get(HOST) {
Some(host) => host.to_str().unwrap_or_default().to_owned(),
let host: FQDN = match req.headers().get(HOST) {
Some(host) => host
.to_str()
.unwrap_or_default()
.parse::<FQDN>()
.unwrap_or_default()
.to_owned(),
None => {
trace!("proxy request has to host header");
trace!("proxy request has no host header");
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
.unwrap());
}
};

let service = match host.strip_suffix(&fqdn) {
Some(service) => service,
if host != fqdn {
trace!(?host, "proxy won't serve foreign domain");
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("this domain is not served by proxy"))
.unwrap());
}

// We only have one service per project, and its name coincides
// with that of the project
let service = match req.headers().get("X-Shuttle-Project") {
brokad marked this conversation as resolved.
Show resolved Hide resolved
Some(project) => project.to_str().unwrap_or_default().to_owned(),
None => {
trace!(host, "proxy won't serve foreign domain");
trace!("proxy request has no X-Shuttle-Project header");
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("this domain is not served by proxy"))
.body(Body::from("request has no X-Shuttle-Project header"))
.unwrap());
}
};

// Record current service for tracing purposes
span.record("service", &service);

let proxy_address = match address_getter.get_address_for_service(service).await {
let proxy_address = match address_getter.get_address_for_service(&service).await {
Ok(Some(address)) => address,
Ok(None) => {
trace!(host, "host not found on this server");
let response_body = format!("could not find service for host: {}", host);
trace!(?host, service, "service not found on this server");
let response_body = format!("could not find service: {}", service);
return Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(response_body.into())
.unwrap());
}
Err(err) => {
error!(error = %err, host, "proxy failed to find address for host");
error!(error = %err, service, "proxy failed to find address for host");

let response_body = format!("failed to find service for host: {}", host);
return Ok(Response::builder()
Expand Down
8 changes: 5 additions & 3 deletions gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ async-trait = "0.1.52"

axum = { version = "0.5.8", features = [ "headers" ] }
axum-server = { version = "0.4.4", features = [ "tls-rustls" ] }
rustls = { version = "0.20.6" }
rustls-pemfile = { version = "1.0.1" }
pem = "1.1.0"

base64 = "0.13"
bollard = "0.13"
Expand All @@ -27,12 +24,17 @@ hyper = { version = "0.14.19", features = [ "stream" ] }
# not great, but waiting for WebSocket changes to be merged
hyper-reverse-proxy = { git = "https://github.com/chesedo/hyper-reverse-proxy", branch = "bug/host_header" }
instant-acme = "0.1.0"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

lazy_static = "1.4.0"
once_cell = "1.14.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-datadog = { version = "0.6.0", features = ["reqwest-client"] }
opentelemetry-http = "0.7.0"
pem = "1.1.0"
rand = "0.8.5"
rcgen = "0.10.0"
rustls = { version = "0.20.6" }
rustls-pemfile = { version = "1.0.1" }
serde = { version = "1.0.137", features = [ "derive" ] }
serde_json = "1.0.81"
sqlx = { version = "0.5.11", features = [ "sqlite", "json", "runtime-tokio-rustls", "migrate" ] }
Expand Down
6 changes: 4 additions & 2 deletions gateway/src/acme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::time::Duration;

use axum::body::boxed;
use axum::response::Response;
use fqdn::FQDN;
use futures::future::BoxFuture;
use hyper::server::conn::AddrStream;
use hyper::{Body, Request};
Expand All @@ -25,9 +26,10 @@ const MAX_RETRIES: usize = 15;

#[derive(Debug, Eq, PartialEq)]
pub struct CustomDomain {
pub fqdn: FQDN,
pub project_name: ProjectName,
pub certificate: Vec<u8>,
pub private_key: Vec<u8>,
pub certificate: String,
pub private_key: String,
}

/// An ACME client implementation that completes Http01 challenges
Expand Down
58 changes: 41 additions & 17 deletions gateway/src/api/latest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use tokio::sync::mpsc::Sender;
use tower_http::trace::TraceLayer;
use tracing::{debug, debug_span, field, Span};

use crate::acme::AcmeClient;
use crate::acme::{AcmeClient, CustomDomain};
use crate::auth::{Admin, ScopedUser, User};
use crate::task::{self, BoxedTask};
use crate::project::{Project, ProjectCreating};
use crate::task::{self, BoxedTask, TaskResult};
use crate::tls::GatewayCertResolver;
use crate::worker::WORKER_QUEUE_SIZE;
use crate::{AccountName, Error, GatewayService, ProjectName};
Expand Down Expand Up @@ -108,7 +109,6 @@ async fn post_project(
service
.new_task()
.project(project.clone())
.account(name.clone())
.send(&sender)
.await?;

Expand All @@ -123,18 +123,12 @@ async fn post_project(
async fn delete_project(
Extension(service): Extension<Arc<GatewayService>>,
Extension(sender): Extension<Sender<BoxedTask>>,
ScopedUser {
scope: _,
user: User { name, .. },
}: ScopedUser,
Path(project): Path<ProjectName>,
ScopedUser { scope: project, .. }: ScopedUser,
) -> Result<AxumJson<project::Response>, Error> {
let project_name = project.clone();

let state = service.find_project(&project_name).await?;
let state = service.find_project(&project).await?;

let mut response = project::Response {
name: project_name.to_string(),
name: project.to_string(),
state: state.into(),
};

Expand All @@ -146,7 +140,6 @@ async fn delete_project(
service
.new_task()
.project(project)
.account(name)
.and_then(task::destroy())
.send(&sender)
.await?;
Expand Down Expand Up @@ -209,19 +202,50 @@ async fn request_acme_certificate(
Extension(service): Extension<Arc<GatewayService>>,
Extension(acme_client): Extension<AcmeClient>,
Extension(resolver): Extension<Arc<GatewayCertResolver>>,
Extension(sender): Extension<Sender<BoxedTask>>,
Path((project_name, fqdn)): Path<(ProjectName, String)>,
AxumJson(credentials): AxumJson<AccountCredentials<'_>>,
) -> Result<String, Error> {
let fqdn: FQDN = fqdn
.parse()
.map_err(|_err| Error::from(ErrorKind::InvalidCustomDomain))?;

let (certs, private_key) = acme_client
.create_certificate(&fqdn.to_string(), ChallengeType::Http01, credentials)
.await?;
let (certs, private_key) = match service.project_details_for_custom_domain(&fqdn).await {
brokad marked this conversation as resolved.
Show resolved Hide resolved
Ok(CustomDomain {
certificate,
private_key,
..
}) => (certificate, private_key),
Err(err) if err.kind() == ErrorKind::CustomDomainNotFound => {
let (certs, private_key) = acme_client
.create_certificate(&fqdn.to_string(), ChallengeType::Http01, credentials)
.await?;
service
.create_custom_domain(project_name.clone(), &fqdn, &certs, &private_key)
.await?;
(certs, private_key)
}
Err(err) => return Err(err),
};

// destroy and recreate the project with the new domain
service
.create_custom_domain(project_name, &fqdn, &certs, &private_key)
.new_task()
.project(project_name)
.and_then(task::destroy())
.and_then(task::run_until_done())
.and_then(task::run({
let fqdn = fqdn.to_string();
move |ctx| {
let fqdn = fqdn.clone();
async move {
let creating = ProjectCreating::new_with_random_initial_key(ctx.project_name)
.with_fqdn(fqdn);
TaskResult::Done(Project::Creating(creating))
}
}
}))
.send(&sender)
.await?;

let mut buf = Vec::new();
Expand Down
7 changes: 7 additions & 0 deletions gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ impl StdError for Error {}
#[sqlx(transparent)]
pub struct ProjectName(String);

impl ProjectName {
pub fn as_str(&self) -> &str {
self.0.as_str()
}
}

impl<'de> Deserialize<'de> for ProjectName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
Expand Down Expand Up @@ -777,6 +783,7 @@ pub mod tests {
.request(
Request::get("/hello")
.header("Host", "matrix.test.shuttleapp.rs")
.header("x-shuttle-project", "matrix")
.body(Body::empty())
.unwrap(),
)
Expand Down
24 changes: 19 additions & 5 deletions gateway/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use fqdn::FQDN;
use futures::prelude::*;
use instant_acme::{AccountCredentials, ChallengeType};
use opentelemetry::global;
use shuttle_gateway::acme::AcmeClient;
use shuttle_gateway::acme::{AcmeClient, CustomDomain};
use shuttle_gateway::api::latest::ApiBuilder;
use shuttle_gateway::args::StartArgs;
use shuttle_gateway::args::{Args, Commands, InitArgs, UseTls};
Expand Down Expand Up @@ -77,7 +77,7 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> {

let sender = worker.sender();

for (project_name, account_name) in gateway
for (project_name, _) in gateway
.iter_projects()
.await
.expect("could not list projects")
Expand All @@ -86,7 +86,6 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> {
.clone()
.new_task()
.project(project_name)
.account(account_name)
.and_then(task::refresh())
.send(&sender)
.await
Expand All @@ -110,11 +109,10 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> {
loop {
tokio::time::sleep(Duration::from_secs(60)).await;
if let Ok(projects) = gateway.iter_projects().await {
for (project_name, account_name) in projects {
for (project_name, _) in projects {
let _ = gateway
.new_task()
.project(project_name)
.account(account_name)
.and_then(task::check_health())
.send(&sender)
.await;
Expand Down Expand Up @@ -146,6 +144,22 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> {

api_builder = api_builder.with_acme(acme_client.clone(), resolver.clone());

for CustomDomain {
fqdn,
certificate,
private_key,
..
} in gateway.iter_custom_domains().await.unwrap()
{
let mut buf = Vec::new();
buf.extend(certificate.as_bytes());
buf.extend(private_key.as_bytes());
resolver
.serve_pem(&fqdn.to_string(), Cursor::new(buf))
.await
.unwrap();
}

tokio::spawn(async move {
// make sure we have a certificate for ourselves
let certs = init_certs(fs, args.context.proxy_fqdn.clone(), acme_client.clone()).await;
Expand Down
Loading