Skip to content
This repository has been archived by the owner on Apr 4, 2024. It is now read-only.

Commit

Permalink
Add authentication
Browse files Browse the repository at this point in the history
Add endpoints `/accounts/register` & `/accounts/login`

When you register or login successfully you receive an auth token.
This needs to be sent in the X-Auth-Token header in authenticated
endpoints.
To make an endpoint authenticated add the Authentication extractor to
the endpoint.
This will make the id of the account available to the endpoint.
If the authentication token is invalid the request terminates with an
error and the endpoint handler is never executed.
  • Loading branch information
mads256h committed Apr 3, 2024
1 parent 5b08cc0 commit 1eb9503
Show file tree
Hide file tree
Showing 13 changed files with 344 additions and 84 deletions.
38 changes: 38 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ edition = "2021"

[dependencies]
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.8.0", features = ["v4"] }
uuid = { version = "1.8.0", features = ["v4", "serde"] }
axum = { version = "0.7", features = ["macros"] }
tokio = { version = "1.36", features = ["full"] }
tower = "0.4"
sqlx = { version = "0.7", features = ["sqlite", "macros", "migrate", "runtime-tokio", "chrono"] }
sqlx = { version = "0.7", features = ["sqlite", "macros", "migrate", "runtime-tokio", "chrono", "uuid"] }
serde = "1.0"
serde_with = "3.7"
dotenv = "0.15"
argon2 = { version = "0.5", features = ["std"] }
8 changes: 0 additions & 8 deletions migrations/0001_create_tasks_table.sql

This file was deleted.

22 changes: 22 additions & 0 deletions migrations/0001_initial_schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
CREATE TABLE Accounts(
id INTEGER PRIMARY KEY NOT NULL,
username VARCHAR(255) NOT NULL,
password_hash VARCHAR(255) NOT NULL,
UNIQUE(username)
);

CREATE TABLE AuthTokens(
id VARCHAR(64) PRIMARY KEY NOT NULL,
account_id INTEGER NOT NULL
REFERENCES Accounts(id) ON DELETE CASCADE
);

CREATE TABLE Tasks(
id INTEGER PRIMARY KEY NOT NULL,
timespan_start DATETIME NOT NULL,
timespan_end DATETIME NOT NULL,
duration INTEGER NOT NULL,
effect REAL NOT NULL,
account_id INTEGER NOT NULL
REFERENCES Accounts(id) ON DELETE CASCADE
);
92 changes: 92 additions & 0 deletions src/extractors/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, HeaderMap, StatusCode},
};
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use uuid::Uuid;

#[derive(Deserialize, Serialize, sqlx::Type)]
#[sqlx(transparent)]
pub struct AuthToken(Uuid);

impl AuthToken {
fn new() -> Self {
AuthToken(Uuid::new_v4())
}

fn try_parse(input: &str) -> Result<AuthToken, uuid::Error> {
let uuid = Uuid::try_parse(input)?;
Ok(AuthToken(uuid))
}
}

// Account id
pub struct Authentication(pub i64);

#[async_trait]
impl FromRequestParts<SqlitePool> for Authentication {
type Rejection = (StatusCode, String);

async fn from_request_parts(
parts: &mut Parts,
pool: &SqlitePool,
) -> Result<Self, Self::Rejection> {
match get_auth_token(&parts.headers) {
Some(token) => {
if let Some(account_id) = get_account_id_from_token(token, pool).await {
Ok(Authentication(account_id))
} else {
Err((
StatusCode::UNAUTHORIZED,
"Auth token is not in the database".to_string(),
))
}
}
_ => Err((
StatusCode::UNAUTHORIZED,
"Auth token invalid or missing".to_string(),
)),
}
}
}

fn get_auth_token(headers: &HeaderMap) -> Option<AuthToken> {
let string = headers.get("X-Auth-Token")?.to_str().ok()?;
AuthToken::try_parse(string).ok()
}

async fn get_account_id_from_token(token: AuthToken, pool: &SqlitePool) -> Option<i64> {
sqlx::query_scalar!(
r#"
SELECT account_id
FROM AuthTokens
WHERE id = ?
"#,
token
)
.fetch_optional(pool)
.await
.ok()?
}

pub async fn create_auth_token(
account_id: i64,
pool: &SqlitePool,
) -> Result<AuthToken, sqlx::Error> {
let auth_token = AuthToken::new();

sqlx::query!(
r#"
INSERT INTO AuthTokens (id, account_id)
VALUES (?, ?)
"#,
auth_token,
account_id
)
.execute(pool)
.await?;

Ok(auth_token)
}
1 change: 1 addition & 0 deletions src/extractors/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod auth;
82 changes: 82 additions & 0 deletions src/handlers/accounts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use argon2::{
password_hash::{rand_core::OsRng, SaltString},
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
};
use axum::{debug_handler, extract::State, http::StatusCode, Json};
use sqlx::SqlitePool;

use crate::{
extractors::auth::create_auth_token,
handlers::util::internal_error,
protocol::accounts::{RegisterOrLoginRequest, RegisterOrLoginResponse},
};

#[debug_handler]
pub async fn register_account(
State(pool): State<SqlitePool>,
Json(register_request): Json<RegisterOrLoginRequest>,
) -> Result<Json<RegisterOrLoginResponse>, (StatusCode, String)> {
let password = register_request.password;
let password_bytes = password.as_bytes();
let salt = SaltString::generate(&mut OsRng);

let argon2 = Argon2::default();

let password_hash = argon2
.hash_password(password_bytes, &salt)
.map_err(internal_error)?
.to_string();

let account_id = sqlx::query_scalar!(
r#"
INSERT INTO Accounts (username, password_hash)
VALUES (?, ?)
RETURNING id
"#,
register_request.username,
password_hash
)
.fetch_one(&pool)
.await
.map_err(internal_error)?;

let auth_token = create_auth_token(account_id, &pool)
.await
.map_err(internal_error)?;

Ok(Json(RegisterOrLoginResponse { auth_token }))
}

#[debug_handler]
pub async fn login_to_account(
State(pool): State<SqlitePool>,
Json(login_request): Json<RegisterOrLoginRequest>,
) -> Result<Json<RegisterOrLoginResponse>, (StatusCode, String)> {
let account = sqlx::query!(
r#"
SELECT id, password_hash
FROM Accounts
WHERE username = ?
"#,
login_request.username
)
.fetch_optional(&pool)
.await
.map_err(internal_error)?;

let account = account.ok_or((
StatusCode::NOT_FOUND,
"No account with username exists".to_string(),
))?;

let password_hash = PasswordHash::new(&account.password_hash).map_err(internal_error)?;

Argon2::default()
.verify_password(login_request.password.as_bytes(), &password_hash)
.map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid password".to_string()))?;

let auth_token = create_auth_token(account.id, &pool)
.await
.map_err(internal_error)?;
Ok(Json(RegisterOrLoginResponse { auth_token }))
}
3 changes: 3 additions & 0 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod accounts;
pub mod tasks;
pub mod util;
65 changes: 65 additions & 0 deletions src/handlers/tasks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use axum::{debug_handler, extract::State, http::StatusCode, Json};
use sqlx::SqlitePool;

use crate::{
data_model::{task::Task, time::Timespan},
extractors::auth::Authentication,
handlers::util::internal_error,
};

#[debug_handler]
pub async fn get_tasks(
State(pool): State<SqlitePool>,
Authentication(account_id): Authentication,
) -> Result<Json<Vec<Task>>, (StatusCode, String)> {
let tasks = sqlx::query!(
r#"
SELECT id, timespan_start, timespan_end, duration, effect
FROM Tasks
WHERE account_id = ?
"#,
account_id
)
.fetch_all(&pool)
.await
.map_err(internal_error)?;

let my_tasks = tasks
.iter()
.map(|t| Task {
id: t.id,
timespan: Timespan::new_from_naive(t.timespan_start, t.timespan_end),
duration: t.duration.into(),
effect: t.effect,
})
.collect();

Ok(Json(my_tasks))
}

#[debug_handler]
pub async fn create_task(
State(pool): State<SqlitePool>,
Authentication(account_id): Authentication,
Json(mut task): Json<Task>,
) -> Result<Json<Task>, (StatusCode, String)> {
let id = sqlx::query_scalar!(
r#"
INSERT INTO Tasks (timespan_start, timespan_end, duration, effect, account_id)
VALUES (?, ?, ?, ?, ?)
RETURNING id
"#,
task.timespan.start,
task.timespan.end,
task.duration,
task.effect,
account_id
)
.fetch_one(&pool)
.await
.map_err(internal_error)?;

task.id = id;

Ok(Json(task))
}
Loading

0 comments on commit 1eb9503

Please sign in to comment.