diff --git a/postgres-macros/tests/fail-nightly/enum_extra_variant.stderr b/postgres-macros/tests/fail-nightly/enum_extra_variant.stderr index a5524a0..bad7d20 100644 --- a/postgres-macros/tests/fail-nightly/enum_extra_variant.stderr +++ b/postgres-macros/tests/fail-nightly/enum_extra_variant.stderr @@ -4,7 +4,7 @@ error[E0277]: the trait bound `Vec: Query, 26 | let _: Vec = sql!("SELECT id, role FROM users").await.unwrap(); | -^^^^^ | || - | |the trait `Query, StructColumn, EnumVariant<"user">)>, "role">)>>` is not implemented for `Vec` + | |the trait `Query, StructColumn, EnumVariant<"user">)>, "role">)>>` is not implemented for `Vec`, which is required by `Sql<'_, Struct<(StructColumn, StructColumn, EnumVariant<"user">)>, "role">)>, _>: IntoFuture` | help: remove the `.await` | = help: the trait `Query, StructColumn, EnumVariant<"moderator">, EnumVariant<"user">)>, "role">)>>` is implemented for `Vec` diff --git a/postgres-macros/tests/fail-nightly/enum_missing_variant.stderr b/postgres-macros/tests/fail-nightly/enum_missing_variant.stderr index 4472d46..ee23a83 100644 --- a/postgres-macros/tests/fail-nightly/enum_missing_variant.stderr +++ b/postgres-macros/tests/fail-nightly/enum_missing_variant.stderr @@ -4,7 +4,7 @@ error[E0277]: the trait bound `Vec: Query, 20 | let _: Vec = sql!("SELECT id, role FROM users").await.unwrap(); | -^^^^^ | || - | |the trait `Query, StructColumn, EnumVariant<"user">)>, "role">)>>` is not implemented for `Vec` + | |the trait `Query, StructColumn, EnumVariant<"user">)>, "role">)>>` is not implemented for `Vec`, which is required by `Sql<'_, Struct<(StructColumn, StructColumn, EnumVariant<"user">)>, "role">)>, _>: IntoFuture` | help: remove the `.await` | = help: the trait `Query, StructColumn,)>, "role">)>>` is implemented for `Vec` diff --git a/postgres-macros/tests/fail-nightly/enum_variant_mismatch.stderr b/postgres-macros/tests/fail-nightly/enum_variant_mismatch.stderr index 5370ab9..908809a 100644 --- a/postgres-macros/tests/fail-nightly/enum_variant_mismatch.stderr +++ b/postgres-macros/tests/fail-nightly/enum_variant_mismatch.stderr @@ -4,7 +4,7 @@ error[E0277]: the trait bound `Vec: Query, 23 | let _: Vec = sql!("SELECT id, role FROM users").await.unwrap(); | -^^^^^ | || - | |the trait `Query, StructColumn, EnumVariant<"user">)>, "role">)>>` is not implemented for `Vec` + | |the trait `Query, StructColumn, EnumVariant<"user">)>, "role">)>>` is not implemented for `Vec`, which is required by `Sql<'_, Struct<(StructColumn, StructColumn, EnumVariant<"user">)>, "role">)>, _>: IntoFuture` | help: remove the `.await` | = help: the trait `Query, StructColumn, EnumVariant<"user">)>, "role">)>>` is implemented for `Vec` diff --git a/postgres-macros/tests/fail-stable/enum_extra_variant.stderr b/postgres-macros/tests/fail-stable/enum_extra_variant.stderr index e32ec85..2cfe016 100644 --- a/postgres-macros/tests/fail-stable/enum_extra_variant.stderr +++ b/postgres-macros/tests/fail-stable/enum_extra_variant.stderr @@ -2,9 +2,10 @@ error[E0277]: the trait bound `Vec: Query tests/fail-stable/enum_extra_variant.rs:26:59 | 26 | let _: Vec = sql!("SELECT id, role FROM users").await.unwrap(); - | -^^^^^ the trait `Query, StructColumn, EnumVariant<10465144470622129318>)>, 18137070463969723500>)>>` is not implemented for `Vec` - | | + | -^^^^^ + | || + | |the trait `Query, StructColumn, EnumVariant<10465144470622129318>)>, 18137070463969723500>)>>` is not implemented for `Vec` | help: remove the `.await` | - = help: the trait `Query>` is implemented for `Vec` + = help: the trait `Query, StructColumn, EnumVariant<16036746858103170191>, EnumVariant<10465144470622129318>)>, 18137070463969723500>)>>` is implemented for `Vec` = note: required for `Sql<'_, Struct<(StructColumn, StructColumn, 18137070463969723500>)>, ...>` to implement `IntoFuture` diff --git a/postgres-macros/tests/fail-stable/enum_missing_variant.stderr b/postgres-macros/tests/fail-stable/enum_missing_variant.stderr index 4828cd8..744d8d9 100644 --- a/postgres-macros/tests/fail-stable/enum_missing_variant.stderr +++ b/postgres-macros/tests/fail-stable/enum_missing_variant.stderr @@ -2,9 +2,10 @@ error[E0277]: the trait bound `Vec: Query tests/fail-stable/enum_missing_variant.rs:20:59 | 20 | let _: Vec = sql!("SELECT id, role FROM users").await.unwrap(); - | -^^^^^ the trait `Query, StructColumn, EnumVariant<10465144470622129318>)>, 18137070463969723500>)>>` is not implemented for `Vec` - | | + | -^^^^^ + | || + | |the trait `Query, StructColumn, EnumVariant<10465144470622129318>)>, 18137070463969723500>)>>` is not implemented for `Vec` | help: remove the `.await` | - = help: the trait `Query>` is implemented for `Vec` + = help: the trait `Query, StructColumn,)>, 18137070463969723500>)>>` is implemented for `Vec` = note: required for `Sql<'_, Struct<(StructColumn, StructColumn, 18137070463969723500>)>, ...>` to implement `IntoFuture` diff --git a/postgres-macros/tests/fail-stable/enum_variant_mismatch.stderr b/postgres-macros/tests/fail-stable/enum_variant_mismatch.stderr index 8ce342b..f6958e8 100644 --- a/postgres-macros/tests/fail-stable/enum_variant_mismatch.stderr +++ b/postgres-macros/tests/fail-stable/enum_variant_mismatch.stderr @@ -2,9 +2,10 @@ error[E0277]: the trait bound `Vec: Query tests/fail-stable/enum_variant_mismatch.rs:23:59 | 23 | let _: Vec = sql!("SELECT id, role FROM users").await.unwrap(); - | -^^^^^ the trait `Query, StructColumn, EnumVariant<10465144470622129318>)>, 18137070463969723500>)>>` is not implemented for `Vec` - | | + | -^^^^^ + | || + | |the trait `Query, StructColumn, EnumVariant<10465144470622129318>)>, 18137070463969723500>)>>` is not implemented for `Vec` | help: remove the `.await` | - = help: the trait `Query>` is implemented for `Vec` + = help: the trait `Query, StructColumn, EnumVariant<10465144470622129318>)>, 18137070463969723500>)>>` is implemented for `Vec` = note: required for `Sql<'_, Struct<(StructColumn, StructColumn, 18137070463969723500>)>, ...>` to implement `IntoFuture` diff --git a/postgres/src/connection.rs b/postgres/src/connection.rs new file mode 100644 index 0000000..4237401 --- /dev/null +++ b/postgres/src/connection.rs @@ -0,0 +1,167 @@ +// TODO: remove once Rust's async lifetime in trait story got improved +#![allow(clippy::manual_async_fn)] + +use std::future::Future; + +use deadpool_postgres::GenericClient; +use tokio_postgres::types::ToSql; +use tokio_postgres::Row; + +use crate::Error; + +pub trait Connection: Send + Sync { + fn query_one<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future> + Send + 'a; + + fn query_opt<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future, Error>> + Send + 'a; + + fn query<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future, Error>> + Send + 'a; + + fn execute<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future> + Send + 'a; +} + +impl Connection for deadpool_postgres::Client { + fn query_one<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future> + Send + 'a { + async move { + let stmt = self.prepare_cached(query).await?; + Ok(tokio_postgres::Client::query_one(self, &stmt, parameters).await?) + } + } + + fn query_opt<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future, Error>> + Send + 'a { + async move { + let stmt = self.prepare_cached(query).await?; + Ok(tokio_postgres::Client::query_opt(self, &stmt, parameters).await?) + } + } + + fn query<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future, Error>> + Send + 'a { + async move { + let stmt = self.prepare_cached(query).await?; + Ok(tokio_postgres::Client::query(self, &stmt, parameters).await?) + } + } + + fn execute<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future> + Send + 'a { + async move { + let stmt = self.prepare_cached(query).await?; + tokio_postgres::Client::execute(self, &stmt, parameters).await?; + Ok(()) + } + } +} + +impl<'t> Connection for deadpool_postgres::Transaction<'t> { + fn query_one<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future> + Send + 'a { + async move { + let stmt = self.prepare_cached(query).await?; + Ok(tokio_postgres::Transaction::query_one(self, &stmt, parameters).await?) + } + } + + fn query_opt<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future, Error>> + Send + 'a { + async move { + let stmt = self.prepare_cached(query).await?; + Ok(tokio_postgres::Transaction::query_opt(self, &stmt, parameters).await?) + } + } + + fn query<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future, Error>> + Send + 'a { + async move { + let stmt = self.prepare_cached(query).await?; + Ok(tokio_postgres::Transaction::query(self, &stmt, parameters).await?) + } + } + + fn execute<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future> + Send + 'a { + async move { + let stmt = self.prepare_cached(query).await?; + tokio_postgres::Transaction::execute(self, &stmt, parameters).await?; + Ok(()) + } + } +} + +impl<'b, C> Connection for &'b C +where + C: Connection, +{ + fn query_one<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future> + Send + 'a { + (*self).query_one(query, parameters) + } + + fn query_opt<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future, Error>> + Send + 'a { + (*self).query_opt(query, parameters) + } + + fn query<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future, Error>> + Send + 'a { + (*self).query(query, parameters) + } + + fn execute<'a>( + &'a self, + query: &'a str, + parameters: &'a [&'a (dyn ToSql + Sync)], + ) -> impl Future> + Send + 'a { + (*self).execute(query, parameters) + } +} diff --git a/postgres/src/future.rs b/postgres/src/future.rs index 6ccce09..ed34495 100644 --- a/postgres/src/future.rs +++ b/postgres/src/future.rs @@ -19,16 +19,33 @@ where type IntoFuture = SqlFuture<'a, T>; fn into_future(self) -> Self::IntoFuture { + SqlFuture::new(self) + } +} + +pub struct SqlFuture<'a, T> { + future: Pin> + Send + 'a>>, + marker: PhantomData<&'a ()>, +} + +impl<'a, T> SqlFuture<'a, T> { + pub fn new(sql: Sql<'a, Cols, T>) -> Self + where + T: Query + Send + Sync + 'a, + Cols: Send + Sync + 'a, + { let span = - tracing::debug_span!("sql query", query = self.query, parameters = ?self.parameters); + tracing::debug_span!("sql query", query = sql.query, parameters = ?sql.parameters); let start = Instant::now(); SqlFuture { future: Box::pin( + // Note: changes here must be applied to `with_connection` below too! async move { let mut i = 1; loop { - match T::query(&self).await { + let conn = super::connect().await?; + match T::query(&sql, &conn).await { Ok(r) => { let elapsed = start.elapsed(); tracing::trace!(?elapsed, "sql query finished"); @@ -55,11 +72,49 @@ where marker: PhantomData, } } -} -pub struct SqlFuture<'a, T> { - future: Pin> + Send + 'a>>, - marker: PhantomData<&'a ()>, + pub fn with_connection(sql: Sql<'a, Cols, T>, conn: impl super::Connection + 'a) -> Self + where + T: Query + Send + Sync + 'a, + Cols: Send + Sync + 'a, + { + let span = + tracing::debug_span!("sql query", query = sql.query, parameters = ?sql.parameters); + let start = Instant::now(); + + SqlFuture { + future: Box::pin( + // Note: changes here must be applied to `bew` above too! + async move { + let mut i = 1; + loop { + match T::query(&sql, &conn).await { + Ok(r) => { + let elapsed = start.elapsed(); + tracing::trace!(?elapsed, "sql query finished"); + return Ok(r); + } + Err(Error { + kind: ErrorKind::Postgres(err), + .. + }) if err.is_closed() && i <= 5 => { + // retry pool size + 1 times if connection is closed (might have + // received a closed one from the connection pool) + i += 1; + tracing::trace!("retry due to connection closed error"); + continue; + } + Err(err) => { + return Err(err); + } + } + } + } + .instrument(span), + ), + marker: PhantomData, + } + } } impl<'a, T> Future for SqlFuture<'a, T> { diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index f07198b..ce2b192 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -5,6 +5,7 @@ #[cfg(test)] extern crate self as sqlm_postgres; +mod connection; mod error; mod future; pub mod internal; @@ -17,8 +18,9 @@ use std::marker::PhantomData; use std::str::FromStr; use std::sync::Arc; -pub use deadpool_postgres::Transaction; -use deadpool_postgres::{ClientWrapper, Manager, ManagerConfig, Object, Pool, RecyclingMethod}; +pub use connection::Connection; +pub use deadpool_postgres::{Client, Transaction}; +use deadpool_postgres::{ClientWrapper, Manager, ManagerConfig, Pool, RecyclingMethod}; pub use error::Error; use error::ErrorKind; pub use future::SqlFuture; @@ -35,10 +37,8 @@ pub use types::SqlType; static POOL: OnceCell = OnceCell::new(); -pub type Connection = Object; - #[tracing::instrument] -pub async fn connect() -> Result { +pub async fn connect() -> Result { // Don't trace connect, as this would create an endless loop of connecting again and // again when persisting the connect trace! let pool = POOL.get_or_try_init(|| { @@ -90,71 +90,12 @@ pub struct Sql<'a, Cols, T> { } impl<'a, Cols, T> Sql<'a, Cols, T> { - pub fn with(mut self, tx: &'a ClientWrapper) -> Self { - self.connection = Some(tx); - self - } - - pub fn with_transaction(mut self, tx: &'a Transaction<'a>) -> Self { - self.transaction = Some(tx); - self - } - - async fn query_one(&self) -> Result { - if let Some(tx) = self.transaction { - let stmt = tx.prepare_cached(self.query).await?; - Ok(tx.query_one(&stmt, self.parameters).await?) - } else if let Some(conn) = self.connection { - let stmt = conn.prepare_cached(self.query).await?; - Ok(conn.query_one(&stmt, self.parameters).await?) - } else { - let conn = connect().await?; - let stmt = conn.prepare_cached(self.query).await?; - Ok(conn.query_one(&stmt, self.parameters).await?) - } - } - - async fn query_opt(&self) -> Result, Error> { - if let Some(tx) = self.transaction { - let stmt = tx.prepare_cached(self.query).await?; - Ok(tx.query_opt(&stmt, self.parameters).await?) - } else if let Some(conn) = self.connection { - let stmt = conn.prepare_cached(self.query).await?; - Ok(conn.query_opt(&stmt, self.parameters).await?) - } else { - let conn = connect().await?; - let stmt = conn.prepare_cached(self.query).await?; - Ok(conn.query_opt(&stmt, self.parameters).await?) - } - } - - async fn query(&self) -> Result, Error> { - if let Some(tx) = self.transaction { - let stmt = tx.prepare_cached(self.query).await?; - Ok(tx.query(&stmt, self.parameters).await?) - } else if let Some(conn) = self.connection { - let stmt = conn.prepare_cached(self.query).await?; - Ok(conn.query(&stmt, self.parameters).await?) - } else { - let conn = connect().await?; - let stmt = conn.prepare_cached(self.query).await?; - Ok(conn.query(&stmt, self.parameters).await?) - } - } - - async fn execute(&self) -> Result<(), Error> { - if let Some(tx) = self.transaction { - let stmt = tx.prepare_cached(self.query).await?; - tx.execute(&stmt, self.parameters).await?; - } else if let Some(conn) = self.connection { - let stmt = conn.prepare_cached(self.query).await?; - conn.execute(&stmt, self.parameters).await?; - } else { - let conn = connect().await?; - let stmt = conn.prepare_cached(self.query).await?; - conn.execute(&stmt, self.parameters).await?; - } - Ok(()) + pub fn run_with(self, conn: impl Connection + 'a) -> SqlFuture<'a, T> + where + T: Query + Send + Sync + 'a, + Cols: Send + Sync + 'a, + { + SqlFuture::with_connection(self, conn) } } diff --git a/postgres/src/query.rs b/postgres/src/query.rs index 812112e..75093aa 100644 --- a/postgres/src/query.rs +++ b/postgres/src/query.rs @@ -9,6 +9,7 @@ use crate::{Error, FromRow, Sql}; pub trait Query: Sized { fn query<'a>( sql: &'a Sql<'a, Cols, Self>, + conn: impl super::Connection + 'a, ) -> Pin> + Send + 'a>>; } @@ -19,8 +20,9 @@ where { fn query<'a>( sql: &'a Sql<'a, Primitive, Self>, + conn: impl super::Connection + 'a, ) -> Pin> + Send + 'a>> { - T::query_literal(sql) + T::query_literal(sql, conn) } } @@ -31,9 +33,10 @@ where { fn query<'a>( sql: &'a Sql<'a, Struct, Self>, + conn: impl super::Connection + 'a, ) -> Pin> + Send + 'a>> { Box::pin(async move { - let row = sql.query_one().await?; + let row = conn.query_one(sql.query, sql.parameters).await?; Ok(FromRow::>::from_row(row.into())?) }) } @@ -46,9 +49,10 @@ where { fn query<'a>( sql: &'a Sql<'a, Struct, Self>, + conn: impl super::Connection + 'a, ) -> Pin> + Send + 'a>> { Box::pin(async move { - let row = sql.query_opt().await?; + let row = conn.query_opt(sql.query, sql.parameters).await?; match row { Some(row) => Ok(Some(FromRow::>::from_row(row.into())?)), None => Ok(None), @@ -64,9 +68,10 @@ where { fn query<'a>( sql: &'a Sql<'a, Struct, Self>, + conn: impl super::Connection + 'a, ) -> Pin> + Send + 'a>> { Box::pin(async move { - let rows = sql.query().await?; + let rows = conn.query(sql.query, sql.parameters).await?; rows.into_iter() .map(|row| FromRow::>::from_row(row.into()).map_err(Error::from)) .collect() @@ -77,9 +82,10 @@ where impl Query<()> for () { fn query<'a>( sql: &'a Sql<'a, (), Self>, + conn: impl super::Connection + 'a, ) -> Pin> + Send + 'a>> { Box::pin(async move { - sql.execute().await?; + conn.execute(sql.query, sql.parameters).await?; Ok(()) }) } diff --git a/postgres/src/types.rs b/postgres/src/types.rs index f00da24..1cfbcd6 100644 --- a/postgres/src/types.rs +++ b/postgres/src/types.rs @@ -11,13 +11,14 @@ pub trait SqlType { fn query_literal<'a>( sql: &'a Sql<'a, Primitive, Self>, + conn: impl super::Connection + 'a, ) -> Pin> + Send + 'a>> where Self: FromSqlOwned + ToSql + Send + Sync, Self::Type: Send + Sync, { Box::pin(async move { - let row = sql.query_one().await?; + let row = conn.query_one(sql.query, sql.parameters).await?; Ok(row.try_get(0)?) }) } @@ -47,13 +48,14 @@ where fn query_literal<'a>( sql: &'a Sql<'a, Primitive, Self>, + conn: impl super::Connection + 'a, ) -> Pin> + Send + 'a>> where Self: FromSqlOwned + ToSql + Send + Sync, Self::Type: Send + Sync, { Box::pin(async move { - let row = sql.query_opt().await?; + let row = conn.query_opt(sql.query, sql.parameters).await?; match row { Some(row) => Ok(row.try_get::<'_, _, Option>(0)?), None => Ok(None), @@ -70,13 +72,14 @@ where fn query_literal<'a>( sql: &'a Sql<'a, Primitive, Self>, + conn: impl super::Connection + 'a, ) -> Pin> + Send + 'a>> where Self: FromSqlOwned + ToSql + Send + Sync, Self::Type: Send + Sync, { Box::pin(async move { - let row = sql.query_one().await?; + let row = conn.query_one(sql.query, sql.parameters).await?; Ok(row.try_get(0)?) }) } @@ -88,13 +91,14 @@ impl SqlType for Vec { fn query_literal<'a>( sql: &'a Sql<'a, Primitive, Self>, + conn: impl super::Connection + 'a, ) -> Pin> + Send + 'a>> where Self: FromSqlOwned + ToSql + Send + Sync, Self::Type: Send + Sync, { Box::pin(async move { - let row = sql.query_one().await?; + let row = conn.query_one(sql.query, sql.parameters).await?; Ok(row.try_get(0)?) }) }