Skip to content

Commit

Permalink
fix: simplify synchronisation to handle cloned request extensions
Browse files Browse the repository at this point in the history
This follows on from the upgrade to axum 0.7 / hyper 1.0 which
introduced the requirement for request extensions to implement `Clone`.

The previous implementation exibited undesirable behaviour if some
middleware (that ran after `axum_sqlx_tx::Layer`) held on to a clone of
the request extensions - specifically an `OverlappingExtractors` error
would then be thrown by any attempt to extract `Tx`. Note that in this
circumstance it's not actually possible for the "inspecting" middleware
to obtain the transaction since it cannot name the type of the extension
(`crate::extension::Extension`, previously `crate::tx::Lazy`) in order
to interact with it.

The fix involved simplifying the `Slot` synchronisation primitive and
the usage of it in the extension. We previously had a "chained" `Slot`
setup (e.g. a `Slot` containing another `Slot`) in order to share the
lazy transaction between the middleware future and the request extension
- the `Slot` remained in the middleware future stack while the `Lease`
was passed to the request extension. This was problematic because
`Lease` doesn't implement `Clone`, and it mustn't since the whole idea
is to synchronise access to a `Transaction` which itself cannot be
cloned. Now `Slot`s can be cloned, so the same `Slot` can be held in the
middleware future's stack and in the request extension. Moreover, an
arbitrary number of copies of the `Slot` can be held without preventing
another `Slot` from leasing (so long as there's no other lease) - just
like `Mutex`.

This still achieves the outcome we want because the only public
interface to the `Slot` is via `Tx::from_request_parts`, thus a conflict
can only occur due to overlapping use of the extractor.
  • Loading branch information
connec committed Dec 23, 2023
1 parent 9e9827f commit 3c48710
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 182 deletions.
99 changes: 99 additions & 0 deletions src/extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use sqlx::Transaction;

use crate::{
slot::{Lease, Slot},
Error, Marker, State,
};

#[derive(Debug)]
pub(crate) struct Extension<DB: Marker> {
slot: Slot<LazyTransaction<DB>>,
}

impl<DB: Marker> Extension<DB> {
pub(crate) fn new(state: State<DB>) -> Self {
let slot = Slot::new(LazyTransaction::new(state));
Self { slot }
}

pub(crate) async fn acquire(&self) -> Result<Lease<LazyTransaction<DB>>, Error> {
let mut tx = self.slot.lease().ok_or(Error::OverlappingExtractors)?;
tx.acquire().await?;
Ok(tx)
}

pub(crate) async fn commit(&self) -> Result<(), sqlx::Error> {
if let Some(tx) = self.slot.lease().map(|lease| lease.steal()) {
tx.commit().await
} else {
Ok(())
}
}
}

impl<DB: Marker> Clone for Extension<DB> {
fn clone(&self) -> Self {
Self {
slot: self.slot.clone(),
}
}
}

#[derive(Debug)]
pub(crate) struct LazyTransaction<DB: Marker>(LazyTransactionState<DB>);

#[derive(Debug)]
enum LazyTransactionState<DB: Marker> {
Unacquired {
state: State<DB>,
},
Acquired {
tx: Transaction<'static, DB::Driver>,
},
}

impl<DB: Marker> LazyTransaction<DB> {
fn new(state: State<DB>) -> Self {
Self(LazyTransactionState::Unacquired { state })
}

pub(crate) async fn acquire(&mut self) -> Result<(), sqlx::Error> {
match &self.0 {
LazyTransactionState::Unacquired { state } => {
let tx = state.transaction().await?;
self.0 = LazyTransactionState::Acquired { tx };
Ok(())
}
LazyTransactionState::Acquired { .. } => Ok(()),
}
}

pub(crate) async fn commit(self) -> Result<(), sqlx::Error> {
match self.0 {
LazyTransactionState::Unacquired { .. } => Ok(()),
LazyTransactionState::Acquired { tx } => tx.commit().await,
}
}
}

impl<DB: Marker> AsRef<Transaction<'static, DB::Driver>> for LazyTransaction<DB> {
fn as_ref(&self) -> &Transaction<'static, DB::Driver> {
match &self.0 {
LazyTransactionState::Unacquired { .. } => {
panic!("BUG: exposed unacquired LazyTransaction")
}
LazyTransactionState::Acquired { tx } => tx,
}
}
}

impl<DB: Marker> AsMut<Transaction<'static, DB::Driver>> for LazyTransaction<DB> {
fn as_mut(&mut self) -> &mut Transaction<'static, DB::Driver> {
match &mut self.0 {
LazyTransactionState::Unacquired { .. } => {
panic!("BUG: exposed unacquired LazyTransaction")
}
LazyTransactionState::Acquired { tx } => tx,
}
}
}
7 changes: 4 additions & 3 deletions src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use bytes::Bytes;
use futures_core::future::BoxFuture;
use http_body::Body;

use crate::{tx::TxSlot, Marker, State};
use crate::{extension::Extension, Marker, State};

/// A [`tower_layer::Layer`] that enables the [`Tx`] extractor.
///
Expand Down Expand Up @@ -109,15 +109,16 @@ where
}

fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
let transaction = TxSlot::bind(req.extensions_mut(), self.state.clone());
let ext = Extension::new(self.state.clone());
req.extensions_mut().insert(ext.clone());

let res = self.inner.call(req);

Box::pin(async move {
let res = res.await.unwrap(); // inner service is infallible

if !res.status().is_server_error() && !res.status().is_client_error() {
if let Err(error) = transaction.commit().await {
if let Err(error) = ext.commit().await {
return Ok(error.into().into_response());
}
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@

mod config;
mod error;
mod extension;
mod layer;
mod marker;
mod slot;
Expand Down
Loading

0 comments on commit 3c48710

Please sign in to comment.