diff --git a/Cargo.lock b/Cargo.lock index 0f042d2..5e50a60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,6 +62,12 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "async-stream" version = "0.3.5" @@ -673,6 +679,20 @@ dependencies = [ "phf_codegen", ] +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -758,6 +778,36 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "deadpool" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6541a3916932fe57768d4be0b1ffb5ec7cbf74ca8c903fdfd5c0fe8aa958f0ed" +dependencies = [ + "deadpool-runtime", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-redis" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfae6799b68a735270e4344ee3e834365f707c72da09c9a8bb89b45cc3351395" +dependencies = [ + "deadpool", + "redis", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +dependencies = [ + "tokio", +] + [[package]] name = "deranged" version = "0.3.11" @@ -1567,6 +1617,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -1591,6 +1651,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "num_enum" version = "0.7.3" @@ -1959,6 +2029,29 @@ dependencies = [ "getrandom", ] +[[package]] +name = "redis" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7e86f5670bd8b028edfb240f0616cad620705b31ec389d55e4f3da2c38dcd48" +dependencies = [ + "arc-swap", + "async-trait", + "bytes", + "combine", + "futures-util", + "itoa", + "native-tls", + "num-bigint", + "percent-encoding", + "pin-project-lite", + "ryu", + "tokio", + "tokio-native-tls", + "tokio-util", + "url", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -3106,8 +3199,12 @@ dependencies = [ "axum", "chrono", "chrono-tz", + "deadpool", + "deadpool-redis", "lambda_http", + "lazy_static", "openssl", + "redis", "reqwest", "serde", "serde_dynamo", diff --git a/Cargo.toml b/Cargo.toml index 5ba4435..82e496b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -exclude = ["worker-infra", "server-infra"] +exclude = ["worker-infra", "server-infra", "target"] members = ["server", "wh-core", "worker", "message-handler"] resolver = "2" diff --git a/server/Cargo.toml b/server/Cargo.toml index 74d0f41..4ddf295 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -29,4 +29,12 @@ openssl = { version = "0.10.66", features = ["vendored"] } wh-core = { path = "../wh-core" } +deadpool-redis = { version = "0.18.0", features = ["rt_tokio_1"] } +redis = { version = "0.27.2", default-features = false, features = [ + "tls", + "tokio-native-tls-comp", +] } + tracing = "0.1.40" +deadpool = { version = "0.12.1", features = ["managed"] } +lazy_static = "1.5.0" diff --git a/server/src/main.rs b/server/src/main.rs index e445ce1..9fb92c0 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,29 +1,58 @@ -use std::env::set_var; +use std::env::{self, set_var}; +use std::sync::Arc; use axum::Router; use lambda_http::tower::ServiceBuilder; use lambda_http::{run, Error}; +use deadpool::managed::{PoolConfig, QueueMode, Timeouts}; +use deadpool_redis::{Config, Pool, Runtime}; + +use lazy_static::lazy_static; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; +use crate::rate_limit::RateLimit; use crate::v2::handler as waterheater_calc; use crate::v2::router::v2_routes; mod common; mod http; mod middleware; -mod rate_limiter; +mod rate_limit; mod tests; mod v2; +lazy_static! { + static ref REDIS_POOL: Arc = Arc::new(create_redis_pool()); +} + #[derive(Clone)] struct AppState { + pub redis_pool: Arc, dynamo_client: aws_sdk_dynamodb::Client, } +fn create_redis_pool() -> Pool { + let redis_endpoint = env::var("REDIS_ENDPOINT").unwrap_or("http://localhost".into()); + let redis_url = format!("redis://{}", redis_endpoint); + + let cfg = Config { + connection: None, + url: Some(redis_url), + pool: Some(PoolConfig { + max_size: 10, + timeouts: Timeouts::default(), + queue_mode: QueueMode::Fifo, + }), + }; + + cfg.create_pool(Some(Runtime::Tokio1)).unwrap() +} + #[tokio::main] async fn main() -> Result<(), Error> { + set_var("AWS_LAMBDA_HTTP_IGNORE_STAGE_IN_PATH", "true"); tracing_subscriber::fmt() .json() .with_max_level(tracing::Level::INFO) @@ -37,11 +66,10 @@ async fn main() -> Result<(), Error> { let client = aws_sdk_dynamodb::Client::new(&config); let state = AppState { + redis_pool: REDIS_POOL.clone(), dynamo_client: client, }; - set_var("AWS_LAMBDA_HTTP_IGNORE_STAGE_IN_PATH", "true"); - #[derive(OpenApi)] #[openapi( paths(waterheater_calc::handle_enable_water_heater), @@ -65,7 +93,7 @@ async fn main() -> Result<(), Error> { .layer(axum::middleware::from_fn(middleware::inject_connect_info)) .layer(axum::middleware::from_fn_with_state( state.clone(), - rate_limiter::rate_limit, + RateLimit::rate_limit, )), ) .with_state(state); diff --git a/server/src/rate_limit.rs b/server/src/rate_limit.rs index cba76aa..b08f835 100644 --- a/server/src/rate_limit.rs +++ b/server/src/rate_limit.rs @@ -5,9 +5,9 @@ use axum::{ middleware::Next, response::Response, }; - -use redis::aio::Connection; +use deadpool_redis::Connection; use redis::RedisError; + use std::env; use tracing::{error, info}; @@ -31,33 +31,29 @@ impl RateLimit { let client_ip = Self::extract_client_ip(&request).unwrap_or("unknown"); let capacity = 20; // Maximum 20 requests allowed - let refill_rate_per_millisecond = 20.0 / 60_000.0; // Tokens per millisecond + let refill_rate = 20.0 / 60.0; // Refill rate: 20 tokens per minute let current_time = chrono::Utc::now().timestamp_millis() as f64 / 1000.0; - let start_time = std::time::Instant::now(); - let mut conn = match state.redis_pool.get_async_connection().await { + let mut conn = match state.redis_pool.get().await { Ok(c) => c, Err(e) => { error!("Failed to connect to Redis: {}", e); + // Allow the request if Redis is unavailable return Ok(next.run(request).await); } }; - let conn_acquisition_time = start_time.elapsed(); - info!( - "Redis connection acquisition time: {:?}", - conn_acquisition_time - ); - - let allowed = Self::check_rate_limit( - &mut conn, - client_ip, - capacity, - refill_rate_per_millisecond, - current_time, - ) - .await - .unwrap_or(true); + + let allowed = + match Self::check_rate_limit(&mut conn, client_ip, capacity, refill_rate, current_time) + .await + { + Ok(val) => val, + Err(e) => { + error!("Rate limit check failed: {}", e); + true + } + }; if allowed { Ok(next.run(request).await) @@ -89,25 +85,33 @@ impl RateLimit { let script = r#" local key = KEYS[1] local capacity = tonumber(ARGV[1]) - local refill_time = tonumber(ARGV[2]) + local refill_rate = tonumber(ARGV[2]) local current_time = tonumber(ARGV[3]) - local data = redis.call("GET", key) - local tokens = tonumber(data) + local data = redis.call("HMGET", key, "tokens", "last_refill") + local tokens = tonumber(data[1]) + local last_refill = tonumber(data[2]) + if tokens == nil then - tokens = capacity - 1 - redis.call("SETEX", key, refill_time, tokens) - return 1 + tokens = capacity + last_refill = current_time end - if tokens > 0 then + local delta = current_time - last_refill + local tokens_to_add = delta * refill_rate + tokens = math.min(tokens + tokens_to_add, capacity) + last_refill = current_time + + local allowed = 0 + if tokens >= 1 then + allowed = 1 tokens = tokens - 1 - redis.call("SET", key, tokens) - redis.call("EXPIRE", key, refill_time) - return 1 - else - return 0 end + + redis.call("HMSET", key, "tokens", tokens, "last_refill", last_refill) + redis.call("EXPIRE", key, 3600) + + return allowed "#; let allowed: i32 = redis::cmd("EVAL") diff --git a/server/src/rate_limiter.rs b/server/src/rate_limiter.rs index 5afe679..5153bb3 100644 --- a/server/src/rate_limiter.rs +++ b/server/src/rate_limiter.rs @@ -43,25 +43,25 @@ async fn is_rate_limited( dynamodb_client: &DynamoDbClient, client_id: &str, ) -> Result { - // Rate limiting constants let max_requests = 20; let window_seconds = 60; - let now = Utc::now(); - let window_start = now.timestamp() - (now.timestamp() % window_seconds); - - // Set the expiration time for the item - let expires_at = now + Duration::seconds(window_seconds * 2); + let now = Utc::now().timestamp(); + let expires_at = now + window_seconds; + let window_start = now; + // Prepare UpdateExpression let update_expression = "\ SET \ - #request_count = if_not_exists(#request_count, :start) + :inc, \ - #window_start = :window_start, \ + #request_count = if_not_exists(#request_count, :start_count) + :inc,\ + #window_start = :window_start,\ #expires_at = :expires_at"; - let condition_expression = - "attribute_not_exists(#request_count) OR #request_count <= :max_requests"; + let condition_expression = "\ + (attribute_not_exists(#expires_at) OR #expires_at <= :now OR #request_count < :max_requests)\ + AND (#window_start = :window_start OR #expires_at <= :now)"; + // Perform the UpdateItem operation let result = dynamodb_client .update_item() .table_name("waterheater_calc_rate_limits") @@ -71,14 +71,11 @@ async fn is_rate_limited( .expression_attribute_names("#request_count", "request_count") .expression_attribute_names("#window_start", "window_start") .expression_attribute_names("#expires_at", "expires_at") - .expression_attribute_values(":inc", AttributeValue::N(1.to_string())) - .expression_attribute_values(":start", AttributeValue::N(0.to_string())) + .expression_attribute_values(":start_count", AttributeValue::N("0".to_string())) + .expression_attribute_values(":inc", AttributeValue::N("1".to_string())) .expression_attribute_values(":window_start", AttributeValue::N(window_start.to_string())) + .expression_attribute_values(":expires_at", AttributeValue::N(expires_at.to_string())) .expression_attribute_values(":max_requests", AttributeValue::N(max_requests.to_string())) - .expression_attribute_values( - ":expires_at", - AttributeValue::N(expires_at.timestamp().to_string()), - ) .send() .await;