diff --git a/Cargo.lock b/Cargo.lock
index f40cb3384..197d8b855 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -161,60 +161,60 @@ checksum = "25bdb32cbbdce2b519a9cd7df3a678443100e265d5e25ca763b7572a5104f5f3"
 
 [[package]]
 name = "apalis"
-version = "0.4.9"
+version = "0.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "78bbaeebf00817d5aa561515b313ef0d280bf4b92592e4709b21925c1233f613"
+checksum = "3661d27ed090fb120a887a8416f648343a8e6e864791b36f6175a72b2ab3df39"
 dependencies = [
  "apalis-core",
  "apalis-cron",
  "apalis-redis",
  "apalis-sql",
+ "futures",
+ "pin-project-lite",
+ "serde",
+ "thiserror",
+ "tokio",
+ "tower",
+ "tracing",
+ "tracing-futures",
 ]
 
 [[package]]
 name = "apalis-core"
-version = "0.4.9"
+version = "0.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1deb48475efcdece1f23a0553209ee842f264c2a5e9bcc4928bfa6a15a044cde"
+checksum = "d82227972a1bb6f5f5c4444b8228aaed79e28d6ad411e5f88ad46dc04cf066bb"
 dependencies = [
- "async-stream",
- "async-trait",
- "chrono",
+ "async-oneshot",
+ "async-timer",
  "futures",
- "graceful-shutdown",
- "http 1.1.0",
- "log",
  "pin-project-lite",
  "serde",
- "strum",
+ "serde_json",
  "thiserror",
- "tokio",
  "tower",
- "tracing",
- "tracing-futures",
  "ulid",
 ]
 
 [[package]]
 name = "apalis-cron"
-version = "0.4.9"
+version = "0.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "43310b7e0132f9520b09224fb6faafb32eec82a672aa79c09e46b5b488ed505b"
+checksum = "d11c4150f1088c1237cfde2d5cd3b045c17b3ed605c52bb3346641e18f2e1f77"
 dependencies = [
  "apalis-core",
  "async-stream",
  "chrono",
  "cron",
  "futures",
- "tokio",
  "tower",
 ]
 
 [[package]]
 name = "apalis-redis"
-version = "0.4.9"
+version = "0.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2abee8225fd875e57b530abbcf2d9c3122c1a2cce905367b67c6410b6f9654d7"
+checksum = "dd6f0968397ad66d4628a3c8022e201d3edc58eb44a522b5c76b5efd334b9fdd"
 dependencies = [
  "apalis-core",
  "async-stream",
@@ -224,23 +224,20 @@ dependencies = [
  "log",
  "redis",
  "serde",
- "serde_json",
  "tokio",
 ]
 
 [[package]]
 name = "apalis-sql"
-version = "0.4.9"
+version = "0.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5899bfd124e460f1449ffab643f6bac0dc417e7ca234f34c732c923c6a3addbf"
+checksum = "99eaea6cf256a5d0fce59c68608ba16e3ea9f01cb4a45e5c7fa5709ea44dacd1"
 dependencies = [
  "apalis-core",
  "async-stream",
- "async-trait",
- "chrono",
- "debounced",
  "futures",
  "futures-lite 2.3.0",
+ "log",
  "serde",
  "serde_json",
  "sqlx",
@@ -488,6 +485,15 @@ dependencies = [
  "pin-project-lite",
 ]
 
+[[package]]
+name = "async-oneshot"
+version = "0.5.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ae47de2a02d543205f3f5457a90b6ecbc9494db70557bd29590ec8f1ddff5463"
+dependencies = [
+ "futures-micro",
+]
+
 [[package]]
 name = "async-process"
 version = "1.8.1"
@@ -578,6 +584,17 @@ version = "4.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de"
 
+[[package]]
+name = "async-timer"
+version = "1.0.0-beta.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54a18932baa05100f01c9980d03e330f95a8f2dee1a7576969fa507bdce3b568"
+dependencies = [
+ "error-code",
+ "libc",
+ "wasm-bindgen",
+]
+
 [[package]]
 name = "async-trait"
 version = "0.1.80"
@@ -1507,16 +1524,6 @@ version = "0.1.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "63dfa964fe2a66f3fde91fc70b267fe193d822c7e603e2a675a49a7f46ad3f49"
"63dfa964fe2a66f3fde91fc70b267fe193d822c7e603e2a675a49a7f46ad3f49" -[[package]] -name = "debounced" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d8b0346b9fa0aa01a3fa4bcce48d62f8738e9c2956e92f275bbf6cf9d6fab5" -dependencies = [ - "futures-timer", - "futures-util", -] - [[package]] name = "debugid" version = "0.8.0" @@ -1699,6 +1706,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "error-code" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0474425d51df81997e2f90a21591180b38eccf27292d755f3e30750225c175b" + [[package]] name = "etcetera" version = "0.8.0" @@ -1966,6 +1979,15 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "futures-micro" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b460264b3593d68b16a7bc35f7bc226ddfebdf9a1c8db1ed95d5cc6b7168c826" +dependencies = [ + "pin-project-lite", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -1978,12 +2000,6 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" -[[package]] -name = "futures-timer" -version = "3.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" - [[package]] name = "futures-util" version = "0.3.30" @@ -2074,17 +2090,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "graceful-shutdown" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3effbaf774a1da3462925bb182ccf975c284cf46edca5569ea93420a657af484" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - [[package]] name = "group" version = "0.13.0" @@ -3618,6 +3623,7 @@ dependencies = [ "anyhow", "apalis", "apalis-core", + "apalis-sql", "async-stream", "async-trait", "chrono", @@ -4727,9 +4733,9 @@ dependencies = [ [[package]] name = "redis" -version = "0.24.0" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd" +checksum = "6472825949c09872e8f2c50bde59fcefc17748b6be5c90fd67cd8b4daca73bfd" dependencies = [ "arc-swap", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index dab553abf..1cab7fa3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,7 @@ oauth2-types = { path = "./crates/oauth2-types/", version = "=0.9.0" } # Async job queue [workspace.dependencies.apalis] -version = "0.4.9" +version = "0.5.1" features = ["cron"] # GraphQL server diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 40cbf9732..5fd4329d9 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -171,8 +171,7 @@ impl Options { info!(worker_name, "Starting task worker"); let monitor = - mas_tasks::init(&worker_name, &pool, &mailer, homeserver_connection.clone()) - .await?; + mas_tasks::init(&worker_name, &pool, &mailer, homeserver_connection.clone()); // TODO: grab the handle tokio::spawn(monitor.run()); } diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index afaeecbf1..9147b2b03 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -77,7 +77,7 @@ impl Options { let worker_name = Alphanumeric.sample_string(&mut rng, 10); info!(worker_name, "Starting task scheduler"); - let monitor = 
+        let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn);
 
         span.exit();
diff --git a/crates/email/src/lib.rs b/crates/email/src/lib.rs
index ce206c640..bc44b9d4d 100644
--- a/crates/email/src/lib.rs
+++ b/crates/email/src/lib.rs
@@ -20,11 +20,13 @@ mod mailer;
 mod transport;
 
 pub use lettre::{
-    message::Mailbox, transport::smtp::authentication::Credentials as SmtpCredentials, Address,
+    address::{Address, AddressError},
+    message::Mailbox,
+    transport::smtp::authentication::Credentials as SmtpCredentials,
 };
 pub use mas_templates::EmailVerificationContext;
 
 pub use self::{
-    mailer::Mailer,
+    mailer::{Error as MailerError, Mailer},
     transport::{SmtpMode, Transport as MailTransport},
 };
diff --git a/crates/email/src/mailer.rs b/crates/email/src/mailer.rs
index 8e5616786..a4aa88a04 100644
--- a/crates/email/src/mailer.rs
+++ b/crates/email/src/mailer.rs
@@ -32,11 +32,17 @@ pub struct Mailer {
     reply_to: Mailbox,
 }
 
+/// Errors that can occur when sending emails
 #[derive(Debug, Error)]
 #[error(transparent)]
 pub enum Error {
+    /// Mail failed to send through the transport
     Transport(#[from] crate::transport::Error),
+
+    /// Failed to render email templates
     Templates(#[from] mas_templates::TemplateError),
+
+    /// Email built was invalid
     Content(#[from] lettre::error::Error),
 }
diff --git a/crates/storage-pg/src/job.rs b/crates/storage-pg/src/job.rs
index 770f2bd28..719133d11 100644
--- a/crates/storage-pg/src/job.rs
+++ b/crates/storage-pg/src/job.rs
@@ -60,9 +60,7 @@ impl<'c> JobRepository for PgJobRepository<'c> {
     ) -> Result<(), Self::Error> {
         let now = clock.now();
         let id = Ulid::from_datetime_with_source(now.into(), rng);
-        // XXX: this is what apalis_core::job::JobId does
-        let id = format!("JID-{id}");
-        tracing::Span::current().record("job.id", &id);
+        tracing::Span::current().record("job.id", tracing::field::display(id));
 
         let res = sqlx::query!(
             r#"
diff --git a/crates/storage/src/job.rs b/crates/storage/src/job.rs
index d48c2f450..1a6cb22bb 100644
--- a/crates/storage/src/job.rs
+++ b/crates/storage/src/job.rs
@@ -32,7 +32,7 @@ pub struct JobSubmission {
     payload: Value,
 }
 
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
 struct SerializableSpanContext {
     trace_id: String,
     span_id: String,
@@ -65,7 +65,7 @@ impl TryFrom<&SerializableSpanContext> for SpanContext {
     }
 }
 
 /// A wrapper for [`Job`] which adds the span context in the payload.
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
 pub struct JobWithSpanContext<T> {
     #[serde(skip_serializing_if = "Option::is_none")]
     span_context: Option<SerializableSpanContext>,
diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml
index de50f1455..7553a433f 100644
--- a/crates/tasks/Cargo.toml
+++ b/crates/tasks/Cargo.toml
@@ -14,7 +14,8 @@ workspace = true
 [dependencies]
 anyhow.workspace = true
 apalis.workspace = true
-apalis-core = "0.4.9"
+apalis-core = "0.5.1"
+apalis-sql = { version = "0.5.1", features = ["postgres"] }
 async-stream = "0.3.5"
 async-trait.workspace = true
 chrono.workspace = true
diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs
index 86c7f64f3..4a22cee03 100644
--- a/crates/tasks/src/database.rs
+++ b/crates/tasks/src/database.rs
@@ -18,9 +18,7 @@ use std::str::FromStr;
 
 use apalis::{
     cron::CronStream,
-    prelude::{
-        timer::TokioTimer, Job, JobContext, Monitor, TokioExecutor, WorkerBuilder, WorkerFactoryFn,
-    },
+    prelude::{Data, Job, Monitor, TokioExecutor, WorkerBuilder, WorkerFactoryFn},
 };
 use chrono::{DateTime, Utc};
 use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess};
@@ -28,7 +26,7 @@ use tracing::{debug, info};
 
 use crate::{
     utils::{metrics_layer, trace_layer, TracedJob},
-    JobContextExt, State,
+    State,
 };
 
 #[derive(Default, Clone)]
@@ -50,13 +48,15 @@ impl TracedJob for CleanupExpiredTokensJob {}
 
 pub async fn cleanup_expired_tokens(
     job: CleanupExpiredTokensJob,
-    ctx: JobContext,
-) -> Result<(), Box<dyn std::error::Error>> {
+    state: Data<State>,
+) -> Result<(), mas_storage::RepositoryError> {
     debug!("cleanup expired tokens job scheduled at {}", job.scheduled);
 
-    let state = ctx.state();
     let clock = state.clock();
-    let mut repo = state.repository().await?;
+    let mut repo = state
+        .repository()
+        .await
+        .map_err(mas_storage::RepositoryError::from_error)?;
 
     let count = repo.oauth2_access_token().cleanup_expired(&clock).await?;
     repo.save().await?;
@@ -78,10 +78,10 @@ pub(crate) fn register(
     let schedule = apalis::cron::Schedule::from_str("*/15 * * * * *").unwrap();
     let worker_name = format!("{job}-{suffix}", job = CleanupExpiredTokensJob::NAME);
     let worker = WorkerBuilder::new(worker_name)
-        .stream(CronStream::new(schedule).timer(TokioTimer).to_stream())
-        .layer(state.inject())
+        .data(state.clone())
         .layer(metrics_layer())
         .layer(trace_layer())
+        .stream(CronStream::new(schedule).into_stream())
        .build_fn(cleanup_expired_tokens);
 
     monitor.register(worker)
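The database.rs hunk above is the cron migration in miniature: apalis 0.4 injected state through an extension layer and drove the schedule with an explicit `TokioTimer`, while 0.5 uses `.data(...)` plus `CronStream::new(schedule).into_stream()`. A minimal self-contained sketch of the 0.5 shape, with hypothetical `TickJob` and `AppState` types (exact trait bounds may differ slightly from apalis-cron's docs; the builder calls are the ones used in the hunk):

```rust
use std::str::FromStr;

use apalis::{
    cron::CronStream,
    prelude::{Data, Job, Monitor, TokioExecutor, WorkerBuilder, WorkerFactoryFn},
};
use chrono::{DateTime, Utc};

#[derive(Clone)]
struct TickJob {
    scheduled: DateTime<Utc>,
}

// apalis-cron constructs the job value from the tick timestamp
impl From<DateTime<Utc>> for TickJob {
    fn from(scheduled: DateTime<Utc>) -> Self {
        Self { scheduled }
    }
}

impl Job for TickJob {
    const NAME: &'static str = "example.tick";
}

#[derive(Clone)]
struct AppState;

async fn tick(job: TickJob, _state: Data<AppState>) -> Result<(), std::convert::Infallible> {
    println!("tick scheduled at {}", job.scheduled);
    Ok(())
}

#[tokio::main]
async fn main() {
    let schedule = apalis::cron::Schedule::from_str("*/15 * * * * *").expect("valid cron expression");
    let worker = WorkerBuilder::new(format!("{}-worker", TickJob::NAME))
        // `.data(...)` is the 0.5 replacement for the 0.4-era extension layer
        .data(AppState)
        .stream(CronStream::new(schedule).into_stream())
        .build_fn(tick);

    if let Err(e) = Monitor::<TokioExecutor>::new().register(worker).run().await {
        eprintln!("monitor exited: {e}");
    }
}
```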
diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs
index 97be3cb43..f851e87c3 100644
--- a/crates/tasks/src/email.rs
+++ b/crates/tasks/src/email.rs
@@ -12,29 +12,49 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use anyhow::Context;
-use apalis::prelude::{JobContext, Monitor, TokioExecutor};
+use apalis::prelude::{Monitor, TokioExecutor};
+use apalis_core::layers::extensions::Data;
 use chrono::Duration;
 use mas_email::{Address, Mailbox};
 use mas_i18n::locale;
 use mas_storage::job::{JobWithSpanContext, VerifyEmailJob};
 use mas_templates::{EmailVerificationContext, TemplateContext};
 use rand::{distributions::Uniform, Rng};
+use sqlx::PgPool;
+use thiserror::Error;
 use tracing::info;
+use ulid::Ulid;
 
-use crate::{storage::PostgresStorageFactory, JobContextExt, State};
+use crate::State;
+
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("User email not found: {0}")]
+    UserEmailNotFound(Ulid),
+
+    #[error("User not found: {0}")]
+    UserNotFound(Ulid),
+
+    #[error("Repository error")]
+    Repository(#[from] mas_storage::RepositoryError),
+
+    #[error("Invalid email address")]
+    InvalidEmailAddress(#[from] mas_email::AddressError),
+
+    #[error("Failed to send email")]
+    Mailer(#[from] mas_email::MailerError),
+}
 
 #[tracing::instrument(
     name = "job.verify_email",
     fields(user_email.id = %job.user_email_id()),
     skip_all,
-    err(Debug),
+    err,
 )]
 async fn verify_email(
     job: JobWithSpanContext<VerifyEmailJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
+    state: Data<State>,
+) -> Result<(), Error> {
     let mut repo = state.repository().await?;
     let mut rng = state.rng();
     let mailer = state.mailer();
@@ -50,14 +70,14 @@ async fn verify_email(
         .user_email()
         .lookup(job.user_email_id())
         .await?
-        .context("User email not found")?;
+        .ok_or(Error::UserEmailNotFound(job.user_email_id()))?;
 
     // Lookup the user associated with the email
     let user = repo
         .user()
         .lookup(user_email.user_id)
         .await?
-        .context("User not found")?;
+        .ok_or(Error::UserNotFound(user_email.user_id))?;
 
     // Generate a verification code
     let range = Uniform::<u32>::from(0..1_000_000);
@@ -100,10 +120,9 @@ pub(crate) fn register(
     suffix: &str,
     monitor: Monitor<TokioExecutor>,
     state: &State,
-    storage_factory: &PostgresStorageFactory,
+    pool: &PgPool,
 ) -> Monitor<TokioExecutor> {
-    let verify_email_worker =
-        crate::build!(VerifyEmailJob => verify_email, suffix, state, storage_factory);
+    let verify_email_worker = crate::build!(VerifyEmailJob => verify_email, suffix, state, pool);
 
     monitor.register(verify_email_worker)
 }
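The email.rs hunk swaps `anyhow::Context` for a dedicated thiserror enum, which is also why the span attribute can change from `err(Debug)` to `err`: the error now carries a meaningful `Display` and keeps the offending ID. The lookup pattern in isolation, as a sketch (`find_user` is a hypothetical stand-in for the async repository call):

```rust
use thiserror::Error;
use ulid::Ulid;

#[derive(Debug, Error)]
enum Error {
    #[error("User not found: {0}")]
    UserNotFound(Ulid),
}

fn find_user(_id: Ulid) -> Option<String> {
    None // stand-in for `repo.user().lookup(id).await?`
}

fn lookup_or_fail(id: Ulid) -> Result<String, Error> {
    // 0.4-era code: `find_user(id).context("User not found")?` erased the ID;
    // the typed variant keeps it on the error and in the span's `err` field.
    find_user(id).ok_or(Error::UserNotFound(id))
}
```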
diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs
index 37c355b3f..ede4f1d3f 100644
--- a/crates/tasks/src/lib.rs
+++ b/crates/tasks/src/lib.rs
@@ -15,21 +15,17 @@
 use std::sync::Arc;
 
 use apalis::prelude::{Monitor, TokioExecutor};
-use apalis_core::layers::extensions::Extension;
+use apalis_core::layers::extensions::Data;
 use mas_email::Mailer;
 use mas_matrix::HomeserverConnection;
-use mas_storage::{BoxClock, BoxRepository, Repository, SystemClock};
-use mas_storage_pg::{DatabaseError, PgRepository};
+use mas_storage::{BoxClock, BoxRepository, Repository, RepositoryError, SystemClock};
+use mas_storage_pg::PgRepository;
 use rand::SeedableRng;
 use sqlx::{Pool, Postgres};
-use tracing::debug;
-
-use crate::storage::PostgresStorageFactory;
 
 mod database;
 mod email;
 mod matrix;
-mod storage;
 mod user;
 mod utils;
@@ -56,8 +52,8 @@ impl State {
         }
     }
 
-    pub fn inject(&self) -> Extension<State> {
-        Extension(self.clone())
+    pub fn inject(&self) -> Data<State> {
+        Data::new(self.clone())
     }
 
     pub fn pool(&self) -> &Pool<Postgres> {
@@ -78,9 +74,10 @@ impl State {
         rand_chacha::ChaChaRng::from_rng(rand::thread_rng()).expect("failed to seed rng")
     }
 
-    pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> {
+    pub async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
         let repo = PgRepository::from_pool(self.pool())
-            .await?
+            .await
+            .map_err(mas_storage::RepositoryError::from_error)?
             .map_err(mas_storage::RepositoryError::from_error)
             .boxed();
@@ -92,22 +89,10 @@ impl State {
     }
 }
 
-trait JobContextExt {
-    fn state(&self) -> State;
-}
-
-impl JobContextExt for apalis::prelude::JobContext {
-    fn state(&self) -> State {
-        self.data_opt::<State>()
-            .expect("state not injected in job context")
-            .clone()
-    }
-}
-
 /// Helper macro to build a storage-backed worker.
 macro_rules! build {
-    ($job:ty => $fn:ident, $suffix:expr, $state:expr, $factory:expr) => {{
-        let storage = $factory.build();
+    ($job:ty => $fn:ident, $suffix:expr, $state:expr, $pool:expr) => {{
+        let storage = ::apalis_sql::postgres::PostgresStorage::new($pool.clone());
         let worker_name = format!(
             "{job}-{suffix}",
             job = <$job as ::apalis::prelude::Job>::NAME,
@@ -117,12 +102,9 @@ macro_rules! build {
         let builder = ::apalis::prelude::WorkerBuilder::new(worker_name)
             .layer($state.inject())
             .layer(crate::utils::trace_layer())
-            .layer(crate::utils::metrics_layer());
-
-        let builder = ::apalis::prelude::WithStorage::with_storage_config(builder, storage, |c| {
-            c.fetch_interval(std::time::Duration::from_secs(1))
-        });
-        ::apalis::prelude::WorkerFactory::build(builder, ::apalis::prelude::job_fn($fn))
+            .layer(crate::utils::metrics_layer())
+            .with_storage(storage);
+        ::apalis::prelude::WorkerFactoryFn::build_fn(builder, $fn)
     }};
 }
@@ -133,26 +115,30 @@ pub(crate) use build;
 
 /// Initialise the workers.
 ///
 /// # Errors
 ///
 /// This function can fail if the database connection fails.
-pub async fn init(
+pub fn init(
     name: &str,
     pool: &Pool<Postgres>,
     mailer: &Mailer,
     homeserver: impl HomeserverConnection + 'static,
-) -> Result<Monitor<TokioExecutor>, sqlx::Error> {
+) -> Monitor<TokioExecutor> {
     let state = State::new(
         pool.clone(),
         SystemClock::default(),
         mailer.clone(),
         homeserver,
     );
-    let factory = PostgresStorageFactory::new(pool.clone());
-    let monitor = Monitor::new().executor(TokioExecutor::new());
+    let monitor = Monitor::<TokioExecutor>::new();
     let monitor = self::database::register(name, monitor, &state);
-    let monitor = self::email::register(name, monitor, &state, &factory);
-    let monitor = self::matrix::register(name, monitor, &state, &factory);
-    let monitor = self::user::register(name, monitor, &state, &factory);
-    // TODO: we might want to grab the join handle here
-    factory.listen().await?;
-    debug!(?monitor, "workers registered");
-    Ok(monitor)
+    let monitor = self::email::register(name, monitor, &state, pool);
+    let monitor = self::matrix::register(name, monitor, &state, pool);
+    let monitor = self::user::register(name, monitor, &state, pool);
+
+    monitor.on_event(|e| {
+        let event = e.inner();
+        if let apalis::prelude::Event::Error(error) = e.inner() {
+            tracing::error!(?error, "worker error");
+        } else {
+            tracing::debug!(?event, "worker event");
+        }
+    })
 }
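For one job type, the reworked `build!` macro expands to roughly the following. This is a sketch, not a standalone program: `State`, `VerifyEmailJob`, `verify_email`, and the layer helpers are the crate items from the hunks above, and only paths that appear in the macro body are used:

```rust
fn register_verify_email(
    suffix: &str,
    monitor: apalis::prelude::Monitor<apalis::prelude::TokioExecutor>,
    state: &State,
    pool: &sqlx::PgPool,
) -> apalis::prelude::Monitor<apalis::prelude::TokioExecutor> {
    // one PostgresStorage per job type, all sharing the same pool
    let storage = ::apalis_sql::postgres::PostgresStorage::new(pool.clone());
    let worker_name = format!(
        "{job}-{suffix}",
        job = <VerifyEmailJob as ::apalis::prelude::Job>::NAME,
    );
    let builder = ::apalis::prelude::WorkerBuilder::new(worker_name)
        .layer(state.inject()) // supplies the `Data<State>` argument of the job fn
        .layer(crate::utils::trace_layer())
        .layer(crate::utils::metrics_layer())
        .with_storage(storage);
    let worker = ::apalis::prelude::WorkerFactoryFn::build_fn(builder, verify_email);
    monitor.register(worker)
}
```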
diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs
index b9c9f2788..43252464e 100644
--- a/crates/tasks/src/matrix.rs
+++ b/crates/tasks/src/matrix.rs
@@ -12,17 +12,32 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use anyhow::Context;
-use apalis::prelude::{JobContext, Monitor, TokioExecutor};
+use apalis::prelude::{Monitor, TokioExecutor};
+use apalis_core::layers::extensions::Data;
 use mas_matrix::ProvisionRequest;
 use mas_storage::{
     job::{DeleteDeviceJob, JobWithSpanContext, ProvisionDeviceJob, ProvisionUserJob},
     user::{UserEmailRepository, UserRepository},
     RepositoryAccess,
 };
+use sqlx::PgPool;
+use thiserror::Error;
 use tracing::info;
+use ulid::Ulid;
 
-use crate::{storage::PostgresStorageFactory, JobContextExt, State};
+use crate::State;
+
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("User not found: {0}")]
+    UserNotFound(Ulid),
+
+    #[error("Failed to do homeserver operation")]
+    HomeserverConnection(#[source] anyhow::Error),
+
+    #[error("Repository error")]
+    Repository(#[from] mas_storage::RepositoryError),
+}
 
 /// Job to provision a user on the Matrix homeserver.
 /// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id}
@@ -35,9 +50,8 @@ use crate::{storage::PostgresStorageFactory, JobContextExt, State};
 )]
 async fn provision_user(
     job: JobWithSpanContext<ProvisionUserJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
+    state: Data<State>,
+) -> Result<(), Error> {
     let matrix = state.matrix_connection();
     let mut repo = state.repository().await?;
 
@@ -45,7 +59,7 @@ async fn provision_user(
         .user()
         .lookup(job.user_id())
         .await?
-        .context("User not found")?;
+        .ok_or(Error::UserNotFound(job.user_id()))?;
 
     let mxid = matrix.mxid(&user.username);
     let emails = repo
@@ -65,7 +79,10 @@ async fn provision_user(
         request = request.set_displayname(display_name.to_owned());
     }
 
-    let created = matrix.provision_user(&request).await?;
+    let created = matrix
+        .provision_user(&request)
+        .await
+        .map_err(Error::HomeserverConnection)?;
 
     if created {
         info!(%user.id, %mxid, "User created");
@@ -90,9 +107,8 @@ async fn provision_user(
 )]
 async fn provision_device(
     job: JobWithSpanContext<ProvisionDeviceJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
+    state: Data<State>,
+) -> Result<(), Error> {
     let matrix = state.matrix_connection();
     let mut repo = state.repository().await?;
 
@@ -100,11 +116,14 @@ async fn provision_device(
         .user()
         .lookup(job.user_id())
         .await?
-        .context("User not found")?;
+        .ok_or(Error::UserNotFound(job.user_id()))?;
 
     let mxid = matrix.mxid(&user.username);
 
-    matrix.create_device(&mxid, job.device_id()).await?;
+    matrix
+        .create_device(&mxid, job.device_id())
+        .await
+        .map_err(Error::HomeserverConnection)?;
 
     info!(%user.id, %mxid, device.id = job.device_id(), "Device created");
 
     Ok(())
@@ -124,9 +143,8 @@ async fn provision_device(
 )]
 async fn delete_device(
     job: JobWithSpanContext<DeleteDeviceJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
+    state: Data<State>,
+) -> Result<(), Error> {
     let matrix = state.matrix_connection();
     let mut repo = state.repository().await?;
 
@@ -134,11 +152,14 @@ async fn delete_device(
         .user()
         .lookup(job.user_id())
         .await?
- .context("User not found")?; + .ok_or(Error::UserNotFound(job.user_id()))?; let mxid = matrix.mxid(&user.username); - matrix.delete_device(&mxid, job.device_id()).await?; + matrix + .delete_device(&mxid, job.device_id()) + .await + .map_err(Error::HomeserverConnection)?; info!(%user.id, %mxid, device.id = job.device_id(), "Device deleted"); Ok(()) @@ -148,14 +169,13 @@ pub(crate) fn register( suffix: &str, monitor: Monitor, state: &State, - storage_factory: &PostgresStorageFactory, + pool: &PgPool, ) -> Monitor { let provision_user_worker = - crate::build!(ProvisionUserJob => provision_user, suffix, state, storage_factory); + crate::build!(ProvisionUserJob => provision_user, suffix, state, pool); let provision_device_worker = - crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory); - let delete_device_worker = - crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory); + crate::build!(ProvisionDeviceJob => provision_device, suffix, state, pool); + let delete_device_worker = crate::build!(DeleteDeviceJob => delete_device, suffix, state, pool); monitor .register(provision_user_worker) diff --git a/crates/tasks/src/storage/from_row.rs b/crates/tasks/src/storage/from_row.rs deleted file mode 100644 index d26203657..000000000 --- a/crates/tasks/src/storage/from_row.rs +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2023 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
diff --git a/crates/tasks/src/storage/from_row.rs b/crates/tasks/src/storage/from_row.rs
deleted file mode 100644
index d26203657..000000000
--- a/crates/tasks/src/storage/from_row.rs
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright 2023 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::str::FromStr;
-
-use apalis_core::{context::JobContext, job::JobId, request::JobRequest, worker::WorkerId};
-use chrono::{DateTime, Utc};
-use serde_json::Value;
-use sqlx::Row;
-
-/// Wrapper for [`JobRequest`]
-pub(crate) struct SqlJobRequest<T>(JobRequest<T>);
-
-impl<T> From<SqlJobRequest<T>> for JobRequest<T> {
-    fn from(val: SqlJobRequest<T>) -> Self {
-        val.0
-    }
-}
-
-impl<'r, T: serde::de::DeserializeOwned> sqlx::FromRow<'r, sqlx::postgres::PgRow>
-    for SqlJobRequest<T>
-{
-    fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
-        let job: Value = row.try_get("job")?;
-        let id: JobId =
-            JobId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
-                index: "id".to_owned(),
-                source: Box::new(e),
-            })?;
-        let mut context = JobContext::new(id);
-
-        let run_at = row.try_get("run_at")?;
-        context.set_run_at(run_at);
-
-        let attempts = row.try_get("attempts").unwrap_or(0);
-        context.set_attempts(attempts);
-
-        let max_attempts = row.try_get("max_attempts").unwrap_or(25);
-        context.set_max_attempts(max_attempts);
-
-        let done_at: Option<DateTime<Utc>> = row.try_get("done_at").unwrap_or_default();
-        context.set_done_at(done_at);
-
-        let lock_at: Option<DateTime<Utc>> = row.try_get("lock_at").unwrap_or_default();
-        context.set_lock_at(lock_at);
-
-        let last_error = row.try_get("last_error").unwrap_or_default();
-        context.set_last_error(last_error);
-
-        let status: String = row.try_get("status")?;
-        context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
-            index: "job".to_owned(),
-            source: Box::new(e),
-        })?);
-
-        let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
-        context.set_lock_by(lock_by.map(WorkerId::new));
-
-        Ok(SqlJobRequest(JobRequest::new_with_context(
-            serde_json::from_value(job).map_err(|e| sqlx::Error::ColumnDecode {
-                index: "job".to_owned(),
-                source: Box::new(e),
-            })?,
-            context,
-        )))
-    }
-}
diff --git a/crates/tasks/src/storage/mod.rs b/crates/tasks/src/storage/mod.rs
deleted file mode 100644
index 7884c083b..000000000
--- a/crates/tasks/src/storage/mod.rs
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2023 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//! Reimplementation of the [`apalis_sql::storage::PostgresStorage`] using a
-//! shared connection for the [`PgListener`]
-
-mod from_row;
-mod postgres;
-
-use self::from_row::SqlJobRequest;
-pub(crate) use self::postgres::StorageFactory as PostgresStorageFactory;
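The deleted `from_row.rs` is an instance of a generally useful sqlx idiom: wrap a foreign type in a newtype so it can carry a manual `FromRow` implementation. A minimal sketch of the idiom, assuming sqlx's `json` feature for decoding the `serde_json::Value` column:

```rust
use sqlx::Row;

// newtype over a payload we don't own, so we can implement FromRow for it
struct JsonPayload(serde_json::Value);

impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for JsonPayload {
    fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
        // decode failures surface as sqlx errors, as in the deleted code
        let job: serde_json::Value = row.try_get("job")?;
        Ok(JsonPayload(job))
    }
}
```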
diff --git a/crates/tasks/src/storage/postgres.rs b/crates/tasks/src/storage/postgres.rs
deleted file mode 100644
index 4836f09cb..000000000
--- a/crates/tasks/src/storage/postgres.rs
+++ /dev/null
@@ -1,399 +0,0 @@
-// Copyright 2023 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::{convert::TryInto, marker::PhantomData, ops::Add, sync::Arc, time::Duration};
-
-use apalis_core::{
-    error::JobStreamError,
-    job::{Job, JobId, JobStreamResult},
-    request::JobRequest,
-    storage::{StorageError, StorageResult, StorageWorkerPulse},
-    utils::Timer,
-    worker::WorkerId,
-};
-use async_stream::try_stream;
-use chrono::{DateTime, Utc};
-use event_listener::Event;
-use futures_lite::{Stream, StreamExt};
-use serde::{de::DeserializeOwned, Serialize};
-use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row};
-use tokio::task::JoinHandle;
-
-use super::SqlJobRequest;
-
-pub struct StorageFactory {
-    pool: PgPool,
-    event: Arc<Event>,
-}
-
-impl StorageFactory {
-    pub fn new(pool: Pool<Postgres>) -> Self {
-        StorageFactory {
-            pool,
-            event: Arc::new(Event::new()),
-        }
-    }
-
-    pub async fn listen(self) -> Result<JoinHandle<()>, sqlx::Error> {
-        let mut listener = PgListener::connect_with(&self.pool).await?;
-        listener.listen("apalis::job").await?;
-
-        let handle = tokio::spawn(async move {
-            loop {
-                let notification = listener.recv().await.expect("Failed to poll notification");
-                self.event.notify(usize::MAX);
-                tracing::debug!(?notification, "Broadcast notification");
-            }
-        });
-
-        Ok(handle)
-    }
-
-    pub fn build<T>(&self) -> Storage<T> {
-        Storage {
-            pool: self.pool.clone(),
-            event: self.event.clone(),
-            job_type: PhantomData,
-        }
-    }
-}
-
-/// Represents a [`apalis_core::storage::Storage`] that persists to Postgres
-#[derive(Debug)]
-pub struct Storage<T> {
-    pool: PgPool,
-    event: Arc<Event>,
-    job_type: PhantomData<T>,
-}
-
-impl<T> Clone for Storage<T> {
-    fn clone(&self) -> Self {
-        Storage {
-            pool: self.pool.clone(),
-            event: self.event.clone(),
-            job_type: PhantomData,
-        }
-    }
-}
-
-impl<T: Job + Serialize + DeserializeOwned> Storage<T> {
-    fn stream_jobs(
-        &self,
-        worker_id: &WorkerId,
-        interval: Duration,
-        buffer_size: usize,
-    ) -> impl Stream<Item = Result<JobRequest<T>, JobStreamError>> {
-        let pool = self.pool.clone();
-        let sleeper = apalis_core::utils::timer::TokioTimer;
-        let worker_id = worker_id.clone();
-        let event = self.event.clone();
-        try_stream! {
-            loop {
-                // Wait for a notification or a timeout
-                let listener = event.listen();
-                let interval = sleeper.sleep(interval);
-                futures_lite::future::race(interval, listener).await;
-
-                let tx = pool.clone();
-                let job_type = T::NAME;
-                let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);";
-                let jobs: Vec<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
-                    .bind(worker_id.name())
-                    .bind(job_type)
-                    // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html
-                    .bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?)
-                    .fetch_all(&tx)
-                    .await.map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?;
-                for job in jobs {
-                    yield job.into()
-                }
-            }
-        }
-    }
-
-    async fn keep_alive_at<Service>(
-        &mut self,
-        worker_id: &WorkerId,
-        last_seen: DateTime<Utc>,
-    ) -> StorageResult<()> {
-        let pool = self.pool.clone();
-
-        let worker_type = T::NAME;
-        let storage_name = std::any::type_name::<Self>();
-        let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen)
-                VALUES ($1, $2, $3, $4, $5)
-                ON CONFLICT (id) DO
-                        UPDATE SET last_seen = EXCLUDED.last_seen";
-        sqlx::query(query)
-            .bind(worker_id.name())
-            .bind(worker_type)
-            .bind(storage_name)
-            .bind(std::any::type_name::<Service>())
-            .bind(last_seen)
-            .execute(&pool)
-            .await
-            .map_err(|e| StorageError::Database(Box::from(e)))?;
-        Ok(())
-    }
-}
-
-#[async_trait::async_trait]
-impl<T> apalis_core::storage::Storage for Storage<T>
-where
-    T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
-{
-    type Output = T;
-
-    /// Push a job to Postgres [Storage]
-    ///
-    /// # SQL Example
-    ///
-    /// ```sql
-    /// SELECT apalis.push_job(job_type::text, job::json);
-    /// ```
-    async fn push(&mut self, job: Self::Output) -> StorageResult<JobId> {
-        let id = JobId::new();
-        let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)";
-        let pool = self.pool.clone();
-        let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;
-        let job_type = T::NAME;
-        sqlx::query(query)
-            .bind(job)
-            .bind(id.to_string())
-            .bind(job_type)
-            .execute(&pool)
-            .await
-            .map_err(|e| StorageError::Database(Box::from(e)))?;
-        Ok(id)
-    }
-
-    async fn schedule(
-        &mut self,
-        job: Self::Output,
-        on: chrono::DateTime<Utc>,
-    ) -> StorageResult<JobId> {
-        let query =
-            "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)";
-
-        let mut conn = self
-            .pool
-            .acquire()
-            .await
-            .map_err(|e| StorageError::Connection(Box::from(e)))?;
-
-        let id = JobId::new();
-        let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;
-        let job_type = T::NAME;
-        sqlx::query(query)
-            .bind(job)
-            .bind(id.to_string())
-            .bind(job_type)
-            .bind(on)
-            .execute(&mut *conn)
-            .await
-            .map_err(|e| StorageError::Database(Box::from(e)))?;
-        Ok(id)
-    }
-
-    async fn fetch_by_id(&self, job_id: &JobId) -> StorageResult<Option<JobRequest<T>>> {
-        let mut conn = self
-            .pool
-            .acquire()
-            .await
-            .map_err(|e| StorageError::Connection(Box::from(e)))?;
-
-        let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1";
-        let res: Option<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
-            .bind(job_id.to_string())
-            .fetch_optional(&mut *conn)
-            .await
-            .map_err(|e| StorageError::Database(Box::from(e)))?;
-        Ok(res.map(Into::into))
-    }
-
-    async fn heartbeat(&mut self, pulse: StorageWorkerPulse) -> StorageResult<bool> {
-        match pulse {
-            StorageWorkerPulse::EnqueueScheduled { count: _ } => {
-                // Ideally jobs are queue via run_at. So this is not necessary
-                Ok(true)
-            }
-
-            // Worker not seen in 5 minutes yet has running jobs
-            StorageWorkerPulse::ReenqueueOrphaned { count, .. } => {
-                let job_type = T::NAME;
-                let mut conn = self
-                    .pool
-                    .acquire()
-                    .await
-                    .map_err(|e| StorageError::Database(Box::from(e)))?;
-                let query = "UPDATE apalis.jobs
-                            SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ='Job was abandoned'
-                            WHERE id in
-                                (SELECT jobs.id from apalis.jobs INNER join apalis.workers ON lock_by = workers.id
-                                    WHERE status = 'Running' AND workers.last_seen < NOW() - INTERVAL '5 minutes'
-                                    AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);";
-                sqlx::query(query)
-                    .bind(job_type)
-                    .bind(count)
-                    .execute(&mut *conn)
-                    .await
-                    .map_err(|e| StorageError::Database(Box::from(e)))?;
-                Ok(true)
-            }
-
-            _ => unimplemented!(),
-        }
-    }
-
-    async fn kill(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
-        let pool = self.pool.clone();
-
-        let mut conn = pool
-            .acquire()
-            .await
-            .map_err(|e| StorageError::Connection(Box::from(e)))?;
-        let query =
-            "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2";
-        sqlx::query(query)
-            .bind(job_id.to_string())
-            .bind(worker_id.name())
-            .execute(&mut *conn)
-            .await
-            .map_err(|e| StorageError::Database(Box::from(e)))?;
-        Ok(())
-    }
-
-    /// Puts the job instantly back into the queue
-    /// Another [Worker] may consume
-    async fn retry(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
-        let pool = self.pool.clone();
-
-        let mut conn = pool
-            .acquire()
-            .await
-            .map_err(|e| StorageError::Connection(Box::from(e)))?;
-        let query =
-            "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2";
-        sqlx::query(query)
-            .bind(job_id.to_string())
-            .bind(worker_id.name())
-            .execute(&mut *conn)
-            .await
-            .map_err(|e| StorageError::Database(Box::from(e)))?;
-        Ok(())
-    }
-
-    fn consume(
-        &mut self,
-        worker_id: &WorkerId,
-        interval: Duration,
-        buffer_size: usize,
-    ) -> JobStreamResult<T> {
-        Box::pin(
-            self.stream_jobs(worker_id, interval, buffer_size)
-                .map(|r| r.map(Some)),
-        )
-    }
-    async fn len(&self) -> StorageResult<i64> {
-        let pool = self.pool.clone();
-        let query = "SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending'";
-        let record = sqlx::query(query)
-            .fetch_one(&pool)
-            .await
-            .map_err(|e| StorageError::Database(Box::from(e)))?;
-        Ok(record
-            .try_get("count")
-            .map_err(|e| StorageError::Database(Box::from(e)))?)
-    }
-    async fn ack(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
-        let pool = self.pool.clone();
-        let query =
-            "UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2";
-        sqlx::query(query)
-            .bind(job_id.to_string())
-            .bind(worker_id.name())
-            .execute(&pool)
-            .await
-            .map_err(|e| StorageError::Database(Box::from(e)))?;
-        Ok(())
-    }
-
-    async fn reschedule(&mut self, job: &JobRequest<T>, wait: Duration) -> StorageResult<()> {
-        let pool = self.pool.clone();
-        let job_id = job.id();
-
-        let wait: i64 = wait
-            .as_secs()
-            .try_into()
-            .map_err(|e| StorageError::Database(Box::new(e)))?;
-        let wait = chrono::Duration::microseconds(wait * 1000 * 1000);
-        // TODO: should we use a clock here?
-        #[allow(clippy::disallowed_methods)]
-        let run_at = Utc::now().add(wait);
-
-        let mut conn = pool
-            .acquire()
-            .await
-            .map_err(|e| StorageError::Connection(Box::from(e)))?;
-        let query =
-            "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1";
-        sqlx::query(query)
-            .bind(job_id.to_string())
-            .bind(run_at)
-            .execute(&mut *conn)
-            .await
-            .map_err(|e| StorageError::Database(Box::from(e)))?;
-        Ok(())
-    }
-
-    async fn update_by_id(
-        &self,
-        job_id: &JobId,
-        job: &JobRequest<T>,
-    ) -> StorageResult<()> {
-        let pool = self.pool.clone();
-        let status = job.status().as_ref();
-        let attempts = job.attempts();
-        let done_at = *job.done_at();
-        let lock_by = job.lock_by().clone();
-        let lock_at = *job.lock_at();
-        let last_error = job.last_error().clone();
-
-        let mut conn = pool
-            .acquire()
-            .await
-            .map_err(|e| StorageError::Connection(Box::from(e)))?;
-        let query =
-            "UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7";
-        sqlx::query(query)
-            .bind(status.to_owned())
-            .bind(attempts)
-            .bind(done_at)
-            .bind(lock_by.as_ref().map(WorkerId::name))
-            .bind(lock_at)
-            .bind(last_error)
-            .bind(job_id.to_string())
-            .execute(&mut *conn)
-            .await
-            .map_err(|e| StorageError::Database(Box::from(e)))?;
-        Ok(())
-    }
-
-    async fn keep_alive<Service>(&mut self, worker_id: &WorkerId) -> StorageResult<()> {
-        #[allow(clippy::disallowed_methods)]
-        let now = Utc::now();
-
-        self.keep_alive_at::<Service>(worker_id, now).await
-    }
-}
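With this roughly 400-line listener-driven reimplementation gone, each worker now builds directly on apalis-sql's own `PostgresStorage`, which does its own polling (the old macro configured a one-second `fetch_interval`; 0.5 falls back to the library's defaults). A sketch of the per-job-type construction that `build!` relies on; the trait bounds here are indicative rather than copied from apalis-sql:

```rust
use apalis_sql::postgres::PostgresStorage;
use sqlx::PgPool;

// one storage per job type, all sharing the same pool
fn storage_for<T>(pool: &PgPool) -> PostgresStorage<T>
where
    T: apalis::prelude::Job + serde::Serialize + serde::de::DeserializeOwned,
{
    PostgresStorage::new(pool.clone())
}
```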
- .context("User not found")?; + .ok_or(Error::UserNotFound(job.user_id()))?; // Let's first lock the user - let user = repo - .user() - .lock(&clock, user) - .await - .context("Failed to lock user")?; + let user = repo.user().lock(&clock, user).await?; // TODO: delete the sessions & access tokens @@ -59,7 +69,10 @@ async fn deactivate_user( let mxid = matrix.mxid(&user.username); info!("Deactivating user {} on homeserver", mxid); - matrix.delete_user(&mxid, job.hs_erase()).await?; + matrix + .delete_user(&mxid, job.hs_erase()) + .await + .map_err(Error::HomeserverConnection)?; Ok(()) } @@ -68,10 +81,10 @@ pub(crate) fn register( suffix: &str, monitor: Monitor, state: &State, - storage_factory: &PostgresStorageFactory, + pool: &PgPool, ) -> Monitor { let deactivate_user_worker = - crate::build!(DeactivateUserJob => deactivate_user, suffix, state, storage_factory); + crate::build!(DeactivateUserJob => deactivate_user, suffix, state, pool); monitor.register(deactivate_user_worker) } diff --git a/crates/tasks/src/utils.rs b/crates/tasks/src/utils.rs index 8b2d04811..ab7d69332 100644 --- a/crates/tasks/src/utils.rs +++ b/crates/tasks/src/utils.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use apalis::{prelude::Job, prelude::JobRequest}; +use apalis::prelude::{Job, Request}; use mas_storage::job::JobWithSpanContext; use mas_tower::{ make_span_fn, DurationRecorderLayer, FnWrapper, IdentityLayer, InFlightCounterLayer, @@ -43,13 +43,11 @@ impl TracedJob for JobWithSpanContext { } } -fn make_span_for_job_request(req: &JobRequest) -> tracing::Span { +fn make_span_for_job_request(req: &Request) -> tracing::Span { let span = info_span!( "job.run", "otel.kind" = "consumer", "otel.status_code" = tracing::field::Empty, - "job.id" = %req.id(), - "job.attempts" = req.attempts(), "job.name" = J::NAME, ); @@ -61,21 +59,21 @@ fn make_span_for_job_request(req: &JobRequest) -> tracing::Span } type TraceLayerForJob = - TraceLayer) -> tracing::Span>, KV<&'static str>, KV<&'static str>>; + TraceLayer) -> tracing::Span>, KV<&'static str>, KV<&'static str>>; pub(crate) fn trace_layer() -> TraceLayerForJob where J: TracedJob, { TraceLayer::new(make_span_fn( - make_span_for_job_request:: as fn(&JobRequest) -> tracing::Span, + make_span_for_job_request:: as fn(&Request) -> tracing::Span, )) .on_response(KV("otel.status_code", "OK")) .on_error(KV("otel.status_code", "ERROR")) } type MetricsLayerForJob = ( - IdentityLayer>, + IdentityLayer>, DurationRecorderLayer, InFlightCounterLayer, );