Skip to content

Commit

Permalink
Merge pull request #19 from digital-society-coop/multi-db
Browse files Browse the repository at this point in the history
Add support for multiple database instances
  • Loading branch information
connec committed Nov 3, 2023
2 parents 5957c32 + 243fc03 commit 0325d41
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 214 deletions.
53 changes: 53 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use std::marker::PhantomData;

use crate::{Layer, Marker, State};

/// Configuration for [`Tx`](crate::Tx) extractors.
///
/// Use `Config` to configure and create a [`State`] and [`Layer`].
///
/// Access the `Config` API from [`Tx::config`](crate::Tx::config).
///
/// ```
/// # async fn foo() {
/// # let pool: sqlx::SqlitePool = todo!();
/// type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite>;
///
/// let config = Tx::config(pool);
/// # }
/// ```
pub struct Config<DB: Marker, LayerError> {
pool: sqlx::Pool<DB::Driver>,
_layer_error: PhantomData<LayerError>,
}

impl<DB: Marker, LayerError> Config<DB, LayerError>
where
LayerError: axum_core::response::IntoResponse,
sqlx::Error: Into<LayerError>,
{
pub(crate) fn new(pool: sqlx::Pool<DB::Driver>) -> Self {
Self {
pool,
_layer_error: PhantomData,
}
}

/// Change the layer error type.
pub fn layer_error<E>(self) -> Config<DB, E>
where
sqlx::Error: Into<E>,
{
Config {
pool: self.pool,
_layer_error: PhantomData,
}
}

/// Create a [`State`] and [`Layer`] to enable the [`Tx`](crate::Tx) extractor.
pub fn setup(self) -> (State<DB>, Layer<DB, LayerError>) {
let state = State::new(self.pool);
let layer = Layer::new(state.clone());
(state, layer)
}
}
105 changes: 105 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/// Possible errors when extracting [`Tx`] from a request.
///
/// Errors can occur at two points during the request lifecycle:
///
/// 1. The [`Tx`] extractor might fail to obtain a connection from the pool and `BEGIN` a
/// transaction. This could be due to:
///
/// - Forgetting to add the middleware: [`Error::MissingExtension`].
/// - Calling the extractor multiple times in the same request: [`Error::OverlappingExtractors`].
/// - A problem communicating with the database: [`Error::Database`].
///
/// 2. The middleware [`Layer`] might fail to commit the transaction. This could be due to a problem
/// communicating with the database, or else a logic error (e.g. unsatisfied deferred
/// constraint): [`Error::Database`].
///
/// `axum` requires that errors can be turned into responses. The [`Error`] type converts into a
/// HTTP 500 response with the error message as the response body. This may be suitable for
/// development or internal services but it's generally not advisable to return internal error
/// details to clients.
///
/// You can override the error types for both the [`Tx`] extractor and [`Layer`]:
///
/// - Override the [`Tx`]`<DB, E>` error type using the `E` generic type parameter. `E` must be
/// convertible from [`Error`] (e.g. [`Error`]`: Into<E>`).
///
/// - Override the [`Layer`] error type using [`Config::layer_error`](crate::Config::layer_error).
/// The layer error type must be convertible from `sqlx::Error` (e.g.
/// `sqlx::Error: Into<LayerError>`).
///
/// In both cases, the error type must implement `axum::response::IntoResponse`.
///
/// ```
/// use axum::{response::IntoResponse, routing::post};
///
/// enum MyError{
/// Extractor(axum_sqlx_tx::Error),
/// Layer(sqlx::Error),
/// }
///
/// impl From<axum_sqlx_tx::Error> for MyError {
/// fn from(error: axum_sqlx_tx::Error) -> Self {
/// Self::Extractor(error)
/// }
/// }
///
/// impl From<sqlx::Error> for MyError {
/// fn from(error: sqlx::Error) -> Self {
/// Self::Layer(error)
/// }
/// }
///
/// impl IntoResponse for MyError {
/// fn into_response(self) -> axum::response::Response {
/// // note that you would probably want to log the error as well
/// (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
/// }
/// }
///
/// // Override the `Tx` error type using the second generic type parameter
/// type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite, MyError>;
///
/// # async fn foo() {
/// let pool = sqlx::SqlitePool::connect("...").await.unwrap();
///
/// let (state, layer) = Tx::config(pool)
/// // Override the `Layer` error type using the `Config` API
/// .layer_error::<MyError>()
/// .setup();
/// # let app = axum::Router::new()
/// # .route("/", post(create_user))
/// # .layer(layer)
/// # .with_state(state);
/// # axum::Server::bind(todo!()).serve(app.into_make_service());
/// # }
/// # async fn create_user(mut tx: Tx, /* ... */) {
/// # /* ... */
/// # }
/// ```
///
/// [`Tx`]: crate::Tx
/// [`Layer`]: crate::Layer
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Indicates that the [`Layer`](crate::Layer) middleware was not installed.
#[error("required extension not registered; did you add the axum_sqlx_tx::Layer middleware?")]
MissingExtension,

/// Indicates that [`Tx`](crate::Tx) was extracted multiple times in a single
/// handler/middleware.
#[error("axum_sqlx_tx::Tx extractor used multiple times in the same handler/middleware")]
OverlappingExtractors,

/// A database error occurred when starting or committing the transaction.
#[error(transparent)]
Database {
#[from]
error: sqlx::Error,
},
}

impl axum_core::response::IntoResponse for Error {
fn into_response(self) -> axum_core::response::Response {
(http::StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
}
}
16 changes: 8 additions & 8 deletions src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use bytes::Bytes;
use futures_core::future::BoxFuture;
use http_body::{combinators::UnsyncBoxBody, Body};

use crate::{tx::TxSlot, State};
use crate::{tx::TxSlot, Marker, State};

/// A [`tower_layer::Layer`] that enables the [`Tx`] extractor.
///
Expand All @@ -20,12 +20,12 @@ use crate::{tx::TxSlot, State};
///
/// [`Tx`]: crate::Tx
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
pub struct Layer<DB: sqlx::Database, E> {
pub struct Layer<DB: Marker, E> {
state: State<DB>,
_error: PhantomData<E>,
}

impl<DB: sqlx::Database, E> Layer<DB, E>
impl<DB: Marker, E> Layer<DB, E>
where
E: IntoResponse,
sqlx::Error: Into<E>,
Expand All @@ -38,7 +38,7 @@ where
}
}

impl<DB: sqlx::Database, E> Clone for Layer<DB, E> {
impl<DB: Marker, E> Clone for Layer<DB, E> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
Expand All @@ -47,7 +47,7 @@ impl<DB: sqlx::Database, E> Clone for Layer<DB, E> {
}
}

impl<DB: sqlx::Database, S, E> tower_layer::Layer<S> for Layer<DB, E>
impl<DB: Marker, S, E> tower_layer::Layer<S> for Layer<DB, E>
where
E: IntoResponse,
sqlx::Error: Into<E>,
Expand All @@ -66,14 +66,14 @@ where
/// A [`tower_service::Service`] that enables the [`Tx`](crate::Tx) extractor.
///
/// See [`Layer`] for more information.
pub struct Service<DB: sqlx::Database, S, E> {
pub struct Service<DB: Marker, S, E> {
state: State<DB>,
inner: S,
_error: PhantomData<E>,
}

// can't simply derive because `DB` isn't `Clone`
impl<DB: sqlx::Database, S: Clone, E> Clone for Service<DB, S, E> {
impl<DB: Marker, S: Clone, E> Clone for Service<DB, S, E> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
Expand All @@ -83,7 +83,7 @@ impl<DB: sqlx::Database, S: Clone, E> Clone for Service<DB, S, E> {
}
}

impl<DB: sqlx::Database, S, E, ReqBody, ResBody> tower_service::Service<http::Request<ReqBody>>
impl<DB: Marker, S, E, ReqBody, ResBody> tower_service::Service<http::Request<ReqBody>>
for Service<DB, S, E>
where
S: tower_service::Service<
Expand Down
Loading

0 comments on commit 0325d41

Please sign in to comment.