Skip to content

Commit

Permalink
worker/swirl/runner: Simplify AssertUnwindSafe usage
Browse files Browse the repository at this point in the history
rust-lang/rust#40628, rust-lang/rust#65717 and rust-lang/rfcs#3260 all show that unwind safety isn't particularly ergonomic to use and implement, and ultimately leads to people slapping `AssertUnwindSafe` everywhere until the compiler stops complaining.

This situation has led to built-in test framework using `catch_unwind(AssertUnwindSafe(...))` (see https://github.com/rust-lang/rust/blob/1.73.0/library/test/src/lib.rs#L649) and libraries like tower-http doing the same (see https://docs.rs/tower-http/0.4.4/src/tower_http/catch_panic.rs.html#198).

As people have mentioned in the threads above, trying to implement this correctly is akin to fighting windmills at the moment. Since the above cases demonstrated that `catch_unwind(AssertUnwindSafe(...))` is currently the easiest way to deal with this situation, this commit does the same and refactors our background job runner code accordingly.
  • Loading branch information
Turbo87 committed Nov 6, 2023
1 parent ccccb82 commit 71cffc6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 27 deletions.
5 changes: 2 additions & 3 deletions src/cloudfront.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ use aws_sdk_cloudfront::types::{InvalidationBatch, Paths};
use aws_sdk_cloudfront::{Client, Config};
use retry::delay::{jitter, Exponential};
use retry::OperationResult;
use std::panic::AssertUnwindSafe;
use std::time::Duration;
use tokio::runtime::Runtime;

pub struct CloudFront {
client: AssertUnwindSafe<Client>,
client: Client,
distribution_id: String,
}

Expand All @@ -27,7 +26,7 @@ impl CloudFront {
.credentials_provider(credentials)
.build();

let client = AssertUnwindSafe(Client::from_conf(config));
let client = Client::from_conf(config);

Some(Self {
client,
Expand Down
9 changes: 4 additions & 5 deletions src/worker/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ use crate::storage::Storage;
use crate::worker::swirl::PerformError;
use crates_io_index::Repository;
use reqwest::blocking::Client;
use std::panic::AssertUnwindSafe;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

pub struct Environment {
index: Mutex<Repository>,
http_client: AssertUnwindSafe<Client>,
http_client: Client,
cloudfront: Option<CloudFront>,
fastly: Option<Fastly>,
pub storage: AssertUnwindSafe<Arc<Storage>>,
pub storage: Arc<Storage>,
}

impl Environment {
Expand All @@ -25,10 +24,10 @@ impl Environment {
) -> Self {
Self {
index: Mutex::new(index),
http_client: AssertUnwindSafe(http_client),
http_client,
cloudfront,
fastly,
storage: AssertUnwindSafe(storage),
storage,
}
}

Expand Down
34 changes: 15 additions & 19 deletions src/worker/swirl/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use parking_lot::RwLock;
use std::any::Any;
use std::collections::HashMap;
use std::error::Error;
use std::panic::{catch_unwind, AssertUnwindSafe, PanicInfo, UnwindSafe};
use std::panic::{catch_unwind, AssertUnwindSafe, PanicInfo};
use std::sync::mpsc::{sync_channel, SyncSender};
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -29,15 +29,15 @@ fn runnable<J: BackgroundJob>(
}

/// The core runner responsible for locking and running jobs
pub struct Runner<Context: Clone + Send + UnwindSafe + 'static> {
pub struct Runner<Context: Clone + Send + 'static> {
connection_pool: DieselPool,
thread_pool: ThreadPool,
job_registry: Arc<RwLock<HashMap<String, RunTaskFn<Context>>>>,
environment: Context,
job_start_timeout: Duration,
}

impl<Context: Clone + Send + UnwindSafe + 'static> Runner<Context> {
impl<Context: Clone + Send + 'static> Runner<Context> {
pub fn new(connection_pool: DieselPool, environment: Context) -> Self {
Self {
connection_pool,
Expand Down Expand Up @@ -110,7 +110,7 @@ impl<Context: Clone + Send + UnwindSafe + 'static> Runner<Context> {
fn run_single_job(&self, sender: SyncSender<Event>) {
use diesel::result::Error::RollbackTransaction;

let job_registry = AssertUnwindSafe(self.job_registry.clone());
let job_registry = self.job_registry.clone();
let environment = self.environment.clone();

// The connection may not be `Send` so we need to clone the pool instead
Expand Down Expand Up @@ -155,11 +155,8 @@ impl<Context: Clone + Send + UnwindSafe + 'static> Runner<Context> {
|| {
conn.transaction(|conn| {
let pool = pool.to_real_pool();
let state = AssertUnwindSafe(PerformState { conn, pool });
catch_unwind(|| {
// Ensure the whole `AssertUnwindSafe(_)` is moved
let state = state;

let state = PerformState { conn, pool };
catch_unwind(AssertUnwindSafe(|| {
let job_registry = job_registry.read();
let run_task_fn =
job_registry.get(&job.job_type).ok_or_else(|| {
Expand All @@ -169,8 +166,8 @@ impl<Context: Clone + Send + UnwindSafe + 'static> Runner<Context> {
))
})?;

run_task_fn(environment, state.0, job.data)
})
run_task_fn(environment, state, job.data)
}))
.map_err(|e| try_to_extract_panic_info(&e))
})
// TODO: Replace with flatten() once that stabilizes
Expand Down Expand Up @@ -294,7 +291,6 @@ mod tests {
use crates_io_test_db::TestDatabase;
use diesel::r2d2;
use diesel::r2d2::ConnectionManager;
use std::panic::AssertUnwindSafe;
use std::sync::{Arc, Barrier};

fn job_exists(id: i64, conn: &mut PgConnection) -> bool {
Expand Down Expand Up @@ -323,8 +319,8 @@ mod tests {
fn jobs_are_locked_when_fetched() {
#[derive(Clone)]
struct TestContext {
job_started_barrier: Arc<AssertUnwindSafe<Barrier>>,
assertions_finished_barrier: Arc<AssertUnwindSafe<Barrier>>,
job_started_barrier: Arc<Barrier>,
assertions_finished_barrier: Arc<Barrier>,
}

#[derive(Serialize, Deserialize)]
Expand All @@ -344,8 +340,8 @@ mod tests {
let test_database = TestDatabase::new();

let test_context = TestContext {
job_started_barrier: Arc::new(AssertUnwindSafe(Barrier::new(2))),
assertions_finished_barrier: Arc::new(AssertUnwindSafe(Barrier::new(2))),
job_started_barrier: Arc::new(Barrier::new(2)),
assertions_finished_barrier: Arc::new(Barrier::new(2)),
};

let runner =
Expand Down Expand Up @@ -409,7 +405,7 @@ mod tests {
fn failed_jobs_do_not_release_lock_before_updating_retry_time() {
#[derive(Clone)]
struct TestContext {
job_started_barrier: Arc<AssertUnwindSafe<Barrier>>,
job_started_barrier: Arc<Barrier>,
}

#[derive(Serialize, Deserialize)]
Expand All @@ -428,7 +424,7 @@ mod tests {
let test_database = TestDatabase::new();

let test_context = TestContext {
job_started_barrier: Arc::new(AssertUnwindSafe(Barrier::new(2))),
job_started_barrier: Arc::new(Barrier::new(2)),
};

let runner =
Expand Down Expand Up @@ -495,7 +491,7 @@ mod tests {
assert_eq!(tries, 1);
}

fn runner<Context: Clone + Send + UnwindSafe + 'static>(
fn runner<Context: Clone + Send + 'static>(
database_url: &str,
context: Context,
) -> Runner<Context> {
Expand Down

0 comments on commit 71cffc6

Please sign in to comment.