Skip to content

Commit

Permalink
feat(gateway): add custom domains table and routing (#465)
Browse files Browse the repository at this point in the history
  • Loading branch information
akrantz01 authored Nov 8, 2022
1 parent b4055af commit 3ab6c71
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 2 deletions.
8 changes: 8 additions & 0 deletions common/src/models/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ pub enum ErrorKind {
ProjectAlreadyExists,
ProjectNotReady,
ProjectUnavailable,
CustomDomainNotFound,
InvalidCustomDomain,
CustomDomainAlreadyExists,
InvalidOperation,
Internal,
NotReady,
Expand Down Expand Up @@ -80,6 +83,11 @@ impl From<ErrorKind> for ApiError {
StatusCode::BAD_REQUEST,
"a project with the same name already exists",
),
ErrorKind::InvalidCustomDomain => (StatusCode::BAD_REQUEST, "invalid custom domain"),
ErrorKind::CustomDomainNotFound => (StatusCode::NOT_FOUND, "custom domain not found"),
ErrorKind::CustomDomainAlreadyExists => {
(StatusCode::BAD_REQUEST, "custom domain already in use")
}
ErrorKind::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized"),
ErrorKind::Forbidden => (StatusCode::FORBIDDEN, "forbidden"),
ErrorKind::NotReady => (StatusCode::INTERNAL_SERVER_ERROR, "service not ready"),
Expand Down
7 changes: 7 additions & 0 deletions gateway/migrations/0002_custom_domains.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS custom_domains (
fqdn TEXT PRIMARY KEY,
project_name TEXT NOT NULL REFERENCES projects (project_name),
state JSON NOT NULL
);

CREATE INDEX IF NOT EXISTS custom_domains_fqdn_project_idx ON custom_domains (fqdn, project_name);
12 changes: 12 additions & 0 deletions gateway/src/custom_domain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CustomDomain {
// TODO: update custom domain states, these are just placeholders for now
Creating,
Verifying,
IssuingCertificate,
Ready,
Errored,
}
107 changes: 106 additions & 1 deletion gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,24 @@ use std::io;
use std::pin::Pin;
use std::str::FromStr;

use axum::headers::{Header, HeaderName, HeaderValue, Host};
use axum::http::uri::Authority;
use axum::response::{IntoResponse, Response};
use axum::Json;
use bollard::Docker;
use futures::prelude::*;
use serde::{Deserialize, Deserializer, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use shuttle_common::models::error::{ApiError, ErrorKind};
use sqlx::database::{HasArguments, HasValueRef};
use sqlx::encode::IsNull;
use sqlx::error::BoxDynError;
use tokio::sync::mpsc::error::SendError;
use tracing::error;

pub mod api;
pub mod args;
pub mod auth;
pub mod custom_domain;
pub mod project;
pub mod proxy;
pub mod service;
Expand Down Expand Up @@ -164,6 +170,105 @@ impl<'de> Deserialize<'de> for AccountName {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Fqdn(fqdn::FQDN);

impl FromStr for Fqdn {
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let fqdn =
fqdn::FQDN::from_str(s).map_err(|_err| Error::from(ErrorKind::InvalidCustomDomain))?;
Ok(Fqdn(fqdn))
}
}

impl std::fmt::Display for Fqdn {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

impl<DB> sqlx::Type<DB> for Fqdn
where
DB: sqlx::Database,
str: sqlx::Type<DB>,
{
fn type_info() -> <DB as sqlx::Database>::TypeInfo {
<&str as sqlx::Type<DB>>::type_info()
}

fn compatible(ty: &<DB as sqlx::Database>::TypeInfo) -> bool {
<&str as sqlx::Type<DB>>::compatible(ty)
}
}

impl<'q, DB> sqlx::Encode<'q, DB> for Fqdn
where
DB: sqlx::Database,
String: sqlx::Encode<'q, DB>,
{
fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
let owned = self.0.to_string();
<String as sqlx::Encode<DB>>::encode(owned, buf)
}
}

impl<'r, DB> sqlx::Decode<'r, DB> for Fqdn
where
DB: sqlx::Database,
&'r str: sqlx::Decode<'r, DB>,
{
fn decode(value: <DB as HasValueRef<'r>>::ValueRef) -> Result<Self, BoxDynError> {
let value = <&str as sqlx::Decode<DB>>::decode(value)?;
Ok(value.parse()?)
}
}

impl Serialize for Fqdn {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.0.to_string())
}
}

impl<'de> Deserialize<'de> for Fqdn {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
String::deserialize(deserializer)?
.parse()
.map_err(<D::Error as serde::de::Error>::custom)
}
}

impl Header for Fqdn {
fn name() -> &'static HeaderName {
Host::name()
}

fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i HeaderValue>,
{
let host = Host::decode(values)?;
let fqdn = fqdn::FQDN::from_str(host.hostname())
.map_err(|_err| axum::headers::Error::invalid())?;

Ok(Fqdn(fqdn))
}

fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
let authority = Authority::from_str(&self.0.to_string()).unwrap();
let host = Host::from(authority);
host.encode(values);
}
}

pub trait DockerContext: Send + Sync {
fn docker(&self) -> &Docker;

Expand Down
87 changes: 86 additions & 1 deletion gateway/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;

use crate::args::ContextArgs;
use crate::auth::{Key, Permissions, User};
use crate::custom_domain::CustomDomain;
use crate::project::Project;
use crate::task::TaskBuilder;
use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectName};
use crate::{AccountName, DockerContext, Error, ErrorKind, Fqdn, ProjectName};

pub static MIGRATIONS: Migrator = sqlx::migrate!("./migrations");
static PROXY_CLIENT: Lazy<ReverseProxy<HttpConnector<GaiResolver>>> =
Expand Down Expand Up @@ -201,6 +202,16 @@ impl GatewayService {
Self { provider, db }
}

pub async fn route_fqdn(&self, req: Request<Body>) -> Result<Response<Body>, Error> {
let fqdn = req
.headers()
.typed_get::<Fqdn>()
.ok_or_else(|| Error::from(ErrorKind::CustomDomainNotFound))?;
let project_name = self.project_name_for_custom_domain(&fqdn).await?;

self.route(&project_name, req).await
}

pub async fn route(
&self,
project_name: &ProjectName,
Expand Down Expand Up @@ -448,6 +459,42 @@ impl GatewayService {
Ok(project)
}

pub async fn create_custom_domain(
&self,
project_name: ProjectName,
fqdn: Fqdn,
) -> Result<CustomDomain, Error> {
let state = SqlxJson(CustomDomain::Creating);

query("INSERT INTO custom_domains (fqdn, project_name, state) VALUES (?1, ?2, ?3)")
.bind(&fqdn)
.bind(&project_name)
.bind(&state)
.execute(&self.db)
.await
.map_err(|err| {
if let Some(db_err_code) = err.as_database_error().and_then(DatabaseError::code) {
if db_err_code == "1555" {
return Error::from(ErrorKind::CustomDomainAlreadyExists);
}
}

err.into()
})?;

Ok(state.0)
}

pub async fn project_name_for_custom_domain(&self, fqdn: &Fqdn) -> Result<ProjectName, Error> {
let project_name = query("SELECT project_name FROM custom_domains WHERE fqdn = ?1")
.bind(fqdn)
.fetch_optional(&self.db)
.await?
.map(|row| row.try_get("project_name").unwrap())
.ok_or_else(|| Error::from(ErrorKind::CustomDomainNotFound))?;
Ok(project_name)
}

pub fn context(&self) -> GatewayContext {
self.provider.context()
}
Expand Down Expand Up @@ -659,4 +706,42 @@ pub mod tests {

Ok(())
}

#[tokio::test]
async fn service_create_find_custom_domain() -> anyhow::Result<()> {
let world = World::new().await;
let svc = Arc::new(GatewayService::init(world.args(), world.fqdn(), world.pool()).await);

let account: AccountName = "neo".parse().unwrap();
let project_name: ProjectName = "matrix".parse().unwrap();
let domain: Fqdn = "neo.the.matrix".parse().unwrap();

svc.create_user(account.clone()).await.unwrap();

assert_err_kind!(
svc.project_name_for_custom_domain(&domain).await,
ErrorKind::CustomDomainNotFound
);

let _ = svc
.create_project(project_name.clone(), account.clone())
.await
.unwrap();

svc.create_custom_domain(project_name.clone(), domain.clone())
.await
.unwrap();

let project = svc.project_name_for_custom_domain(&domain).await.unwrap();

assert_eq!(project, project_name);

assert_err_kind!(
svc.create_custom_domain(project_name.clone(), domain.clone())
.await,
ErrorKind::CustomDomainAlreadyExists
);

Ok(())
}
}

0 comments on commit 3ab6c71

Please sign in to comment.