diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..fae8ef4 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,23 @@ +language: rust +sudo: false +cache: cargo + +rust: + - nightly +env: + global: + # XXX: begin_test_transaction doesn't play nice over threaded tests + - RUST_TEST_THREADS=1 + - ROCKET_DATABASE_URL="mysql://travis@127.0.0.1/megaphone" + +services: + - mysql + +# XXX: kill the diesel_cli requirement: +# https://docs.rs/diesel/0.16.0/diesel/macro.embed_migrations.html +before_script: + - mysql -e 'CREATE DATABASE IF NOT EXISTS megaphone;' + - | + cargo install diesel_cli --no-default-features --features mysql || \ + echo "diesel_cli already installed" + - diesel setup --database-url $ROCKET_DATABASE_URL diff --git a/Cargo.toml b/Cargo.toml index b0852e1..ddebe1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,9 +4,11 @@ version = "0.1.0" authors = ["jrconlin "] [dependencies] -diesel = { version = "1.0", features = ["mysql"] } -dotenv = "0.9" +diesel = { version = "1.0", features = ["mysql", "r2d2"] } +failure = "0.1" rocket = "0.3" rocket_codegen = "0.3" +rocket_contrib = "0.3" +serde = "1.0" +serde_derive = "1.0" serde_json = "1.0" -websocket = "0.20" diff --git a/README.md b/README.md index 32ba993..54777b9 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,15 @@ See [API doc](https://docs.google.com/document/d/1Wxqf1a4HDkKgHDIswPmhmdvk8KPoME ***NOTE***: This will require: - * rust nightly. See [rocket.rs Getting Started](https://rocket.rs/guide/getting-started/) for -additional steps. - * libmysql-dev installed + + * rust nightly. See [rocket.rs Getting + Started](https://rocket.rs/guide/getting-started/) for additional steps. + * mysql + * libmysqlclient installed (brew install mysql on macOS, apt-get install + libmysqlclient-dev on Ubuntu) + * diesel cli: (cargo install diesel_cli --no-default-features + --features mysql) + +Run: + * export ROCKET_DATABASE_URL=mysql://scott:tiger@mydatabase/megaphone + * $ diesel setup --database-url $ROCKET_DATABASE_URL diff --git a/Rocket.toml b/Rocket.toml new file mode 100644 index 0000000..5172c1b --- /dev/null +++ b/Rocket.toml @@ -0,0 +1,11 @@ +[development] +#database_url = "mysql://" +json_logging = false + +[staging] +#database_url = "mysql://" +json_logging = true + +[production] +#database_url = "mysql://" +json_logging = true diff --git a/migrations/.gitkeep b/migrations/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/migrations/2018-02-20-220249_create_broadcastsv1_table/down.sql b/migrations/2018-02-20-220249_create_broadcastsv1_table/down.sql new file mode 100644 index 0000000..d7dcd7e --- /dev/null +++ b/migrations/2018-02-20-220249_create_broadcastsv1_table/down.sql @@ -0,0 +1 @@ +DROP TABLE broadcastsv1; diff --git a/migrations/2018-02-20-220249_create_broadcastsv1_table/up.sql b/migrations/2018-02-20-220249_create_broadcastsv1_table/up.sql new file mode 100644 index 0000000..56a7b45 --- /dev/null +++ b/migrations/2018-02-20-220249_create_broadcastsv1_table/up.sql @@ -0,0 +1,7 @@ +CREATE TABLE broadcastsv1 ( + broadcaster_id VARCHAR(64) NOT NULL, + bchannel_id VARCHAR(128) NOT NULL, + last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP NOT NULL, + version VARCHAR(200) NOT NULL, + PRIMARY KEY(broadcaster_id, bchannel_id) +); diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..ef4fb4a --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,49 @@ +pub mod schema; +pub mod models; + +use std::ops::Deref; + +use diesel::mysql::MysqlConnection; +use diesel::r2d2::{ConnectionManager, Pool, PooledConnection}; +use failure::err_msg; + +use rocket::http::Status; +use rocket::request::{self, FromRequest}; +use rocket::{Config, Outcome, Request, State}; + +use error::Result; + +pub type MysqlPool = Pool>; + +pub fn pool_from_config(config: &Config) -> Result { + let database_url = config + .get_str("database_url") + .map_err(|_| err_msg("ROCKET_DATABASE_URL undefined"))? + .to_string(); + let max_size = config.get_int("database_pool_max_size").unwrap_or(10) as u32; + let manager = ConnectionManager::::new(database_url); + Ok(Pool::builder().max_size(max_size).build(manager)?) +} + +pub struct Conn(pub PooledConnection>); + +impl Deref for Conn { + type Target = MysqlConnection; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for Conn { + type Error = (); + + fn from_request(request: &'a Request<'r>) -> request::Outcome { + let pool = request.guard::>()?; + match pool.get() { + Ok(conn) => Outcome::Success(Conn(conn)), + Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())), + } + } +} diff --git a/src/db/models.rs b/src/db/models.rs new file mode 100644 index 0000000..31543ff --- /dev/null +++ b/src/db/models.rs @@ -0,0 +1,47 @@ +use failure::ResultExt; +use diesel::{replace_into, RunQueryDsl}; +use diesel::mysql::MysqlConnection; + +use super::schema::broadcastsv1; +use error::{HandlerErrorKind, HandlerResult}; + +#[derive(Debug, Queryable, Insertable)] +#[table_name = "broadcastsv1"] +pub struct Broadcast { + pub broadcaster_id: String, + pub bchannel_id: String, + pub version: String, +} + +impl Broadcast { + pub fn id(&self) -> String { + format!("{}/{}", self.broadcaster_id, self.bchannel_id) + } +} + +/// An authorized broadcaster +pub struct Broadcaster { + pub id: String, +} + +impl Broadcaster { + pub fn new_broadcast( + self, + conn: &MysqlConnection, + bchannel_id: String, + version: String, + ) -> HandlerResult { + let broadcast = Broadcast { + broadcaster_id: self.id, + bchannel_id: bchannel_id, + version: version, + }; + Ok(replace_into(broadcastsv1::table) + .values(&broadcast) + .execute(conn) + .context(HandlerErrorKind::DBError)?) + } +} + +// An authorized reader of current broadcasts +//struct BroadcastAdmin; diff --git a/src/db/schema.rs b/src/db/schema.rs new file mode 100644 index 0000000..239ce69 --- /dev/null +++ b/src/db/schema.rs @@ -0,0 +1,8 @@ +table! { + broadcastsv1 (broadcaster_id, bchannel_id) { + broadcaster_id -> Varchar, + bchannel_id -> Varchar, + last_updated -> Timestamp, + version -> Varchar, + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..044ce76 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,110 @@ +/// Error handling based on the failure crate +/// +/// Only rocket's Handlers can render error responses w/ a contextual JSON +/// payload. So request guards should generally return VALIDATION_FAILED, +/// leaving error handling to the Handler (which in turn must take a Result of +/// request guards' fields). +/// +/// HandlerErrors are rocket Responders (render their own error responses). +use std::fmt; +use std::result; + +use failure::{Backtrace, Context, Error, Fail}; +use rocket::{self, response, Request}; +use rocket::http::Status; +use rocket::response::{Responder, Response}; +use rocket_contrib::Json; + +pub type Result = result::Result; + +pub type HandlerResult = result::Result; + +/// Signal a request guard failure, propagated up to the Handler to render an +/// error response +pub const VALIDATION_FAILED: Status = Status::InternalServerError; + +#[derive(Debug)] +pub struct HandlerError { + inner: Context, +} + +#[derive(Clone, Eq, PartialEq, Debug, Fail)] +pub enum HandlerErrorKind { + /// 401 Unauthorized + #[fail(display = "Unauthorized: {}", _0)] + Unauthorized(String), + /// 404 Not Found + #[fail(display = "Not Found")] + NotFound, + #[fail(display = "A database error occurred")] + DBError, + #[fail(display = "Version information not included in body of update")] + MissingVersionDataError, + #[fail(display = "Invalid Version info (must be URL safe Base 64)")] + InvalidVersionDataError, + #[fail(display = "Unexpected rocket error: {:?}", _0)] + RocketError(rocket::Error), // rocket::Error isn't a std Error (so no #[cause]) +} + +impl HandlerErrorKind { + /// Return a rocket response Status to be rendered for an error + pub fn http_status(&self) -> Status { + match *self { + HandlerErrorKind::DBError => Status::ServiceUnavailable, + HandlerErrorKind::NotFound => Status::NotFound, + HandlerErrorKind::Unauthorized(..) => Status::Unauthorized, + _ => Status::BadRequest, + } + } +} + +impl Fail for HandlerError { + fn cause(&self) -> Option<&Fail> { + self.inner.cause() + } + + fn backtrace(&self) -> Option<&Backtrace> { + self.inner.backtrace() + } +} + +impl fmt::Display for HandlerError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.inner, f) + } +} + +impl HandlerError { + pub fn kind(&self) -> &HandlerErrorKind { + self.inner.get_context() + } +} + +impl From for HandlerError { + fn from(kind: HandlerErrorKind) -> HandlerError { + HandlerError { + inner: Context::new(kind), + } + } +} + +impl From> for HandlerError { + fn from(inner: Context) -> HandlerError { + HandlerError { inner: inner } + } +} + +/// Generate HTTP error responses for HandlerErrors +impl<'r> Responder<'r> for HandlerError { + fn respond_to(self, request: &Request) -> response::Result<'r> { + let status = self.kind().http_status(); + let json = Json(json!({ + "status": status.code, + "error": format!("{}", self) + })); + // XXX: logging + Response::build_from(json.respond_to(request)?) + .status(status) + .ok() + } +} diff --git a/src/http.rs b/src/http.rs new file mode 100644 index 0000000..1ce2943 --- /dev/null +++ b/src/http.rs @@ -0,0 +1,235 @@ +use std::convert::Into; +use std::collections::HashMap; +use std::io::Read; + +use diesel::{QueryDsl, RunQueryDsl}; +use failure::ResultExt; +use rocket::{self, Data, Request, Rocket}; +use rocket::data::{self, FromData}; +use rocket::Outcome::*; +use rocket::outcome::IntoOutcome; +use rocket::request::{self, FromRequest}; +use rocket_contrib::Json; + +use db::{self, pool_from_config}; +use db::models::{Broadcast, Broadcaster}; +use db::schema::broadcastsv1; +use error::{HandlerError, HandlerErrorKind, HandlerResult, Result, VALIDATION_FAILED}; + +impl<'a, 'r> FromRequest<'a, 'r> for Broadcaster { + type Error = HandlerError; + + fn from_request(request: &'a Request<'r>) -> request::Outcome { + if let Some(_auth) = request.headers().get_one("Authorization") { + // These should be guaranteed on the path when we're called + let broadcaster_id = request + .get_param::(0) + .map_err(HandlerErrorKind::RocketError) + .map_err(Into::into) + .into_outcome(VALIDATION_FAILED)?; + // TODO: Validate auth cookie + Success(Broadcaster { id: broadcaster_id }) + } else { + Failure(( + VALIDATION_FAILED, + HandlerErrorKind::Unauthorized("Missing Authorization header".to_string()).into(), + )) + } + } +} + +/// Version information from command line. +struct VersionInput { + value: String, +} + +impl FromData for VersionInput { + type Error = HandlerError; + + fn from_data(_: &Request, data: Data) -> data::Outcome { + let mut string = String::new(); + data.open() + .read_to_string(&mut string) + .context(HandlerErrorKind::MissingVersionDataError) + .map_err(Into::into) + .into_outcome(VALIDATION_FAILED)?; + if string.is_empty() { + return Failure(( + VALIDATION_FAILED, + HandlerErrorKind::InvalidVersionDataError.into(), + )); + } + // TODO Validate the version info + Success(VersionInput { value: string }) + } +} + +// REST Functions + +/// Set a version for a broadcaster / bchannel +#[post("/v1/broadcasts/<_broadcaster_id>/", data = "")] +fn broadcast( + conn: db::Conn, + broadcaster: HandlerResult, + _broadcaster_id: String, + bchannel_id: String, + version: HandlerResult, +) -> HandlerResult { + broadcaster?.new_broadcast(&conn, bchannel_id, version?.value)?; + Ok(Json(json!({ + "status": 200 + }))) +} + +/// Dump the current version table +#[get("/v1/broadcasts")] +//fn get_broadcasts(bcast_admin: BroadcastAdmin, conn: db::Conn) -> HandlerResult { +fn get_broadcasts(conn: db::Conn) -> HandlerResult { + // flatten into HashMap FromIterator<(K, V)> + let broadcasts: HashMap = broadcastsv1::table + .select(( + broadcastsv1::broadcaster_id, + broadcastsv1::bchannel_id, + broadcastsv1::version, + )) + .load::(&*conn) + .context(HandlerErrorKind::DBError)? + .into_iter() + .map(|bcast| (bcast.id(), bcast.version)) + .collect(); + Ok(Json(json!({ + "status": 200, + "broadcasts": broadcasts + }))) +} + +#[error(404)] +fn not_found() -> HandlerResult { + Err(HandlerErrorKind::NotFound)? +} + +pub fn rocket() -> Result { + let rocket = rocket::ignite(); + let pool = pool_from_config(rocket.config())?; + Ok(rocket + .manage(pool) + .mount("/", routes![broadcast, get_broadcasts]) + .catch(errors![not_found])) +} + +#[cfg(test)] +mod test { + use std::env; + + use diesel::Connection; + use rocket::local::Client; + use rocket::http::{Header, Status}; + use rocket::response::Response; + use serde_json::{self, Value}; + + use db::MysqlPool; + use super::rocket; + + /// Return a Rocket Client for testing + /// + /// The managed db pool is set to a maxiumum of one connection w/ + /// a transaction began that is never committed + fn rocket_client() -> Client { + // hacky/easiest way to set into rocket's config + env::set_var("ROCKET_DATABASE_POOL_MAX_SIZE", "1"); + let rocket = rocket().expect("rocket failed"); + { + let pool = rocket.state::().unwrap(); + let conn = &*pool.get().expect("Couldn't connect to database"); + conn.begin_test_transaction().unwrap(); + } + Client::new(rocket).expect("rocket launch failed") + } + + fn auth() -> Header<'static> { + Header::new("Authorization".to_string(), "Bearer XXX".to_string()) + } + + fn json_body(response: &mut Response) -> Value { + assert!(response.content_type().map_or(false, |ct| ct.is_json())); + serde_json::from_str(&response.body_string().unwrap()).unwrap() + } + + #[test] + fn test_post() { + let client = rocket_client(); + let mut response = client + .post("/v1/broadcasts/foo/bar") + .header(auth()) + .body("v1") + .dispatch(); + assert_eq!(response.status(), Status::Ok); + assert_eq!(json_body(&mut response), json!({"status": 200})); + } + + #[test] + fn test_post_no_body() { + let client = rocket_client(); + let mut response = client + .post("/v1/broadcasts/foo/bar") + .header(auth()) + .dispatch(); + assert_eq!(response.status(), Status::BadRequest); + let result = json_body(&mut response); + assert_eq!(result.get("status").unwrap(), Status::BadRequest.code); + assert!( + result + .get("error") + .unwrap() + .as_str() + .unwrap() + .contains("Version") + ); + } + + #[test] + fn test_post_no_id() { + let client = rocket_client(); + let mut response = client + .post("/v1/broadcasts/foo") + .header(auth()) + .body("v1") + .dispatch(); + assert_eq!(response.status(), Status::NotFound); + assert_eq!( + json_body(&mut response), + json!({"status": 404, "error": "Not Found"}) + ); + } + + #[test] + fn test_post_no_auth() { + let client = rocket_client(); + let mut response = client.post("/v1/broadcasts/foo/bar").body("v1").dispatch(); + assert_eq!(response.status(), Status::Unauthorized); + let result = json_body(&mut response); + assert_eq!(result.get("status").unwrap(), 401); + } + + #[test] + fn test_post_get() { + let client = rocket_client(); + let _ = client + .post("/v1/broadcasts/foo/bar") + .header(auth()) + .body("v1") + .dispatch(); + let _ = client + .post("/v1/broadcasts/baz/quux") + .header(auth()) + .body("v0") + .dispatch(); + let mut response = client.get("/v1/broadcasts").header(auth()).dispatch(); + assert_eq!(response.status(), Status::Ok); + assert_eq!( + json_body(&mut response), + json!({"status": 200, "broadcasts": {"baz/quux": "v0", "foo/bar": "v1"}}) + ); + } + +} diff --git a/src/main.rs b/src/main.rs index 7e24818..bbc6b89 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,25 +1,22 @@ #![feature(plugin)] #![plugin(rocket_codegen)] -extern crate rocket; +#[macro_use] extern crate diesel; +#[macro_use] +extern crate failure; +extern crate rocket; +#[macro_use] +extern crate rocket_contrib; +extern crate serde; extern crate serde_json; -extern crate websocket; -/* Set a version */ -#[post("/v1/rtu//")] -fn accept(broadcaster_id: String, collection_id: String) -> String { - return String::from("Hello, Other world"); -} - -/* Dump the current table */ -#[get("/v1/rtu")] -fn dump() -> String { - return String::from("Hello, Other world"); -} +mod db; +mod error; +mod http; -// TODO: Websocket handler. +use http::rocket; fn main() { - println!("Hello world."); + rocket().expect("rocket failed").launch(); }