diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index a747bc673..ad1ac8357 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -119,7 +119,7 @@ impl std::fmt::Display for Error { impl StdError for Error {} -#[derive(Debug, sqlx::Type, Serialize, Clone, PartialEq, Eq)] +#[derive(Debug, sqlx::Type, Serialize, Clone, PartialEq, Eq, Hash)] #[sqlx(transparent)] pub struct ProjectName(String); diff --git a/gateway/src/service.rs b/gateway/src/service.rs index ed32cd884..afd03dd67 100644 --- a/gateway/src/service.rs +++ b/gateway/src/service.rs @@ -28,7 +28,8 @@ use crate::acme::CustomDomain; use crate::args::ContextArgs; use crate::auth::{Key, Permissions, ScopedUser, User}; use crate::project::Project; -use crate::task::TaskBuilder; +use crate::task::{BoxedTask, TaskBuilder}; +use crate::worker::TaskRouter; use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectDetails, ProjectName}; pub static MIGRATIONS: Migrator = sqlx::migrate!("./migrations"); @@ -187,6 +188,7 @@ impl GatewayContextProvider { pub struct GatewayService { provider: GatewayContextProvider, db: SqlitePool, + task_router: TaskRouter, } impl GatewayService { @@ -201,7 +203,13 @@ impl GatewayService { let provider = GatewayContextProvider::new(docker, container_settings); - Self { provider, db } + let task_router = TaskRouter::new(); + + Self { + provider, + db, + task_router, + } } pub async fn route( @@ -547,6 +555,10 @@ impl GatewayService { pub fn new_task(self: &Arc) -> TaskBuilder { TaskBuilder::new(self.clone()) } + + pub fn task_router(&self) -> TaskRouter { + self.task_router.clone() + } } #[derive(Clone)] diff --git a/gateway/src/task.rs b/gateway/src/task.rs index 5a23c64ac..dc506306a 100644 --- a/gateway/src/task.rs +++ b/gateway/src/task.rs @@ -12,6 +12,7 @@ use uuid::Uuid; use crate::project::*; use crate::service::{GatewayContext, GatewayService}; +use crate::worker::TaskRouter; use crate::{AccountName, EndState, Error, ErrorKind, ProjectName, Refresh, State}; // Default maximum _total_ time a task is allowed to run @@ -199,7 +200,10 @@ impl TaskBuilder { } pub async fn send(self, sender: &Sender) -> Result { + let project_name = self.project_name.clone().expect("project_name is required"); + let task_router = self.service.task_router(); let (task, handle) = AndThenNotify::after(self.build()); + let task = Route::::to(project_name, Box::new(task), task_router); match timeout(TASK_SEND_TIMEOUT, sender.send(Box::new(task))).await { Ok(Ok(_)) => Ok(handle), _ => Err(Error::from_kind(ErrorKind::ServiceUnavailable)), @@ -207,6 +211,40 @@ impl TaskBuilder { } } +pub struct Route { + project_name: ProjectName, + inner: Option, + router: TaskRouter, +} + +impl Route { + pub fn to(project_name: ProjectName, what: T, router: TaskRouter) -> Self { + Self { + project_name, + inner: Some(what), + router, + } + } +} + +#[async_trait] +impl Task<()> for Route { + type Output = (); + + type Error = Error; + + async fn poll(&mut self, _ctx: ()) -> TaskResult { + if let Some(task) = self.inner.take() { + match self.router.route(&self.project_name, task).await { + Ok(_) => TaskResult::Done(()), + Err(_) => TaskResult::Err(Error::from_kind(ErrorKind::Internal)), + } + } else { + TaskResult::Done(()) + } + } +} + pub struct RunFn { f: F, _output: PhantomData, diff --git a/gateway/src/worker.rs b/gateway/src/worker.rs index 25a3914ba..b81bb1ad0 100644 --- a/gateway/src/worker.rs +++ b/gateway/src/worker.rs @@ -1,8 +1,13 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::RwLock; use tracing::{debug, info}; use crate::task::{BoxedTask, TaskResult}; -use crate::Error; +use crate::{Error, ProjectName}; pub const WORKER_QUEUE_SIZE: usize = 2048; @@ -71,3 +76,53 @@ impl Worker { Ok(self) } } + +pub struct TaskRouter { + table: Arc>>>, +} + +impl Clone for TaskRouter { + fn clone(&self) -> Self { + Self { + table: self.table.clone(), + } + } +} + +impl Default for TaskRouter { + fn default() -> Self { + Self::new() + } +} + +impl TaskRouter { + pub fn new() -> Self { + Self { + table: Arc::new(RwLock::new(HashMap::new())), + } + } +} + +impl TaskRouter { + pub async fn route( + &self, + name: &ProjectName, + task: BoxedTask, + ) -> Result<(), SendError> { + let mut table = self.table.write().await; + if let Some(sender) = table.get(name) { + sender.send(task).await + } else { + let worker = Worker::new(); + let sender = worker.sender(); + + tokio::spawn(worker.start()); + + let res = sender.send(task).await; + + table.insert(name.clone(), sender); + + res + } + } +}