diff --git a/src/errors.rs b/src/errors.rs index 301d841c..b750a852 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -148,6 +148,9 @@ pub enum ServiceError { #[display(fmt = "Database error.")] DatabaseError, + #[display(fmt = "Authentication error, please sign in")] + LoggedInUserNotFound, + // Begin tracker errors #[display(fmt = "Sorry, we have an error with our tracker connection.")] TrackerOffline, @@ -311,6 +314,7 @@ pub fn http_status_code_for_service_error(error: &ServiceError) -> StatusCode { ServiceError::TrackerUnknownResponse => StatusCode::INTERNAL_SERVER_ERROR, ServiceError::TorrentNotFoundInTracker => StatusCode::NOT_FOUND, ServiceError::InvalidTrackerToken => StatusCode::INTERNAL_SERVER_ERROR, + ServiceError::LoggedInUserNotFound => StatusCode::UNAUTHORIZED, } } diff --git a/src/web/api/server/v1/extractors/mod.rs b/src/web/api/server/v1/extractors/mod.rs index 36d737ca..2c55e042 100644 --- a/src/web/api/server/v1/extractors/mod.rs +++ b/src/web/api/server/v1/extractors/mod.rs @@ -1 +1,2 @@ pub mod bearer_token; +pub mod user_id; diff --git a/src/web/api/server/v1/extractors/user_id.rs b/src/web/api/server/v1/extractors/user_id.rs new file mode 100644 index 00000000..4ea81900 --- /dev/null +++ b/src/web/api/server/v1/extractors/user_id.rs @@ -0,0 +1,37 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use axum::extract::{FromRef, FromRequestParts}; +use axum::http::request::Parts; +use axum::response::{IntoResponse, Response}; + +use super::bearer_token; +use crate::common::AppData; +use crate::errors::ServiceError; +use crate::models::user::UserId; + +pub struct ExtractLoggedInUser(pub UserId); + +#[async_trait] +impl FromRequestParts for ExtractLoggedInUser +where + Arc: FromRef, + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let maybe_bearer_token = match bearer_token::Extract::from_request_parts(parts, state).await { + Ok(maybe_bearer_token) => maybe_bearer_token.0, + Err(_) => return Err(ServiceError::TokenNotFound.into_response()), + }; + + //Extracts the app state + let app_data = Arc::from_ref(state); + + match app_data.auth.get_user_id_from_bearer_token(&maybe_bearer_token).await { + Ok(user_id) => Ok(ExtractLoggedInUser(user_id)), + Err(_) => Err(ServiceError::LoggedInUserNotFound.into_response()), + } + } +}