Skip to content

Commit

Permalink
feat: per-project parallelism (#533)
Browse files Browse the repository at this point in the history
  • Loading branch information
brokad authored Dec 14, 2022
1 parent 5e604b4 commit ae8ee01
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 4 deletions.
2 changes: 1 addition & 1 deletion gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl std::fmt::Display for Error {

impl StdError for Error {}

#[derive(Debug, sqlx::Type, Serialize, Clone, PartialEq, Eq)]
#[derive(Debug, sqlx::Type, Serialize, Clone, PartialEq, Eq, Hash)]
#[sqlx(transparent)]
pub struct ProjectName(String);

Expand Down
16 changes: 14 additions & 2 deletions gateway/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ use crate::acme::CustomDomain;
use crate::args::ContextArgs;
use crate::auth::{Key, Permissions, ScopedUser, User};
use crate::project::Project;
use crate::task::TaskBuilder;
use crate::task::{BoxedTask, TaskBuilder};
use crate::worker::TaskRouter;
use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectDetails, ProjectName};

pub static MIGRATIONS: Migrator = sqlx::migrate!("./migrations");
Expand Down Expand Up @@ -187,6 +188,7 @@ impl GatewayContextProvider {
pub struct GatewayService {
provider: GatewayContextProvider,
db: SqlitePool,
task_router: TaskRouter<BoxedTask>,
}

impl GatewayService {
Expand All @@ -201,7 +203,13 @@ impl GatewayService {

let provider = GatewayContextProvider::new(docker, container_settings);

Self { provider, db }
let task_router = TaskRouter::new();

Self {
provider,
db,
task_router,
}
}

pub async fn route(
Expand Down Expand Up @@ -547,6 +555,10 @@ impl GatewayService {
pub fn new_task(self: &Arc<Self>) -> TaskBuilder {
TaskBuilder::new(self.clone())
}

pub fn task_router(&self) -> TaskRouter<BoxedTask> {
self.task_router.clone()
}
}

#[derive(Clone)]
Expand Down
38 changes: 38 additions & 0 deletions gateway/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use uuid::Uuid;

use crate::project::*;
use crate::service::{GatewayContext, GatewayService};
use crate::worker::TaskRouter;
use crate::{AccountName, EndState, Error, ErrorKind, ProjectName, Refresh, State};

// Default maximum _total_ time a task is allowed to run
Expand Down Expand Up @@ -199,14 +200,51 @@ impl TaskBuilder {
}

pub async fn send(self, sender: &Sender<BoxedTask>) -> Result<TaskHandle, Error> {
let project_name = self.project_name.clone().expect("project_name is required");
let task_router = self.service.task_router();
let (task, handle) = AndThenNotify::after(self.build());
let task = Route::<BoxedTask>::to(project_name, Box::new(task), task_router);
match timeout(TASK_SEND_TIMEOUT, sender.send(Box::new(task))).await {
Ok(Ok(_)) => Ok(handle),
_ => Err(Error::from_kind(ErrorKind::ServiceUnavailable)),
}
}
}

pub struct Route<T> {
project_name: ProjectName,
inner: Option<T>,
router: TaskRouter<T>,
}

impl<T> Route<T> {
pub fn to(project_name: ProjectName, what: T, router: TaskRouter<T>) -> Self {
Self {
project_name,
inner: Some(what),
router,
}
}
}

#[async_trait]
impl Task<()> for Route<BoxedTask> {
type Output = ();

type Error = Error;

async fn poll(&mut self, _ctx: ()) -> TaskResult<Self::Output, Self::Error> {
if let Some(task) = self.inner.take() {
match self.router.route(&self.project_name, task).await {
Ok(_) => TaskResult::Done(()),
Err(_) => TaskResult::Err(Error::from_kind(ErrorKind::Internal)),
}
} else {
TaskResult::Done(())
}
}
}

pub struct RunFn<F, O> {
f: F,
_output: PhantomData<O>,
Expand Down
57 changes: 56 additions & 1 deletion gateway/src/worker.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::RwLock;
use tracing::{debug, info};

use crate::task::{BoxedTask, TaskResult};
use crate::Error;
use crate::{Error, ProjectName};

pub const WORKER_QUEUE_SIZE: usize = 2048;

Expand Down Expand Up @@ -71,3 +76,53 @@ impl Worker<BoxedTask> {
Ok(self)
}
}

pub struct TaskRouter<W> {
table: Arc<RwLock<HashMap<ProjectName, Sender<W>>>>,
}

impl<W> Clone for TaskRouter<W> {
fn clone(&self) -> Self {
Self {
table: self.table.clone(),
}
}
}

impl<W> Default for TaskRouter<W> {
fn default() -> Self {
Self::new()
}
}

impl<W> TaskRouter<W> {
pub fn new() -> Self {
Self {
table: Arc::new(RwLock::new(HashMap::new())),
}
}
}

impl TaskRouter<BoxedTask> {
pub async fn route(
&self,
name: &ProjectName,
task: BoxedTask,
) -> Result<(), SendError<BoxedTask>> {
let mut table = self.table.write().await;
if let Some(sender) = table.get(name) {
sender.send(task).await
} else {
let worker = Worker::new();
let sender = worker.sender();

tokio::spawn(worker.start());

let res = sender.send(task).await;

table.insert(name.clone(), sender);

res
}
}
}

0 comments on commit ae8ee01

Please sign in to comment.