Skip to content

Feat: Expose ping method from SQLx #1627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 1, 2023
15 changes: 15 additions & 0 deletions src/database/db_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,21 @@ impl DatabaseConnection {
}
}

/// Checks if a connection to the database is still valid.
pub async fn ping(&self) -> Result<(), DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.ping().await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.ping().await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.ping().await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.ping(),
DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
}
}

/// Explicitly close the database connection
pub async fn close(self) -> Result<(), DbErr> {
match self {
Expand Down
4 changes: 4 additions & 0 deletions src/database/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ impl MockDatabaseTrait for MockDatabase {
fn get_database_backend(&self) -> DbBackend {
self.db_backend
}

fn ping(&self) -> Result<(), DbErr> {
Ok(())
}
}

impl MockRow {
Expand Down
8 changes: 8 additions & 0 deletions src/driver/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ pub trait MockDatabaseTrait: Send + Debug {

/// Get the backend being used in the [MockDatabase]
fn get_database_backend(&self) -> DbBackend;

/// Ping the [MockDatabase]
fn ping(&self) -> Result<(), DbErr>;
}

impl MockDatabaseConnector {
Expand Down Expand Up @@ -194,4 +197,9 @@ impl MockDatabaseConnection {
.expect("Failed to acquire mocker")
.rollback()
}

/// Checks if a connection to the database is still valid.
pub fn ping(&self) -> Result<(), DbErr> {
self.mocker.lock().map_err(query_err)?.ping()
}
}
14 changes: 13 additions & 1 deletion src/driver/sqlx_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc};
use sqlx::{
mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
pool::PoolConnection,
Executor, MySql, MySqlPool,
Connection, Executor, MySql, MySqlPool,
};

use sea_query_binder::SqlxValues;
Expand Down Expand Up @@ -222,6 +222,18 @@ impl SqlxMySqlPoolConnection {
self.metric_callback = Some(Arc::new(callback));
}

/// Checks if a connection to the database is still valid.
pub async fn ping(&self) -> Result<(), DbErr> {
if let Ok(conn) = &mut self.pool.acquire().await {
match conn.ping().await {
Ok(_) => Ok(()),
Err(err) => Err(sqlx_error_to_conn_err(err)),
}
} else {
Err(DbErr::ConnectionAcquire)
}
}

/// Explicitly close the MySQL connection
pub async fn close(self) -> Result<(), DbErr> {
self.pool.close().await;
Expand Down
14 changes: 13 additions & 1 deletion src/driver/sqlx_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc};
use sqlx::{
pool::PoolConnection,
postgres::{PgConnectOptions, PgQueryResult, PgRow},
Executor, PgPool, Postgres,
Connection, Executor, PgPool, Postgres,
};

use sea_query_binder::SqlxValues;
Expand Down Expand Up @@ -237,6 +237,18 @@ impl SqlxPostgresPoolConnection {
self.metric_callback = Some(Arc::new(callback));
}

/// Checks if a connection to the database is still valid.
pub async fn ping(&self) -> Result<(), DbErr> {
if let Ok(conn) = &mut self.pool.acquire().await {
match conn.ping().await {
Ok(_) => Ok(()),
Err(err) => Err(sqlx_error_to_conn_err(err)),
}
} else {
Err(DbErr::ConnectionAcquire)
}
}

/// Explicitly close the Postgres connection
pub async fn close(self) -> Result<(), DbErr> {
self.pool.close().await;
Expand Down
14 changes: 13 additions & 1 deletion src/driver/sqlx_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc};
use sqlx::{
pool::PoolConnection,
sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
Executor, Sqlite, SqlitePool,
Connection, Executor, Sqlite, SqlitePool,
};

use sea_query_binder::SqlxValues;
Expand Down Expand Up @@ -229,6 +229,18 @@ impl SqlxSqlitePoolConnection {
self.metric_callback = Some(Arc::new(callback));
}

/// Checks if a connection to the database is still valid.
pub async fn ping(&self) -> Result<(), DbErr> {
if let Ok(conn) = &mut self.pool.acquire().await {
match conn.ping().await {
Ok(_) => Ok(()),
Err(err) => Err(sqlx_error_to_conn_err(err)),
}
} else {
Err(DbErr::ConnectionAcquire)
}
}

/// Explicitly close the SQLite connection
pub async fn close(self) -> Result<(), DbErr> {
self.pool.close().await;
Expand Down
53 changes: 53 additions & 0 deletions tests/connection_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
pub mod common;

pub use common::{bakery_chain::*, setup::*, TestContext};
use pretty_assertions::assert_eq;
pub use sea_orm::entity::*;
pub use sea_orm::*;

#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
pub async fn connection_ping() {
let ctx = TestContext::new("connection_ping").await;

ctx.db.ping().await.unwrap();

ctx.delete().await;
}

#[sea_orm_macros::test]
#[cfg(feature = "sqlx-mysql")]
pub async fn connection_ping_closed_mysql() {
let ctx = std::rc::Rc::new(Box::new(TestContext::new("connection_ping_closed").await));
let ctx_ping = std::rc::Rc::clone(&ctx);

ctx.db.get_mysql_connection_pool().close().await;
assert_eq!(ctx_ping.db.ping().await, Err(DbErr::ConnectionAcquire));
ctx.delete().await;
}

#[sea_orm_macros::test]
#[cfg(feature = "sqlx-sqlite")]
pub async fn connection_ping_closed_sqlite() {
let ctx = std::rc::Rc::new(Box::new(TestContext::new("connection_ping_closed").await));
let ctx_ping = std::rc::Rc::clone(&ctx);

ctx.db.get_sqlite_connection_pool().close().await;
assert_eq!(ctx_ping.db.ping().await, Err(DbErr::ConnectionAcquire));
ctx.delete().await;
}

#[sea_orm_macros::test]
#[cfg(feature = "sqlx-postgres")]
pub async fn connection_ping_closed_postgres() {
let ctx = std::rc::Rc::new(Box::new(TestContext::new("connection_ping_closed").await));
let ctx_ping = std::rc::Rc::clone(&ctx);

ctx.db.get_postgres_connection_pool().close().await;
assert_eq!(ctx_ping.db.ping().await, Err(DbErr::ConnectionAcquire));
ctx.delete().await;
}