Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: per-project parallelism #533

Merged
merged 4 commits into from
Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl std::fmt::Display for Error {

impl StdError for Error {}

#[derive(Debug, sqlx::Type, Serialize, Clone, PartialEq, Eq)]
#[derive(Debug, sqlx::Type, Serialize, Clone, PartialEq, Eq, Hash)]
#[sqlx(transparent)]
pub struct ProjectName(String);

Expand Down
16 changes: 14 additions & 2 deletions gateway/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ use crate::acme::CustomDomain;
use crate::args::ContextArgs;
use crate::auth::{Key, Permissions, ScopedUser, User};
use crate::project::Project;
use crate::task::TaskBuilder;
use crate::task::{BoxedTask, TaskBuilder};
use crate::worker::TaskRouter;
use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectDetails, ProjectName};

pub static MIGRATIONS: Migrator = sqlx::migrate!("./migrations");
Expand Down Expand Up @@ -187,6 +188,7 @@ impl GatewayContextProvider {
pub struct GatewayService {
provider: GatewayContextProvider,
db: SqlitePool,
task_router: TaskRouter<BoxedTask>,
}

impl GatewayService {
Expand All @@ -201,7 +203,13 @@ impl GatewayService {

let provider = GatewayContextProvider::new(docker, container_settings);

Self { provider, db }
let task_router = TaskRouter::new();

Self {
provider,
db,
task_router,
}
}

pub async fn route(
Expand Down Expand Up @@ -547,6 +555,10 @@ impl GatewayService {
pub fn new_task(self: &Arc<Self>) -> TaskBuilder {
TaskBuilder::new(self.clone())
}

pub fn task_router(&self) -> TaskRouter<BoxedTask> {
self.task_router.clone()
}
}

#[derive(Clone)]
Expand Down
38 changes: 38 additions & 0 deletions gateway/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use uuid::Uuid;

use crate::project::*;
use crate::service::{GatewayContext, GatewayService};
use crate::worker::TaskRouter;
use crate::{AccountName, EndState, Error, ErrorKind, ProjectName, Refresh, State};

// Default maximum _total_ time a task is allowed to run
Expand Down Expand Up @@ -199,14 +200,51 @@ impl TaskBuilder {
}

pub async fn send(self, sender: &Sender<BoxedTask>) -> Result<TaskHandle, Error> {
let project_name = self.project_name.clone().expect("project_name is required");
let task_router = self.service.task_router();
let (task, handle) = AndThenNotify::after(self.build());
let task = Route::<BoxedTask>::to(project_name, Box::new(task), task_router);
match timeout(TASK_SEND_TIMEOUT, sender.send(Box::new(task))).await {
Ok(Ok(_)) => Ok(handle),
_ => Err(Error::from_kind(ErrorKind::ServiceUnavailable)),
}
}
}

pub struct Route<T> {
project_name: ProjectName,
inner: Option<T>,
router: TaskRouter<T>,
}

impl<T> Route<T> {
pub fn to(project_name: ProjectName, what: T, router: TaskRouter<T>) -> Self {
Self {
project_name,
inner: Some(what),
router,
}
}
}

#[async_trait]
impl Task<()> for Route<BoxedTask> {
type Output = ();

type Error = Error;

async fn poll(&mut self, _ctx: ()) -> TaskResult<Self::Output, Self::Error> {
if let Some(task) = self.inner.take() {
match self.router.route(&self.project_name, task).await {
Ok(_) => TaskResult::Done(()),
Err(_) => TaskResult::Err(Error::from_kind(ErrorKind::Internal)),
}
} else {
TaskResult::Done(())
}
}
}

pub struct RunFn<F, O> {
f: F,
_output: PhantomData<O>,
Expand Down
57 changes: 56 additions & 1 deletion gateway/src/worker.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::RwLock;
use tracing::{debug, info};

use crate::task::{BoxedTask, TaskResult};
use crate::Error;
use crate::{Error, ProjectName};

pub const WORKER_QUEUE_SIZE: usize = 2048;

Expand Down Expand Up @@ -71,3 +76,53 @@ impl Worker<BoxedTask> {
Ok(self)
}
}

pub struct TaskRouter<W> {
table: Arc<RwLock<HashMap<ProjectName, Sender<W>>>>,
}

impl<W> Clone for TaskRouter<W> {
fn clone(&self) -> Self {
Self {
table: self.table.clone(),
}
}
}

impl<W> Default for TaskRouter<W> {
fn default() -> Self {
Self::new()
}
}

impl<W> TaskRouter<W> {
pub fn new() -> Self {
Self {
table: Arc::new(RwLock::new(HashMap::new())),
}
}
}

impl TaskRouter<BoxedTask> {
pub async fn route(
&self,
name: &ProjectName,
task: BoxedTask,
) -> Result<(), SendError<BoxedTask>> {
let mut table = self.table.write().await;
if let Some(sender) = table.get(name) {
sender.send(task).await
} else {
let worker = Worker::new();
let sender = worker.sender();

tokio::spawn(worker.start());

let res = sender.send(task).await;

table.insert(name.clone(), sender);

res
}
}
}