Skip to content

Commit

Permalink
feat: String parameters accept any Into<String> (#1439)
Browse files Browse the repository at this point in the history
  • Loading branch information
billy1624 authored Apr 11, 2023
1 parent 8f785c1 commit 6833529
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 87 deletions.
2 changes: 1 addition & 1 deletion src/database/db_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl ConnectionTrait for DatabaseConnection {
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => {
let db_backend = conn.get_database_backend();
let stmt = Statement::from_string(db_backend, sql.into());
let stmt = Statement::from_string(db_backend, sql);
conn.execute(stmt)
}
DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
Expand Down
66 changes: 25 additions & 41 deletions src/database/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,31 +281,27 @@ where
}
}

impl IntoMockRow for BTreeMap<String, Value> {
fn into_mock_row(self) -> MockRow {
MockRow {
values: self.into_iter().map(|(k, v)| (k, v)).collect(),
}
}
}

impl IntoMockRow for BTreeMap<&str, Value> {
impl<T> IntoMockRow for BTreeMap<T, Value>
where
T: Into<String>,
{
fn into_mock_row(self) -> MockRow {
MockRow {
values: self.into_iter().map(|(k, v)| (k.to_owned(), v)).collect(),
values: self.into_iter().map(|(k, v)| (k.into(), v)).collect(),
}
}
}

impl Transaction {
/// Get the [Value]s from s raw SQL statement depending on the [DatabaseBackend](crate::DatabaseBackend)
pub fn from_sql_and_values<I>(db_backend: DbBackend, sql: &str, values: I) -> Self
pub fn from_sql_and_values<I, T>(db_backend: DbBackend, sql: T, values: I) -> Self
where
I: IntoIterator<Item = Value>,
T: Into<String>,
{
Self::one(Statement::from_string_values_tuple(
db_backend,
(sql.to_string(), Values(values.into_iter().collect())),
(sql, Values(values.into_iter().collect())),
))
}

Expand Down Expand Up @@ -336,10 +332,7 @@ impl Transaction {
impl OpenTransaction {
fn init() -> Self {
Self {
stmts: vec![Statement::from_string(
DbBackend::Postgres,
"BEGIN".to_owned(),
)],
stmts: vec![Statement::from_string(DbBackend::Postgres, "BEGIN")],
transaction_depth: 0,
}
}
Expand All @@ -354,7 +347,7 @@ impl OpenTransaction {

fn commit(&mut self, db_backend: DbBackend) -> bool {
if self.transaction_depth == 0 {
self.push(Statement::from_string(db_backend, "COMMIT".to_owned()));
self.push(Statement::from_string(db_backend, "COMMIT"));
true
} else {
self.push(Statement::from_string(
Expand All @@ -368,7 +361,7 @@ impl OpenTransaction {

fn rollback(&mut self, db_backend: DbBackend) -> bool {
if self.transaction_depth == 0 {
self.push(Statement::from_string(db_backend, "ROLLBACK".to_owned()));
self.push(Statement::from_string(db_backend, "ROLLBACK"));
true
} else {
self.push(Statement::from_string(
Expand Down Expand Up @@ -433,7 +426,7 @@ mod tests {
db.into_transaction_log(),
[
Transaction::many([
Statement::from_string(DbBackend::Postgres, "BEGIN".to_owned()),
Statement::from_string(DbBackend::Postgres, "BEGIN"),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
Expand All @@ -444,7 +437,7 @@ mod tests {
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
[]
),
Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()),
Statement::from_string(DbBackend::Postgres, "COMMIT"),
]),
Transaction::from_sql_and_values(
DbBackend::Postgres,
Expand Down Expand Up @@ -478,13 +471,13 @@ mod tests {
assert_eq!(
db.into_transaction_log(),
[Transaction::many([
Statement::from_string(DbBackend::Postgres, "BEGIN".to_owned()),
Statement::from_string(DbBackend::Postgres, "BEGIN"),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
[1u64.into()]
),
Statement::from_string(DbBackend::Postgres, "ROLLBACK".to_owned()),
Statement::from_string(DbBackend::Postgres, "ROLLBACK"),
])]
);
}
Expand Down Expand Up @@ -516,23 +509,20 @@ mod tests {
assert_eq!(
db.into_transaction_log(),
[Transaction::many([
Statement::from_string(DbBackend::Postgres, "BEGIN".to_owned()),
Statement::from_string(DbBackend::Postgres, "BEGIN"),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
[1u64.into()]
),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1"),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
[]
),
Statement::from_string(
DbBackend::Postgres,
"RELEASE SAVEPOINT savepoint_1".to_owned()
),
Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()),
Statement::from_string(DbBackend::Postgres, "RELEASE SAVEPOINT savepoint_1"),
Statement::from_string(DbBackend::Postgres, "COMMIT"),
]),]
);
}
Expand Down Expand Up @@ -574,33 +564,27 @@ mod tests {
assert_eq!(
db.into_transaction_log(),
[Transaction::many([
Statement::from_string(DbBackend::Postgres, "BEGIN".to_owned()),
Statement::from_string(DbBackend::Postgres, "BEGIN"),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
[1u64.into()]
),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1"),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
[]
),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_2".to_owned()),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_2"),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake""#,
[]
),
Statement::from_string(
DbBackend::Postgres,
"RELEASE SAVEPOINT savepoint_2".to_owned()
),
Statement::from_string(
DbBackend::Postgres,
"RELEASE SAVEPOINT savepoint_1".to_owned()
),
Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()),
Statement::from_string(DbBackend::Postgres, "RELEASE SAVEPOINT savepoint_2"),
Statement::from_string(DbBackend::Postgres, "RELEASE SAVEPOINT savepoint_1"),
Statement::from_string(DbBackend::Postgres, "COMMIT"),
]),]
);
}
Expand Down
39 changes: 16 additions & 23 deletions src/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,23 @@ impl Database {
}
}

impl From<&str> for ConnectOptions {
fn from(string: &str) -> ConnectOptions {
ConnectOptions::from_str(string)
}
}

impl From<&String> for ConnectOptions {
fn from(string: &String) -> ConnectOptions {
ConnectOptions::from_str(string.as_str())
}
}

impl From<String> for ConnectOptions {
fn from(string: String) -> ConnectOptions {
ConnectOptions::new(string)
impl<T> From<T> for ConnectOptions
where
T: Into<String>,
{
fn from(s: T) -> ConnectOptions {
ConnectOptions::new(s.into())
}
}

impl ConnectOptions {
/// Create new [ConnectOptions] for a [Database] by passing in a URI string
pub fn new(url: String) -> Self {
pub fn new<T>(url: T) -> Self
where
T: Into<String>,
{
Self {
url,
url: url.into(),
max_connections: None,
min_connections: None,
connect_timeout: None,
Expand All @@ -122,10 +116,6 @@ impl ConnectOptions {
}
}

fn from_str(url: &str) -> Self {
Self::new(url.to_owned())
}

#[cfg(feature = "sqlx-dep")]
/// Convert [ConnectOptions] into [sqlx::pool::PoolOptions]
pub fn pool_options<DB>(self) -> sqlx::pool::PoolOptions<DB>
Expand Down Expand Up @@ -258,8 +248,11 @@ impl ConnectOptions {
}

/// Set schema search path (PostgreSQL only)
pub fn set_schema_search_path(&mut self, schema_search_path: String) -> &mut Self {
self.schema_search_path = Some(schema_search_path);
pub fn set_schema_search_path<T>(&mut self, schema_search_path: T) -> &mut Self
where
T: Into<String>,
{
self.schema_search_path = Some(schema_search_path.into());
self
}
}
25 changes: 13 additions & 12 deletions src/database/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,33 @@ pub trait StatementBuilder {

impl Statement {
/// Create a [Statement] from a [crate::DatabaseBackend] and a raw SQL statement
pub fn from_string(db_backend: DbBackend, stmt: String) -> Statement {
pub fn from_string<T>(db_backend: DbBackend, stmt: T) -> Statement
where
T: Into<String>,
{
Statement {
sql: stmt,
sql: stmt.into(),
values: None,
db_backend,
}
}

/// Create a SQL statement from a [crate::DatabaseBackend], a
/// raw SQL statement and param values
pub fn from_sql_and_values<I>(db_backend: DbBackend, sql: &str, values: I) -> Self
pub fn from_sql_and_values<I, T>(db_backend: DbBackend, sql: T, values: I) -> Self
where
I: IntoIterator<Item = Value>,
T: Into<String>,
{
Self::from_string_values_tuple(
db_backend,
(sql.to_owned(), Values(values.into_iter().collect())),
)
Self::from_string_values_tuple(db_backend, (sql, Values(values.into_iter().collect())))
}

pub(crate) fn from_string_values_tuple(
db_backend: DbBackend,
stmt: (String, Values),
) -> Statement {
pub(crate) fn from_string_values_tuple<T>(db_backend: DbBackend, stmt: (T, Values)) -> Statement
where
T: Into<String>,
{
Statement {
sql: stmt.0,
sql: stmt.0.into(),
values: Some(stmt.1),
db_backend,
}
Expand Down
2 changes: 1 addition & 1 deletion src/database/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ impl ConnectionTrait for DatabaseTransaction {
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => {
let db_backend = conn.get_database_backend();
let stmt = Statement::from_string(db_backend, sql.into());
let stmt = Statement::from_string(db_backend, sql);
conn.execute(stmt)
}
#[allow(unreachable_patterns)]
Expand Down
31 changes: 23 additions & 8 deletions src/entity/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ pub trait ColumnTrait: IdenStatic + Iterable + FromStr {
/// "SELECT `cake`.`id`, `cake`.`name` FROM `cake` WHERE `cake`.`name` LIKE 'cheese'"
/// );
/// ```
fn like(&self, s: &str) -> SimpleExpr {
fn like<T>(&self, s: T) -> SimpleExpr
where
T: Into<String>,
{
Expr::col((self.entity_name(), *self)).like(s)
}

Expand All @@ -164,7 +167,10 @@ pub trait ColumnTrait: IdenStatic + Iterable + FromStr {
/// "SELECT `cake`.`id`, `cake`.`name` FROM `cake` WHERE `cake`.`name` NOT LIKE 'cheese'"
/// );
/// ```
fn not_like(&self, s: &str) -> SimpleExpr {
fn not_like<T>(&self, s: T) -> SimpleExpr
where
T: Into<String>,
{
Expr::col((self.entity_name(), *self)).not_like(s)
}

Expand All @@ -179,8 +185,11 @@ pub trait ColumnTrait: IdenStatic + Iterable + FromStr {
/// "SELECT `cake`.`id`, `cake`.`name` FROM `cake` WHERE `cake`.`name` LIKE 'cheese%'"
/// );
/// ```
fn starts_with(&self, s: &str) -> SimpleExpr {
let pattern = format!("{s}%");
fn starts_with<T>(&self, s: T) -> SimpleExpr
where
T: Into<String>,
{
let pattern = format!("{}%", s.into());
Expr::col((self.entity_name(), *self)).like(pattern)
}

Expand All @@ -195,8 +204,11 @@ pub trait ColumnTrait: IdenStatic + Iterable + FromStr {
/// "SELECT `cake`.`id`, `cake`.`name` FROM `cake` WHERE `cake`.`name` LIKE '%cheese'"
/// );
/// ```
fn ends_with(&self, s: &str) -> SimpleExpr {
let pattern = format!("%{s}");
fn ends_with<T>(&self, s: T) -> SimpleExpr
where
T: Into<String>,
{
let pattern = format!("%{}", s.into());
Expr::col((self.entity_name(), *self)).like(pattern)
}

Expand All @@ -211,8 +223,11 @@ pub trait ColumnTrait: IdenStatic + Iterable + FromStr {
/// "SELECT `cake`.`id`, `cake`.`name` FROM `cake` WHERE `cake`.`name` LIKE '%cheese%'"
/// );
/// ```
fn contains(&self, s: &str) -> SimpleExpr {
let pattern = format!("%{s}%");
fn contains<T>(&self, s: T) -> SimpleExpr
where
T: Into<String>,
{
let pattern = format!("%{}%", s.into());
Expr::col((self.entity_name(), *self)).like(pattern)
}

Expand Down
2 changes: 1 addition & 1 deletion tests/common/features/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ pub async fn create_json_struct_table(db: &DbConn) -> Result<ExecResult, DbErr>
pub async fn create_collection_table(db: &DbConn) -> Result<ExecResult, DbErr> {
db.execute(sea_orm::Statement::from_string(
db.get_database_backend(),
"CREATE EXTENSION IF NOT EXISTS citext".into(),
"CREATE EXTENSION IF NOT EXISTS citext",
))
.await?;

Expand Down

0 comments on commit 6833529

Please sign in to comment.