Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 35 additions & 170 deletions crates/lib/src/rpc_server/auth.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,33 @@
use crate::{
constant::{X_API_KEY, X_HMAC_SIGNATURE, X_TIMESTAMP},
rpc_server::middleware_utils::{extract_parts_and_body_bytes, verify_jsonrpc_method},
rpc_server::middleware_utils::{extract_parts_and_body_bytes, get_jsonrpc_method},
};
use hmac::{Hmac, Mac};
use http::{Request, Response, StatusCode};
use jsonrpsee::server::logger::Body;
use sha2::Sha256;
use std::collections::HashSet;

#[derive(Clone)]
pub struct ApiKeyAuthLayer {
api_key: String,
allowed_methods: HashSet<String>,
}

impl ApiKeyAuthLayer {
pub fn new(api_key: String, allowed_methods: Vec<String>) -> Self {
Self { api_key, allowed_methods: allowed_methods.into_iter().collect() }
pub fn new(api_key: String) -> Self {
Self { api_key }
}
}

#[derive(Clone)]
pub struct ApiKeyAuthService<S> {
inner: S,
api_key: String,
allowed_methods: HashSet<String>,
}

impl<S> tower::Layer<S> for ApiKeyAuthLayer {
type Service = ApiKeyAuthService<S>;
fn layer(&self, inner: S) -> Self::Service {
ApiKeyAuthService {
inner,
api_key: self.api_key.clone(),
allowed_methods: self.allowed_methods.clone(),
}
ApiKeyAuthService { inner, api_key: self.api_key.clone() }
}
}

Expand All @@ -58,7 +51,6 @@ where

fn call(&mut self, request: Request<Body>) -> Self::Future {
let api_key = self.api_key.clone();
let allowed_methods = self.allowed_methods.clone();
let mut inner = self.inner.clone();

Box::pin(async move {
Expand All @@ -69,19 +61,16 @@ where

let (parts, body_bytes) = extract_parts_and_body_bytes(request).await;

match verify_jsonrpc_method(&body_bytes, &allowed_methods) {
Ok(method) => {
if method == "liveness" {
let new_body = Body::from(body_bytes);
let new_request = Request::from_parts(parts, new_body);
return inner.call(new_request).await;
}
}
Err(_) => {
return Ok(unauthorized_response);
// Bypass auth for liveness endpoint
if let Some(method) = get_jsonrpc_method(&body_bytes) {
if method == "liveness" {
let new_body = Body::from(body_bytes);
let new_request = Request::from_parts(parts, new_body);
return inner.call(new_request).await;
}
}

// Check for API key header
let req = Request::from_parts(parts, Body::from(body_bytes));
if let Some(provided_key) = req.headers().get(X_API_KEY) {
if provided_key.to_str().unwrap_or("") == api_key {
Expand All @@ -98,12 +87,11 @@ where
pub struct HmacAuthLayer {
secret: String,
max_timestamp_age: i64,
allowed_methods: HashSet<String>,
}

impl HmacAuthLayer {
pub fn new(secret: String, max_timestamp_age: i64, allowed_methods: Vec<String>) -> Self {
Self { secret, max_timestamp_age, allowed_methods: allowed_methods.into_iter().collect() }
pub fn new(secret: String, max_timestamp_age: i64) -> Self {
Self { secret, max_timestamp_age }
}
}

Expand All @@ -115,7 +103,6 @@ impl<S> tower::Layer<S> for HmacAuthLayer {
inner,
secret: self.secret.clone(),
max_timestamp_age: self.max_timestamp_age,
allowed_methods: self.allowed_methods.clone(),
}
}
}
Expand All @@ -125,7 +112,6 @@ pub struct HmacAuthService<S> {
inner: S,
secret: String,
max_timestamp_age: i64,
allowed_methods: HashSet<String>,
}

impl<S> tower::Service<Request<Body>> for HmacAuthService<S>
Expand All @@ -149,7 +135,6 @@ where
fn call(&mut self, request: Request<Body>) -> Self::Future {
let secret = self.secret.clone();
let max_timestamp_age = self.max_timestamp_age;
let allowed_methods = self.allowed_methods.clone();
let mut inner = self.inner.clone();

Box::pin(async move {
Expand All @@ -163,16 +148,12 @@ where

let (parts, body_bytes) = extract_parts_and_body_bytes(request).await;

match verify_jsonrpc_method(&body_bytes, &allowed_methods) {
Ok(method) => {
if method == "liveness" {
let new_body = Body::from(body_bytes);
let new_request = Request::from_parts(parts, new_body);
return inner.call(new_request).await;
}
}
Err(_) => {
return Ok(unauthorized_response);
// Bypass auth for liveness endpoint
if let Some(method) = get_jsonrpc_method(&body_bytes) {
if method == "liveness" {
let new_body = Body::from(body_bytes);
let new_request = Request::from_parts(parts, new_body);
return inner.call(new_request).await;
}
}

Expand All @@ -185,7 +166,7 @@ where
let signature = signature.to_str().unwrap_or("");
let timestamp = timestamp.to_str().unwrap_or("");

// Verify timestamp is within 5 minutes
// Verify timestamp is within allowed age
let parsed_timestamp = timestamp.parse::<i64>();
if parsed_timestamp.is_err() {
return Ok(unauthorized_response);
Expand Down Expand Up @@ -275,8 +256,7 @@ mod tests {

#[tokio::test]
async fn test_api_key_auth_valid_key() {
let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()];
let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods);
let layer = ApiKeyAuthLayer::new("test-key".to_string());
let mut service = layer.layer(MockService);
let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#;
let request = Request::builder()
Expand All @@ -291,8 +271,7 @@ mod tests {

#[tokio::test]
async fn test_api_key_auth_invalid_key() {
let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()];
let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods);
let layer = ApiKeyAuthLayer::new("test-key".to_string());
let mut service = layer.layer(MockService);
let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#;
let request = Request::builder()
Expand All @@ -307,8 +286,7 @@ mod tests {

#[tokio::test]
async fn test_api_key_auth_missing_header() {
let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()];
let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods);
let layer = ApiKeyAuthLayer::new("test-key".to_string());
let mut service = layer.layer(MockService);
let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#;
let request = Request::builder().uri("/test").body(Body::from(body)).unwrap();
Expand All @@ -319,8 +297,7 @@ mod tests {

#[tokio::test]
async fn test_api_key_auth_liveness_bypass() {
let allowed_methods = vec!["liveness".to_string()];
let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods);
let layer = ApiKeyAuthLayer::new("test-key".to_string());
let mut service = layer.layer(MockService);
let liveness_body = r#"{"jsonrpc":"2.0","method":"liveness","params":[],"id":1}"#;
let request = Request::builder()
Expand All @@ -336,9 +313,7 @@ mod tests {
#[tokio::test]
async fn test_hmac_auth_valid_signature() {
let secret = "test-secret";
let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()];
let layer =
HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods);
let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE);
let mut service = layer.layer(MockService);

let timestamp = std::time::SystemTime::now()
Expand Down Expand Up @@ -369,9 +344,7 @@ mod tests {
#[tokio::test]
async fn test_hmac_auth_invalid_signature() {
let secret = "test-secret";
let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()];
let layer =
HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods);
let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE);
let mut service = layer.layer(MockService);

let timestamp = std::time::SystemTime::now()
Expand All @@ -397,9 +370,7 @@ mod tests {
#[tokio::test]
async fn test_hmac_auth_missing_headers() {
let secret = "test-secret";
let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()];
let layer =
HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods);
let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE);
let mut service = layer.layer(MockService);

let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#;
Expand All @@ -413,9 +384,7 @@ mod tests {
#[tokio::test]
async fn test_hmac_auth_expired_timestamp() {
let secret = "test-secret";
let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()];
let layer =
HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods);
let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE);
let mut service = layer.layer(MockService);

// Timestamp from 10 minutes ago (expired)
Expand Down Expand Up @@ -443,31 +412,10 @@ mod tests {
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn test_hmac_auth_liveness_bypass() {
let secret = "test-secret";
let allowed_methods = vec!["liveness".to_string()];
let layer =
HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods);
let mut service = layer.layer(MockService);

let liveness_body = r#"{"jsonrpc":"2.0","method":"liveness","params":[],"id":1}"#;
let request = Request::builder()
.method(Method::POST)
.uri("/")
.body(Body::from(liveness_body))
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}

#[tokio::test]
async fn test_hmac_auth_malformed_timestamp() {
let secret = "test-secret";
let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()];
let layer =
HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods);
let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE);
let mut service = layer.layer(MockService);

let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#;
Expand All @@ -485,102 +433,19 @@ mod tests {
}

#[tokio::test]
async fn test_api_key_auth_unknown_method_rejected() {
let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()];
let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods);
let mut service = layer.layer(MockService);
let body = r#"{"jsonrpc":"2.0","method":"unknownMethod","id":1}"#;
let request = Request::builder()
.uri("/test")
.header(X_API_KEY, "test-key")
.body(Body::from(body))
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn test_api_key_auth_disabled_method_rejected() {
// Only allow liveness, not getConfig
let allowed_methods = vec!["liveness".to_string()];
let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods);
let mut service = layer.layer(MockService);
let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#;
let request = Request::builder()
.uri("/test")
.header(X_API_KEY, "test-key")
.body(Body::from(body))
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn test_hmac_auth_unknown_method_rejected() {
let secret = "test-secret";
let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()];
let layer =
HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods);
let mut service = layer.layer(MockService);

let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs()
.to_string();

let body = r#"{"jsonrpc":"2.0","method":"unknownMethod","id":1}"#;
let message = format!("{timestamp}{body}");

let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
mac.update(message.as_bytes());
let signature = hex::encode(mac.finalize().into_bytes());

let request = Request::builder()
.method(Method::POST)
.uri("/test")
.header(X_TIMESTAMP, &timestamp)
.header(X_HMAC_SIGNATURE, &signature)
.body(Body::from(body))
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn test_hmac_auth_disabled_method_rejected() {
async fn test_hmac_auth_liveness_bypass() {
let secret = "test-secret";
// Only allow liveness, not getConfig
let allowed_methods = vec!["liveness".to_string()];
let layer =
HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods);
let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE);
let mut service = layer.layer(MockService);

let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs()
.to_string();

let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#;
let message = format!("{timestamp}{body}");

let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
mac.update(message.as_bytes());
let signature = hex::encode(mac.finalize().into_bytes());

let liveness_body = r#"{"jsonrpc":"2.0","method":"liveness","params":[],"id":1}"#;
let request = Request::builder()
.method(Method::POST)
.uri("/test")
.header(X_TIMESTAMP, &timestamp)
.header(X_HMAC_SIGNATURE, &signature)
.body(Body::from(body))
.uri("/")
.body(Body::from(liveness_body))
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
assert_eq!(response.status(), StatusCode::OK);
}
}
Loading
Loading