diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 6aa9edb1e..158cf224b 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -390,6 +390,21 @@ impl DatabaseConnection { } } + /// Checks if a connection to the database is still valid. + pub async fn ping(&self) -> Result<(), DbErr> { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.ping().await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.ping().await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.ping().await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.ping(), + DatabaseConnection::Disconnected => Err(conn_err("Disconnected")), + } + } + /// Explicitly close the database connection pub async fn close(self) -> Result<(), DbErr> { match self { diff --git a/src/database/mock.rs b/src/database/mock.rs index e6cac7730..df222d8ee 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -203,6 +203,10 @@ impl MockDatabaseTrait for MockDatabase { fn get_database_backend(&self) -> DbBackend { self.db_backend } + + fn ping(&self) -> Result<(), DbErr> { + Ok(()) + } } impl MockRow { diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 570b5e530..af3145cbb 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -47,6 +47,9 @@ pub trait MockDatabaseTrait: Send + Debug { /// Get the backend being used in the [MockDatabase] fn get_database_backend(&self) -> DbBackend; + + /// Ping the [MockDatabase] + fn ping(&self) -> Result<(), DbErr>; } impl MockDatabaseConnector { @@ -194,4 +197,9 @@ impl MockDatabaseConnection { .expect("Failed to acquire mocker") .rollback() } + + /// Checks if a connection to the database is still valid. + pub fn ping(&self) -> Result<(), DbErr> { + self.mocker.lock().map_err(query_err)?.ping() + } } diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 8268ee943..ebea1b837 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc}; use sqlx::{ mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow}, pool::PoolConnection, - Executor, MySql, MySqlPool, + Connection, Executor, MySql, MySqlPool, }; use sea_query_binder::SqlxValues; @@ -222,6 +222,18 @@ impl SqlxMySqlPoolConnection { self.metric_callback = Some(Arc::new(callback)); } + /// Checks if a connection to the database is still valid. + pub async fn ping(&self) -> Result<(), DbErr> { + if let Ok(conn) = &mut self.pool.acquire().await { + match conn.ping().await { + Ok(_) => Ok(()), + Err(err) => Err(sqlx_error_to_conn_err(err)), + } + } else { + Err(DbErr::ConnectionAcquire) + } + } + /// Explicitly close the MySQL connection pub async fn close(self) -> Result<(), DbErr> { self.pool.close().await; diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index d24ebf2b5..e78aa8d84 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc}; use sqlx::{ pool::PoolConnection, postgres::{PgConnectOptions, PgQueryResult, PgRow}, - Executor, PgPool, Postgres, + Connection, Executor, PgPool, Postgres, }; use sea_query_binder::SqlxValues; @@ -237,6 +237,18 @@ impl SqlxPostgresPoolConnection { self.metric_callback = Some(Arc::new(callback)); } + /// Checks if a connection to the database is still valid. + pub async fn ping(&self) -> Result<(), DbErr> { + if let Ok(conn) = &mut self.pool.acquire().await { + match conn.ping().await { + Ok(_) => Ok(()), + Err(err) => Err(sqlx_error_to_conn_err(err)), + } + } else { + Err(DbErr::ConnectionAcquire) + } + } + /// Explicitly close the Postgres connection pub async fn close(self) -> Result<(), DbErr> { self.pool.close().await; diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index 6309208d5..5e1875495 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc}; use sqlx::{ pool::PoolConnection, sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow}, - Executor, Sqlite, SqlitePool, + Connection, Executor, Sqlite, SqlitePool, }; use sea_query_binder::SqlxValues; @@ -229,6 +229,18 @@ impl SqlxSqlitePoolConnection { self.metric_callback = Some(Arc::new(callback)); } + /// Checks if a connection to the database is still valid. + pub async fn ping(&self) -> Result<(), DbErr> { + if let Ok(conn) = &mut self.pool.acquire().await { + match conn.ping().await { + Ok(_) => Ok(()), + Err(err) => Err(sqlx_error_to_conn_err(err)), + } + } else { + Err(DbErr::ConnectionAcquire) + } + } + /// Explicitly close the SQLite connection pub async fn close(self) -> Result<(), DbErr> { self.pool.close().await; diff --git a/tests/connection_tests.rs b/tests/connection_tests.rs new file mode 100644 index 000000000..ae0304d57 --- /dev/null +++ b/tests/connection_tests.rs @@ -0,0 +1,53 @@ +pub mod common; + +pub use common::{bakery_chain::*, setup::*, TestContext}; +use pretty_assertions::assert_eq; +pub use sea_orm::entity::*; +pub use sea_orm::*; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn connection_ping() { + let ctx = TestContext::new("connection_ping").await; + + ctx.db.ping().await.unwrap(); + + ctx.delete().await; +} + +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-mysql")] +pub async fn connection_ping_closed_mysql() { + let ctx = std::rc::Rc::new(Box::new(TestContext::new("connection_ping_closed").await)); + let ctx_ping = std::rc::Rc::clone(&ctx); + + ctx.db.get_mysql_connection_pool().close().await; + assert_eq!(ctx_ping.db.ping().await, Err(DbErr::ConnectionAcquire)); + ctx.delete().await; +} + +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-sqlite")] +pub async fn connection_ping_closed_sqlite() { + let ctx = std::rc::Rc::new(Box::new(TestContext::new("connection_ping_closed").await)); + let ctx_ping = std::rc::Rc::clone(&ctx); + + ctx.db.get_sqlite_connection_pool().close().await; + assert_eq!(ctx_ping.db.ping().await, Err(DbErr::ConnectionAcquire)); + ctx.delete().await; +} + +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-postgres")] +pub async fn connection_ping_closed_postgres() { + let ctx = std::rc::Rc::new(Box::new(TestContext::new("connection_ping_closed").await)); + let ctx_ping = std::rc::Rc::clone(&ctx); + + ctx.db.get_postgres_connection_pool().close().await; + assert_eq!(ctx_ping.db.ping().await, Err(DbErr::ConnectionAcquire)); + ctx.delete().await; +}