diff --git a/Cargo.lock b/Cargo.lock index 584fb54aa..1c62206d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -989,7 +989,7 @@ dependencies = [ "serde_bytes", "serde_json", "time 0.3.11", - "uuid 1.1.2", + "uuid 1.2.1", ] [[package]] @@ -1221,7 +1221,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "uuid 1.1.2", + "uuid 1.2.1", "webbrowser", ] @@ -5329,7 +5329,7 @@ dependencies = [ "serde_json", "strum", "tracing", - "uuid 1.1.2", + "uuid 1.2.1", ] [[package]] @@ -5377,7 +5377,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", - "uuid 1.1.2", + "uuid 1.2.1", ] [[package]] @@ -5416,6 +5416,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", + "uuid 1.2.1", ] [[package]] @@ -5497,7 +5498,7 @@ dependencies = [ "tower", "tracing", "tracing-subscriber", - "uuid 1.1.2", + "uuid 1.2.1", "warp", ] @@ -5749,7 +5750,7 @@ dependencies = [ "thiserror", "tokio-stream", "url", - "uuid 1.1.2", + "uuid 1.2.1", "whoami", ] @@ -7011,9 +7012,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.1.2" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd6469f4314d5f1ffec476e05f17cc9a78bc7a27a6a857842170bdf8d6f98d2f" +checksum = "feb41e78f93363bb2df8b0e86a2ca30eed7806ea16ea0c790d757cf93f79be83" dependencies = [ "getrandom 0.2.7", "serde", diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index 0c8e54807..eb54ee0c6 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -32,6 +32,7 @@ tower-http = { version = "0.3.4", features = ["trace"] } tracing = "0.1.35" tracing-opentelemetry = "0.18.0" tracing-subscriber = { version = "0.3.11", features = ["env-filter"] } +uuid = { version = "1.2.1", features = [ "v4" ] } [dependencies.shuttle-common] version = "0.7.2" diff --git a/gateway/src/api/latest.rs b/gateway/src/api/latest.rs index 1942e16ba..8ebb68730 100644 --- a/gateway/src/api/latest.rs +++ b/gateway/src/api/latest.rs @@ -15,7 +15,7 @@ use tower_http::trace::TraceLayer; use tracing::{debug, debug_span, field, Span}; use crate::auth::{Admin, ScopedUser, User}; -use crate::worker::Work; +use crate::task::{self, BoxedTask}; use crate::{AccountName, Error, GatewayService, ProjectName}; #[derive(Serialize, Deserialize)] @@ -79,39 +79,52 @@ async fn get_project( async fn post_project( Extension(service): Extension>, - Extension(sender): Extension>, + Extension(sender): Extension>, User { name, .. }: User, Path(project): Path, ) -> Result, Error> { - let work = service.create_project(project.clone(), name).await?; + let state = service + .create_project(project.clone(), name.clone()) + .await?; - let name = work.project_name.to_string(); - let state = work.work.clone().into(); + service + .new_task() + .project(project.clone()) + .account(name.clone()) + .send(&sender) + .await?; - sender.send(work).await?; - - let response = project::Response { name, state }; + let response = project::Response { + name: project.to_string(), + state: state.into(), + }; Ok(AxumJson(response)) } async fn delete_project( Extension(service): Extension>, - Extension(sender): Extension>, + Extension(sender): Extension>, ScopedUser { scope: _, user: User { name, .. }, }: ScopedUser, Path(project): Path, ) -> Result, Error> { - let work = service.destroy_project(project, name).await?; - - let name = work.project_name.to_string(); - let state = work.work.clone().into(); + let project_name = project.clone(); - sender.send(work).await?; + service + .new_task() + .project(project) + .account(name) + .and_then(task::destroy()) + .send(&sender) + .await?; - let response = project::Response { name, state }; + let response = project::Response { + name: project_name.to_string(), + state: shuttle_common::models::project::State::Destroying, + }; Ok(AxumJson(response)) } @@ -123,7 +136,7 @@ async fn route_project( service.route(&scope, req).await } -async fn get_status(Extension(sender): Extension>) -> Response { +async fn get_status(Extension(sender): Extension>) -> Response { let (status, body) = if !sender.is_closed() && sender.capacity() > 0 { (StatusCode::OK, StatusResponse::healthy()) } else { @@ -140,8 +153,9 @@ async fn get_status(Extension(sender): Extension>) -> Response, sender: Sender) -> Router { +pub fn make_api(service: Arc, sender: Sender) -> Router { debug!("making api route"); + Router::::new() .route( "/", @@ -185,14 +199,13 @@ pub mod tests { use super::*; use crate::service::GatewayService; use crate::tests::{RequestBuilderExt, World}; - use crate::worker::Work; #[tokio::test] async fn api_create_get_delete_projects() -> anyhow::Result<()> { let world = World::new().await; let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); - let (sender, mut receiver) = channel::(256); + let (sender, mut receiver) = channel::(256); tokio::spawn(async move { while receiver.recv().await.is_some() { // do not do any work with inbound requests @@ -327,7 +340,7 @@ pub mod tests { let world = World::new().await; let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); - let (sender, mut receiver) = channel::(256); + let (sender, mut receiver) = channel::(256); tokio::spawn(async move { while receiver.recv().await.is_some() { // do not do any work with inbound requests @@ -416,7 +429,7 @@ pub mod tests { let world = World::new().await; let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); - let (sender, mut receiver) = channel::(1); + let (sender, mut receiver) = channel::(1); let (ctl_send, ctl_recv) = oneshot::channel(); let (done_send, done_recv) = oneshot::channel(); let worker = tokio::spawn(async move { @@ -468,6 +481,7 @@ pub mod tests { assert_eq!(resp.status(), StatusCode::OK); worker.abort(); + let _ = worker.await; let resp = router.call(get_status()).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index 6a99be184..416b39e5c 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -25,6 +25,7 @@ pub mod auth; pub mod project; pub mod proxy; pub mod service; +pub mod task; pub mod worker; use crate::service::{ContainerSettings, GatewayService}; @@ -169,22 +170,22 @@ impl<'de> Deserialize<'de> for AccountName { } } -pub trait Context<'c>: Send + Sync { - fn docker(&self) -> &'c Docker; +pub trait DockerContext: Send + Sync { + fn docker(&self) -> &Docker; - fn container_settings(&self) -> &'c ContainerSettings; + fn container_settings(&self) -> &ContainerSettings; } #[async_trait] -pub trait Service<'c> { - type Context: Context<'c>; +pub trait Service { + type Context; - type State: EndState<'c>; + type State: EndState; type Error; /// Asks for the latest available context for task execution - fn context(&'c self) -> Self::Context; + fn context(&self) -> Self::Context; /// Commit a state update to persistence async fn update(&self, state: &Self::State) -> Result<(), Self::Error>; @@ -193,42 +194,39 @@ pub trait Service<'c> { /// A generic state which can, when provided with a [`Context`], do /// some work and advance itself #[async_trait] -pub trait State<'c>: Send + Sized + Clone { +pub trait State: Send { type Next; type Error; - async fn next>(self, ctx: &C) -> Result; + async fn next(self, ctx: &Ctx) -> Result; } -/// A [`State`] which contains all its transitions, including -/// failures -pub trait EndState<'c> +pub type StateTryStream<'c, St, Err> = Pin> + Send + 'c>>; + +pub trait EndState where - Self: State<'c, Error = Infallible, Next = Self>, + Self: State, { - type ErrorVariant; - fn is_done(&self) -> bool; - - fn into_result(self) -> Result; } -pub type StateTryStream<'c, St, Err> = Pin> + Send + 'c>>; - -pub trait EndStateExt<'c>: EndState<'c> { +pub trait EndStateExt: TryState + EndState +where + Ctx: Sync, + Self: Clone, +{ /// Convert the state into a [`TryStream`] that yields /// the generated states. /// /// This stream will not end. - fn into_stream(self, ctx: Ctx) -> StateTryStream<'c, Self, Self::ErrorVariant> + fn into_stream<'c>(self, ctx: &'c Ctx) -> StateTryStream<'c, Self, Self::ErrorVariant> where Self: 'c, - Ctx: 'c + Context<'c>, { Box::pin(stream::try_unfold((self, ctx), |(state, ctx)| async move { state - .next(&ctx) + .next(ctx) .await .unwrap() // EndState's `next` is Infallible .into_result() @@ -237,29 +235,42 @@ pub trait EndStateExt<'c>: EndState<'c> { } } -impl<'c, S> EndStateExt<'c> for S where S: EndState<'c> {} +impl EndStateExt for S +where + S: Clone + TryState + EndState, + Ctx: Send + Sync, +{ +} + +/// A [`State`] which contains all its transitions, including +/// failures +pub trait TryState: Sized { + type ErrorVariant; + + fn into_result(self) -> Result; +} -pub trait IntoEndState<'c, E> +pub trait IntoTryState where - E: EndState<'c>, + S: TryState, { - fn into_end_state(self) -> Result; + fn into_try_state(self) -> Result; } -impl<'c, E, S, Err> IntoEndState<'c, E> for Result +impl IntoTryState for Result where - E: EndState<'c> + From + From, + S: TryState + From + From, { - fn into_end_state(self) -> Result { - self.map(|s| E::from(s)).or_else(|err| Ok(E::from(err))) + fn into_try_state(self) -> Result { + self.map(|s| S::from(s)).or_else(|err| Ok(S::from(err))) } } #[async_trait] -pub trait Refresh: Sized { +pub trait Refresh: Sized { type Error: StdError; - async fn refresh<'c, C: Context<'c>>(self, ctx: &C) -> Result; + async fn refresh(self, ctx: &Ctx) -> Result; } #[cfg(test)] @@ -284,7 +295,6 @@ pub mod tests { use shuttle_common::models::{project, service}; use sqlx::SqlitePool; use tokio::sync::mpsc::channel; - use tracing::info; use crate::api::make_api; use crate::args::{ContextArgs, StartArgs}; @@ -292,7 +302,7 @@ pub mod tests { use crate::proxy::make_proxy; use crate::service::{ContainerSettings, GatewayService, MIGRATIONS}; use crate::worker::Worker; - use crate::Context; + use crate::DockerContext; macro_rules! value_block_helper { ($next:ident, $block:block) => { @@ -355,7 +365,7 @@ pub mod tests { $($(#[$($meta:tt)*])* $($patterns:pat_param)|+ $(if $guards:expr)? $(=> $mores:block)?,)+ } => {{ let state = $state; - let mut stream = crate::EndStateExt::into_stream(state, $ctx); + let mut stream = crate::EndStateExt::into_stream(state, &$ctx); assert_stream_matches!( stream, $($(#[$($meta)*])* $($patterns)|+ $(if $guards)? $(=> $mores)?,)+ @@ -487,11 +497,11 @@ pub mod tests { pool: SqlitePool, } - #[derive(Clone, Copy)] - pub struct WorldContext<'c> { - pub docker: &'c Docker, - pub container_settings: &'c ContainerSettings, - pub hyper: &'c HyperClient, + #[derive(Clone)] + pub struct WorldContext { + pub docker: Docker, + pub container_settings: ContainerSettings, + pub hyper: HyperClient, } impl World { @@ -579,20 +589,20 @@ pub mod tests { impl World { pub fn context(&self) -> WorldContext { WorldContext { - docker: &self.docker, - container_settings: &self.settings, - hyper: &self.hyper, + docker: self.docker.clone(), + container_settings: self.settings.clone(), + hyper: self.hyper.clone(), } } } - impl<'c> Context<'c> for WorldContext<'c> { - fn docker(&self) -> &'c Docker { - self.docker + impl DockerContext for WorldContext { + fn docker(&self) -> &Docker { + &self.docker } - fn container_settings(&self) -> &'c ContainerSettings { - self.container_settings + fn container_settings(&self) -> &ContainerSettings { + &self.container_settings } } @@ -600,17 +610,19 @@ pub mod tests { async fn end_to_end() { let world = World::new().await; let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); - let worker = Worker::new(Arc::clone(&service)); + let worker = Worker::new(); let (log_out, mut log_in) = channel(256); tokio::spawn({ let sender = worker.sender(); async move { while let Some(work) = log_in.recv().await { - info!("work: {work:?}"); - sender.send(work).await.unwrap() + sender + .send(work) + .await + .map_err(|_| "could not send work") + .unwrap(); } - info!("work channel closed"); } }); @@ -777,8 +789,8 @@ pub mod tests { ) .await .unwrap(); - println!("{resp:?}"); - if matches!(resp.status(), StatusCode::NOT_FOUND) { + let resp = serde_json::from_slice::(resp.body().as_slice()).unwrap(); + if matches!(resp.state, project::State::Destroyed) { break; } }); diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 0696bad8e..a05d5df14 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -3,16 +3,18 @@ use futures::prelude::*; use opentelemetry::global; use shuttle_gateway::args::{Args, Commands, ExecCmd, ExecCmds, InitArgs}; use shuttle_gateway::auth::Key; +use shuttle_gateway::project; use shuttle_gateway::proxy::make_proxy; use shuttle_gateway::service::{GatewayService, MIGRATIONS}; -use shuttle_gateway::worker::{Work, Worker}; +use shuttle_gateway::task; +use shuttle_gateway::worker::Worker; use shuttle_gateway::{api::make_api, args::StartArgs}; -use shuttle_gateway::{project, Refresh, Service}; use sqlx::migrate::MigrateDatabase; use sqlx::{query, Sqlite, SqlitePool}; use std::io; use std::path::Path; use std::sync::Arc; +use std::time::Duration; use tracing::{debug, error, info, trace}; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; @@ -71,43 +73,26 @@ async fn start(db: SqlitePool, args: StartArgs) -> io::Result<()> { .to_string(); let gateway = Arc::new(GatewayService::init(args.context.clone(), db).await); - let worker = Worker::new(Arc::clone(&gateway)); + let worker = Worker::new(); let sender = worker.sender(); - let gateway_clone = gateway.clone(); - let sender_clone = sender.clone(); - - tokio::spawn(async move { - for Work { - project_name, - account_name, - work, - } in gateway_clone - .iter_projects() + for (project_name, account_name) in gateway + .iter_projects() + .await + .expect("could not list projects") + { + gateway + .clone() + .new_task() + .project(project_name) + .account(account_name) + .and_then(task::refresh()) + .send(&sender) .await - .expect("could not list projects") - { - match work.refresh(&gateway_clone.context()).await { - Ok(work) => sender_clone - .send(Work { - account_name, - project_name, - work, - }) - .await - .unwrap(), - Err(err) => { - error!( - error = %err, - %account_name, - %project_name, - "could not refresh state. Skipping it for now.", - ); - } - } - } - }); + .ok() + .unwrap(); + } let worker_handle = tokio::spawn( worker @@ -116,6 +101,29 @@ async fn start(db: SqlitePool, args: StartArgs) -> io::Result<()> { .map_err(|err| error!("worker error: {}", err)), ); + // Every 60secs go over all `::Ready` projects and check their + // health + let ambulance_handle = tokio::spawn({ + let gateway = Arc::clone(&gateway); + let sender = sender.clone(); + async move { + loop { + tokio::time::sleep(Duration::from_secs(60)).await; + if let Ok(projects) = gateway.iter_projects().await { + for (project_name, account_name) in projects { + let _ = gateway + .new_task() + .project(project_name) + .account(account_name) + .and_then(task::check_health()) + .send(&sender) + .await; + } + } + } + } + }); + let api = make_api(Arc::clone(&gateway), sender); let api_handle = tokio::spawn(axum::Server::bind(&args.control).serve(api.into_make_service())); @@ -128,8 +136,9 @@ async fn start(db: SqlitePool, args: StartArgs) -> io::Result<()> { tokio::select!( _ = worker_handle => info!("worker handle finished"), - _ = api_handle => info!("api handle finished"), - _ = proxy_handle => info!("proxy handle finished"), + _ = api_handle => error!("api handle finished"), + _ = proxy_handle => error!("proxy handle finished"), + _ = ambulance_handle => error!("ambulance handle finished"), ); Ok(()) diff --git a/gateway/src/project.rs b/gateway/src/project.rs index 5e9ec6ecb..92e539731 100644 --- a/gateway/src/project.rs +++ b/gateway/src/project.rs @@ -9,17 +9,18 @@ use bollard::errors::Error as DockerError; use bollard::models::{ContainerConfig, ContainerInspectResponse, ContainerStateStatusEnum}; use futures::prelude::*; use http::uri::InvalidUri; -use http::StatusCode; +use http::Uri; use hyper::client::HttpConnector; use hyper::Client; use once_cell::sync::Lazy; +use rand::distributions::{Alphanumeric, DistString}; use serde::{Deserialize, Serialize}; use tokio::time; use tracing::{debug, error}; use crate::{ - ContainerSettings, Context, EndState, Error, ErrorKind, IntoEndState, ProjectName, Refresh, - State, + ContainerSettings, DockerContext, EndState, Error, ErrorKind, IntoTryState, ProjectName, + Refresh, State, TryState, }; macro_rules! safe_unwrap { @@ -58,14 +59,18 @@ macro_rules! impl_from_variant { } const RUNTIME_API_PORT: u16 = 8001; +const MAX_RESTARTS: i64 = 3; // Client used for health checks static CLIENT: Lazy> = Lazy::new(Client::new); #[async_trait] -impl Refresh for ContainerInspectResponse { +impl Refresh for ContainerInspectResponse +where + Ctx: DockerContext, +{ type Error = DockerError; - async fn refresh<'c, C: Context<'c>>(self, ctx: &C) -> Result { + async fn refresh(self, ctx: &Ctx) -> Result { ctx.docker() .inspect_container(self.id.as_ref().unwrap(), None) .await @@ -116,6 +121,10 @@ impl Project { } } + pub fn create(project_name: ProjectName) -> Self { + Self::Creating(ProjectCreating::new_with_random_initial_key(project_name)) + } + pub fn destroy(self) -> Result { if let Some(container) = self.container() { Ok(Self::Destroying(ProjectDestroying { container })) @@ -124,6 +133,14 @@ impl Project { } } + pub fn is_ready(&self) -> bool { + matches!(self, Self::Ready(_)) + } + + pub fn is_destroyed(&self) -> bool { + matches!(self, Self::Destroyed(_)) + } + pub fn target_ip(&self) -> Result, Error> { match self.clone() { Self::Ready(project_ready) => Ok(Some(*project_ready.target_ip())), @@ -164,6 +181,14 @@ impl Project { } } + pub fn initial_key(&self) -> Option<&String> { + if let Self::Creating(ProjectCreating { initial_key, .. }) = self { + Some(initial_key) + } else { + None + } + } + pub fn container_id(&self) -> Option { self.container().and_then(|container| container.id) } @@ -186,27 +211,30 @@ impl From for shuttle_common::models::project::State { } #[async_trait] -impl<'c> State<'c> for Project { +impl State for Project +where + Ctx: DockerContext, +{ type Next = Self; type Error = Infallible; - async fn next>(self, ctx: &C) -> Result { + async fn next(self, ctx: &Ctx) -> Result { let previous = self.clone(); let previous_state = previous.state(); let mut new = match self { - Self::Creating(creating) => creating.next(ctx).await.into_end_state(), - Self::Starting(ready) => ready.next(ctx).await.into_end_state(), + Self::Creating(creating) => creating.next(ctx).await.into_try_state(), + Self::Starting(ready) => ready.next(ctx).await.into_try_state(), Self::Started(started) => match started.next(ctx).await { Ok(ProjectReadying::Ready(ready)) => Ok(ready.into()), Ok(ProjectReadying::Started(started)) => Ok(started.into()), Err(err) => Ok(Self::Errored(err)), }, - Self::Ready(ready) => ready.next(ctx).await.into_end_state(), - Self::Stopped(stopped) => stopped.next(ctx).await.into_end_state(), - Self::Stopping(stopping) => stopping.next(ctx).await.into_end_state(), - Self::Destroying(destroying) => destroying.next(ctx).await.into_end_state(), - Self::Destroyed(destroyed) => destroyed.next(ctx).await.into_end_state(), + Self::Ready(ready) => ready.next(ctx).await.into_try_state(), + Self::Stopped(stopped) => stopped.next(ctx).await.into_try_state(), + Self::Stopping(stopping) => stopping.next(ctx).await.into_try_state(), + Self::Destroying(destroying) => destroying.next(ctx).await.into_try_state(), + Self::Destroyed(destroyed) => destroyed.next(ctx).await.into_try_state(), Self::Errored(errored) => Ok(Self::Errored(errored)), }; @@ -228,15 +256,17 @@ impl<'c> State<'c> for Project { } } -impl<'c> EndState<'c> for Project { - type ErrorVariant = ProjectError; - +impl EndState for Project +where + Ctx: DockerContext, +{ fn is_done(&self) -> bool { - matches!( - self, - Self::Errored(_) | Self::Ready(_) | Self::Stopped(_) | Self::Destroyed(_) - ) + matches!(self, Self::Errored(_) | Self::Ready(_) | Self::Destroyed(_)) } +} + +impl TryState for Project { + type ErrorVariant = ProjectError; fn into_result(self) -> Result { match self { @@ -247,7 +277,10 @@ impl<'c> EndState<'c> for Project { } #[async_trait] -impl Refresh for Project { +impl Refresh for Project +where + Ctx: DockerContext, +{ type Error = Error; /// TODO: we could be a bit more clever than this by using the @@ -255,7 +288,7 @@ impl Refresh for Project { /// state which is probably prone to erroneously setting the /// project into the wrong state if the docker is transitioning /// the state of its resources under us - async fn refresh<'c, C: Context<'c>>(self, ctx: &C) -> Result { + async fn refresh(self, ctx: &Ctx) -> Result { let _container = if let Some(container_id) = self.container_id() { Some(ctx.docker().inspect_container(&container_id, None).await?) } else { @@ -272,7 +305,8 @@ impl Refresh for Project { let container = container.refresh(ctx).await?; match container.state.as_ref().unwrap().status.as_ref().unwrap() { ContainerStateStatusEnum::RUNNING => { - Self::Started(ProjectStarted { container }) + let service = Service::from_container(container.clone())?; + Self::Started(ProjectStarted { container, service }) } ContainerStateStatusEnum::CREATED => { Self::Starting(ProjectStarting { container }) @@ -308,11 +342,16 @@ impl ProjectCreating { } } + pub fn new_with_random_initial_key(project_name: ProjectName) -> Self { + let initial_key = Alphanumeric.sample_string(&mut rand::thread_rng(), 32); + Self::new(project_name, initial_key) + } + pub fn project_name(&self) -> &ProjectName { &self.project_name } - fn container_name<'c, C: Context<'c>>(&self, ctx: &C) -> String { + fn container_name(&self, ctx: &C) -> String { let prefix = &ctx.container_settings().prefix; let Self { project_name, .. } = &self; @@ -320,7 +359,7 @@ impl ProjectCreating { format!("{prefix}{project_name}_run") } - fn generate_container_config<'c, C: Context<'c>>( + fn generate_container_config( &self, ctx: &C, ) -> (CreateContainerOptions, Config) { @@ -408,11 +447,14 @@ Config: {config:#?} } #[async_trait] -impl<'c> State<'c> for ProjectCreating { +impl State for ProjectCreating +where + Ctx: DockerContext, +{ type Next = ProjectStarting; type Error = ProjectError; - async fn next>(self, ctx: &C) -> Result { + async fn next(self, ctx: &Ctx) -> Result { let container_name = self.container_name(ctx); let container = ctx .docker() @@ -441,11 +483,14 @@ pub struct ProjectStarting { } #[async_trait] -impl<'c> State<'c> for ProjectStarting { +impl State for ProjectStarting +where + Ctx: DockerContext, +{ type Next = ProjectStarted; type Error = ProjectError; - async fn next>(self, ctx: &C) -> Result { + async fn next(self, ctx: &Ctx) -> Result { let container_id = self.container.id.as_ref().unwrap(); ctx.docker() .start_container::(container_id, None) @@ -459,15 +504,18 @@ impl<'c> State<'c> for ProjectStarting { } })?; - Ok(Self::Next { - container: self.container.refresh(ctx).await?, - }) + let container = self.container.refresh(ctx).await?; + + let service = Service::from_container(container.clone())?; + + Ok(Self::Next { container, service }) } } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct ProjectStarted { container: ContainerInspectResponse, + service: Service, } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -477,35 +525,20 @@ pub enum ProjectReadying { } #[async_trait] -impl<'c> State<'c> for ProjectStarted { +impl State for ProjectStarted +where + Ctx: DockerContext, +{ type Next = ProjectReadying; type Error = ProjectError; - async fn next>(self, ctx: &C) -> Result { + async fn next(self, ctx: &Ctx) -> Result { time::sleep(Duration::from_secs(1)).await; + let container = self.container.refresh(ctx).await?; - let ready_service = if matches!( - safe_unwrap!(container.state.status), - ContainerStateStatusEnum::RUNNING - ) { - let service = Service::from_container(container.clone())?; - let uri = format!( - "http://{}:8001/projects/{}/status", - service.target, service.name - ); - let uri = uri.parse()?; - let res = CLIENT.get(uri).await?; - - if res.status() == StatusCode::OK { - Some(service) - } else { - None - } - } else { - None - }; + let mut service = self.service; - if let Some(service) = ready_service { + if service.is_healthy().await { Ok(Self::Next::Ready(ProjectReady { container, service })) } else { let started_at = @@ -520,7 +553,7 @@ impl<'c> State<'c> for ProjectStarted { )); } - Ok(Self::Next::Started(ProjectStarted { container })) + Ok(Self::Next::Started(ProjectStarted { container, service })) } } } @@ -532,11 +565,14 @@ pub struct ProjectReady { } #[async_trait] -impl<'c> State<'c> for ProjectReady { +impl State for ProjectReady +where + Ctx: DockerContext, +{ type Next = Self; type Error = ProjectError; - async fn next>(self, _ctx: &C) -> Result { + async fn next(mut self, _ctx: &Ctx) -> Result { Ok(self) } } @@ -549,12 +585,32 @@ impl ProjectReady { pub fn target_ip(&self) -> &IpAddr { &self.service.target } + + pub async fn is_healthy(&mut self) -> bool { + self.service.is_healthy().await + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct HealthCheckRecord { + at: chrono::DateTime, + is_healthy: bool, +} + +impl HealthCheckRecord { + pub fn new(is_healthy: bool) -> Self { + Self { + at: chrono::Utc::now(), + is_healthy, + } + } } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Service { name: String, target: IpAddr, + last_check: Option, } impl Service { @@ -572,8 +628,22 @@ impl Service { Ok(Self { name: resource_name, target, + last_check: None, }) } + + pub fn uri>(&self, path: S) -> Result { + format!("http://{}:8001{}", self.target, path.as_ref()) + .parse::() + .map_err(|err| err.into()) + } + + pub async fn is_healthy(&mut self) -> bool { + let uri = self.uri(format!("/projects/{}/status", self.name)).unwrap(); + let is_healthy = matches!(CLIENT.get(uri).await, Ok(res) if res.status().is_success()); + self.last_check = Some(HealthCheckRecord::new(is_healthy)); + is_healthy + } } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -582,12 +652,15 @@ pub struct ProjectStopping { } #[async_trait] -impl<'c> State<'c> for ProjectStopping { +impl State for ProjectStopping +where + Ctx: DockerContext, +{ type Next = ProjectStopped; type Error = ProjectError; - async fn next>(self, ctx: &C) -> Result { + async fn next(self, ctx: &Ctx) -> Result { let Self { container } = self; ctx.docker() .stop_container( @@ -607,15 +680,22 @@ pub struct ProjectStopped { } #[async_trait] -impl<'c> State<'c> for ProjectStopped { +impl State for ProjectStopped +where + Ctx: DockerContext, +{ type Next = ProjectStarting; type Error = ProjectError; - async fn next>(self, _ctx: &C) -> Result { - // If stopped, try to restart - Ok(ProjectStarting { - container: self.container, - }) + async fn next(self, _ctx: &Ctx) -> Result { + // If stopped, and has not restarted too much, try to restart + if self.container.restart_count.unwrap_or_default() < MAX_RESTARTS { + Ok(ProjectStarting { + container: self.container, + }) + } else { + Err(ProjectError::internal("too many restarts")) + } } } @@ -625,11 +705,14 @@ pub struct ProjectDestroying { } #[async_trait] -impl<'c> State<'c> for ProjectDestroying { +impl State for ProjectDestroying +where + Ctx: DockerContext, +{ type Next = ProjectDestroyed; type Error = ProjectError; - async fn next>(self, ctx: &C) -> Result { + async fn next(self, ctx: &Ctx) -> Result { let container_id = self.container.id.as_ref().unwrap(); ctx.docker() .stop_container(container_id, Some(StopContainerOptions { t: 1 })) @@ -657,11 +740,14 @@ pub struct ProjectDestroyed { } #[async_trait] -impl<'c> State<'c> for ProjectDestroyed { +impl State for ProjectDestroyed +where + Ctx: DockerContext, +{ type Next = ProjectDestroyed; type Error = ProjectError; - async fn next>(self, _ctx: &C) -> Result { + async fn next(self, _ctx: &Ctx) -> Result { Ok(self) } } @@ -739,38 +825,42 @@ impl From for Error { } #[async_trait] -impl<'c> State<'c> for ProjectError { +impl State for ProjectError +where + Ctx: DockerContext, +{ type Next = Self; type Error = Infallible; - async fn next>(self, _ctx: &C) -> Result { + async fn next(self, _ctx: &Ctx) -> Result { Ok(self) } } pub mod exec { + use std::sync::Arc; + use bollard::service::ContainerState; use crate::{ service::GatewayService, - worker::{do_work, Work}, + task::{self, TaskResult}, }; use super::*; pub async fn revive(gateway: GatewayService) -> Result<(), ProjectError> { let mut mutations = Vec::new(); + let gateway = Arc::new(gateway); - for Work { - project_name, - account_name, - work, - } in gateway + for (project_name, account_name) in gateway .iter_projects() .await .expect("could not list projects") { - if let Project::Errored(ProjectError { ctx: Some(ctx), .. }) = work { + if let Project::Errored(ProjectError { ctx: Some(ctx), .. }) = + gateway.find_project(&project_name).await.unwrap() + { if let Some(container) = ctx.container() { if let Ok(container) = gateway .context() @@ -783,21 +873,28 @@ pub mod exec { .. }) = container.state { - mutations.push(Work { - project_name, - account_name, - work: Project::Stopped(ProjectStopped { container }), - }); + mutations.push(( + project_name.clone(), + gateway + .new_task() + .project(project_name) + .account(account_name) + .and_then(task::run(|ctx| async move { + TaskResult::Done(Project::Stopped(ProjectStopped { + container: ctx.state.container().unwrap(), + })) + })) + .build(), + )); } } } } } - for work in mutations { - debug!(?work, "project will be revived"); - - do_work(work, &gateway).await; + for (project_name, mut work) in mutations { + debug!(?project_name, "project will be revived"); + while let TaskResult::Pending(_) = work.poll(()).await {} } Ok(()) @@ -856,7 +953,7 @@ pub mod tests { futures::pin_mut!(delay); let mut project_readying = project_started .unwrap() - .into_stream(ctx) + .into_stream(&ctx) .take_until(delay) .try_skip_while(|state| future::ready(Ok(!matches!(state, Project::Ready(_))))); diff --git a/gateway/src/service.rs b/gateway/src/service.rs index 54d0065e8..5460f9a96 100644 --- a/gateway/src/service.rs +++ b/gateway/src/service.rs @@ -14,7 +14,6 @@ use hyper_reverse_proxy::ReverseProxy; use once_cell::sync::Lazy; use opentelemetry::global; use opentelemetry_http::HeaderInjector; -use rand::distributions::{Alphanumeric, DistString}; use sqlx::error::DatabaseError; use sqlx::migrate::Migrator; use sqlx::sqlite::SqlitePool; @@ -25,9 +24,9 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; use crate::args::ContextArgs; use crate::auth::{Key, User}; -use crate::project::{self, Project}; -use crate::worker::Work; -use crate::{AccountName, Context, Error, ErrorKind, ProjectName, Service}; +use crate::project::Project; +use crate::task::TaskBuilder; +use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectName}; pub static MIGRATIONS: Migrator = sqlx::migrate!("./migrations"); static PROXY_CLIENT: Lazy>> = @@ -148,6 +147,7 @@ impl<'d> ContainerSettingsBuilder<'d> { } } +#[derive(Clone)] pub struct ContainerSettings { pub prefix: String, pub image: String, @@ -175,8 +175,8 @@ impl GatewayContextProvider { pub fn context(&self) -> GatewayContext { GatewayContext { - docker: &self.docker, - settings: &self.settings, + docker: self.docker.clone(), + settings: self.settings.clone(), } } } @@ -234,16 +234,14 @@ impl GatewayService { Ok(resp) } - pub async fn iter_projects(&self) -> Result, Error> { - let iter = query("SELECT * FROM projects") + pub async fn iter_projects( + &self, + ) -> Result, Error> { + let iter = query("SELECT project_name, account_name FROM projects") .fetch_all(&self.db) .await? .into_iter() - .map(|row| Work { - project_name: row.get("project_name"), - work: row.get::, _>("project_state").0, - account_name: row.get("account_name"), - }); + .map(|row| (row.get("project_name"), row.get("account_name"))); Ok(iter) } @@ -260,21 +258,16 @@ impl GatewayService { .ok_or_else(|| Error::from_kind(ErrorKind::ProjectNotFound)) } - async fn update_project( + pub async fn update_project( &self, project_name: &ProjectName, project: &Project, ) -> Result<(), Error> { - let query = match project { - Project::Destroyed(_) => { - query("DELETE FROM projects WHERE project_name = ?1").bind(project_name) - } - _ => query("UPDATE projects SET project_state = ?1 WHERE project_name = ?2") - .bind(SqlxJson(project)) - .bind(project_name), - }; - - query.execute(&self.db).await?; + query("UPDATE projects SET project_state = ?1 WHERE project_name = ?2") + .bind(SqlxJson(project)) + .bind(project_name) + .execute(&self.db) + .await?; Ok(()) } @@ -402,18 +395,48 @@ impl GatewayService { &self, project_name: ProjectName, account_name: AccountName, - ) -> Result { - let initial_key = Alphanumeric.sample_string(&mut rand::thread_rng(), 32); + ) -> Result { + if let Some(row) = query("SELECT project_name, account_name, initial_key, project_state FROM projects WHERE project_name = ?1 AND account_name = ?2") + .bind(&project_name) + .bind(&account_name) + .fetch_optional(&self.db) + .await? + { + // If the project already exists and belongs to this account + let project = row.get::, _>("project_state").0; + if project.is_destroyed() { + // But is in `::Destroyed` state, recreate it + let project = SqlxJson(Project::create(project_name.clone())); + query("UPDATE projects SET project_state = ?1 AND initial_key = ?2 WHERE project_name = ?3") + .bind(&project) + .bind(project.initial_key().unwrap()) + .bind(&project_name) + .execute(&self.db) + .await?; + Ok(project.0) + } else { + // Otherwise it already exists + Err(Error::from_kind(ErrorKind::ProjectAlreadyExists)) + } + } else { + // Otherwise attempt to create a new one. This will fail + // outright if the project already exists (this happens if + // it belongs to another account). + self.insert_project(project_name, account_name).await + } + } - let project = SqlxJson(Project::Creating(project::ProjectCreating::new( - project_name.clone(), - initial_key.clone(), - ))); + pub async fn insert_project( + &self, + project_name: ProjectName, + account_name: AccountName, + ) -> Result { + let project = SqlxJson(Project::create(project_name.clone())); query("INSERT INTO projects (project_name, account_name, initial_key, project_state) VALUES (?1, ?2, ?3, ?4)") .bind(&project_name) .bind(&account_name) - .bind(&initial_key) + .bind(project.initial_key().unwrap()) .bind(&project) .execute(&self.db) .await @@ -431,88 +454,32 @@ impl GatewayService { let project = project.0; - Ok(Work { - project_name, - account_name, - work: project, - }) - } - - pub async fn destroy_project( - &self, - project_name: ProjectName, - account_name: AccountName, - ) -> Result { - let project = self.find_project(&project_name).await?.destroy()?; - - Ok(Work { - project_name, - account_name, - work: project, - }) + Ok(project) } pub fn context(&self) -> GatewayContext { self.provider.context() } -} - -#[async_trait] -impl<'c> Service<'c> for GatewayService { - type Context = GatewayContext<'c>; - - type State = Work; - - type Error = Error; - - fn context(&'c self) -> Self::Context { - GatewayService::context(self) - } - async fn update( - &self, - Work { - project_name, work, .. - }: &Self::State, - ) -> Result<(), Self::Error> { - self.update_project(project_name, work).await - } -} - -#[async_trait] -impl<'c> Service<'c> for Arc { - type Context = GatewayContext<'c>; - - type State = Work; - - type Error = Error; - - fn context(&'c self) -> Self::Context { - GatewayService::context(self) - } - - async fn update( - &self, - Work { - project_name, work, .. - }: &Self::State, - ) -> Result<(), Self::Error> { - self.update_project(project_name, work).await + /// Create a builder for a new [ProjectTask] + pub fn new_task(self: &Arc) -> TaskBuilder { + TaskBuilder::new(self.clone()) } } -pub struct GatewayContext<'c> { - docker: &'c Docker, - settings: &'c ContainerSettings, +#[derive(Clone)] +pub struct GatewayContext { + docker: Docker, + settings: ContainerSettings, } -impl<'c> Context<'c> for GatewayContext<'c> { - fn docker(&self) -> &'c Docker { - self.docker +impl DockerContext for GatewayContext { + fn docker(&self) -> &Docker { + &self.docker } - fn container_settings(&self) -> &'c ContainerSettings { - self.settings + fn container_settings(&self) -> &ContainerSettings { + &self.settings } } @@ -522,7 +489,9 @@ pub mod tests { use std::str::FromStr; use super::*; + use crate::task::{self, TaskResult}; use crate::tests::{assert_err_kind, World}; + use crate::{Error, ErrorKind}; #[tokio::test] async fn service_create_find_user() -> anyhow::Result<()> { @@ -581,6 +550,7 @@ pub mod tests { let svc = Arc::new(GatewayService::init(world.args(), world.pool()).await); let neo: AccountName = "neo".parse().unwrap(); + let trinity: AccountName = "trinity".parse().unwrap(); let matrix: ProjectName = "matrix".parse().unwrap(); let creating_same_project_name = |project: &Project, project_name: &ProjectName| { @@ -591,35 +561,110 @@ pub mod tests { }; svc.create_user(neo.clone()).await.unwrap(); + svc.create_user(trinity.clone()).await.unwrap(); - let work = svc + let project = svc .create_project(matrix.clone(), neo.clone()) .await .unwrap(); - // work work work work - let project = work.work; - assert!(creating_same_project_name(&project, &matrix)); assert_eq!(svc.find_project(&matrix).await.unwrap(), project); - let work = svc.destroy_project(matrix.clone(), neo).await.unwrap(); - - let project = work.work; + let mut work = svc + .new_task() + .project(matrix.clone()) + .account(neo.clone()) + .and_then(task::destroy()) + .build(); - assert!(matches!(&project, Project::Destroyed(_))); - - svc.update_project(&matrix, &project).await.unwrap(); + while let TaskResult::Pending(_) = work.poll(()).await {} + assert!(matches!(work.poll(()).await, TaskResult::Done(()))); + // After project has been destroyed... assert!(matches!( svc.find_project(&matrix).await, + Ok(Project::Destroyed(_)) + )); + + // If recreated by a different user + assert!(matches!( + svc.create_project(matrix.clone(), trinity.clone()).await, Err(Error { - kind: ErrorKind::ProjectNotFound, + kind: ErrorKind::ProjectAlreadyExists, .. }) )); + // If recreated by the same user + assert!(matches!( + svc.create_project(matrix, neo).await, + Ok(Project::Creating(_)) + )); + + Ok(()) + } + + #[tokio::test] + async fn service_create_ready_kill_restart_docker() -> anyhow::Result<()> { + let world = World::new().await; + let svc = Arc::new(GatewayService::init(world.args(), world.pool()).await); + + let neo: AccountName = "neo".parse().unwrap(); + let matrix: ProjectName = "matrix".parse().unwrap(); + + svc.create_user(neo.clone()).await.unwrap(); + svc.create_project(matrix.clone(), neo.clone()) + .await + .unwrap(); + + let mut task = svc + .new_task() + .account(neo.clone()) + .project(matrix.clone()) + .build(); + + while let TaskResult::Pending(_) = task.poll(()).await { + // keep polling + } + + let project = svc.find_project(&matrix).await.unwrap(); + println!("{:?}", project); + assert!(project.is_ready()); + + let container = project.container().unwrap(); + svc.context() + .docker() + .kill_container::(container.name.unwrap().strip_prefix('/').unwrap(), None) + .await + .unwrap(); + + println!("killed container"); + + let mut ambulance_task = svc + .new_task() + .project(matrix.clone()) + .account(neo.clone()) + .and_then(task::check_health()) + .build(); + + // the first poll will trigger a refresh + let _ = ambulance_task.poll(()).await; + + let project = svc.find_project(&matrix).await.unwrap(); + println!("{:?}", project); + assert!(!project.is_ready()); + + // the subsequent will trigger a restart task + while let TaskResult::Pending(_) = ambulance_task.poll(()).await { + // keep polling + } + + let project = svc.find_project(&matrix).await.unwrap(); + println!("{:?}", project); + assert!(project.is_ready()); + Ok(()) } } diff --git a/gateway/src/task.rs b/gateway/src/task.rs new file mode 100644 index 000000000..2a1d38f1f --- /dev/null +++ b/gateway/src/task.rs @@ -0,0 +1,404 @@ +use futures::Future; +use std::collections::VecDeque; +use std::marker::PhantomData; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::mpsc::Sender; +use tracing::warn; +use uuid::Uuid; + +use crate::project::*; +use crate::service::{GatewayContext, GatewayService}; +use crate::{AccountName, EndState, Error, ErrorKind, ProjectName, Refresh, State}; + +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300); + +#[async_trait] +pub trait Task: Send { + type Output; + + type Error; + + async fn poll(&mut self, ctx: Ctx) -> TaskResult; +} + +#[async_trait] +impl Task for Box +where + Ctx: Send + 'static, + T: Task + ?Sized, +{ + type Output = T::Output; + + type Error = T::Error; + + async fn poll(&mut self, ctx: Ctx) -> TaskResult { + self.as_mut().poll(ctx).await + } +} + +#[must_use] +#[derive(Debug, PartialEq, Eq)] +pub enum TaskResult { + /// More work needs to be done + Pending(R), + /// No further work needed + Done(R), + /// Try again later + TryAgain, + /// Task has been cancelled + Cancelled, + /// Task has failed + Err(E), +} + +impl TaskResult { + pub fn ok(self) -> Option { + match self { + Self::Pending(r) | Self::Done(r) => Some(r), + _ => None, + } + } + + pub fn as_ref(&self) -> TaskResult<&R, &E> { + match self { + Self::Pending(r) => TaskResult::Pending(r), + Self::Done(r) => TaskResult::Done(r), + Self::TryAgain => TaskResult::TryAgain, + Self::Cancelled => TaskResult::Cancelled, + Self::Err(e) => TaskResult::Err(e), + } + } +} + +pub fn run(f: F) -> impl Task +where + F: FnMut(ProjectContext) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, +{ + RunFn { + f, + _output: PhantomData, + } +} + +pub fn refresh() -> impl Task { + run(|ctx: ProjectContext| async move { + match ctx.state.refresh(&ctx.gateway).await { + Ok(new) => TaskResult::Done(new), + Err(err) => TaskResult::Err(err), + } + }) +} + +pub fn destroy() -> impl Task { + run(|ctx| async move { + match ctx.state.destroy() { + Ok(state) => TaskResult::Done(state), + Err(err) => TaskResult::Err(err), + } + }) +} + +pub fn check_health() -> impl Task { + run(|ctx| async move { + if let Project::Ready(mut ready) = ctx.state { + if ready.is_healthy().await { + TaskResult::Done(Project::Ready(ready)) + } else { + match Project::Ready(ready).refresh(&ctx.gateway).await { + Ok(update) => TaskResult::Done(update), + Err(err) => TaskResult::Err(err), + } + } + } else { + TaskResult::Err(Error::from_kind(ErrorKind::NotReady)) + } + }) +} + +pub fn run_until_done() -> impl Task { + RunUntilDone +} + +pub struct TaskBuilder { + project_name: Option, + account_name: Option, + service: Arc, + timeout: Option, + tasks: VecDeque>, +} + +impl TaskBuilder { + pub fn new(service: Arc) -> Self { + Self { + service, + project_name: None, + account_name: None, + timeout: None, + tasks: VecDeque::new(), + } + } +} + +impl TaskBuilder { + pub fn project(mut self, name: ProjectName) -> Self { + self.project_name = Some(name); + self + } + + pub fn account(mut self, name: AccountName) -> Self { + self.account_name = Some(name); + self + } + + pub fn and_then(mut self, task: T) -> Self + where + T: Task + 'static, + { + self.tasks.push_back(Box::new(task)); + self + } + + pub fn with_timeout(mut self, duration: Duration) -> Self { + self.timeout = Some(duration); + self + } + + pub fn build(mut self) -> BoxedTask { + self.tasks.push_back(Box::new(RunUntilDone)); + + let timeout = self.timeout.unwrap_or(DEFAULT_TIMEOUT); + + Box::new(WithTimeout::on( + timeout, + ProjectTask { + uuid: Uuid::new_v4(), + project_name: self.project_name.expect("project_name is required"), + account_name: self.account_name.expect("account_name is required"), + service: self.service, + tasks: self.tasks, + }, + )) + } + + pub async fn send(self, sender: &Sender) -> Result<(), SendError> { + sender.send(self.build()).await + } +} + +pub struct RunFn { + f: F, + _output: PhantomData, +} + +#[async_trait] +impl Task for RunFn +where + F: FnMut(ProjectContext) -> Fut + Send, + Fut: Future> + Send, +{ + type Output = Project; + + type Error = Error; + + async fn poll(&mut self, ctx: ProjectContext) -> TaskResult { + (self.f)(ctx).await + } +} + +/// Advance a project's state until it's returning `is_done` +pub struct RunUntilDone; + +#[async_trait] +impl Task for RunUntilDone { + type Output = Project; + + type Error = Error; + + async fn poll(&mut self, ctx: ProjectContext) -> TaskResult { + if !>::is_done(&ctx.state) { + TaskResult::Pending(ctx.state.next(&ctx.gateway).await.unwrap()) + } else { + TaskResult::Done(ctx.state) + } + } +} + +pub struct WithTimeout { + inner: T, + start: Option, + timeout: Duration, +} + +impl WithTimeout { + pub fn on(timeout: Duration, inner: T) -> Self { + Self { + inner, + start: None, + timeout, + } + } +} + +#[async_trait] +impl Task for WithTimeout +where + Ctx: Send + 'static, + T: Task, +{ + type Output = T::Output; + + type Error = T::Error; + + async fn poll(&mut self, ctx: Ctx) -> TaskResult { + if self.start.is_none() { + self.start = Some(Instant::now()); + } + + if Instant::now() - *self.start.as_ref().unwrap() > self.timeout { + warn!( + "task has timed out: was running for more than {}s", + self.timeout.as_secs() + ); + return TaskResult::Cancelled; + } + + self.inner.poll(ctx).await + } +} + +/// A collection of tasks scoped to a specific project. +/// +/// All the tasks in the collection are run to completion. If an error +/// is encountered, the `ProjectTask` completes early passing through +/// the error. The value returned by the inner tasks upon their +/// completion is committed back to persistence through +/// [GatewayService]. +pub struct ProjectTask { + uuid: Uuid, + project_name: ProjectName, + account_name: AccountName, + service: Arc, + tasks: VecDeque, +} + +impl ProjectTask { + pub fn uuid(&self) -> &Uuid { + &self.uuid + } +} + +/// A context for tasks which are scoped to a specific project. +/// +/// This will be always instantiated with the latest known state of +/// the project and gives access to the broader gateway context. +#[derive(Clone)] +pub struct ProjectContext { + /// The name of the project this task is about + pub project_name: ProjectName, + /// The name of the user the project belongs to + pub account_name: AccountName, + /// The gateway context in which this task is running + pub gateway: GatewayContext, + /// The last known state of the project + pub state: Project, +} + +pub type BoxedTask = Box>; + +#[async_trait] +impl Task<()> for ProjectTask +where + T: Task, +{ + type Output = (); + + type Error = Error; + + async fn poll(&mut self, _: ()) -> TaskResult { + if self.tasks.is_empty() { + return TaskResult::Done(()); + } + + let ctx = self.service.context(); + + let project = match self.service.find_project(&self.project_name).await { + Ok(project) => project, + Err(err) => return TaskResult::Err(err), + }; + + let project_ctx = ProjectContext { + project_name: self.project_name.clone(), + account_name: self.account_name.clone(), + gateway: ctx, + state: project, + }; + + let task = self.tasks.front_mut().unwrap(); + + let res = task.poll(project_ctx).await; + + if let Some(update) = res.as_ref().ok() { + match self + .service + .update_project(&self.project_name, update) + .await + { + Ok(_) => {} + Err(err) => return TaskResult::Err(err), + } + } + + match res { + TaskResult::Pending(_) => TaskResult::Pending(()), + TaskResult::TryAgain => TaskResult::TryAgain, + TaskResult::Done(_) => { + let _ = self.tasks.pop_front().unwrap(); + if self.tasks.is_empty() { + TaskResult::Done(()) + } else { + TaskResult::Pending(()) + } + } + TaskResult::Cancelled => TaskResult::Cancelled, + TaskResult::Err(err) => TaskResult::Err(err), + } + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + + struct NeverEnding; + + #[async_trait] + impl Task<()> for NeverEnding { + type Output = (); + + type Error = (); + + async fn poll(&mut self, _ctx: ()) -> TaskResult { + TaskResult::Pending(()) + } + } + + #[tokio::test] + async fn task_with_timeout() -> anyhow::Result<()> { + let timeout = Duration::from_secs(1); + + let mut task_with_timeout = WithTimeout::on(timeout, NeverEnding); + + let start = Instant::now(); + + while let TaskResult::Pending(()) = task_with_timeout.poll(()).await { + assert!(Instant::now() - start <= timeout + Duration::from_secs(1)); + } + + assert_eq!(task_with_timeout.poll(()).await, TaskResult::Cancelled); + + Ok(()) + } +} diff --git a/gateway/src/worker.rs b/gateway/src/worker.rs index d371bdb66..9584fa245 100644 --- a/gateway/src/worker.rs +++ b/gateway/src/worker.rs @@ -1,88 +1,32 @@ -use std::fmt::Debug; -use std::sync::Arc; - use tokio::sync::mpsc::{channel, Receiver, Sender}; use tracing::{debug, info}; -use crate::project::Project; -use crate::service::GatewayService; -use crate::{AccountName, Context, EndState, Error, ProjectName, Refresh, Service, State}; - -#[must_use] -#[derive(Debug, Clone)] -pub struct Work { - pub project_name: ProjectName, - pub account_name: AccountName, - pub work: W, -} - -#[async_trait] -impl Refresh for Work -where - W: Refresh + Send, -{ - type Error = W::Error; +use crate::task::{BoxedTask, TaskResult}; +use crate::Error; - async fn refresh<'c, C: Context<'c>>(self, ctx: &C) -> Result { - Ok(Self { - project_name: self.project_name, - account_name: self.account_name, - work: self.work.refresh(ctx).await?, - }) - } -} +const WORKER_QUEUE_SIZE: usize = 2048; -#[async_trait] -impl<'c, W> State<'c> for Work -where - W: State<'c>, -{ - type Next = Work; - - type Error = W::Error; - - async fn next>(self, ctx: &C) -> Result { - Ok(Work:: { - project_name: self.project_name, - account_name: self.account_name, - work: self.work.next(ctx).await?, - }) - } +pub struct Worker { + send: Option>, + recv: Receiver, } -impl<'c, W> EndState<'c> for Work +impl Default for Worker where - W: EndState<'c>, + W: Send, { - type ErrorVariant = W::ErrorVariant; - - fn is_done(&self) -> bool { - self.work.is_done() - } - - fn into_result(self) -> Result { - Ok(Self { - project_name: self.project_name, - account_name: self.account_name, - work: self.work.into_result()?, - }) + fn default() -> Self { + Self::new() } } -pub struct Worker, W = Work> { - service: Svc, - send: Option>, - recv: Receiver, -} - -impl Worker +impl Worker where W: Send, { - pub fn new(service: Svc) -> Self { - let (send, recv) = channel(32); + pub fn new() -> Self { + let (send, recv) = channel(WORKER_QUEUE_SIZE); Self { - service, send: Some(send), recv, } @@ -97,11 +41,7 @@ where } } -impl Worker -where - Svc: for<'c> Service<'c, State = W, Error = Error>, - W: Debug + Send + for<'c> EndState<'c>, -{ +impl Worker { /// Starts the worker, waiting and processing elements from the /// queue until the last sending end for the channel is dropped, /// at which point this future resolves. @@ -115,160 +55,19 @@ where let _ = self.send.take().unwrap(); debug!("starting worker"); - while let Some(work) = self.recv.recv().await { - debug!(?work, "received work"); - do_work(work, &self.service).await; - } - - Ok(self) - } -} - -pub async fn do_work< - 'c, - E: std::fmt::Display, - S: Service<'c, State = W, Error = E>, - W: EndState<'c> + Debug, ->( - mut work: W, - service: &'c S, -) { - loop { - work = { - let context = service.context(); - - // Safety: EndState's transitions are Infallible - work.next(&context).await.unwrap() - }; - - match service.update(&work).await { - Ok(_) => {} - Err(err) => info!("failed to update a state: {}\nstate: {:?}", err, work), - }; - - if work.is_done() { - break; - } else { - debug!(?work, "work not done yet"); - } - } -} - -#[cfg(test)] -pub mod tests { - use std::convert::Infallible; - - use anyhow::anyhow; - use tokio::sync::Mutex; - - use super::*; - use crate::tests::{World, WorldContext}; - - pub struct DummyService { - world: World, - state: Mutex>, - } - - impl DummyService<()> { - pub async fn new() -> DummyService { - let world = World::new().await; - DummyService { - world, - state: Mutex::new(None), + while let Some(mut work) = self.recv.recv().await { + loop { + match work.poll(()).await { + TaskResult::Done(_) | TaskResult::Cancelled => break, + TaskResult::Pending(_) | TaskResult::TryAgain => continue, + TaskResult::Err(err) => { + info!("task failed: {err}"); + break; + } + } } } - } - - #[async_trait] - impl<'c, S> Service<'c> for DummyService - where - S: EndState<'c> + Sync, - { - type Context = WorldContext<'c>; - - type State = S; - type Error = Error; - - fn context(&'c self) -> Self::Context { - self.world.context() - } - - async fn update(&self, state: &Self::State) -> Result<(), Self::Error> { - let mut lock = self.state.lock().await; - *lock = Some(Self::State::clone(state)); - Ok(()) - } - } - - #[derive(Debug, PartialEq, Eq, Clone)] - pub struct FiniteState { - count: usize, - max_count: usize, - } - - #[async_trait] - impl<'c> State<'c> for FiniteState { - type Next = Self; - - type Error = Infallible; - - async fn next>(mut self, _ctx: &C) -> Result { - if self.count < self.max_count { - self.count += 1; - } - Ok(self) - } - } - - impl<'c> EndState<'c> for FiniteState { - type ErrorVariant = anyhow::Error; - - fn is_done(&self) -> bool { - self.count == self.max_count - } - - fn into_result(self) -> Result { - if self.count > self.max_count { - Err(anyhow!( - "count is over max_count: {} > {}", - self.count, - self.max_count - )) - } else { - Ok(self) - } - } - } - - #[tokio::test] - async fn worker_queue_and_proceed_until_done() { - let svc = DummyService::new::().await; - - let worker = Worker::new(svc); - - { - let sender = worker.sender(); - - let state = FiniteState { - count: 0, - max_count: 42, - }; - - sender.send(state).await.unwrap(); - } - - let Worker { - service: DummyService { state, .. }, - .. - } = worker.start().await.unwrap(); - - assert_eq!( - *state.lock().await, - Some(FiniteState { - count: 42, - max_count: 42 - }) - ); + Ok(self) } }