diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index 4edb9d93bc..e001f08fa3 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -62,9 +62,15 @@ impl Connection for SqliteConnection { type Options = SqliteConnectOptions; - fn close(self) -> BoxFuture<'static, Result<(), Error>> { - // nothing explicit to do; connection will close in drop - Box::pin(future::ok(())) + fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { + Box::pin(async move { + let shutdown = self.worker.shutdown(); + // Drop the statement worker and any outstanding statements, which should + // cover all references to the connection handle outside of the worker thread + drop(self); + // Ensure the worker thread has terminated + shutdown.await + }) } fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { diff --git a/sqlx-core/src/sqlite/statement/worker.rs b/sqlx-core/src/sqlite/statement/worker.rs index b5a7cf88ed..5a06f637b0 100644 --- a/sqlx-core/src/sqlite/statement/worker.rs +++ b/sqlx-core/src/sqlite/statement/worker.rs @@ -30,6 +30,9 @@ enum StatementWorkerCommand { statement: Weak, tx: oneshot::Sender<()>, }, + Shutdown { + tx: oneshot::Sender<()>, + }, } impl StatementWorker { @@ -72,6 +75,13 @@ impl StatementWorker { let _ = tx.send(()); } } + StatementWorkerCommand::Shutdown { tx } => { + // drop the connection reference before sending confirmation + // and ending the command loop + drop(conn); + let _ = tx.send(()); + return; + } } } @@ -127,4 +137,25 @@ impl StatementWorker { rx.await.map_err(|_| Error::WorkerCrashed) } } + + /// Send a command to the worker to shut down the processing thread. + /// + /// A `WorkerCrashed` error may be returned if the thread has already stopped. + /// Subsequent calls to `step()`, `reset()`, or this method will fail with + /// `WorkerCrashed`. Ensure that any associated statements are dropped first. + pub(crate) fn shutdown(&mut self) -> impl Future> { + let (tx, rx) = oneshot::channel(); + + let send_res = self + .tx + .send(StatementWorkerCommand::Shutdown { tx }) + .map_err(|_| Error::WorkerCrashed); + + async move { + send_res?; + + // wait for the response + rx.await.map_err(|_| Error::WorkerCrashed) + } + } } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index ea983540e9..12f1834e8c 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -206,7 +206,8 @@ async fn it_executes_with_pool() -> anyhow::Result<()> { async fn it_opens_in_memory() -> anyhow::Result<()> { // If the filename is ":memory:", then a private, temporary in-memory database // is created for the connection. - let _ = SqliteConnection::connect(":memory:").await?; + let conn = SqliteConnection::connect(":memory:").await?; + conn.close().await?; Ok(()) } @@ -215,7 +216,8 @@ async fn it_opens_in_memory() -> anyhow::Result<()> { async fn it_opens_temp_on_disk() -> anyhow::Result<()> { // If the filename is an empty string, then a private, temporary on-disk database will // be created. - let _ = SqliteConnection::connect("").await?; + let conn = SqliteConnection::connect("").await?; + conn.close().await?; Ok(()) }