Skip to content

Commit

Permalink
encapsulate transaction depth inside transaction instannce
Browse files Browse the repository at this point in the history
The depth tracking can be encapsulated entirely inside a transaction
instance, simplifying the code significantly.
  • Loading branch information
LucianBuzzo committed Sep 26, 2024
1 parent 71d3bd8 commit a220cf5
Show file tree
Hide file tree
Showing 12 changed files with 85 additions and 128 deletions.
13 changes: 2 additions & 11 deletions quaint/src/connector/mssql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ use futures::lock::Mutex;
use std::{
convert::TryFrom,
future::Future,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use tiberius::*;
Expand All @@ -48,11 +45,7 @@ impl TransactionCapable for Mssql {
.or(self.url.query_params.transaction_isolation_level)
.or(Some(SQL_SERVER_DEFAULT_ISOLATION));

let opts = TransactionOptions::new(
isolation,
self.requires_isolation_first(),
self.transaction_depth.clone(),
);
let opts = TransactionOptions::new(isolation, self.requires_isolation_first());

Ok(Box::new(DefaultTransaction::new(self, opts).await?))
}
Expand All @@ -65,7 +58,6 @@ pub struct Mssql {
url: MssqlUrl,
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
transaction_depth: Arc<Mutex<i32>>,
}

impl Mssql {
Expand Down Expand Up @@ -97,7 +89,6 @@ impl Mssql {
url,
socket_timeout,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(Mutex::new(0)),
};

if let Some(isolation) = this.url.transaction_isolation_level() {
Expand Down
7 changes: 1 addition & 6 deletions quaint/src/connector/mysql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ use mysql_async::{
};
use std::{
future::Future,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use tokio::sync::Mutex;
Expand Down Expand Up @@ -79,7 +76,6 @@ pub struct Mysql {
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
statement_cache: Mutex<LruCache<String, my::Statement>>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

impl Mysql {
Expand All @@ -93,7 +89,6 @@ impl Mysql {
statement_cache: Mutex::new(url.cache()),
url,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down
7 changes: 1 addition & 6 deletions quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ use std::{
fmt::{Debug, Display},
fs,
future::Future,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use tokio_postgres::{config::ChannelBinding, Client, Config, Statement};
Expand Down Expand Up @@ -64,7 +61,6 @@ pub struct PostgreSql {
is_healthy: AtomicBool,
is_cockroachdb: bool,
is_materialize: bool,
transaction_depth: Arc<Mutex<i32>>,
}

/// Key uniquely representing an SQL statement in the prepared statements cache.
Expand Down Expand Up @@ -293,7 +289,6 @@ impl PostgreSql {
is_healthy: AtomicBool::new(true),
is_cockroachdb,
is_materialize,
transaction_depth: Arc::new(Mutex::new(0)),
})
}

Expand Down
6 changes: 1 addition & 5 deletions quaint/src/connector/queryable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,7 @@ macro_rules! impl_default_TransactionCapable {
&'a self,
isolation: Option<IsolationLevel>,
) -> crate::Result<Box<dyn crate::connector::Transaction + 'a>> {
let opts = crate::connector::TransactionOptions::new(
isolation,
self.requires_isolation_first(),
self.transaction_depth.clone(),
);
let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first());

Ok(Box::new(
crate::connector::DefaultTransaction::new(self, opts).await?,
Expand Down
9 changes: 2 additions & 7 deletions quaint/src/connector/sqlite/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
visitor::{self, Visitor},
};
use async_trait::async_trait;
use std::{convert::TryFrom, sync::Arc};
use std::convert::TryFrom;
use tokio::sync::Mutex;

/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature.
Expand All @@ -27,7 +27,6 @@ pub use rusqlite;
/// A connector interface for the SQLite database
pub struct Sqlite {
pub(crate) client: Mutex<rusqlite::Connection>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

impl TryFrom<&str> for Sqlite {
Expand Down Expand Up @@ -65,10 +64,7 @@ impl TryFrom<&str> for Sqlite {

let client = Mutex::new(conn);

Ok(Sqlite {
client,
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
Ok(Sqlite { client })
}
}

Expand All @@ -83,7 +79,6 @@ impl Sqlite {

Ok(Sqlite {
client: Mutex::new(client),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down
71 changes: 41 additions & 30 deletions quaint/src/connector/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use crate::{
error::{Error, ErrorKind},
};
use async_trait::async_trait;
use futures::lock::Mutex;
use metrics::{decrement_gauge, increment_gauge};
use std::{fmt, str::FromStr, sync::Arc};
use std::{
fmt,
str::FromStr,
sync::{Arc, Mutex},
};

extern crate metrics as metrics;

Expand All @@ -31,9 +34,6 @@ pub(crate) struct TransactionOptions {

/// Whether or not to put the isolation level `SET` before or after the `BEGIN`.
pub(crate) isolation_first: bool,

/// The depth of the transaction, used to determine the nested transaction statements.
pub depth: Arc<Mutex<i32>>,
}

/// A default representation of an SQL database transaction. If not commited, a
Expand All @@ -53,7 +53,7 @@ impl<'a> DefaultTransaction<'a> {
) -> crate::Result<DefaultTransaction<'a>> {
let mut this = Self {
inner,
depth: tx_opts.depth,
depth: Arc::new(Mutex::new(0)),
};

if tx_opts.isolation_first {
Expand Down Expand Up @@ -81,14 +81,13 @@ impl<'a> Transaction for DefaultTransaction<'a> {
async fn begin(&mut self) -> crate::Result<()> {
increment_gauge!("prisma_client_queries_active", 1.0);

let mut depth_guard = self.depth.lock().await;

// Modify the depth value through the MutexGuard
*depth_guard += 1;

let st_depth = *depth_guard;
let current_depth = {
let mut depth = self.depth.lock().unwrap();
*depth += 1;
*depth
};

let begin_statement = self.inner.begin_statement(st_depth).await;
let begin_statement = self.inner.begin_statement(current_depth).await;

self.inner.raw_cmd(&begin_statement).await?;

Expand All @@ -99,36 +98,49 @@ impl<'a> Transaction for DefaultTransaction<'a> {
async fn commit(&mut self) -> crate::Result<i32> {
decrement_gauge!("prisma_client_queries_active", 1.0);

let mut depth_guard = self.depth.lock().await;

let st_depth = *depth_guard;

let commit_statement = self.inner.commit_statement(st_depth).await;
// Lock the mutex and get the depth value
let depth_val = {
let depth = self.depth.lock().unwrap();
*depth
};

// Perform the asynchronous operation without holding the lock
let commit_statement = self.inner.commit_statement(depth_val).await;
self.inner.raw_cmd(&commit_statement).await?;

// Modify the depth value through the MutexGuard
*depth_guard -= 1;
// Lock the mutex again to modify the depth
let new_depth = {
let mut depth = self.depth.lock().unwrap();
*depth -= 1;
*depth
};

Ok(*depth_guard)
Ok(new_depth)
}

/// Rolls back the changes to the database.
async fn rollback(&mut self) -> crate::Result<i32> {
decrement_gauge!("prisma_client_queries_active", 1.0);

let mut depth_guard = self.depth.lock().await;

let st_depth = *depth_guard;
// Lock the mutex and get the depth value
let depth_val = {
let depth = self.depth.lock().unwrap();
*depth
};

let rollback_statement = self.inner.rollback_statement(st_depth).await;
// Perform the asynchronous operation without holding the lock
let rollback_statement = self.inner.rollback_statement(depth_val).await;

self.inner.raw_cmd(&rollback_statement).await?;

// Modify the depth value through the MutexGuard
*depth_guard -= 1;
// Lock the mutex again to modify the depth
let new_depth = {
let mut depth = self.depth.lock().unwrap();
*depth -= 1;
*depth
};

Ok(*depth_guard)
Ok(new_depth)
}

fn as_queryable(&self) -> &dyn Queryable {
Expand Down Expand Up @@ -240,11 +252,10 @@ impl FromStr for IsolationLevel {
}
}
impl TransactionOptions {
pub fn new(isolation_level: Option<IsolationLevel>, isolation_first: bool, depth: Arc<Mutex<i32>>) -> Self {
pub fn new(isolation_level: Option<IsolationLevel>, isolation_first: bool) -> Self {
Self {
isolation_level,
isolation_first,
depth,
}
}
}
5 changes: 1 addition & 4 deletions quaint/src/pooled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,10 +507,7 @@ impl Quaint {
}
};

Ok(PooledConnection {
inner,
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
Ok(PooledConnection { inner })
}

/// Info about the connection and underlying database.
Expand Down
3 changes: 0 additions & 3 deletions quaint/src/pooled/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@ use crate::{
error::Error,
};
use async_trait::async_trait;
use futures::lock::Mutex;
use mobc::{Connection as MobcPooled, Manager};
use std::sync::Arc;

/// A connection from the pool. Implements
/// [Queryable](connector/trait.Queryable.html).
pub struct PooledConnection {
pub(crate) inner: MobcPooled<QuaintManager>,
pub transaction_depth: Arc<Mutex<i32>>,
}

impl_default_TransactionCapable!(PooledConnection);
Expand Down
9 changes: 1 addition & 8 deletions quaint/src/single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable},
};
use async_trait::async_trait;
use futures::lock::Mutex;
use std::{fmt, sync::Arc};

#[cfg(feature = "sqlite-native")]
Expand All @@ -19,7 +18,6 @@ use crate::connector::NativeConnectionInfo;
pub struct Quaint {
inner: Arc<dyn Queryable>,
connection_info: Arc<ConnectionInfo>,
transaction_depth: Arc<Mutex<i32>>,
}

impl fmt::Debug for Quaint {
Expand Down Expand Up @@ -167,11 +165,7 @@ impl Quaint {
let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?);
Self::log_start(&connection_info);

Ok(Self {
inner,
connection_info,
transaction_depth: Arc::new(Mutex::new(0)),
})
Ok(Self { inner, connection_info })
}

#[cfg(feature = "sqlite-native")]
Expand All @@ -184,7 +178,6 @@ impl Quaint {
connection_info: Arc::new(ConnectionInfo::Native(NativeConnectionInfo::InMemorySqlite {
db_name: DEFAULT_SQLITE_DATABASE.to_owned(),
})),
transaction_depth: Arc::new(Mutex::new(0)),
})
}

Expand Down
9 changes: 5 additions & 4 deletions quaint/src/tests/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,19 @@ async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> {
assert_eq!(Value::int32(10), res[0]);

// Check that nested transactions are also rolled back, even at multiple levels deep
let mut tx_inner = api.conn().start_transaction(None).await?;
tx.begin().await?;
let inner_insert1 = Insert::single_into(&table).value("value", 20);
let inner_rows_affected1 = tx.execute(inner_insert1.into()).await?;
assert_eq!(1, inner_rows_affected1);

let mut tx_inner2 = api.conn().start_transaction(None).await?;
// Open another nested transaction
tx.begin().await?;
let inner_insert2 = Insert::single_into(&table).value("value", 20);
let inner_rows_affected2 = tx.execute(inner_insert2.into()).await?;
assert_eq!(1, inner_rows_affected2);
tx_inner2.commit().await?;
tx.commit().await?;

tx_inner.commit().await?;
tx.commit().await?;

tx.rollback().await?;

Expand Down
Loading

0 comments on commit a220cf5

Please sign in to comment.