Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ r2d2 = "0.8.8"
rand = "0.8.4"
env_logger = "0.9.0"
config = "0.11"
derive_more = "0.99"
4 changes: 4 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub struct UdpTrackerConfig {

#[derive(Serialize, Deserialize)]
pub struct HttpTrackerConfig {
pub enabled: bool,
pub bind_address: String,
pub announce_interval: u32,
pub ssl_enabled: bool,
Expand All @@ -34,6 +35,7 @@ impl HttpTrackerConfig {

#[derive(Serialize, Deserialize)]
pub struct HttpApiConfig {
pub enabled: bool,
pub bind_address: String,
pub access_tokens: HashMap<String, String>,
}
Expand Down Expand Up @@ -124,13 +126,15 @@ impl Configuration {
announce_interval: 120,
},
http_tracker: Option::from(HttpTrackerConfig {
enabled: false,
bind_address: String::from("0.0.0.0:7878"),
announce_interval: 120,
ssl_enabled: false,
ssl_cert_path: None,
ssl_key_path: None
}),
http_api: Option::from(HttpApiConfig {
enabled: true,
bind_address: String::from("127.0.0.1:1212"),
access_tokens: [(String::from("admin"), String::from("MyAccessToken"))].iter().cloned().collect(),
}),
Expand Down
40 changes: 17 additions & 23 deletions src/database.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,31 @@
use crate::{InfoHash, AUTH_KEY_LENGTH};
use log::debug;
use std::sync::Arc;
use r2d2_sqlite::{SqliteConnectionManager, rusqlite};
use r2d2::{Pool};
use r2d2_sqlite::rusqlite::NO_PARAMS;
use crate::key_manager::AuthKey;
use std::str::FromStr;

pub struct SqliteDatabase {
pool: Arc<Pool<SqliteConnectionManager>>
pool: Pool<SqliteConnectionManager>
}

impl SqliteDatabase {
pub async fn new(db_path: &str) -> Option<SqliteDatabase> {
pub fn new(db_path: &str) -> Result<SqliteDatabase, rusqlite::Error> {
let sqlite_connection_manager = SqliteConnectionManager::file(db_path);
let sqlite_pool = r2d2::Pool::new(sqlite_connection_manager)
.expect("Failed to create r2d2 SQLite connection pool.");
let pool_arc = Arc::new(sqlite_pool);

match SqliteDatabase::create_database_tables(pool_arc.clone()) {
Ok(_) => {
Some(SqliteDatabase {
pool: pool_arc.clone()
})
}
Err(_) => {
eprintln!("Could not create database tables.");
None
}
}
let sqlite_pool = r2d2::Pool::new(sqlite_connection_manager).expect("Failed to create r2d2 SQLite connection pool.");
let sqlite_database = SqliteDatabase {
pool: sqlite_pool
};

if let Err(error) = SqliteDatabase::create_database_tables(&sqlite_database.pool) {
return Err(error)
};

Ok(sqlite_database)
}

pub fn create_database_tables(pool: Arc<Pool<SqliteConnectionManager>>) -> Result<usize, rusqlite::Error> {
pub fn create_database_tables(pool: &Pool<SqliteConnectionManager>) -> Result<usize, rusqlite::Error> {
let create_whitelist_table = "
CREATE TABLE IF NOT EXISTS whitelist (
id integer PRIMARY KEY AUTOINCREMENT,
Expand Down Expand Up @@ -106,10 +100,10 @@ impl SqliteDatabase {
}
}

pub async fn get_key_from_keys(&self, key: String) -> Result<AuthKey, rusqlite::Error> {
pub async fn get_key_from_keys(&self, key: &str) -> Result<AuthKey, rusqlite::Error> {
let conn = self.pool.get().unwrap();
let mut stmt = conn.prepare("SELECT key, valid_until FROM keys WHERE key = ?")?;
let mut rows = stmt.query(&[key])?;
let mut rows = stmt.query(&[key.to_string()])?;

if let Some(row) = rows.next()? {
let key: String = row.get(0).unwrap();
Expand All @@ -124,10 +118,10 @@ impl SqliteDatabase {
}
}

pub async fn add_key_to_keys(&self, auth_key: AuthKey) -> Result<usize, rusqlite::Error> {
pub async fn add_key_to_keys(&self, auth_key: &AuthKey) -> Result<usize, rusqlite::Error> {
let conn = self.pool.get().unwrap();
match conn.execute("INSERT INTO keys (key, valid_until) VALUES (?1, ?2)",
&[auth_key.key, auth_key.valid_until.unwrap().to_string()]
&[auth_key.key.to_string(), auth_key.valid_until.unwrap().to_string()]
) {
Ok(updated) => {
if updated > 0 { return Ok(updated) }
Expand Down
12 changes: 5 additions & 7 deletions src/http_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::str::FromStr;
use log::{debug};
use warp::{filters, reply::Reply, Filter};
use warp::http::Response;
use crate::{Configuration, TorrentError, TorrentPeer, TorrentStats};
use crate::{TorrentError, TorrentPeer, TorrentStats};
use crate::key_manager::AuthKey;
use crate::utils::url_encode_bytes;
use super::common::*;
Expand Down Expand Up @@ -133,14 +133,12 @@ impl warp::Reply for HttpErrorResponse {

#[derive(Clone)]
pub struct HttpServer {
pub config: Arc<Configuration>,
pub tracker: Arc<TorrentTracker>,
tracker: Arc<TorrentTracker>,
}

impl HttpServer {
pub fn new(config: Arc<Configuration>, tracker: Arc<TorrentTracker>) -> HttpServer {
pub fn new(tracker: Arc<TorrentTracker>) -> HttpServer {
HttpServer {
config,
tracker
}
}
Expand Down Expand Up @@ -312,7 +310,7 @@ impl HttpServer {
}
};

let peer = TorrentPeer::from_http_announce_request(&query, remote_addr, self.config.get_ext_ip());
let peer = TorrentPeer::from_http_announce_request(&query, remote_addr, self.tracker.config.get_ext_ip());

match self.tracker.update_torrent_with_peer_and_get_stats(&info_hash, &peer).await {
Err(e) => {
Expand All @@ -329,7 +327,7 @@ impl HttpServer {

// todo: add http announce interval config option
// success response
let announce_interval = self.config.http_tracker.as_ref().unwrap().announce_interval;
let announce_interval = self.tracker.config.http_tracker.as_ref().unwrap().announce_interval;
HttpServer::send_announce_response(&query, torrent_stats, peers.unwrap(), announce_interval)
}
}
Expand Down
62 changes: 23 additions & 39 deletions src/key_manager.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use super::common::AUTH_KEY_LENGTH;
use crate::utils::current_time;
use crate::database::SqliteDatabase;
use std::sync::Arc;
use rand::{thread_rng, Rng};
use rand::distributions::Alphanumeric;
use serde::Serialize;
use log::debug;
use derive_more::{Display, Error};

#[derive(Serialize, Debug, Eq, PartialEq, Clone)]
pub struct AuthKey {
Expand All @@ -31,18 +30,26 @@ impl AuthKey {
}
}

pub struct KeyManager {
database: Arc<SqliteDatabase>,
#[derive(Debug, Display, PartialEq, Error)]
#[allow(dead_code)]
pub enum Error {
#[display(fmt = "Key is invalid.")]
KeyVerificationError,
#[display(fmt = "Key has expired.")]
KeyExpired
}

impl KeyManager {
pub fn new(database: Arc<SqliteDatabase>) -> KeyManager {
KeyManager {
database
}
impl From<r2d2_sqlite::rusqlite::Error> for Error {
fn from(e: r2d2_sqlite::rusqlite::Error) -> Self {
eprintln!("{}", e);
Error::KeyVerificationError
}
}

pub async fn generate_auth_key(&self, seconds_valid: u64) -> Result<AuthKey, ()> {
pub struct KeyManager;

impl KeyManager {
pub fn generate_auth_key(&self, seconds_valid: u64) -> AuthKey {
let key: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(AUTH_KEY_LENGTH)
Expand All @@ -51,40 +58,17 @@ impl KeyManager {

debug!("Generated key: {}, valid for: {} seconds", key, seconds_valid);

let auth_key = AuthKey {
key: key.clone(),
AuthKey {
key,
valid_until: Some(current_time() + seconds_valid),
};

// add key to database
match self.database.add_key_to_keys(auth_key.clone()).await {
Ok(_) => Ok(auth_key),
Err(_) => Err(())
}
}

pub async fn remove_auth_key(&self, key: String) -> Result<(), ()> {
match self.database.remove_key_from_keys(key).await {
Ok(_) => Ok(()),
Err(_) => Err(())
}
}

pub async fn verify_auth_key(&self, auth_key: &AuthKey) -> bool {
pub async fn verify_auth_key(&self, auth_key: &AuthKey) -> Result<(), Error> {
let current_time = current_time();
if auth_key.valid_until.is_none() { return Err(Error::KeyVerificationError) }
if &auth_key.valid_until.unwrap() < &current_time { return Err(Error::KeyExpired) }

match self.database.get_key_from_keys(auth_key.key.to_string()).await {
Ok(auth_key) => {
match auth_key.valid_until {
// should not be possible, valid_until is required
None => false,
Some(valid_until) => valid_until > current_time
}
}
Err(e) => {
debug!{"{:?}", e}
false
}
}
Ok(())
}
}
Loading