Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce number of unwraps in api crate #2681

Merged
merged 2 commits into from
Mar 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions api/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,25 @@
// limitations under the License.

use crate::router::{Handler, HandlerObj, ResponseFuture};
use crate::web::response;
use futures::future::ok;
use hyper::header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE};
use hyper::{Body, Request, Response, StatusCode};
use ring::constant_time::verify_slices_are_equal;

lazy_static! {
pub static ref GRIN_BASIC_REALM: HeaderValue =
HeaderValue::from_str("Basic realm=GrinAPI").unwrap();
}

// Basic Authentication Middleware
pub struct BasicAuthMiddleware {
api_basic_auth: String,
basic_realm: String,
basic_realm: &'static HeaderValue,
}

impl BasicAuthMiddleware {
pub fn new(api_basic_auth: String, basic_realm: String) -> BasicAuthMiddleware {
pub fn new(api_basic_auth: String, basic_realm: &'static HeaderValue) -> BasicAuthMiddleware {
BasicAuthMiddleware {
api_basic_auth,
basic_realm,
Expand All @@ -39,8 +45,12 @@ impl Handler for BasicAuthMiddleware {
req: Request<Body>,
mut handlers: Box<dyn Iterator<Item = HandlerObj>>,
) -> ResponseFuture {
let next_handler = match handlers.next() {
Some(h) => h,
None => return response(StatusCode::INTERNAL_SERVER_ERROR, "no handler found"),
};
if req.method().as_str() == "OPTIONS" {
return handlers.next().unwrap().call(req, handlers);
return next_handler.call(req, handlers);
}
if req.headers().contains_key(AUTHORIZATION)
&& verify_slices_are_equal(
Expand All @@ -49,21 +59,18 @@ impl Handler for BasicAuthMiddleware {
)
.is_ok()
{
handlers.next().unwrap().call(req, handlers)
next_handler.call(req, handlers)
} else {
// Unauthorized 401
unauthorized_response(&self.basic_realm)
}
}
}

fn unauthorized_response(basic_realm: &str) -> ResponseFuture {
fn unauthorized_response(basic_realm: &HeaderValue) -> ResponseFuture {
let response = Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(
WWW_AUTHENTICATE,
HeaderValue::from_str(basic_realm).unwrap(),
)
.header(WWW_AUTHENTICATE, basic_realm)
.body(Body::empty())
.unwrap();
Box::new(ok(response))
Expand Down
8 changes: 4 additions & 4 deletions api/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@ fn build_request<'a>(
.into()
})?;
let mut builder = Request::builder();
if api_secret.is_some() {
let basic_auth =
"Basic ".to_string() + &to_base64(&("grin:".to_string() + &api_secret.unwrap()));
if let Some(api_secret) = api_secret {
let basic_auth = format!("Basic {}", to_base64(&format!("grin:{}", api_secret)));
builder.header(AUTHORIZATION, basic_auth);
}

Expand Down Expand Up @@ -223,6 +222,7 @@ fn send_request_async(req: Request<Body>) -> Box<dyn Future<Item = String, Error

fn send_request(req: Request<Body>) -> Result<String, Error> {
let task = send_request_async(req);
let mut rt = Runtime::new().unwrap();
let mut rt =
Runtime::new().context(ErrorKind::Internal("can't create Tokio runtime".to_owned()))?;
Ok(rt.block_on(task)?)
}
36 changes: 11 additions & 25 deletions api/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,26 @@ mod server_api;
mod transactions_api;
mod utils;

use crate::router::{Router, RouterError};

// Server
use self::server_api::IndexHandler;
use self::server_api::StatusHandler;

// Blocks
use self::blocks_api::BlockHandler;
use self::blocks_api::HeaderHandler;

// TX Set
use self::transactions_api::TxHashSetHandler;

// Chain
use self::chain_api::ChainCompactHandler;
use self::chain_api::ChainHandler;
use self::chain_api::ChainValidationHandler;
use self::chain_api::OutputHandler;

// Pool Handlers
use self::pool_api::PoolInfoHandler;
use self::pool_api::PoolPushHandler;

// Peers
use self::peers_api::PeerHandler;
use self::peers_api::PeersAllHandler;
use self::peers_api::PeersConnectedHandler;

use crate::auth::BasicAuthMiddleware;
use self::pool_api::PoolInfoHandler;
use self::pool_api::PoolPushHandler;
use self::server_api::IndexHandler;
use self::server_api::StatusHandler;
use self::transactions_api::TxHashSetHandler;
use crate::auth::{BasicAuthMiddleware, GRIN_BASIC_REALM};
use crate::chain;
use crate::p2p;
use crate::pool;
use crate::rest::*;
use crate::router::{Router, RouterError};
use crate::util;
use crate::util::RwLock;
use std::net::SocketAddr;
Expand All @@ -76,11 +63,10 @@ pub fn start_rest_apis(
) -> bool {
let mut apis = ApiServer::new();
let mut router = build_router(chain, tx_pool, peers).expect("unable to build API router");
if api_secret.is_some() {
let api_basic_auth =
"Basic ".to_string() + &util::to_base64(&("grin:".to_string() + &api_secret.unwrap()));
let basic_realm = "Basic realm=GrinAPI".to_string();
let basic_auth_middleware = Arc::new(BasicAuthMiddleware::new(api_basic_auth, basic_realm));
if let Some(api_secret) = api_secret {
let api_basic_auth = format!("Basic {}", util::to_base64(&format!("grin:{}", api_secret)));
let basic_auth_middleware =
Arc::new(BasicAuthMiddleware::new(api_basic_auth, &GRIN_BASIC_REALM));
router.add_middleware(basic_auth_middleware);
}

Expand Down
23 changes: 12 additions & 11 deletions api/src/handlers/blocks_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl HeaderHandler {
return Ok(h);
}
if let Ok(height) = input.parse() {
match w(&self.chain).get_header_by_height(height) {
match w(&self.chain)?.get_header_by_height(height) {
Ok(header) => return Ok(BlockHeaderPrintable::from_header(&header)),
Err(_) => return Err(ErrorKind::NotFound)?,
}
Expand All @@ -50,15 +50,15 @@ impl HeaderHandler {
let vec = util::from_hex(input)
.map_err(|e| ErrorKind::Argument(format!("invalid input: {}", e)))?;
let h = Hash::from_vec(&vec);
let header = w(&self.chain)
let header = w(&self.chain)?
.get_block_header(&h)
.context(ErrorKind::NotFound)?;
Ok(BlockHeaderPrintable::from_header(&header))
}

fn get_header_for_output(&self, commit_id: String) -> Result<BlockHeaderPrintable, Error> {
let oid = get_output(&self.chain, &commit_id)?.1;
match w(&self.chain).get_header_for_output(&oid) {
match w(&self.chain)?.get_header_for_output(&oid) {
Ok(header) => Ok(BlockHeaderPrintable::from_header(&header)),
Err(_) => Err(ErrorKind::NotFound)?,
}
Expand All @@ -85,22 +85,23 @@ pub struct BlockHandler {

impl BlockHandler {
fn get_block(&self, h: &Hash) -> Result<BlockPrintable, Error> {
let block = w(&self.chain).get_block(h).context(ErrorKind::NotFound)?;
Ok(BlockPrintable::from_block(&block, w(&self.chain), false))
let chain = w(&self.chain)?;
let block = chain.get_block(h).context(ErrorKind::NotFound)?;
BlockPrintable::from_block(&block, chain, false)
.map_err(|_| ErrorKind::Internal("chain error".to_owned()).into())
}

fn get_compact_block(&self, h: &Hash) -> Result<CompactBlockPrintable, Error> {
let block = w(&self.chain).get_block(h).context(ErrorKind::NotFound)?;
Ok(CompactBlockPrintable::from_compact_block(
&block.into(),
w(&self.chain),
))
let chain = w(&self.chain)?;
let block = chain.get_block(h).context(ErrorKind::NotFound)?;
CompactBlockPrintable::from_compact_block(&block.into(), chain)
.map_err(|_| ErrorKind::Internal("chain error".to_owned()).into())
}

// Try to decode the string as a height or a hash.
fn parse_input(&self, input: String) -> Result<Hash, Error> {
if let Ok(height) = input.parse() {
match w(&self.chain).get_header_by_height(height) {
match w(&self.chain)?.get_header_by_height(height) {
Ok(header) => return Ok(header.hash()),
Err(_) => return Err(ErrorKind::NotFound)?,
}
Expand Down
17 changes: 10 additions & 7 deletions api/src/handlers/chain_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::types::*;
use crate::util;
use crate::util::secp::pedersen::Commitment;
use crate::web::*;
use failure::ResultExt;
use hyper::{Body, Request, StatusCode};
use std::sync::Weak;

Expand All @@ -32,7 +33,7 @@ pub struct ChainHandler {

impl ChainHandler {
fn get_tip(&self) -> Result<Tip, Error> {
let head = w(&self.chain)
let head = w(&self.chain)?
.head()
.map_err(|e| ErrorKind::Internal(format!("can't get head: {}", e)))?;
Ok(Tip::from_tip(head))
Expand All @@ -53,7 +54,7 @@ pub struct ChainValidationHandler {

impl Handler for ChainValidationHandler {
fn get(&self, _req: Request<Body>) -> ResponseFuture {
match w(&self.chain).validate(true) {
match w_fut!(&self.chain).validate(true) {
Ok(_) => response(StatusCode::OK, "{}"),
Err(e) => response(
StatusCode::INTERNAL_SERVER_ERROR,
Expand All @@ -72,7 +73,7 @@ pub struct ChainCompactHandler {

impl Handler for ChainCompactHandler {
fn post(&self, _req: Request<Body>) -> ResponseFuture {
match w(&self.chain).compact() {
match w_fut!(&self.chain).compact() {
Ok(_) => response(StatusCode::OK, "{}"),
Err(e) => response(
StatusCode::INTERNAL_SERVER_ERROR,
Expand Down Expand Up @@ -118,23 +119,25 @@ impl OutputHandler {
commitments: Vec<Commitment>,
include_proof: bool,
) -> Result<BlockOutputs, Error> {
let header = w(&self.chain)
let header = w(&self.chain)?
.get_header_by_height(block_height)
.map_err(|_| ErrorKind::NotFound)?;

// TODO - possible to compact away blocks we care about
// in the period between accepting the block and refreshing the wallet
let block = w(&self.chain)
let chain = w(&self.chain)?;
let block = chain
.get_block(&header.hash())
.map_err(|_| ErrorKind::NotFound)?;
let outputs = block
.outputs()
.iter()
.filter(|output| commitments.is_empty() || commitments.contains(&output.commit))
.map(|output| {
OutputPrintable::from_output(output, w(&self.chain), Some(&header), include_proof)
OutputPrintable::from_output(output, chain.clone(), Some(&header), include_proof)
})
.collect();
.collect::<Result<Vec<_>, _>>()
.context(ErrorKind::Internal("cain error".to_owned()))?;

Ok(BlockOutputs {
header: BlockHeaderInfo::from_header(&header),
Expand Down
10 changes: 5 additions & 5 deletions api/src/handlers/peers_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub struct PeersAllHandler {

impl Handler for PeersAllHandler {
fn get(&self, _req: Request<Body>) -> ResponseFuture {
let peers = &w(&self.peers).all_peers();
let peers = &w_fut!(&self.peers).all_peers();
json_response_pretty(&peers)
}
}
Expand All @@ -37,7 +37,7 @@ pub struct PeersConnectedHandler {

impl Handler for PeersConnectedHandler {
fn get(&self, _req: Request<Body>) -> ResponseFuture {
let peers: Vec<PeerInfoDisplay> = w(&self.peers)
let peers: Vec<PeerInfoDisplay> = w_fut!(&self.peers)
.connected_peers()
.iter()
.map(|p| p.info.clone().into())
Expand Down Expand Up @@ -73,7 +73,7 @@ impl Handler for PeerHandler {
);
}

match w(&self.peers).get_peer(peer_addr) {
match w_fut!(&self.peers).get_peer(peer_addr) {
Ok(peer) => json_response(&peer),
Err(_) => response(StatusCode::NOT_FOUND, "peer not found"),
}
Expand Down Expand Up @@ -101,8 +101,8 @@ impl Handler for PeerHandler {
};

match command {
"ban" => w(&self.peers).ban_peer(addr, ReasonForBan::ManualBan),
"unban" => w(&self.peers).unban_peer(addr),
"ban" => w_fut!(&self.peers).ban_peer(addr, ReasonForBan::ManualBan),
"unban" => w_fut!(&self.peers).unban_peer(addr),
_ => return response(StatusCode::BAD_REQUEST, "invalid command"),
};

Expand Down
10 changes: 7 additions & 3 deletions api/src/handlers/pool_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::util;
use crate::util::RwLock;
use crate::web::*;
use failure::ResultExt;
use futures::future::ok;
use futures::future::{err, ok};
use futures::Future;
use hyper::{Body, Request, StatusCode};
use std::sync::Weak;
Expand All @@ -37,7 +37,7 @@ pub struct PoolInfoHandler {

impl Handler for PoolInfoHandler {
fn get(&self, _req: Request<Body>) -> ResponseFuture {
let pool_arc = w(&self.tx_pool);
let pool_arc = w_fut!(&self.tx_pool);
let pool = pool_arc.read();

json_response(&PoolInfo {
Expand All @@ -63,7 +63,11 @@ impl PoolPushHandler {
let params = QueryParams::from(req.uri().query());

let fluff = params.get("fluff").is_some();
let pool_arc = w(&self.tx_pool).clone();
let pool_arc = match w(&self.tx_pool) {
//w(&self.tx_pool).clone();
Ok(p) => p,
Err(e) => return Box::new(err(e)),
};

Box::new(
parse_body(req)
Expand Down
4 changes: 2 additions & 2 deletions api/src/handlers/server_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ pub struct StatusHandler {

impl StatusHandler {
fn get_status(&self) -> Result<Status, Error> {
let head = w(&self.chain)
let head = w(&self.chain)?
.head()
.map_err(|e| ErrorKind::Internal(format!("can't get head: {}", e)))?;
Ok(Status::from_tip_and_peers(
head,
w(&self.peers).peer_count(),
w(&self.peers)?.peer_count(),
))
}
}
Expand Down
Loading