From 3259d7801b7673a86b33e0f29146f1612f0cbe48 Mon Sep 17 00:00:00 2001 From: Alex Krantz Date: Mon, 7 Nov 2022 00:18:36 -0800 Subject: [PATCH] feat(gateway): add custom domains table and routing --- common/src/models/error.rs | 8 ++ gateway/migrations/0001_custom_domains.sql | 7 ++ gateway/src/custom_domain.rs | 12 +++ gateway/src/lib.rs | 107 ++++++++++++++++++++- gateway/src/service.rs | 87 ++++++++++++++++- 5 files changed, 219 insertions(+), 2 deletions(-) create mode 100644 gateway/migrations/0001_custom_domains.sql create mode 100644 gateway/src/custom_domain.rs diff --git a/common/src/models/error.rs b/common/src/models/error.rs index 5e56779252..2ac2d512ae 100644 --- a/common/src/models/error.rs +++ b/common/src/models/error.rs @@ -44,6 +44,9 @@ pub enum ErrorKind { ProjectAlreadyExists, ProjectNotReady, ProjectUnavailable, + CustomDomainNotFound, + InvalidCustomDomain, + CustomDomainAlreadyExists, InvalidOperation, Internal, NotReady, @@ -75,6 +78,11 @@ impl From for ApiError { StatusCode::BAD_REQUEST, "a project with the same name already exists", ), + ErrorKind::InvalidCustomDomain => (StatusCode::BAD_REQUEST, "invalid custom domain"), + ErrorKind::CustomDomainNotFound => (StatusCode::NOT_FOUND, "custom domain not found"), + ErrorKind::CustomDomainAlreadyExists => { + (StatusCode::BAD_REQUEST, "custom domain already in use") + } ErrorKind::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized"), ErrorKind::Forbidden => (StatusCode::FORBIDDEN, "forbidden"), ErrorKind::NotReady => (StatusCode::INTERNAL_SERVER_ERROR, "service not ready"), diff --git a/gateway/migrations/0001_custom_domains.sql b/gateway/migrations/0001_custom_domains.sql new file mode 100644 index 0000000000..e00a079bd1 --- /dev/null +++ b/gateway/migrations/0001_custom_domains.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS custom_domains ( + fqdn TEXT PRIMARY KEY, + project_name TEXT NOT NULL REFERENCES projects (project_name), + state JSON NOT NULL +); + +CREATE INDEX IF NOT EXISTS custom_domains_fqdn_project_idx ON custom_domains (fqdn, project_name); diff --git a/gateway/src/custom_domain.rs b/gateway/src/custom_domain.rs new file mode 100644 index 0000000000..48cd09593a --- /dev/null +++ b/gateway/src/custom_domain.rs @@ -0,0 +1,12 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CustomDomain { + // TODO: update custom domain states, these are just placeholders for now + Creating, + Verifying, + IssuingCertificate, + Ready, + Errored, +} diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index d5d2980bde..3f0e61d5fd 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -8,20 +8,26 @@ use std::io; use std::pin::Pin; use std::str::FromStr; +use axum::headers::{Header, HeaderName, HeaderValue, Host}; +use axum::http::uri::Authority; use axum::response::{IntoResponse, Response}; use axum::Json; use bollard::Docker; use futures::prelude::*; use once_cell::sync::Lazy; use regex::Regex; -use serde::{Deserialize, Deserializer, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use shuttle_common::models::error::{ApiError, ErrorKind}; +use sqlx::database::{HasArguments, HasValueRef}; +use sqlx::encode::IsNull; +use sqlx::error::BoxDynError; use tokio::sync::mpsc::error::SendError; use tracing::error; pub mod api; pub mod args; pub mod auth; +pub mod custom_domain; pub mod project; pub mod proxy; pub mod service; @@ -169,6 +175,105 @@ impl<'de> Deserialize<'de> for AccountName { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Fqdn(fqdn::FQDN); + +impl FromStr for Fqdn { + type Err = Error; + + fn from_str(s: &str) -> Result { + let fqdn = + fqdn::FQDN::from_str(s).map_err(|_err| Error::from(ErrorKind::InvalidCustomDomain))?; + Ok(Fqdn(fqdn)) + } +} + +impl std::fmt::Display for Fqdn { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl sqlx::Type for Fqdn +where + DB: sqlx::Database, + str: sqlx::Type, +{ + fn type_info() -> ::TypeInfo { + <&str as sqlx::Type>::type_info() + } + + fn compatible(ty: &::TypeInfo) -> bool { + <&str as sqlx::Type>::compatible(ty) + } +} + +impl<'q, DB> sqlx::Encode<'q, DB> for Fqdn +where + DB: sqlx::Database, + String: sqlx::Encode<'q, DB>, +{ + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { + let owned = self.0.to_string(); + >::encode(owned, buf) + } +} + +impl<'r, DB> sqlx::Decode<'r, DB> for Fqdn +where + DB: sqlx::Database, + &'r str: sqlx::Decode<'r, DB>, +{ + fn decode(value: >::ValueRef) -> Result { + let value = <&str as sqlx::Decode>::decode(value)?; + Ok(value.parse()?) + } +} + +impl Serialize for Fqdn { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.0.to_string()) + } +} + +impl<'de> Deserialize<'de> for Fqdn { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + String::deserialize(deserializer)? + .parse() + .map_err(::custom) + } +} + +impl Header for Fqdn { + fn name() -> &'static HeaderName { + Host::name() + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let host = Host::decode(values)?; + let fqdn = fqdn::FQDN::from_str(host.hostname()) + .map_err(|_err| axum::headers::Error::invalid())?; + + Ok(Fqdn(fqdn)) + } + + fn encode>(&self, values: &mut E) { + let authority = Authority::from_str(&self.0.to_string()).unwrap(); + let host = Host::from(authority); + host.encode(values); + } +} + pub trait Context<'c>: Send + Sync { fn docker(&self) -> &'c Docker; diff --git a/gateway/src/service.rs b/gateway/src/service.rs index 2cae0e7bf3..29104f3806 100644 --- a/gateway/src/service.rs +++ b/gateway/src/service.rs @@ -22,9 +22,10 @@ use tracing::debug; use crate::args::StartArgs; use crate::auth::{Key, User}; +use crate::custom_domain::CustomDomain; use crate::project::{self, Project}; use crate::worker::Work; -use crate::{AccountName, Context, Error, ErrorKind, ProjectName, Service}; +use crate::{AccountName, Context, Error, ErrorKind, Fqdn, ProjectName, Service}; pub static MIGRATIONS: Migrator = sqlx::migrate!("./migrations"); static PROXY_CLIENT: Lazy>> = @@ -193,6 +194,16 @@ impl GatewayService { Self { provider, db } } + pub async fn route_fqdn(&self, req: Request) -> Result, Error> { + let fqdn = req + .headers() + .typed_get::() + .ok_or_else(|| Error::from(ErrorKind::CustomDomainNotFound))?; + let project_name = self.project_name_for_custom_domain(&fqdn).await?; + + self.route(&project_name, req).await + } + pub async fn route( &self, project_name: &ProjectName, @@ -439,6 +450,42 @@ impl GatewayService { }) } + pub async fn create_custom_domain( + &self, + project_name: ProjectName, + fqdn: Fqdn, + ) -> Result { + let state = SqlxJson(CustomDomain::Creating); + + query("INSERT INTO custom_domains (fqdn, project_name, state) VALUES (?1, ?2, ?3)") + .bind(&fqdn) + .bind(&project_name) + .bind(&state) + .execute(&self.db) + .await + .map_err(|err| { + if let Some(db_err_code) = err.as_database_error().and_then(DatabaseError::code) { + if db_err_code == "1555" { + return Error::from(ErrorKind::CustomDomainAlreadyExists); + } + } + + err.into() + })?; + + Ok(state.0) + } + + pub async fn project_name_for_custom_domain(&self, fqdn: &Fqdn) -> Result { + let project_name = query("SELECT project_name FROM custom_domains WHERE fqdn = ?1") + .bind(fqdn) + .fetch_optional(&self.db) + .await? + .map(|row| row.try_get("project_name").unwrap()) + .ok_or_else(|| Error::from(ErrorKind::CustomDomainNotFound))?; + Ok(project_name) + } + fn context(&self) -> GatewayContext { self.provider.context() } @@ -587,4 +634,42 @@ pub mod tests { Ok(()) } + + #[tokio::test] + async fn service_create_find_custom_domain() -> anyhow::Result<()> { + let world = World::new().await; + let svc = Arc::new(GatewayService::init(world.args(), world.fqdn(), world.pool()).await); + + let account: AccountName = "neo".parse().unwrap(); + let project_name: ProjectName = "matrix".parse().unwrap(); + let domain: Fqdn = "neo.the.matrix".parse().unwrap(); + + svc.create_user(account.clone()).await.unwrap(); + + assert_err_kind!( + svc.project_name_for_custom_domain(&domain).await, + ErrorKind::CustomDomainNotFound + ); + + let _ = svc + .create_project(project_name.clone(), account.clone()) + .await + .unwrap(); + + svc.create_custom_domain(project_name.clone(), domain.clone()) + .await + .unwrap(); + + let project = svc.project_name_for_custom_domain(&domain).await.unwrap(); + + assert_eq!(project, project_name); + + assert_err_kind!( + svc.create_custom_domain(project_name.clone(), domain.clone()) + .await, + ErrorKind::CustomDomainAlreadyExists + ); + + Ok(()) + } }