Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Expose ping method from SQLx #1627

Merged
merged 15 commits into from
Jun 1, 2023
7 changes: 7 additions & 0 deletions sea-orm-migration/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ impl<'c> ConnectionTrait for SchemaManagerConnection<'c> {
SchemaManagerConnection::Transaction(trans) => trans.is_mock_connection(),
}
}

async fn ping(&self) -> Result<(), DbErr> {
match self {
SchemaManagerConnection::Connection(conn) => conn.ping().await,
SchemaManagerConnection::Transaction(trans) => trans.ping().await,
}
}
}

#[async_trait::async_trait]
Expand Down
3 changes: 3 additions & 0 deletions src/database/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ pub trait ConnectionTrait: Sync {
fn is_mock_connection(&self) -> bool {
false
}

/// Checks if a connection to the database is still valid.
async fn ping(&self) -> Result<(), DbErr>;
}

/// Stream query results
Expand Down
14 changes: 14 additions & 0 deletions src/database/db_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,20 @@ impl ConnectionTrait for DatabaseConnection {
fn is_mock_connection(&self) -> bool {
matches!(self, DatabaseConnection::MockDatabaseConnection(_))
}

async fn ping(&self) -> Result<(), DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.ping().await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.ping().await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.ping().await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.ping(),
DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
}
}
}

#[async_trait::async_trait]
Expand Down
4 changes: 4 additions & 0 deletions src/database/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ impl MockDatabaseTrait for MockDatabase {
fn get_database_backend(&self) -> DbBackend {
self.db_backend
}

fn ping(&self) -> Result<(), DbErr> {
Ok(())
}
}

impl MockRow {
Expand Down
21 changes: 19 additions & 2 deletions src/database/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use crate::{
TransactionTrait,
};
#[cfg(feature = "sqlx-dep")]
use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
use crate::{sqlx_error_to_conn_err, sqlx_error_to_exec_err, sqlx_error_to_query_err};
use futures::lock::Mutex;

#[cfg(feature = "sqlx-dep")]
use sqlx::{pool::PoolConnection, TransactionManager};
use sqlx::{pool::PoolConnection, Connection, TransactionManager};
use std::{future::Future, pin::Pin, sync::Arc};
use tracing::instrument;

Expand Down Expand Up @@ -457,6 +458,22 @@ impl ConnectionTrait for DatabaseTransaction {
_ => Err(conn_err("Disconnected")),
}
}

#[allow(unused_variables)]
async fn ping(&self) -> Result<(), DbErr> {
match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => conn.ping().await.map_err(sqlx_error_to_conn_err),
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => conn.ping().await.map_err(sqlx_error_to_conn_err),
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => conn.ping().await.map_err(sqlx_error_to_conn_err),
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => return conn.ping(),
#[allow(unreachable_patterns)]
_ => Err(conn_err("Disconnected")),
}
}
}

impl StreamTrait for DatabaseTransaction {
Expand Down
8 changes: 8 additions & 0 deletions src/driver/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ pub trait MockDatabaseTrait: Send + Debug {

/// Get the backend being used in the [MockDatabase]
fn get_database_backend(&self) -> DbBackend;

/// Ping the [MockDatabase]
fn ping(&self) -> Result<(), DbErr>;
}

impl MockDatabaseConnector {
Expand Down Expand Up @@ -194,4 +197,9 @@ impl MockDatabaseConnection {
.expect("Failed to acquire mocker")
.rollback()
}

/// Checks if a connection to the database is still valid.
pub fn ping(&self) -> Result<(), DbErr> {
self.mocker.lock().map_err(query_err)?.ping()
}
}
14 changes: 13 additions & 1 deletion src/driver/sqlx_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc};
use sqlx::{
mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
pool::PoolConnection,
Executor, MySql, MySqlPool,
Connection, Executor, MySql, MySqlPool,
};

use sea_query_binder::SqlxValues;
Expand Down Expand Up @@ -222,6 +222,18 @@ impl SqlxMySqlPoolConnection {
self.metric_callback = Some(Arc::new(callback));
}

/// Checks if a connection to the database is still valid.
pub async fn ping(&self) -> Result<(), DbErr> {
if let Ok(conn) = &mut self.pool.acquire().await {
match conn.ping().await {
Ok(_) => Ok(()),
Err(err) => Err(sqlx_error_to_conn_err(err)),
}
} else {
Err(DbErr::ConnectionAcquire)
}
}

/// Explicitly close the MySQL connection
pub async fn close(self) -> Result<(), DbErr> {
self.pool.close().await;
Expand Down
14 changes: 13 additions & 1 deletion src/driver/sqlx_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc};
use sqlx::{
pool::PoolConnection,
postgres::{PgConnectOptions, PgQueryResult, PgRow},
Executor, PgPool, Postgres,
Connection, Executor, PgPool, Postgres,
};

use sea_query_binder::SqlxValues;
Expand Down Expand Up @@ -237,6 +237,18 @@ impl SqlxPostgresPoolConnection {
self.metric_callback = Some(Arc::new(callback));
}

/// Checks if a connection to the database is still valid.
pub async fn ping(&self) -> Result<(), DbErr> {
if let Ok(conn) = &mut self.pool.acquire().await {
match conn.ping().await {
Ok(_) => Ok(()),
Err(err) => Err(sqlx_error_to_conn_err(err)),
}
} else {
Err(DbErr::ConnectionAcquire)
}
}

/// Explicitly close the Postgres connection
pub async fn close(self) -> Result<(), DbErr> {
self.pool.close().await;
Expand Down
15 changes: 14 additions & 1 deletion src/driver/sqlx_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc};
use sqlx::{
pool::PoolConnection,
sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
Executor, Sqlite, SqlitePool,
Connection, Executor, Sqlite, SqlitePool,
};

use sea_query_binder::SqlxValues;
Expand Down Expand Up @@ -229,6 +229,19 @@ impl SqlxSqlitePoolConnection {
self.metric_callback = Some(Arc::new(callback));
}

/// Checks if a connection to the database is still valid.
tyt2y3 marked this conversation as resolved.
Show resolved Hide resolved
/// Checks if a connection to the database is still valid.
pub async fn ping(&self) -> Result<(), DbErr> {
if let Ok(conn) = &mut self.pool.acquire().await {
match conn.ping().await {
Ok(_) => Ok(()),
Err(err) => Err(sqlx_error_to_conn_err(err)),
}
} else {
Err(DbErr::ConnectionAcquire)
}
}

/// Explicitly close the SQLite connection
pub async fn close(self) -> Result<(), DbErr> {
self.pool.close().await;
Expand Down
53 changes: 53 additions & 0 deletions tests/connection_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
pub mod common;

pub use common::{bakery_chain::*, setup::*, TestContext};
use pretty_assertions::assert_eq;
pub use sea_orm::entity::*;
pub use sea_orm::*;

#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
pub async fn connection_ping() {
let ctx = TestContext::new("connection_ping").await;

ctx.db.ping().await.unwrap();

ctx.delete().await;
}

#[sea_orm_macros::test]
#[cfg(feature = "sqlx-mysql")]
pub async fn connection_ping_closed_mysql() {
let ctx = std::rc::Rc::new(Box::new(TestContext::new("connection_ping_closed").await));
let ctx_ping = std::rc::Rc::clone(&ctx);

ctx.db.get_mysql_connection_pool().close().await;
assert_eq!(ctx_ping.db.ping().await, Err(DbErr::ConnectionAcquire));
ctx.delete().await;
}

#[sea_orm_macros::test]
#[cfg(feature = "sqlx-sqlite")]
pub async fn connection_ping_closed_sqlite() {
let ctx = std::rc::Rc::new(Box::new(TestContext::new("connection_ping_closed").await));
let ctx_ping = std::rc::Rc::clone(&ctx);

ctx.db.get_sqlite_connection_pool().close().await;
assert_eq!(ctx_ping.db.ping().await, Err(DbErr::ConnectionAcquire));
ctx.delete().await;
}

#[sea_orm_macros::test]
#[cfg(feature = "sqlx-postgres")]
pub async fn connection_ping_closed_postgres() {
let ctx = std::rc::Rc::new(Box::new(TestContext::new("connection_ping_closed").await));
let ctx_ping = std::rc::Rc::clone(&ctx);

ctx.db.get_postgres_connection_pool().close().await;
assert_eq!(ctx_ping.db.ping().await, Err(DbErr::ConnectionAcquire));
ctx.delete().await;
}
14 changes: 14 additions & 0 deletions tests/transaction_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@ pub async fn transaction() {
ctx.delete().await;
}

#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
pub async fn transaction_ping() {
let ctx = TestContext::new("transaction_ping").await;

ctx.db.transaction(|txn| txn.ping()).await.unwrap();

ctx.delete().await;
}

#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
Expand Down