diff --git a/sgl-model-gateway/src/middleware.rs b/sgl-model-gateway/src/middleware.rs index 5cd188d7ec01..fbad1188adc5 100644 --- a/sgl-model-gateway/src/middleware.rs +++ b/sgl-model-gateway/src/middleware.rs @@ -612,7 +612,8 @@ pub async fn wasm_middleware( let method = request.method().clone(); let uri = request.uri().clone(); let mut headers = request.headers().clone(); - let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await { + let max_body_size = wasm_manager.get_max_body_size(); + let body_bytes = match axum::body::to_bytes(request.into_body(), max_body_size).await { Ok(bytes) => bytes.to_vec(), Err(e) => { error!("Failed to read request body: {}", e); @@ -708,7 +709,7 @@ pub async fn wasm_middleware( // Extract response data once before processing modules let mut status = response.status(); let mut headers = response.headers().clone(); - let mut body_bytes = match axum::body::to_bytes(response.into_body(), usize::MAX).await { + let mut body_bytes = match axum::body::to_bytes(response.into_body(), max_body_size).await { Ok(bytes) => bytes.to_vec(), Err(e) => { error!("Failed to read response body: {}", e); diff --git a/sgl-model-gateway/src/wasm/config.rs b/sgl-model-gateway/src/wasm/config.rs index 6954b57554c6..edc75092799c 100644 --- a/sgl-model-gateway/src/wasm/config.rs +++ b/sgl-model-gateway/src/wasm/config.rs @@ -17,6 +17,8 @@ pub struct WasmRuntimeConfig { pub thread_pool_size: usize, /// Maximum number of modules to cache per worker pub module_cache_size: usize, + /// Maximum HTTP body size in bytes for middleware request/response processing + pub max_body_size: usize, } impl Default for WasmRuntimeConfig { @@ -32,6 +34,7 @@ impl Default for WasmRuntimeConfig { max_stack_size: 1024 * 1024, // 1MB thread_pool_size: default_thread_pool_size, // based on cpu count module_cache_size: 10, // Cache up to 10 modules per worker + max_body_size: 10 * 1024 * 1024, // 10MB } } } @@ -82,6 +85,14 @@ impl WasmRuntimeConfig { return Err("module_cache_size cannot exceed 1000".to_string()); } + // Validate max_body_size + if self.max_body_size == 0 { + return Err("max_body_size cannot be 0".to_string()); + } + if self.max_body_size > 100 * 1024 * 1024 { + return Err("max_body_size cannot exceed 100MB".to_string()); + } + Ok(()) } @@ -92,6 +103,7 @@ impl WasmRuntimeConfig { max_stack_size: usize, thread_pool_size: usize, module_cache_size: usize, + max_body_size: usize, ) -> Result { let config = Self { max_memory_pages, @@ -99,6 +111,7 @@ impl WasmRuntimeConfig { max_stack_size, thread_pool_size, module_cache_size, + max_body_size, }; config.validate()?; Ok(config) @@ -122,7 +135,7 @@ mod tests { #[test] fn test_config_new_with_validation() { - let config = WasmRuntimeConfig::new(1024, 1000, 1024 * 1024, 2, 10); + let config = WasmRuntimeConfig::new(1024, 1000, 1024 * 1024, 2, 10, 10 * 1024 * 1024); assert!(config.is_ok()); } @@ -134,6 +147,7 @@ mod tests { max_stack_size: 1024 * 1024, thread_pool_size: 2, module_cache_size: 10, + max_body_size: 10 * 1024 * 1024, }; let result = config.validate(); assert!(result.is_err()); @@ -148,6 +162,7 @@ mod tests { max_stack_size: 1024 * 1024, thread_pool_size: 2, module_cache_size: 10, + max_body_size: 10 * 1024 * 1024, }; let result = config.validate(); assert!(result.is_err()); @@ -164,6 +179,7 @@ mod tests { max_stack_size: 1024 * 1024, thread_pool_size: 2, module_cache_size: 10, + max_body_size: 10 * 1024 * 1024, }; let result = config.validate(); assert!(result.is_err()); @@ -180,6 +196,7 @@ mod tests { max_stack_size: 1024 * 1024, thread_pool_size: 2, module_cache_size: 10, + max_body_size: 10 * 1024 * 1024, }; let result = config.validate(); assert!(result.is_err()); @@ -196,6 +213,7 @@ mod tests { max_stack_size: 32 * 1024, // Less than 64KB thread_pool_size: 2, module_cache_size: 10, + max_body_size: 10 * 1024 * 1024, }; let result = config.validate(); assert!(result.is_err()); @@ -212,6 +230,7 @@ mod tests { max_stack_size: 17 * 1024 * 1024, // Exceeds 16MB thread_pool_size: 2, module_cache_size: 10, + max_body_size: 10 * 1024 * 1024, }; let result = config.validate(); assert!(result.is_err()); @@ -228,6 +247,7 @@ mod tests { max_stack_size: 1024 * 1024, thread_pool_size: 0, module_cache_size: 10, + max_body_size: 10 * 1024 * 1024, }; let result = config.validate(); assert!(result.is_err()); @@ -242,6 +262,7 @@ mod tests { max_stack_size: 1024 * 1024, thread_pool_size: 129, // Exceeds 128 module_cache_size: 10, + max_body_size: 10 * 1024 * 1024, }; let result = config.validate(); assert!(result.is_err()); @@ -258,6 +279,7 @@ mod tests { max_stack_size: 1024 * 1024, thread_pool_size: 2, module_cache_size: 0, + max_body_size: 10 * 1024 * 1024, }; let result = config.validate(); assert!(result.is_err()); @@ -274,6 +296,7 @@ mod tests { max_stack_size: 1024 * 1024, thread_pool_size: 2, module_cache_size: 1001, // Exceeds 1000 + max_body_size: 10 * 1024 * 1024, }; let result = config.validate(); assert!(result.is_err()); @@ -290,8 +313,41 @@ mod tests { max_stack_size: 1024 * 1024, thread_pool_size: 2, module_cache_size: 10, + max_body_size: 10 * 1024 * 1024, }; // 1024 pages * 64KB = 64MB assert_eq!(config.get_total_memory_bytes(), 64 * 1024 * 1024); } + + #[test] + fn test_validation_max_body_size_zero() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 1000, + max_stack_size: 1024 * 1024, + thread_pool_size: 2, + module_cache_size: 10, + max_body_size: 0, + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("max_body_size cannot be 0")); + } + + #[test] + fn test_validation_max_body_size_too_large() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 1000, + max_stack_size: 1024 * 1024, + thread_pool_size: 2, + module_cache_size: 10, + max_body_size: 101 * 1024 * 1024, // Exceeds 100MB + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("max_body_size cannot exceed 100MB")); + } } diff --git a/sgl-model-gateway/src/wasm/module_manager.rs b/sgl-model-gateway/src/wasm/module_manager.rs index 7fc9876e13b8..a6c5ec273c91 100644 --- a/sgl-model-gateway/src/wasm/module_manager.rs +++ b/sgl-model-gateway/src/wasm/module_manager.rs @@ -128,6 +128,11 @@ impl WasmModuleManager { &self.runtime } + /// Get the configured maximum body size for HTTP request/response processing + pub fn get_max_body_size(&self) -> usize { + self.runtime.get_config().max_body_size + } + /// Execute WASM module using WebAssembly component model based on attach_point pub async fn execute_module_interface( &self,