Skip to content

Commit

Permalink
feat: add shutdown method to AsyncDB DB and Runner (#250)
Browse files Browse the repository at this point in the history
* add shutdown interface & use

Signed-off-by: Bugen Zhao <[email protected]>

* fix postgres impl

Signed-off-by: Bugen Zhao <[email protected]>

* manual shutdown

Signed-off-by: Bugen Zhao <[email protected]>

* remove f word

Signed-off-by: Bugen Zhao <[email protected]>

* also add shutdown method to sync db

Signed-off-by: Bugen Zhao <[email protected]>

* bump version and add change log

Signed-off-by: Bugen Zhao <[email protected]>

---------

Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao authored Feb 11, 2025
1 parent c3b8c52 commit 89d0d3c
Show file tree
Hide file tree
Showing 15 changed files with 95 additions and 33 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

## [0.27.0] - 2025-02-11

* runner: add `shutdown` method to `DB` and `AsyncDB` trait to allow for graceful shutdown of the database connection. Users are encouraged to call `Runner::shutdown` or `Runner::shutdown_async` after running tests to ensure that the database connections are properly closed.

## [0.26.4] - 2025-01-27

* runner: add random string in path generation to avoid conflict when using `include`.
Expand Down
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ resolver = "2"
members = ["sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"]

[workspace.package]
version = "0.26.4"
version = "0.27.0"
edition = "2021"
homepage = "https://github.com/risinglightdb/sqllogictest-rs"
keywords = ["sql", "database", "parser", "cli"]
Expand Down
4 changes: 2 additions & 2 deletions sqllogictest-bin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ glob = "0.3"
itertools = "0.13"
quick-junit = { version = "0.5" }
rand = "0.8"
sqllogictest = { path = "../sqllogictest", version = "0.26" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.26" }
sqllogictest = { path = "../sqllogictest", version = "0.27" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.27" }
tokio = { version = "1", features = [
"rt",
"rt-multi-thread",
Expand Down
4 changes: 4 additions & 0 deletions sqllogictest-bin/src/engines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,8 @@ impl AsyncDB for Engines {
async fn run_command(command: std::process::Command) -> std::io::Result<std::process::Output> {
Command::from(command).output().await
}

async fn shutdown(&mut self) {
dispatch_engines!(self, e, { e.shutdown().await })
}
}
23 changes: 15 additions & 8 deletions sqllogictest-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,9 @@ async fn run_parallel(
}
}

// Shutdown the connection for managing temporary databases.
db.shutdown().await;

if !failed_case.is_empty() {
Err(anyhow!("some test case failed:\n{:#?}", failed_case))
} else {
Expand Down Expand Up @@ -467,7 +470,7 @@ async fn run_serial(
let filename = file.to_string_lossy().to_string();
let test_case_name = filename.replace(['/', ' ', '.', '-'], "_");
let mut failed = false;
let case = match run_test_file(&mut std::io::stdout(), runner, &file).await {
let case = match run_test_file(&mut std::io::stdout(), &mut runner, &file).await {
Ok(duration) => {
let mut case = TestCase::new(test_case_name, TestCaseStatus::success());
case.set_time(duration);
Expand Down Expand Up @@ -495,6 +498,7 @@ async fn run_serial(
case
}
};
runner.shutdown_async().await;
test_suite.add_test_case(case);
if connection_refused {
eprintln!("Connection refused. The server may be down. Exiting...");
Expand Down Expand Up @@ -534,14 +538,16 @@ async fn update_test_files(
format: bool,
) -> Result<()> {
for file in files {
let runner = Runner::new(|| engines::connect(engine, &config));
let mut runner = Runner::new(|| engines::connect(engine, &config));

if let Err(e) = update_test_file(&mut std::io::stdout(), runner, &file, format).await {
if let Err(e) = update_test_file(&mut std::io::stdout(), &mut runner, &file, format).await {
{
println!("{}\n\n{:?}", style("[FAILED]").red().bold(), e);
println!();
}
};

runner.shutdown_async().await;
}

Ok(())
Expand All @@ -562,16 +568,17 @@ async fn connect_and_run_test_file(
for label in labels {
runner.add_label(label);
}
let result = run_test_file(out, runner, filename).await?;
let result = run_test_file(out, &mut runner, filename).await;
runner.shutdown_async().await;

Ok(result)
result
}

/// Different from [`Runner::run_file_async`], we re-implement it here to print some progress
/// information.
async fn run_test_file<T: std::io::Write, M: MakeConnection>(
out: &mut T,
mut runner: Runner<M::Conn, M>,
runner: &mut Runner<M::Conn, M>,
filename: impl AsRef<Path>,
) -> Result<Duration> {
let filename = filename.as_ref();
Expand Down Expand Up @@ -676,7 +683,7 @@ fn finish_test_file<T: std::io::Write>(
/// progress information.
async fn update_test_file<T: std::io::Write, M: MakeConnection>(
out: &mut T,
mut runner: Runner<M::Conn, M>,
runner: &mut Runner<M::Conn, M>,
filename: impl AsRef<Path>,
format: bool,
) -> Result<()> {
Expand Down Expand Up @@ -804,7 +811,7 @@ async fn update_test_file<T: std::io::Write, M: MakeConnection>(
writeln!(outfile, "{record}")?;
continue;
}
update_record(outfile, &mut runner, record, format)
update_record(outfile, runner, record, format)
.await
.context(format!("failed to run `{}`", style(filename).bold()))?;
}
Expand Down
2 changes: 1 addition & 1 deletion sqllogictest-engines/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"] }
rust_decimal = { version = "1.36.0", features = ["tokio-pg"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
sqllogictest = { path = "../sqllogictest", version = "0.26" }
sqllogictest = { path = "../sqllogictest", version = "0.27" }
thiserror = "2"
tokio = { version = "1", features = [
"rt",
Expand Down
5 changes: 5 additions & 0 deletions sqllogictest-engines/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ impl AsyncDB for ExternalDriver {
}
}

async fn shutdown(&mut self) {
self.stdin.shutdown().await.ok();
self.child.wait().await.ok();
}

fn engine_name(&self) -> &str {
"external"
}
Expand Down
4 changes: 4 additions & 0 deletions sqllogictest-engines/src/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ impl sqllogictest::AsyncDB for MySql {
}
}

async fn shutdown(&mut self) {
self.pool.clone().disconnect().await.ok();
}

fn engine_name(&self) -> &str {
"mysql"
}
Expand Down
24 changes: 12 additions & 12 deletions sqllogictest-engines/src/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ mod extended;
mod simple;

use std::marker::PhantomData;
use std::sync::Arc;

use tokio::task::JoinHandle;

Expand All @@ -16,8 +15,8 @@ pub struct Extended;
/// Generic Postgres engine based on the client from [`tokio_postgres`]. The protocol `P` can be
/// either [`Simple`] or [`Extended`].
pub struct Postgres<P> {
client: Arc<tokio_postgres::Client>,
join_handle: JoinHandle<()>,
/// `None` means the connection is closed.
conn: Option<(tokio_postgres::Client, JoinHandle<()>)>,
_protocol: PhantomData<P>,
}

Expand All @@ -34,27 +33,28 @@ impl<P> Postgres<P> {
pub async fn connect(config: PostgresConfig) -> Result<Self> {
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;

let join_handle = tokio::spawn(async move {
let connection = tokio::spawn(async move {
if let Err(e) = connection.await {
log::error!("Postgres connection error: {:?}", e);
}
});

Ok(Self {
client: Arc::new(client),
join_handle,
conn: Some((client, connection)),
_protocol: PhantomData,
})
}

/// Returns a reference of the inner Postgres client.
pub fn pg_client(&self) -> &tokio_postgres::Client {
&self.client
pub fn client(&self) -> &tokio_postgres::Client {
&self.conn.as_ref().expect("connection is shutdown").0
}
}

impl<P> Drop for Postgres<P> {
fn drop(&mut self) {
self.join_handle.abort()
/// Shutdown the Postgres connection.
async fn shutdown(&mut self) {
if let Some((client, connection)) = self.conn.take() {
drop(client);
connection.await.ok();
}
}
}
12 changes: 8 additions & 4 deletions sqllogictest-engines/src/postgres/extended.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ macro_rules! array_process {
match v {
Some(v) => {
let sql = format!("select ($1::{})::varchar", stringify!($ty_name));
let tmp_rows = $self.client.query(&sql, &[&v]).await.unwrap();
let tmp_rows = $self.client().query(&sql, &[&v]).await.unwrap();
let value: &str = tmp_rows.get(0).unwrap().get(0);
assert!(value.len() > 0);
write!(output, "{}", value).unwrap();
Expand Down Expand Up @@ -128,7 +128,7 @@ macro_rules! single_process {
match value {
Some(value) => {
let sql = format!("select ($1::{})::varchar", stringify!($ty_name));
let tmp_rows = $self.client.query(&sql, &[&value]).await.unwrap();
let tmp_rows = $self.client().query(&sql, &[&value]).await.unwrap();
let value: &str = tmp_rows.get(0).unwrap().get(0);
assert!(value.len() > 0);
$row_vec.push(value.to_string());
Expand Down Expand Up @@ -188,9 +188,9 @@ impl sqllogictest::AsyncDB for Postgres<Extended> {
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>> {
let mut output = vec![];

let stmt = self.client.prepare(sql).await?;
let stmt = self.client().prepare(sql).await?;
let rows = self
.client
.client()
.query_raw(&stmt, std::iter::empty::<&(dyn ToSql + Sync)>())
.await?;

Expand Down Expand Up @@ -311,6 +311,10 @@ impl sqllogictest::AsyncDB for Postgres<Extended> {
}
}

async fn shutdown(&mut self) {
self.shutdown().await;
}

fn engine_name(&self) -> &str {
"postgres-extended"
}
Expand Down
6 changes: 5 additions & 1 deletion sqllogictest-engines/src/postgres/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl sqllogictest::AsyncDB for Postgres<Simple> {
// and we have to follow the format given by the specific database (pg).
// For example, postgres will output `t` as true and `f` as false,
// thus we have to write `t`/`f` in the expected results.
let rows = self.client.simple_query(sql).await?;
let rows = self.client().simple_query(sql).await?;
let mut cnt = 0;
for row in rows {
let mut row_vec = vec![];
Expand Down Expand Up @@ -62,6 +62,10 @@ impl sqllogictest::AsyncDB for Postgres<Simple> {
}
}

async fn shutdown(&mut self) {
self.shutdown().await;
}

fn engine_name(&self) -> &str {
"postgres"
}
Expand Down
6 changes: 6 additions & 0 deletions sqllogictest/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;
use std::future::IntoFuture;

use futures::future::join_all;
use futures::Future;

use crate::{AsyncDB, Connection as ConnectionName, DBOutput};
Expand Down Expand Up @@ -68,4 +69,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Connections<D, M> {
pub async fn run_default(&mut self, sql: &str) -> Result<DBOutput<D::ColumnType>, D::Error> {
self.get(ConnectionName::Default).await?.run(sql).await
}

/// Shutdown all connections.
pub async fn shutdown_all(&mut self) {
join_all(self.conns.values_mut().map(|conn| conn.shutdown())).await;
}
}
1 change: 1 addition & 0 deletions sqllogictest/src/harness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ macro_rules! harness {
pub fn test(filename: impl AsRef<Path>, make_conn: impl MakeConnection) -> Result<(), Failed> {
let mut tester = Runner::new(make_conn);
tester.run_file(filename)?;
tester.shutdown();
Ok(())
}
25 changes: 24 additions & 1 deletion sqllogictest/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ pub trait AsyncDB {
/// Async run a SQL query and return the output.
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error>;

/// Shutdown the connection gracefully.
async fn shutdown(&mut self);

/// Engine name of current database.
fn engine_name(&self) -> &str {
""
Expand Down Expand Up @@ -106,6 +109,9 @@ pub trait DB {
/// Run a SQL query and return the output.
fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error>;

/// Shutdown the connection gracefully.
fn shutdown(&mut self) {}

/// Engine name of current database.
fn engine_name(&self) -> &str {
""
Expand All @@ -125,6 +131,10 @@ where
D::run(self, sql)
}

async fn shutdown(&mut self) {
D::shutdown(self);
}

fn engine_name(&self) -> &str {
D::engine_name(self)
}
Expand Down Expand Up @@ -512,7 +522,7 @@ pub fn strict_column_validator<T: ColumnType>(actual: &Vec<T>, expected: &Vec<T>
}

/// Sqllogictest runner.
pub struct Runner<D: AsyncDB, M: MakeConnection> {
pub struct Runner<D: AsyncDB, M: MakeConnection<Conn = D>> {
conn: Connections<D, M>,
// validator is used for validate if the result of query equals to expected.
validator: Validator,
Expand Down Expand Up @@ -1472,6 +1482,19 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
}
}

impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
/// Shutdown all connections in the runner.
pub async fn shutdown_async(&mut self) {
tracing::debug!("shutting down runner...");
self.conn.shutdown_all().await;
}

/// Shutdown all connections in the runner.
pub fn shutdown(&mut self) {
block_on(self.shutdown_async());
}
}

/// Updates the specified [`Record`] with the [`QueryOutput`] produced
/// by a Database, returning `Some(new_record)`.
///
Expand Down

0 comments on commit 89d0d3c

Please sign in to comment.