From 8248f95b2776e9b336625556a3782d482d4524d2 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Thu, 22 Dec 2022 16:32:30 +0800 Subject: [PATCH] Execute unprepared statement --- src/database/connection.rs | 3 +++ src/database/db_connection.rs | 26 ++++++++++++++++++++ src/database/transaction.rs | 32 ++++++++++++++++++++++++ src/driver/sqlx_mysql.rs | 15 +++++++++++ src/driver/sqlx_postgres.rs | 15 +++++++++++ src/driver/sqlx_sqlite.rs | 17 ++++++++++++- tests/execute_unprepared_tests.rs | 41 +++++++++++++++++++++++++++++++ 7 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 tests/execute_unprepared_tests.rs diff --git a/src/database/connection.rs b/src/database/connection.rs index b6a3c1652..389c84c8c 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -15,6 +15,9 @@ pub trait ConnectionTrait: Sync { /// Execute a [Statement] async fn execute(&self, stmt: Statement) -> Result; + /// Execute a unprepared [Statement] + async fn execute_unprepared(&self, sql: &str) -> Result; + /// Execute a [Statement] and return a query async fn query_one(&self, stmt: Statement) -> Result, DbErr>; diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index e0839d3c8..9309e9686 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -122,6 +122,32 @@ impl ConnectionTrait for DatabaseConnection { } } + #[instrument(level = "trace")] + #[allow(unused_variables)] + async fn execute_unprepared(&self, sql: &str) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute_unprepared(sql).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => { + conn.execute_unprepared(sql).await + } + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => { + conn.execute_unprepared(sql).await + } + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + let db_backend = conn.get_database_backend(); + let stmt = Statement::from_string(db_backend, sql.into()); + conn.execute(stmt) + } + DatabaseConnection::Disconnected => { + Err(DbErr::Conn(RuntimeErr::Internal("Disconnected".to_owned()))) + } + } + } + #[instrument(level = "trace")] #[allow(unused_variables)] async fn query_one(&self, stmt: Statement) -> Result, DbErr> { diff --git a/src/database/transaction.rs b/src/database/transaction.rs index fa8d57755..558ce70a1 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -330,6 +330,38 @@ impl ConnectionTrait for DatabaseTransaction { } } + #[instrument(level = "trace")] + #[allow(unused_variables)] + async fn execute_unprepared(&self, sql: &str) -> Result { + debug_print!("{}", sql); + + match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => sqlx::Executor::execute(conn, sql) + .await + .map(Into::into) + .map_err(sqlx_error_to_exec_err), + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => sqlx::Executor::execute(conn, sql) + .await + .map(Into::into) + .map_err(sqlx_error_to_exec_err), + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => sqlx::Executor::execute(conn, sql) + .await + .map(Into::into) + .map_err(sqlx_error_to_exec_err), + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => { + let db_backend = conn.get_database_backend(); + let stmt = Statement::from_string(db_backend, sql.into()); + conn.execute(stmt) + } + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + } + #[instrument(level = "trace")] #[allow(unused_variables)] async fn query_one(&self, stmt: Statement) -> Result, DbErr> { diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 4ee9c7378..f73efbd6a 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -94,6 +94,21 @@ impl SqlxMySqlPoolConnection { } } + /// Execute an unprepared SQL statement on a MySQL backend + #[instrument(level = "trace")] + pub async fn execute_unprepared(&self, sql: &str) -> Result { + debug_print!("{}", sql); + + if let Ok(conn) = &mut self.pool.acquire().await { + match conn.execute(sql).await { + Ok(res) => Ok(res.into()), + Err(err) => Err(sqlx_error_to_exec_err(err)), + } + } else { + Err(DbErr::ConnectionAcquire) + } + } + /// Get one result from a SQL query. Returns [Option::None] if no match was found #[instrument(level = "trace")] pub async fn query_one(&self, stmt: Statement) -> Result, DbErr> { diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index ff4d55ada..716f111d7 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -109,6 +109,21 @@ impl SqlxPostgresPoolConnection { } } + /// Execute an unprepared SQL statement on a PostgreSQL backend + #[instrument(level = "trace")] + pub async fn execute_unprepared(&self, sql: &str) -> Result { + debug_print!("{}", sql); + + if let Ok(conn) = &mut self.pool.acquire().await { + match conn.execute(sql).await { + Ok(res) => Ok(res.into()), + Err(err) => Err(sqlx_error_to_exec_err(err)), + } + } else { + Err(DbErr::ConnectionAcquire) + } + } + /// Get one result from a SQL query. Returns [Option::None] if no match was found #[instrument(level = "trace")] pub async fn query_one(&self, stmt: Statement) -> Result, DbErr> { diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index c0d7e3f0f..b67fadb9b 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc}; use sqlx::{ pool::PoolConnection, sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow}, - Sqlite, SqlitePool, + Executor, Sqlite, SqlitePool, }; use sea_query_binder::SqlxValues; @@ -101,6 +101,21 @@ impl SqlxSqlitePoolConnection { } } + /// Execute an unprepared SQL statement on a SQLite backend + #[instrument(level = "trace")] + pub async fn execute_unprepared(&self, sql: &str) -> Result { + debug_print!("{}", sql); + + if let Ok(conn) = &mut self.pool.acquire().await { + match conn.execute(sql).await { + Ok(res) => Ok(res.into()), + Err(err) => Err(sqlx_error_to_exec_err(err)), + } + } else { + Err(DbErr::ConnectionAcquire) + } + } + /// Get one result from a SQL query. Returns [Option::None] if no match was found #[instrument(level = "trace")] pub async fn query_one(&self, stmt: Statement) -> Result, DbErr> { diff --git a/tests/execute_unprepared_tests.rs b/tests/execute_unprepared_tests.rs new file mode 100644 index 000000000..7b1082fb5 --- /dev/null +++ b/tests/execute_unprepared_tests.rs @@ -0,0 +1,41 @@ +pub mod common; + +pub use common::{features::*, setup::*, TestContext}; +use pretty_assertions::assert_eq; +use sea_orm::{entity::prelude::*, ConnectionTrait, DatabaseConnection}; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn main() -> Result<(), DbErr> { + let ctx = TestContext::new("execute_unprepared_tests").await; + create_tables(&ctx.db).await?; + execute_unprepared(&ctx.db).await?; + ctx.delete().await; + + Ok(()) +} + +pub async fn execute_unprepared(db: &DatabaseConnection) -> Result<(), DbErr> { + use insert_default::*; + + db.execute_unprepared( + [ + "INSERT INTO insert_default VALUES (1), (2), (3), (4), (5)", + "DELETE FROM insert_default WHERE id % 2 = 0", + ] + .join(";") + .as_str(), + ) + .await?; + + assert_eq!( + Entity::find().all(db).await?, + vec![Model { id: 1 }, Model { id: 3 }, Model { id: 5 },] + ); + + Ok(()) +}