Skip to content

Commit

Permalink
Execute unprepared statement (#1327)
Browse files Browse the repository at this point in the history
  • Loading branch information
billy1624 authored Jan 5, 2023
1 parent d332afa commit e927a0e
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/database/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ pub trait ConnectionTrait: Sync {
/// Execute a [Statement]
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr>;

/// Execute a unprepared [Statement]
async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr>;

/// Execute a [Statement] and return a query
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr>;

Expand Down
26 changes: 26 additions & 0 deletions src/database/db_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,32 @@ impl ConnectionTrait for DatabaseConnection {
}
}

#[instrument(level = "trace")]
#[allow(unused_variables)]
async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute_unprepared(sql).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
conn.execute_unprepared(sql).await
}
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
conn.execute_unprepared(sql).await
}
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => {
let db_backend = conn.get_database_backend();
let stmt = Statement::from_string(db_backend, sql.into());
conn.execute(stmt)
}
DatabaseConnection::Disconnected => {
Err(DbErr::Conn(RuntimeErr::Internal("Disconnected".to_owned())))
}
}
}

#[instrument(level = "trace")]
#[allow(unused_variables)]
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
Expand Down
32 changes: 32 additions & 0 deletions src/database/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,38 @@ impl ConnectionTrait for DatabaseTransaction {
}
}

#[instrument(level = "trace")]
#[allow(unused_variables)]
async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
debug_print!("{}", sql);

match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => sqlx::Executor::execute(conn, sql)
.await
.map(Into::into)
.map_err(sqlx_error_to_exec_err),
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => sqlx::Executor::execute(conn, sql)
.await
.map(Into::into)
.map_err(sqlx_error_to_exec_err),
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => sqlx::Executor::execute(conn, sql)
.await
.map(Into::into)
.map_err(sqlx_error_to_exec_err),
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => {
let db_backend = conn.get_database_backend();
let stmt = Statement::from_string(db_backend, sql.into());
conn.execute(stmt)
}
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
}

#[instrument(level = "trace")]
#[allow(unused_variables)]
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
Expand Down
15 changes: 15 additions & 0 deletions src/driver/sqlx_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ impl SqlxMySqlPoolConnection {
}
}

/// Execute an unprepared SQL statement on a MySQL backend
#[instrument(level = "trace")]
pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
debug_print!("{}", sql);

if let Ok(conn) = &mut self.pool.acquire().await {
match conn.execute(sql).await {
Ok(res) => Ok(res.into()),
Err(err) => Err(sqlx_error_to_exec_err(err)),
}
} else {
Err(DbErr::ConnectionAcquire)
}
}

/// Get one result from a SQL query. Returns [Option::None] if no match was found
#[instrument(level = "trace")]
pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
Expand Down
15 changes: 15 additions & 0 deletions src/driver/sqlx_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ impl SqlxPostgresPoolConnection {
}
}

/// Execute an unprepared SQL statement on a PostgreSQL backend
#[instrument(level = "trace")]
pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
debug_print!("{}", sql);

if let Ok(conn) = &mut self.pool.acquire().await {
match conn.execute(sql).await {
Ok(res) => Ok(res.into()),
Err(err) => Err(sqlx_error_to_exec_err(err)),
}
} else {
Err(DbErr::ConnectionAcquire)
}
}

/// Get one result from a SQL query. Returns [Option::None] if no match was found
#[instrument(level = "trace")]
pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
Expand Down
17 changes: 16 additions & 1 deletion src/driver/sqlx_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin, sync::Arc};
use sqlx::{
pool::PoolConnection,
sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
Sqlite, SqlitePool,
Executor, Sqlite, SqlitePool,
};

use sea_query_binder::SqlxValues;
Expand Down Expand Up @@ -101,6 +101,21 @@ impl SqlxSqlitePoolConnection {
}
}

/// Execute an unprepared SQL statement on a SQLite backend
#[instrument(level = "trace")]
pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
debug_print!("{}", sql);

if let Ok(conn) = &mut self.pool.acquire().await {
match conn.execute(sql).await {
Ok(res) => Ok(res.into()),
Err(err) => Err(sqlx_error_to_exec_err(err)),
}
} else {
Err(DbErr::ConnectionAcquire)
}
}

/// Get one result from a SQL query. Returns [Option::None] if no match was found
#[instrument(level = "trace")]
pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
Expand Down
41 changes: 41 additions & 0 deletions tests/execute_unprepared_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
pub mod common;

pub use common::{features::*, setup::*, TestContext};
use pretty_assertions::assert_eq;
use sea_orm::{entity::prelude::*, ConnectionTrait, DatabaseConnection};

#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
async fn main() -> Result<(), DbErr> {
let ctx = TestContext::new("execute_unprepared_tests").await;
create_tables(&ctx.db).await?;
execute_unprepared(&ctx.db).await?;
ctx.delete().await;

Ok(())
}

pub async fn execute_unprepared(db: &DatabaseConnection) -> Result<(), DbErr> {
use insert_default::*;

db.execute_unprepared(
[
"INSERT INTO insert_default VALUES (1), (2), (3), (4), (5)",
"DELETE FROM insert_default WHERE id % 2 = 0",
]
.join(";")
.as_str(),
)
.await?;

assert_eq!(
Entity::find().all(db).await?,
vec![Model { id: 1 }, Model { id: 3 }, Model { id: 5 },]
);

Ok(())
}

0 comments on commit e927a0e

Please sign in to comment.