diff --git a/Makefile b/Makefile index 9d7bdffd..6d77d76f 100644 --- a/Makefile +++ b/Makefile @@ -14,4 +14,4 @@ include makefiles/METRICS.makefile all: check test build # Run all tests (unit + TypeScript + integration) -test-all: test test-ts test-integration \ No newline at end of file +test-all: build test test-ts test-integration \ No newline at end of file diff --git a/crates/lib/src/rpc_server/auth.rs b/crates/lib/src/rpc_server/auth.rs index d61d1e2b..132683fa 100644 --- a/crates/lib/src/rpc_server/auth.rs +++ b/crates/lib/src/rpc_server/auth.rs @@ -1,22 +1,20 @@ use crate::{ constant::{X_API_KEY, X_HMAC_SIGNATURE, X_TIMESTAMP}, - rpc_server::middleware_utils::{extract_parts_and_body_bytes, verify_jsonrpc_method}, + rpc_server::middleware_utils::{extract_parts_and_body_bytes, get_jsonrpc_method}, }; use hmac::{Hmac, Mac}; use http::{Request, Response, StatusCode}; use jsonrpsee::server::logger::Body; use sha2::Sha256; -use std::collections::HashSet; #[derive(Clone)] pub struct ApiKeyAuthLayer { api_key: String, - allowed_methods: HashSet, } impl ApiKeyAuthLayer { - pub fn new(api_key: String, allowed_methods: Vec) -> Self { - Self { api_key, allowed_methods: allowed_methods.into_iter().collect() } + pub fn new(api_key: String) -> Self { + Self { api_key } } } @@ -24,17 +22,12 @@ impl ApiKeyAuthLayer { pub struct ApiKeyAuthService { inner: S, api_key: String, - allowed_methods: HashSet, } impl tower::Layer for ApiKeyAuthLayer { type Service = ApiKeyAuthService; fn layer(&self, inner: S) -> Self::Service { - ApiKeyAuthService { - inner, - api_key: self.api_key.clone(), - allowed_methods: self.allowed_methods.clone(), - } + ApiKeyAuthService { inner, api_key: self.api_key.clone() } } } @@ -58,7 +51,6 @@ where fn call(&mut self, request: Request) -> Self::Future { let api_key = self.api_key.clone(); - let allowed_methods = self.allowed_methods.clone(); let mut inner = self.inner.clone(); Box::pin(async move { @@ -69,19 +61,16 @@ where let (parts, body_bytes) = extract_parts_and_body_bytes(request).await; - match verify_jsonrpc_method(&body_bytes, &allowed_methods) { - Ok(method) => { - if method == "liveness" { - let new_body = Body::from(body_bytes); - let new_request = Request::from_parts(parts, new_body); - return inner.call(new_request).await; - } - } - Err(_) => { - return Ok(unauthorized_response); + // Bypass auth for liveness endpoint + if let Some(method) = get_jsonrpc_method(&body_bytes) { + if method == "liveness" { + let new_body = Body::from(body_bytes); + let new_request = Request::from_parts(parts, new_body); + return inner.call(new_request).await; } } + // Check for API key header let req = Request::from_parts(parts, Body::from(body_bytes)); if let Some(provided_key) = req.headers().get(X_API_KEY) { if provided_key.to_str().unwrap_or("") == api_key { @@ -98,12 +87,11 @@ where pub struct HmacAuthLayer { secret: String, max_timestamp_age: i64, - allowed_methods: HashSet, } impl HmacAuthLayer { - pub fn new(secret: String, max_timestamp_age: i64, allowed_methods: Vec) -> Self { - Self { secret, max_timestamp_age, allowed_methods: allowed_methods.into_iter().collect() } + pub fn new(secret: String, max_timestamp_age: i64) -> Self { + Self { secret, max_timestamp_age } } } @@ -115,7 +103,6 @@ impl tower::Layer for HmacAuthLayer { inner, secret: self.secret.clone(), max_timestamp_age: self.max_timestamp_age, - allowed_methods: self.allowed_methods.clone(), } } } @@ -125,7 +112,6 @@ pub struct HmacAuthService { inner: S, secret: String, max_timestamp_age: i64, - allowed_methods: HashSet, } impl tower::Service> for HmacAuthService @@ -149,7 +135,6 @@ where fn call(&mut self, request: Request) -> Self::Future { let secret = self.secret.clone(); let max_timestamp_age = self.max_timestamp_age; - let allowed_methods = self.allowed_methods.clone(); let mut inner = self.inner.clone(); Box::pin(async move { @@ -163,16 +148,12 @@ where let (parts, body_bytes) = extract_parts_and_body_bytes(request).await; - match verify_jsonrpc_method(&body_bytes, &allowed_methods) { - Ok(method) => { - if method == "liveness" { - let new_body = Body::from(body_bytes); - let new_request = Request::from_parts(parts, new_body); - return inner.call(new_request).await; - } - } - Err(_) => { - return Ok(unauthorized_response); + // Bypass auth for liveness endpoint + if let Some(method) = get_jsonrpc_method(&body_bytes) { + if method == "liveness" { + let new_body = Body::from(body_bytes); + let new_request = Request::from_parts(parts, new_body); + return inner.call(new_request).await; } } @@ -185,7 +166,7 @@ where let signature = signature.to_str().unwrap_or(""); let timestamp = timestamp.to_str().unwrap_or(""); - // Verify timestamp is within 5 minutes + // Verify timestamp is within allowed age let parsed_timestamp = timestamp.parse::(); if parsed_timestamp.is_err() { return Ok(unauthorized_response); @@ -275,8 +256,7 @@ mod tests { #[tokio::test] async fn test_api_key_auth_valid_key() { - let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; - let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods); + let layer = ApiKeyAuthLayer::new("test-key".to_string()); let mut service = layer.layer(MockService); let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#; let request = Request::builder() @@ -291,8 +271,7 @@ mod tests { #[tokio::test] async fn test_api_key_auth_invalid_key() { - let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; - let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods); + let layer = ApiKeyAuthLayer::new("test-key".to_string()); let mut service = layer.layer(MockService); let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#; let request = Request::builder() @@ -307,8 +286,7 @@ mod tests { #[tokio::test] async fn test_api_key_auth_missing_header() { - let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; - let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods); + let layer = ApiKeyAuthLayer::new("test-key".to_string()); let mut service = layer.layer(MockService); let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#; let request = Request::builder().uri("/test").body(Body::from(body)).unwrap(); @@ -319,8 +297,7 @@ mod tests { #[tokio::test] async fn test_api_key_auth_liveness_bypass() { - let allowed_methods = vec!["liveness".to_string()]; - let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods); + let layer = ApiKeyAuthLayer::new("test-key".to_string()); let mut service = layer.layer(MockService); let liveness_body = r#"{"jsonrpc":"2.0","method":"liveness","params":[],"id":1}"#; let request = Request::builder() @@ -336,9 +313,7 @@ mod tests { #[tokio::test] async fn test_hmac_auth_valid_signature() { let secret = "test-secret"; - let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; - let layer = - HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods); + let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE); let mut service = layer.layer(MockService); let timestamp = std::time::SystemTime::now() @@ -369,9 +344,7 @@ mod tests { #[tokio::test] async fn test_hmac_auth_invalid_signature() { let secret = "test-secret"; - let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; - let layer = - HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods); + let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE); let mut service = layer.layer(MockService); let timestamp = std::time::SystemTime::now() @@ -397,9 +370,7 @@ mod tests { #[tokio::test] async fn test_hmac_auth_missing_headers() { let secret = "test-secret"; - let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; - let layer = - HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods); + let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE); let mut service = layer.layer(MockService); let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#; @@ -413,9 +384,7 @@ mod tests { #[tokio::test] async fn test_hmac_auth_expired_timestamp() { let secret = "test-secret"; - let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; - let layer = - HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods); + let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE); let mut service = layer.layer(MockService); // Timestamp from 10 minutes ago (expired) @@ -443,31 +412,10 @@ mod tests { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } - #[tokio::test] - async fn test_hmac_auth_liveness_bypass() { - let secret = "test-secret"; - let allowed_methods = vec!["liveness".to_string()]; - let layer = - HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods); - let mut service = layer.layer(MockService); - - let liveness_body = r#"{"jsonrpc":"2.0","method":"liveness","params":[],"id":1}"#; - let request = Request::builder() - .method(Method::POST) - .uri("/") - .body(Body::from(liveness_body)) - .unwrap(); - - let response = service.ready().await.unwrap().call(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); - } - #[tokio::test] async fn test_hmac_auth_malformed_timestamp() { let secret = "test-secret"; - let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; - let layer = - HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods); + let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE); let mut service = layer.layer(MockService); let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#; @@ -485,102 +433,19 @@ mod tests { } #[tokio::test] - async fn test_api_key_auth_unknown_method_rejected() { - let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; - let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods); - let mut service = layer.layer(MockService); - let body = r#"{"jsonrpc":"2.0","method":"unknownMethod","id":1}"#; - let request = Request::builder() - .uri("/test") - .header(X_API_KEY, "test-key") - .body(Body::from(body)) - .unwrap(); - - let response = service.ready().await.unwrap().call(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - } - - #[tokio::test] - async fn test_api_key_auth_disabled_method_rejected() { - // Only allow liveness, not getConfig - let allowed_methods = vec!["liveness".to_string()]; - let layer = ApiKeyAuthLayer::new("test-key".to_string(), allowed_methods); - let mut service = layer.layer(MockService); - let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#; - let request = Request::builder() - .uri("/test") - .header(X_API_KEY, "test-key") - .body(Body::from(body)) - .unwrap(); - - let response = service.ready().await.unwrap().call(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - } - - #[tokio::test] - async fn test_hmac_auth_unknown_method_rejected() { - let secret = "test-secret"; - let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; - let layer = - HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods); - let mut service = layer.layer(MockService); - - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() - .to_string(); - - let body = r#"{"jsonrpc":"2.0","method":"unknownMethod","id":1}"#; - let message = format!("{timestamp}{body}"); - - let mut mac = Hmac::::new_from_slice(secret.as_bytes()).unwrap(); - mac.update(message.as_bytes()); - let signature = hex::encode(mac.finalize().into_bytes()); - - let request = Request::builder() - .method(Method::POST) - .uri("/test") - .header(X_TIMESTAMP, ×tamp) - .header(X_HMAC_SIGNATURE, &signature) - .body(Body::from(body)) - .unwrap(); - - let response = service.ready().await.unwrap().call(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - } - - #[tokio::test] - async fn test_hmac_auth_disabled_method_rejected() { + async fn test_hmac_auth_liveness_bypass() { let secret = "test-secret"; - // Only allow liveness, not getConfig - let allowed_methods = vec!["liveness".to_string()]; - let layer = - HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE, allowed_methods); + let layer = HmacAuthLayer::new(secret.to_string(), DEFAULT_MAX_TIMESTAMP_AGE); let mut service = layer.layer(MockService); - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() - .to_string(); - - let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#; - let message = format!("{timestamp}{body}"); - - let mut mac = Hmac::::new_from_slice(secret.as_bytes()).unwrap(); - mac.update(message.as_bytes()); - let signature = hex::encode(mac.finalize().into_bytes()); - + let liveness_body = r#"{"jsonrpc":"2.0","method":"liveness","params":[],"id":1}"#; let request = Request::builder() .method(Method::POST) - .uri("/test") - .header(X_TIMESTAMP, ×tamp) - .header(X_HMAC_SIGNATURE, &signature) - .body(Body::from(body)) + .uri("/") + .body(Body::from(liveness_body)) .unwrap(); let response = service.ready().await.unwrap().call(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert_eq!(response.status(), StatusCode::OK); } } diff --git a/crates/lib/src/rpc_server/middleware_utils.rs b/crates/lib/src/rpc_server/middleware_utils.rs index 36dfc42d..bc2420ae 100644 --- a/crates/lib/src/rpc_server/middleware_utils.rs +++ b/crates/lib/src/rpc_server/middleware_utils.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use futures_util::TryStreamExt; -use http::Request; +use http::{Request, Response, StatusCode}; use jsonrpsee::server::logger::Body; use crate::KoraError; @@ -43,3 +43,169 @@ pub fn verify_jsonrpc_method( } Err(KoraError::InvalidRequest("Method not allowed".to_string())) } + +/// Method validation layer - applies first in middleware stack to fail fast +#[derive(Clone)] +pub struct MethodValidationLayer { + allowed_methods: HashSet, +} + +impl MethodValidationLayer { + pub fn new(allowed_methods: Vec) -> Self { + Self { allowed_methods: allowed_methods.into_iter().collect() } + } +} + +#[derive(Clone)] +pub struct MethodValidationService { + inner: S, + allowed_methods: HashSet, +} + +impl tower::Layer for MethodValidationLayer { + type Service = MethodValidationService; + + fn layer(&self, inner: S) -> Self::Service { + MethodValidationService { inner, allowed_methods: self.allowed_methods.clone() } + } +} + +impl tower::Service> for MethodValidationService +where + S: tower::Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let allowed_methods = self.allowed_methods.clone(); + let mut inner = self.inner.clone(); + + Box::pin(async move { + let (parts, body_bytes) = extract_parts_and_body_bytes(request).await; + + match verify_jsonrpc_method(&body_bytes, &allowed_methods) { + Ok(_) => {} + Err(_) => { + // Method not allowed + let method_not_allowed_response = Response::builder() + .status(StatusCode::METHOD_NOT_ALLOWED) + .body(Body::empty()) + .expect("Failed to build METHOD_NOT_ALLOWED response"); + return Ok(method_not_allowed_response); + } + } + + let new_body = Body::from(body_bytes); + let new_request = Request::from_parts(parts, new_body); + inner.call(new_request).await + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::Method; + use std::{ + future::Ready, + task::{Context, Poll}, + }; + use tower::{Layer, Service, ServiceExt}; + + // Mock service that always returns OK + #[derive(Clone)] + struct MockService; + + impl tower::Service> for MockService { + type Response = Response; + type Error = std::convert::Infallible; + type Future = Ready>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: Request) -> Self::Future { + std::future::ready(Ok(Response::builder().status(200).body(Body::empty()).unwrap())) + } + } + + #[tokio::test] + async fn test_method_validation_disallowed_method() { + let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; + let layer = MethodValidationLayer::new(allowed_methods); + let mut service = layer.layer(MockService); + + let body = r#"{"jsonrpc":"2.0","method":"unknownMethod","id":1}"#; + let request = + Request::builder().method(Method::POST).uri("/test").body(Body::from(body)).unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED); + } + + #[tokio::test] + async fn test_method_validation_malformed_json() { + let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; + let layer = MethodValidationLayer::new(allowed_methods); + let mut service = layer.layer(MockService); + + let body = r#"{"invalid json"#; + let request = + Request::builder().method(Method::POST).uri("/test").body(Body::from(body)).unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED); + } + + #[tokio::test] + async fn test_method_validation_missing_method_field() { + let allowed_methods = vec!["liveness".to_string(), "getConfig".to_string()]; + let layer = MethodValidationLayer::new(allowed_methods); + let mut service = layer.layer(MockService); + + let body = r#"{"jsonrpc":"2.0","id":1}"#; + let request = + Request::builder().method(Method::POST).uri("/test").body(Body::from(body)).unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED); + } + + #[tokio::test] + async fn test_method_validation_multiple_allowed_methods() { + let allowed_methods = vec![ + "liveness".to_string(), + "getConfig".to_string(), + "signTransaction".to_string(), + "estimateTransactionFee".to_string(), + ]; + let layer = MethodValidationLayer::new(allowed_methods); + let mut service = layer.layer(MockService); + + // Test each allowed method + for method in &["liveness", "getConfig", "signTransaction", "estimateTransactionFee"] { + let body = format!(r#"{{"jsonrpc":"2.0","method":"{}","id":1}}"#, method); + let request = Request::builder() + .method(Method::POST) + .uri("/test") + .body(Body::from(body)) + .unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK, "Method {} should be allowed", method); + } + } +} diff --git a/crates/lib/src/rpc_server/server.rs b/crates/lib/src/rpc_server/server.rs index 5e1b4511..64740ff1 100644 --- a/crates/lib/src/rpc_server/server.rs +++ b/crates/lib/src/rpc_server/server.rs @@ -3,6 +3,7 @@ use crate::{ metrics::run_metrics_server_if_required, rpc_server::{ auth::{ApiKeyAuthLayer, HmacAuthLayer}, + middleware_utils::MethodValidationLayer, rpc::KoraRpc, }, usage_limit::UsageTracker, @@ -76,24 +77,19 @@ pub async fn run_rpc_server(rpc: KoraRpc, port: u16) -> Result("liveness", rpc_params![]).await; assert!(result.is_err()); - assert!(result.err().unwrap().to_string().contains("Method not found")); + let error_msg = result.err().unwrap().to_string(); + // The error should be HTTP 405 (caught by MethodValidationLayer middleware) + assert!(error_msg.contains("405"), "Expected 405 METHOD_NOT_ALLOWED, got: {}", error_msg); }