Skip to content

Commit

Permalink
feat: auto detect db_driver from connect_url
Browse files Browse the repository at this point in the history
  • Loading branch information
mickvandijke committed Sep 13, 2022
1 parent 3435ca6 commit 4bb814c
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 67 deletions.
3 changes: 0 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use config::{ConfigError, Config, File};
use std::path::Path;
use serde::{Serialize, Deserialize};
use tokio::sync::RwLock;
use crate::databases::database::DatabaseDriver;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Website {
Expand Down Expand Up @@ -50,7 +49,6 @@ pub struct Auth {

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Database {
pub db_driver: DatabaseDriver,
pub connect_url: String,
pub torrent_info_update_interval: u64,
}
Expand Down Expand Up @@ -105,7 +103,6 @@ impl Configuration {
secret_key: "MaxVerstappenWC2021".to_string()
},
database: Database {
db_driver: DatabaseDriver::Sqlite3,
connect_url: "sqlite://data.db?mode=rwc".to_string(),
torrent_info_update_interval: 3600
},
Expand Down
106 changes: 51 additions & 55 deletions src/databases/database.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use async_trait::async_trait;
use chrono::{NaiveDateTime};
use serde::{Serialize, Deserialize};

use crate::databases::mysql::MysqlDatabase;
use crate::databases::sqlite::SqliteDatabase;
use crate::models::response::{TorrentsResponse};
Expand All @@ -9,25 +10,29 @@ use crate::models::torrent_file::{DbTorrentInfo, Torrent, TorrentFile};
use crate::models::tracker_key::TrackerKey;
use crate::models::user::{User, UserAuthentication, UserCompact, UserProfile};

/// Database drivers.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum DatabaseDriver {
Sqlite3,
Mysql
}

/// Compact representation of torrent.
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct TorrentCompact {
pub torrent_id: i64,
pub info_hash: String,
}

/// Torrent category.
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
pub struct Category {
pub category_id: i64,
pub name: String,
pub num_torrents: i64
}

/// Sorting options for torrents.
#[derive(Clone, Copy, Debug, Deserialize)]
pub enum Sorting {
UploadedAsc,
Expand All @@ -42,9 +47,11 @@ pub enum Sorting {
SizeDesc,
}

/// Database errors.
#[derive(Debug)]
pub enum DatabaseError {
Error,
UnrecognizedDatabaseDriver, // when the db path does not start with sqlite or mysql
UsernameTaken,
EmailTaken,
UserNotFound,
Expand All @@ -55,127 +62,116 @@ pub enum DatabaseError {
TorrentTitleAlreadyExists,
}

pub async fn connect_database(db_driver: &DatabaseDriver, db_path: &str) -> Box<dyn Database> {
// match &db_path.chars().collect::<Vec<char>>() as &[char] {
// ['s', 'q', 'l', 'i', 't', 'e', ..] => {
// let db = SqliteDatabase::new(db_path).await;
// Ok(Box::new(db))
// }
// ['m', 'y', 's', 'q', 'l', ..] => {
// let db = MysqlDatabase::new(db_path).await;
// Ok(Box::new(db))
// }
// _ => {
// Err(())
// }
// }

match db_driver {
DatabaseDriver::Sqlite3 => {
/// Connect to a database.
pub async fn connect_database(db_path: &str) -> Result<Box<dyn Database>, DatabaseError> {
match &db_path.chars().collect::<Vec<char>>() as &[char] {
['s', 'q', 'l', 'i', 't', 'e', ..] => {
let db = SqliteDatabase::new(db_path).await;
Box::new(db)
Ok(Box::new(db))
}
DatabaseDriver::Mysql => {
['m', 'y', 's', 'q', 'l', ..] => {
let db = MysqlDatabase::new(db_path).await;
Box::new(db)
Ok(Box::new(db))
}
_ => {
Err(DatabaseError::UnrecognizedDatabaseDriver)
}
}
}

/// Trait for database implementations.
#[async_trait]
pub trait Database: Sync + Send {
// return current database driver
/// Return current database driver.
fn get_database_driver(&self) -> DatabaseDriver;

// add new user and get the newly inserted user_id
/// Add new user and return the newly inserted `user_id`.
async fn insert_user_and_get_id(&self, username: &str, email: &str, password: &str) -> Result<i64, DatabaseError>;

// get user profile by user_id
/// Get `User` from `user_id`.
async fn get_user_from_id(&self, user_id: i64) -> Result<User, DatabaseError>;

// get user authentication by user_id
/// Get `UserAuthentication` from `user_id`.
async fn get_user_authentication_from_id(&self, user_id: i64) -> Result<UserAuthentication, DatabaseError>;

// get user profile by username
/// Get `UserProfile` from `username`.
async fn get_user_profile_from_username(&self, username: &str) -> Result<UserProfile, DatabaseError>;

// get user compact by user_id
/// Get `UserCompact` from `user_id`.
async fn get_user_compact_from_id(&self, user_id: i64) -> Result<UserCompact, DatabaseError>;

// todo: change to get all tracker keys of user, no matter if they are still valid
// get a user's tracker key
/// Get a user's `TrackerKey`.
async fn get_user_tracker_key(&self, user_id: i64) -> Option<TrackerKey>;

// count users
/// Get total user count.
async fn count_users(&self) -> Result<i64, DatabaseError>;

// todo: make DateTime struct for the date_expiry
// ban user
/// Ban user with `user_id`, `reason` and `date_expiry`.
async fn ban_user(&self, user_id: i64, reason: &str, date_expiry: NaiveDateTime) -> Result<(), DatabaseError>;

// give a user administrator rights
/// Grant a user the administrator role.
async fn grant_admin_role(&self, user_id: i64) -> Result<(), DatabaseError>;

// verify email
/// Verify a user's email with `user_id`.
async fn verify_email(&self, user_id: i64) -> Result<(), DatabaseError>;

// create a new tracker key for a certain user
/// Link a `TrackerKey` to a certain user with `user_id`.
async fn add_tracker_key(&self, user_id: i64, tracker_key: &TrackerKey) -> Result<(), DatabaseError>;

// delete user
/// Delete user and all related user data with `user_id`.
async fn delete_user(&self, user_id: i64) -> Result<(), DatabaseError>;

// add new category
/// Add a new category and return `category_id`.
async fn insert_category_and_get_id(&self, category_name: &str) -> Result<i64, DatabaseError>;

// get category by id
async fn get_category_from_id(&self, id: i64) -> Result<Category, DatabaseError>;
/// Get `Category` from `category_id`.
async fn get_category_from_id(&self, category_id: i64) -> Result<Category, DatabaseError>;

// get category by name
async fn get_category_from_name(&self, category: &str) -> Result<Category, DatabaseError>;
/// Get `Category` from `category_name`.
async fn get_category_from_name(&self, category_name: &str) -> Result<Category, DatabaseError>;

// get all categories
/// Get all categories as `Vec<Category>`.
async fn get_categories(&self) -> Result<Vec<Category>, DatabaseError>;

// delete category
/// Delete category with `category_name`.
async fn delete_category(&self, category_name: &str) -> Result<(), DatabaseError>;

// get results of a torrent search in a paginated and sorted form
/// Get results of a torrent search in a paginated and sorted form as `TorrentsResponse` from `search`, `categories`, `sort`, `offset` and `page_size`.
async fn get_torrents_search_sorted_paginated(&self, search: &Option<String>, categories: &Option<Vec<String>>, sort: &Sorting, offset: u64, page_size: u8) -> Result<TorrentsResponse, DatabaseError>;

// add new torrent and get the newly inserted torrent_id
/// Add new torrent and return the newly inserted `torrent_id` with `torrent`, `uploader_id`, `category_id`, `title` and `description`.
async fn insert_torrent_and_get_id(&self, torrent: &Torrent, uploader_id: i64, category_id: i64, title: &str, description: &str) -> Result<i64, DatabaseError>;

// get torrent by id
/// Get `Torrent` from `torrent_id`.
async fn get_torrent_from_id(&self, torrent_id: i64) -> Result<Torrent, DatabaseError>;

// get torrent info by id
/// Get torrent's info as `DbTorrentInfo` from `torrent_id`.
async fn get_torrent_info_from_id(&self, torrent_id: i64) -> Result<DbTorrentInfo, DatabaseError>;

// get torrent files by id
/// Get all torrent's files as `Vec<TorrentFile>` from `torrent_id`.
async fn get_torrent_files_from_id(&self, torrent_id: i64) -> Result<Vec<TorrentFile>, DatabaseError>;

// get torrent announce urls by id
/// Get all torrent's announce urls as `Vec<Vec<String>>` from `torrent_id`.
async fn get_torrent_announce_urls_from_id(&self, torrent_id: i64) -> Result<Vec<Vec<String>>, DatabaseError>;

// get torrent listing by id
/// Get `TorrentListing` from `torrent_id`.
async fn get_torrent_listing_from_id(&self, torrent_id: i64) -> Result<TorrentListing, DatabaseError>;

// get all torrents (torrent_id + info_hash)
/// Get all torrents as `Vec<TorrentCompact>`.
async fn get_all_torrents_compact(&self) -> Result<Vec<TorrentCompact>, DatabaseError>;

// update a torrent's title
/// Update a torrent's title with `torrent_id` and `title`.
async fn update_torrent_title(&self, torrent_id: i64, title: &str) -> Result<(), DatabaseError>;

// update a torrent's description
/// Update a torrent's description with `torrent_id` and `description`.
async fn update_torrent_description(&self, torrent_id: i64, description: &str) -> Result<(), DatabaseError>;

// update the seeders and leechers info for a particular torrent
/// Update the seeders and leechers info for a torrent with `torrent_id`, `tracker_url`, `seeders` and `leechers`.
async fn update_tracker_info(&self, torrent_id: i64, tracker_url: &str, seeders: i64, leechers: i64) -> Result<(), DatabaseError>;

// delete a torrent
/// Delete a torrent with `torrent_id`.
async fn delete_torrent(&self, torrent_id: i64) -> Result<(), DatabaseError>;

// DELETES ALL DATABASE ROWS, ONLY CALL THIS IF YOU KNOW WHAT YOU'RE DOING!
/// DELETES ALL DATABASE ROWS, ONLY CALL THIS IF YOU KNOW WHAT YOU'RE DOING!
async fn delete_all_database_rows(&self) -> Result<(), DatabaseError>;
}
3 changes: 2 additions & 1 deletion src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ impl From<DatabaseError> for ServiceError {
DatabaseError::CategoryNotFound => ServiceError::InvalidCategory,
DatabaseError::TorrentNotFound => ServiceError::TorrentNotFound,
DatabaseError::TorrentAlreadyExists => ServiceError::InfoHashAlreadyExists,
DatabaseError::TorrentTitleAlreadyExists => ServiceError::TorrentTitleAlreadyExists
DatabaseError::TorrentTitleAlreadyExists => ServiceError::TorrentTitleAlreadyExists,
DatabaseError::UnrecognizedDatabaseDriver => ServiceError::InternalServerError,
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ async fn main() -> std::io::Result<()> {

let settings = cfg.settings.read().await;

let database = Arc::new(connect_database(&settings.database.db_driver, &settings.database.connect_url).await);
let database = Arc::new(connect_database(&settings.database.connect_url)
.await
.expect("Database error.")
);

let auth = Arc::new(AuthorizationService::new(cfg.clone(), database.clone()));
let tracker_service = Arc::new(TrackerService::new(cfg.clone(), database.clone()));
let mailer_service = Arc::new(MailerService::new(cfg.clone()).await);
Expand Down
10 changes: 7 additions & 3 deletions tests/databases/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::future::Future;
use torrust_index_backend::databases::database::{connect_database, Database, DatabaseDriver};
use torrust_index_backend::databases::database::{connect_database, Database};

mod mysql;
mod tests;
Expand All @@ -19,8 +19,12 @@ async fn run_test<'a, T, F>(db_fn: T, db: &'a Box<dyn Database>)
}

// runs all tests
pub async fn run_tests(db_driver: DatabaseDriver, db_path: &str) {
let db = connect_database(&db_driver, db_path).await;
pub async fn run_tests(db_path: &str) {
let db_res = connect_database(db_path).await;

assert!(db_res.is_ok());

let db = db_res.unwrap();

run_test(tests::it_can_add_a_user, &db).await;
run_test(tests::it_can_add_a_torrent_category, &db).await;
Expand Down
3 changes: 1 addition & 2 deletions tests/databases/mysql.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use torrust_index_backend::databases::database::{DatabaseDriver};
use crate::databases::{run_tests};

const DATABASE_URL: &str = "mysql://root:password@localhost:3306/torrust-index_test";

#[tokio::test]
async fn run_mysql_tests() {
run_tests(DatabaseDriver::Mysql, DATABASE_URL).await;
run_tests(DATABASE_URL).await;
}


3 changes: 1 addition & 2 deletions tests/databases/sqlite.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use torrust_index_backend::databases::database::{DatabaseDriver};
use crate::databases::{run_tests};

const DATABASE_URL: &str = "sqlite::memory:";

#[tokio::test]
async fn run_sqlite_tests() {
run_tests(DatabaseDriver::Sqlite3, DATABASE_URL).await;
run_tests(DATABASE_URL).await;
}


0 comments on commit 4bb814c

Please sign in to comment.