diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 903877c4d2e..888259110bc 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -39,5 +39,7 @@ /sgl-router/src/routers @CatherineSue @key4ng @slin1237 /sgl-router/src/tokenizer @slin1237 @CatherineSue /sgl-router/src/tool_parser @slin1237 @CatherineSue +/sgl-router/src/wasm @tonyluj +/sgl-router/examples/wasm @tonyluj /test/srt/test_modelopt* @Edwardf0t1 /test/srt/ascend @ping1jing2 diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index c980f99139e..760a844cad9 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -95,6 +95,12 @@ backoff = { version = "0.4", features = ["tokio"] } strum = { version = "0.26", features = ["derive"] } once_cell = "1.21.3" +# wasm dependencies +sha2 = "0.10" +wasmtime = { version = "38.0", features = ["component-model", "async"] } +wasmtime-wasi = "38.0" +async-channel = "2.5" + [build-dependencies] tonic-prost-build = "0.14.2" prost-build = "0.14.1" diff --git a/sgl-router/examples/wasm/.gitignore b/sgl-router/examples/wasm/.gitignore new file mode 100644 index 00000000000..5fc28f7c44c --- /dev/null +++ b/sgl-router/examples/wasm/.gitignore @@ -0,0 +1,28 @@ +# Rust build artifacts +target/ +**/target/ + +# Cargo lock files (examples don't need locked dependencies) +Cargo.lock +**/Cargo.lock + +# Generated WASM files +*.wasm +*.component.wasm +**/*.wasm +**/*.component.wasm + +# Build scripts output +build/ + +# IDE files +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# OS files +.DS_Store +Thumbs.db + diff --git a/sgl-router/examples/wasm/README.md b/sgl-router/examples/wasm/README.md new file mode 100644 index 00000000000..2c0f185ab18 --- /dev/null +++ b/sgl-router/examples/wasm/README.md @@ -0,0 +1,103 @@ +# WASM Guest Examples for sgl-router + +This directory contains example WASM middleware components demonstrating how to implement custom middleware for sgl-router using the WebAssembly Component Model (WIT). 
+ +## Examples Overview + +### [wasm-guest-auth](./wasm-guest-auth/) + +API key authentication middleware that validates API keys for requests to `/api` and `/v1` paths. + +**Features:** +- Validates API keys from `Authorization` header or `x-api-key` header +- Returns `401 Unauthorized` for missing or invalid keys +- Attach point: `OnRequest` only + +**Use case:** Protect API endpoints with API key authentication. + +### [wasm-guest-logging](./wasm-guest-logging/) + +Request tracking and status code conversion middleware. + +**Features:** +- Adds tracking headers (`x-request-id`, `x-wasm-processed`, `x-processed-at`, `x-api-route`) +- Converts `500` errors to `503` for better client handling +- Attach points: `OnRequest` and `OnResponse` + +**Use case:** Request tracing and error status code conversion. + +### [wasm-guest-ratelimit](./wasm-guest-ratelimit/) + +Rate limiting middleware with configurable limits. + +**Features:** +- Rate limiting per identifier (API Key, IP, or Request ID) +- Default: 60 requests per minute +- Returns `429 Too Many Requests` when limit exceeded +- Attach point: `OnRequest` only + +**Note:** This is a simplified demonstration with per-instance state. For production, use router-level rate limiting with shared state. + +**Use case:** Protect against request flooding and abuse. + +## Quick Start + +Each example includes its own README with detailed build and deployment instructions. See individual example directories for: + +- Build instructions +- Deployment configuration +- Customization options +- Testing examples + +## Common Prerequisites + +All examples require: + +- Rust toolchain (latest stable) +- `wasm32-wasip2` target: `rustup target add wasm32-wasip2` +- `wasm-tools`: `cargo install wasm-tools` +- sgl-router running with WASM enabled (`--enable-wasm`) + +## Building All Examples + +```bash +cd examples/wasm +for example in wasm-guest-auth wasm-guest-logging wasm-guest-ratelimit; do + echo "Building $example..." 
+ cd $example && ./build.sh && cd .. +done +``` + +## Deploying Multiple Modules + +You can deploy all three modules together: + +```bash +curl -X POST http://localhost:3000/wasm \ + -H "Content-Type: application/json" \ + -d '{ + "modules": [ + { + "name": "auth-middleware", + "file_path": "/path/to/wasm_guest_auth.component.wasm", + "module_type": "Middleware", + "attach_points": [{"Middleware": "OnRequest"}] + }, + { + "name": "logging-middleware", + "file_path": "/path/to/wasm_guest_logging.component.wasm", + "module_type": "Middleware", + "attach_points": [{"Middleware": "OnRequest"}, {"Middleware": "OnResponse"}] + }, + { + "name": "ratelimit-middleware", + "file_path": "/path/to/wasm_guest_ratelimit.component.wasm", + "module_type": "Middleware", + "attach_points": [{"Middleware": "OnRequest"}] + } + ] + }' +``` + +Modules execute in the order they are deployed. If a module returns `Reject`, subsequent modules won't execute. + diff --git a/sgl-router/examples/wasm/wasm-guest-auth/Cargo.toml b/sgl-router/examples/wasm/wasm-guest-auth/Cargo.toml new file mode 100644 index 00000000000..7b5027681b8 --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-auth/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "wasm-guest-auth" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wit-bindgen = { version = "0.21", features = ["macros"] } + diff --git a/sgl-router/examples/wasm/wasm-guest-auth/README.md b/sgl-router/examples/wasm/wasm-guest-auth/README.md new file mode 100644 index 00000000000..905a03ca414 --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-auth/README.md @@ -0,0 +1,62 @@ +# WASM Auth Example for sgl-router + +This example demonstrates API key authentication middleware for sgl-router using the WebAssembly Component Model (WIT). 
+ +## Overview + +This middleware validates API keys for requests to `/api` and `/v1` paths: + +- Supports `Authorization: Bearer <api-key>` header +- Supports `Authorization: ApiKey <api-key>` header +- Supports `x-api-key` header +- Returns `401 Unauthorized` for missing or invalid keys + +**Default API Key**: `secret-api-key-12345` + +## Quick Start + +### Build and Deploy + +```bash +# Build +cd examples/wasm/wasm-guest-auth +./build.sh + +# Deploy (replace file_path with actual path) +curl -X POST http://localhost:3000/wasm \ + -H "Content-Type: application/json" \ + -d '{ + "modules": [{ + "name": "auth-middleware", + "file_path": "/absolute/path/to/wasm_guest_auth.component.wasm", + "module_type": "Middleware", + "attach_points": [{"Middleware": "OnRequest"}] + }] + }' +``` + +### Customization + +Modify `EXPECTED_API_KEY` in `src/lib.rs`: + +```rust +const EXPECTED_API_KEY: &str = "your-secret-key"; +``` + +## Testing + +```bash +# Test unauthorized (returns 401) +curl -v http://localhost:3000/api/test + +# Test authorized (passes) +curl -v http://localhost:3000/api/test \ + -H "Authorization: Bearer secret-api-key-12345" +``` + +## Troubleshooting + +- Verify the API key matches `EXPECTED_API_KEY` in code +- Check request header format and path (`/api` or `/v1`) +- Verify the module is attached to the `OnRequest` phase +- Check router logs for errors diff --git a/sgl-router/examples/wasm/wasm-guest-auth/build.sh b/sgl-router/examples/wasm/wasm-guest-auth/build.sh new file mode 100755 index 00000000000..07e42eda9cc --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-auth/build.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Build script for WASM guest auth example +# This script simplifies the build process for the WASM middleware component + +set -e + +echo "Building WASM guest auth example..." + +# Check if we're in the right directory +if [ ! -f "Cargo.toml" ]; then + echo "Error: Cargo.toml not found. Please run this script from the wasm-guest-auth directory." 
+ exit 1 +fi + +# Check for required tools +command -v cargo >/dev/null 2>&1 || { echo "Error: cargo is required but not installed. Aborting." >&2; exit 1; } + +# Check and install wasm32-wasip2 target +echo "Checking for wasm32-wasip2 target..." +if ! rustup target list --installed | grep -q "wasm32-wasip2"; then + echo "wasm32-wasip2 target not found. Installing..." + rustup target add wasm32-wasip2 + echo "✓ wasm32-wasip2 target installed" +else + echo "✓ wasm32-wasip2 target already installed" +fi + +# Check for wasm-tools +if ! command -v wasm-tools >/dev/null 2>&1; then + echo "Error: wasm-tools is required but not installed." + echo "Install it with: cargo install wasm-tools" + exit 1 +fi + +# Build with cargo (wit-bindgen uses cargo, not wasm-pack) +echo "Running cargo build..." +cargo build --target wasm32-wasip2 --release + +# Output locations +WASM_MODULE="target/wasm32-wasip2/release/wasm_guest_auth.wasm" +WASM_COMPONENT="target/wasm32-wasip2/release/wasm_guest_auth.component.wasm" + +if [ ! -f "$WASM_MODULE" ]; then + echo "Error: Build failed - WASM module not found" + exit 1 +fi + +# Check if the file is already a component +echo "Checking WASM file format..." +if wasm-tools print "$WASM_MODULE" 2>/dev/null | grep -q "^(\s*component"; then + echo "✓ WASM file is already in component format" + # Copy to component path for consistency + cp "$WASM_MODULE" "$WASM_COMPONENT" +else + # Wrap the WASM module into a component format + echo "Wrapping WASM module into component format..." + wasm-tools component new "$WASM_MODULE" -o "$WASM_COMPONENT" + if [ ! -f "$WASM_COMPONENT" ]; then + echo "Error: Failed to create component file" + exit 1 + fi +fi + +if [ -f "$WASM_COMPONENT" ]; then + echo "" + echo "✓ Build successful!" + echo " WASM module: $WASM_MODULE" + echo " WASM component: $WASM_COMPONENT" + echo "" + echo "Next steps:" + echo "1. Use the component file ($WASM_COMPONENT) when adding the module" + echo "2. 
Prepare the module configuration (see README.md for JSON format)" + echo "3. Use the API endpoint to add the module (see README.md for details)" +else + echo "Error: Component file not found" + exit 1 +fi + diff --git a/sgl-router/examples/wasm/wasm-guest-auth/src/lib.rs b/sgl-router/examples/wasm/wasm-guest-auth/src/lib.rs new file mode 100644 index 00000000000..e381b415d7a --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-auth/src/lib.rs @@ -0,0 +1,66 @@ +//! WASM Guest Auth Example for sgl-router +//! +//! This example demonstrates API key authentication middleware +//! for sgl-router using the WebAssembly Component Model (WIT). +//! +//! Features: +//! - API Key authentication + +wit_bindgen::generate!({ + path: "../../../src/wasm/wit", + world: "sgl-router", +}); + +use exports::sgl::router::middleware_on_request::Guest as OnRequestGuest; +use exports::sgl::router::middleware_on_response::Guest as OnResponseGuest; +use sgl::router::middleware_types::{Request, Response, Action}; + +/// Expected API Key (in production, this should be passed as configuration) +const EXPECTED_API_KEY: &str = "secret-api-key-12345"; + +/// Main middleware implementation +struct Middleware; + +// Helper function to find header value +fn find_header_value(headers: &[sgl::router::middleware_types::Header], name: &str) -> Option { + headers + .iter() + .find(|h| h.name.eq_ignore_ascii_case(name)) + .map(|h| h.value.clone()) +} + +// Implement on-request interface +impl OnRequestGuest for Middleware { + fn on_request(req: Request) -> Action { + // API Key Authentication + // Check for API key in Authorization header for /api routes + if req.path.starts_with("/api") || req.path.starts_with("/v1") { + let api_key = find_header_value(&req.headers, "authorization") + .and_then(|h| { + h.strip_prefix("Bearer ") + .or_else(|| h.strip_prefix("ApiKey ")) + .map(|s| s.to_string()) + }) + .or_else(|| find_header_value(&req.headers, "x-api-key")); + + // Reject if API key is missing or 
invalid + if api_key.as_deref() != Some(EXPECTED_API_KEY) { + return Action::Reject(401); + } + } + + // Authentication passed, continue processing + Action::Continue + } +} + +// Implement on-response interface (empty - not used for auth) +impl OnResponseGuest for Middleware { + fn on_response(_resp: Response) -> Action { + Action::Continue + } +} + +// Export the component +export!(Middleware); + diff --git a/sgl-router/examples/wasm/wasm-guest-logging/Cargo.toml b/sgl-router/examples/wasm/wasm-guest-logging/Cargo.toml new file mode 100644 index 00000000000..f6351a1d7f5 --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-logging/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "wasm-guest-logging" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wit-bindgen = { version = "0.21", features = ["macros"] } + diff --git a/sgl-router/examples/wasm/wasm-guest-logging/README.md b/sgl-router/examples/wasm/wasm-guest-logging/README.md new file mode 100644 index 00000000000..38b18d69e3b --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-logging/README.md @@ -0,0 +1,53 @@ +# WASM Logging Example for sgl-router + +This example demonstrates logging and tracing middleware for sgl-router using the WebAssembly Component Model (WIT). 
+ +## Overview + +This middleware provides: + +- **Request Tracking** - Adds tracking headers (`x-request-id`, `x-wasm-processed`, `x-processed-at`, `x-api-route`) +- **Status Code Conversion** - Converts `500` errors to `503` + +## Quick Start + +### Build and Deploy + +```bash +# Build +cd examples/wasm/wasm-guest-logging +./build.sh + +# Deploy (replace file_path with actual path) +curl -X POST http://localhost:3000/wasm \ + -H "Content-Type: application/json" \ + -d '{ + "modules": [{ + "name": "logging-middleware", + "file_path": "/absolute/path/to/wasm_guest_logging.component.wasm", + "module_type": "Middleware", + "attach_points": [{"Middleware": "OnRequest"}, {"Middleware": "OnResponse"}] + }] + }' +``` + +### Customization + +Modify the `on_request` or `on_response` functions in `src/lib.rs` to add custom tracking headers or status code conversions. + +## Testing + +```bash +# Check tracking headers +curl -v http://localhost:3000/v1/models 2>&1 | \ + grep -E "(x-request-id|x-wasm-processed|x-processed-at)" + +# Test status code conversion (requires endpoint returning 500) +curl -v http://localhost:3000/some-endpoint 2>&1 | grep -E "(< HTTP|500|503)" +``` + +## Troubleshooting + +- Verify the module is attached to both the `OnRequest` and `OnResponse` phases +- Check router logs for execution errors +- Ensure the module built successfully diff --git a/sgl-router/examples/wasm/wasm-guest-logging/build.sh b/sgl-router/examples/wasm/wasm-guest-logging/build.sh new file mode 100755 index 00000000000..ffbca21d2c0 --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-logging/build.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Build script for WASM guest logging example +# This script simplifies the build process for the WASM middleware component + +set -e + +echo "Building WASM guest logging example..." + +# Check if we're in the right directory +if [ ! -f "Cargo.toml" ]; then + echo "Error: Cargo.toml not found. Please run this script from the wasm-guest-logging directory." 
+ exit 1 +fi + +# Check for required tools +command -v cargo >/dev/null 2>&1 || { echo "Error: cargo is required but not installed. Aborting." >&2; exit 1; } + +# Check and install wasm32-wasip2 target +echo "Checking for wasm32-wasip2 target..." +if ! rustup target list --installed | grep -q "wasm32-wasip2"; then + echo "wasm32-wasip2 target not found. Installing..." + rustup target add wasm32-wasip2 + echo "✓ wasm32-wasip2 target installed" +else + echo "✓ wasm32-wasip2 target already installed" +fi + +# Check for wasm-tools +if ! command -v wasm-tools >/dev/null 2>&1; then + echo "Error: wasm-tools is required but not installed." + echo "Install it with: cargo install wasm-tools" + exit 1 +fi + +# Build with cargo (wit-bindgen uses cargo, not wasm-pack) +echo "Running cargo build..." +cargo build --target wasm32-wasip2 --release + +# Output locations +WASM_MODULE="target/wasm32-wasip2/release/wasm_guest_logging.wasm" +WASM_COMPONENT="target/wasm32-wasip2/release/wasm_guest_logging.component.wasm" + +if [ ! -f "$WASM_MODULE" ]; then + echo "Error: Build failed - WASM module not found" + exit 1 +fi + +# Check if the file is already a component +echo "Checking WASM file format..." +if wasm-tools print "$WASM_MODULE" 2>/dev/null | grep -q "^(\s*component"; then + echo "✓ WASM file is already in component format" + # Copy to component path for consistency + cp "$WASM_MODULE" "$WASM_COMPONENT" +else + # Wrap the WASM module into a component format + echo "Wrapping WASM module into component format..." + wasm-tools component new "$WASM_MODULE" -o "$WASM_COMPONENT" + if [ ! -f "$WASM_COMPONENT" ]; then + echo "Error: Failed to create component file" + exit 1 + fi +fi + +if [ -f "$WASM_COMPONENT" ]; then + echo "" + echo "✓ Build successful!" + echo " WASM module: $WASM_MODULE" + echo " WASM component: $WASM_COMPONENT" + echo "" + echo "Next steps:" + echo "1. Use the component file ($WASM_COMPONENT) when adding the module" + echo "2. 
Prepare the module configuration (see README.md for JSON format)" + echo "3. Use the API endpoint to add the module (see README.md for details)" +else + echo "Error: Component file not found" + exit 1 +fi + diff --git a/sgl-router/examples/wasm/wasm-guest-logging/src/lib.rs b/sgl-router/examples/wasm/wasm-guest-logging/src/lib.rs new file mode 100644 index 00000000000..47e5719a66b --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-logging/src/lib.rs @@ -0,0 +1,90 @@ +//! WASM Guest Logging Example for sgl-router +//! +//! This example demonstrates logging and tracing middleware +//! for sgl-router using the WebAssembly Component Model (WIT). +//! +//! Features: +//! - Request tracking and tracing headers +//! - Response status code conversion + +wit_bindgen::generate!({ + path: "../../../src/wasm/wit", + world: "sgl-router", +}); + +use exports::sgl::router::middleware_on_request::Guest as OnRequestGuest; +use exports::sgl::router::middleware_on_response::Guest as OnResponseGuest; +use sgl::router::middleware_types::{Request, Response, Action, Header, ModifyAction}; + +/// Main middleware implementation +struct Middleware; + +// Helper function to create header +fn create_header(name: &str, value: &str) -> Header { + Header { + name: name.to_string(), + value: value.to_string(), + } +} + +// Implement on-request interface +impl OnRequestGuest for Middleware { + fn on_request(req: Request) -> Action { + let mut modify_action = ModifyAction { + status: None, + headers_set: vec![], + headers_add: vec![], + headers_remove: vec![], + body_replace: None, + }; + + // Request Logging and Tracing + // Add tracing headers with request ID + modify_action.headers_add.push(create_header( + "x-request-id", + &req.request_id, + )); + modify_action.headers_add.push(create_header( + "x-wasm-processed", + "true", + )); + modify_action.headers_add.push(create_header( + "x-processed-at", + &req.now_epoch_ms.to_string(), + )); + + // Add custom header for API requests + if 
req.path.starts_with("/api") || req.path.starts_with("/v1") { + modify_action.headers_add.push(create_header( + "x-api-route", + "true", + )); + } + + Action::Modify(modify_action) + } +} + +// Implement on-response interface +impl OnResponseGuest for Middleware { + fn on_response(resp: Response) -> Action { + // Status code conversion: Convert 500 to 503 for better client handling + if resp.status == 500 { + let modify_action = ModifyAction { + status: Some(503), + headers_set: vec![], + headers_add: vec![], + headers_remove: vec![], + body_replace: None, + }; + Action::Modify(modify_action) + } else { + // No modification needed + Action::Continue + } + } +} + +// Export the component +export!(Middleware); + diff --git a/sgl-router/examples/wasm/wasm-guest-ratelimit/Cargo.toml b/sgl-router/examples/wasm/wasm-guest-ratelimit/Cargo.toml new file mode 100644 index 00000000000..087e7551569 --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-ratelimit/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "wasm-guest-ratelimit" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wit-bindgen = { version = "0.21", features = ["macros"] } + diff --git a/sgl-router/examples/wasm/wasm-guest-ratelimit/README.md b/sgl-router/examples/wasm/wasm-guest-ratelimit/README.md new file mode 100644 index 00000000000..c8b333fb057 --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-ratelimit/README.md @@ -0,0 +1,68 @@ +# WASM Rate Limit Example for sgl-router + +This example demonstrates rate limiting middleware for sgl-router using the WebAssembly Component Model (WIT). + +## Overview + +This middleware provides rate limiting: + +- **Default**: 60 requests per minute per identifier +- **Identifier Priority**: API Key > IP Address > Request ID +- **Response**: Returns `429 Too Many Requests` when limit exceeded + +**Important**: This is a simplified demonstration. Since WASM components are stateless, each worker thread maintains its own counter. 
For production, implement rate limiting at the router/host level with shared state. + +## Quick Start + +### Build and Deploy + +```bash +# Build +cd examples/wasm/wasm-guest-ratelimit +./build.sh + +# Deploy (replace file_path with actual path) +curl -X POST http://localhost:3000/wasm \ + -H "Content-Type: application/json" \ + -d '{ + "modules": [{ + "name": "ratelimit-middleware", + "file_path": "/absolute/path/to/wasm_guest_ratelimit.component.wasm", + "module_type": "Middleware", + "attach_points": [{"Middleware": "OnRequest"}] + }] + }' +``` + +### Customization + +Modify the constants in `src/lib.rs`, for example: + +```rust +const RATE_LIMIT_REQUESTS: u64 = 100; // requests per window (default: 60) +const RATE_LIMIT_WINDOW_MS: u64 = 60_000; // time window in ms +``` + +## Testing + +```bash +# Send multiple requests (with the default limit, the first 60 succeed, then 429) +for i in {1..65}; do + curl -s -o /dev/null -w "%{http_code}\n" \ + http://localhost:3000/v1/models \ + -H "Authorization: Bearer secret-api-key-12345" +done +``` + +## Limitations + +- Per-instance state (not shared across workers) +- No cross-process state sharing +- Memory growth with unique identifiers +- State lost on instance restart + +## Troubleshooting + +- Verify the module is attached to the `OnRequest` phase +- Check that the identifier extraction logic matches the request format +- Note: Each WASM worker has a separate counter diff --git a/sgl-router/examples/wasm/wasm-guest-ratelimit/build.sh b/sgl-router/examples/wasm/wasm-guest-ratelimit/build.sh new file mode 100755 index 00000000000..3c7e4c8338d --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-ratelimit/build.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Build script for WASM guest rate limit example +# This script simplifies the build process for the WASM middleware component + +set -e + +echo "Building WASM guest rate limit example..." + +# Check if we're in the right directory +if [ ! -f "Cargo.toml" ]; then + echo "Error: Cargo.toml not found. Please run this script from the wasm-guest-ratelimit directory." 
+ exit 1 +fi + +# Check for required tools +command -v cargo >/dev/null 2>&1 || { echo "Error: cargo is required but not installed. Aborting." >&2; exit 1; } + +# Check and install wasm32-wasip2 target +echo "Checking for wasm32-wasip2 target..." +if ! rustup target list --installed | grep -q "wasm32-wasip2"; then + echo "wasm32-wasip2 target not found. Installing..." + rustup target add wasm32-wasip2 + echo "✓ wasm32-wasip2 target installed" +else + echo "✓ wasm32-wasip2 target already installed" +fi + +# Check for wasm-tools +if ! command -v wasm-tools >/dev/null 2>&1; then + echo "Error: wasm-tools is required but not installed." + echo "Install it with: cargo install wasm-tools" + exit 1 +fi + +# Build with cargo (wit-bindgen uses cargo, not wasm-pack) +echo "Running cargo build..." +cargo build --target wasm32-wasip2 --release + +# Output locations +WASM_MODULE="target/wasm32-wasip2/release/wasm_guest_ratelimit.wasm" +WASM_COMPONENT="target/wasm32-wasip2/release/wasm_guest_ratelimit.component.wasm" + +if [ ! -f "$WASM_MODULE" ]; then + echo "Error: Build failed - WASM module not found" + exit 1 +fi + +# Check if the file is already a component +echo "Checking WASM file format..." +if wasm-tools print "$WASM_MODULE" 2>/dev/null | grep -q "^(\s*component"; then + echo "✓ WASM file is already in component format" + # Copy to component path for consistency + cp "$WASM_MODULE" "$WASM_COMPONENT" +else + # Wrap the WASM module into a component format + echo "Wrapping WASM module into component format..." + wasm-tools component new "$WASM_MODULE" -o "$WASM_COMPONENT" + if [ ! -f "$WASM_COMPONENT" ]; then + echo "Error: Failed to create component file" + exit 1 + fi +fi + +if [ -f "$WASM_COMPONENT" ]; then + echo "" + echo "✓ Build successful!" + echo " WASM module: $WASM_MODULE" + echo " WASM component: $WASM_COMPONENT" + echo "" + echo "Next steps:" + echo "1. Use the component file ($WASM_COMPONENT) when adding the module" + echo "2. 
Prepare the module configuration (see README.md for JSON format)" + echo "3. Use the API endpoint to add the module (see README.md for details)" +else + echo "Error: Component file not found" + exit 1 +fi + diff --git a/sgl-router/examples/wasm/wasm-guest-ratelimit/src/lib.rs b/sgl-router/examples/wasm/wasm-guest-ratelimit/src/lib.rs new file mode 100644 index 00000000000..ad74b0628fe --- /dev/null +++ b/sgl-router/examples/wasm/wasm-guest-ratelimit/src/lib.rs @@ -0,0 +1,155 @@ +//! WASM Guest Rate Limit Example for sgl-router +//! +//! This example demonstrates rate limiting middleware +//! for sgl-router using the WebAssembly Component Model (WIT). +//! +//! Features: +//! - Rate limiting based on API Key or IP address +//! - Fixed time window (e.g., 60 requests per minute) +//! - Returns 429 Too Many Requests when limit exceeded +//! +//! Note: This is a simplified implementation. Since WASM components are stateless, +//! each instance maintains its own counters. For production use, consider +//! implementing rate limiting at the host/router level with shared state. 
+ +wit_bindgen::generate!({ + path: "../../../src/wasm/wit", + world: "sgl-router", +}); + +use std::cell::RefCell; + +use exports::sgl::router::{ + middleware_on_request::Guest as OnRequestGuest, + middleware_on_response::Guest as OnResponseGuest, +}; +use sgl::router::middleware_types::{Action, Request, Response}; + +/// Main middleware implementation +struct Middleware; + +// Rate limit configuration +const RATE_LIMIT_REQUESTS: u64 = 60; // Maximum requests per window +const RATE_LIMIT_WINDOW_MS: u64 = 60_000; // Time window in milliseconds (1 minute) + +// Simple in-memory counter (per WASM instance) +// In a real implementation, this would be shared across all instances +// This is a simplified example for demonstration purposes +struct RateLimitState { + requests: Vec<(String, u64)>, // (identifier, timestamp_ms) +} + +impl RateLimitState { + fn new() -> Self { + Self { + requests: Vec::new(), + } + } + + // Clean up old entries outside the time window + fn cleanup(&mut self, current_time_ms: u64) { + let cutoff = current_time_ms.saturating_sub(RATE_LIMIT_WINDOW_MS); + self.requests.retain(|(_, timestamp)| *timestamp > cutoff); + } + + // Check if identifier has exceeded rate limit + fn check_limit(&mut self, identifier: &str, current_time_ms: u64) -> bool { + self.cleanup(current_time_ms); + + // Count requests in current window for this identifier + let count = self + .requests + .iter() + .filter(|(id, timestamp)| { + id == identifier + && *timestamp > current_time_ms.saturating_sub(RATE_LIMIT_WINDOW_MS) + }) + .count() as u64; + + if count >= RATE_LIMIT_REQUESTS { + return false; // Limit exceeded + } + + // Add new request + self.requests + .push((identifier.to_string(), current_time_ms)); + true // Within limit + } +} + +// Thread-local state (per WASM instance thread) +// Using thread_local! is safer than static mut as it avoids unsafe blocks +// and provides separate state for each thread automatically +thread_local! 
{ + static RATE_LIMIT_STATE: RefCell<RateLimitState> = RefCell::new(RateLimitState::new()); +} + +fn get_identifier(req: &Request) -> String { + // Case-insensitive lookup of the first header called `name` (value is cloned) + let find_header_value = + |headers: &[sgl::router::middleware_types::Header], name: &str| -> Option<String> { + headers + .iter() + .find(|h| h.name.eq_ignore_ascii_case(name)) + .map(|h| h.value.clone()) + }; + + // Prefer API Key as identifier (more stable than IP) + if let Some(auth_header) = find_header_value(&req.headers, "authorization") { + if let Some(key) = auth_header.strip_prefix("Bearer ") { + return format!("api_key:{}", key); + } else if let Some(key) = auth_header.strip_prefix("ApiKey ") { + return format!("api_key:{}", key); + } + } + + if let Some(api_key) = find_header_value(&req.headers, "x-api-key") { + return format!("api_key:{}", api_key); + } + + // Fall back to IP address from forwarded headers + if let Some(forwarded_for) = find_header_value(&req.headers, "x-forwarded-for") { + // Take first IP from comma-separated list + let ip = forwarded_for.split(',').next().unwrap_or("").trim(); + if !ip.is_empty() { + return format!("ip:{}", ip); + } + } + + if let Some(real_ip) = find_header_value(&req.headers, "x-real-ip") { + return format!("ip:{}", real_ip.trim()); + } + + // Last resort: use request ID (not ideal, but better than nothing) + format!("req_id:{}", req.request_id) +} + +// Implement on-request interface +impl OnRequestGuest for Middleware { + fn on_request(req: Request) -> Action { + let identifier = get_identifier(&req); + let current_time_ms = req.now_epoch_ms; + + // Access thread-local state safely without unsafe blocks + // Each thread gets its own RateLimitState instance + RATE_LIMIT_STATE.with(|state| { + let mut state = state.borrow_mut(); + if !state.check_limit(&identifier, current_time_ms) { + // Rate limit exceeded + return Action::Reject(429); + } + // Within rate limit, continue processing + Action::Continue + }) + } +} + +// Implement on-response interface (empty - not 
used for rate limiting) +impl OnResponseGuest for Middleware { + fn on_response(_resp: Response) -> Action { + Action::Continue + } +} + +// Export the component +export!(Middleware); diff --git a/sgl-router/src/app_context.rs b/sgl-router/src/app_context.rs index 31e3161c09e..2eb82ce6dd1 100644 --- a/sgl-router/src/app_context.rs +++ b/sgl-router/src/app_context.rs @@ -23,6 +23,7 @@ use crate::{ traits::Tokenizer, }, tool_parser::ParserFactory as ToolParserFactory, + wasm::{config::WasmRuntimeConfig, module_manager::WasmModuleManager}, }; /// Error type for AppContext builder @@ -57,6 +58,7 @@ pub struct AppContext { pub worker_job_queue: Arc>>, pub workflow_engine: Arc>>, pub mcp_manager: Arc>>, + pub wasm_manager: Option>, } pub struct AppContextBuilder { @@ -76,6 +78,7 @@ pub struct AppContextBuilder { worker_job_queue: Option>>>, workflow_engine: Option>>>, mcp_manager: Option>>>, + wasm_manager: Option>, } impl AppContext { @@ -115,6 +118,7 @@ impl AppContextBuilder { worker_job_queue: None, workflow_engine: None, mcp_manager: None, + wasm_manager: None, } } @@ -207,6 +211,11 @@ impl AppContextBuilder { self } + pub fn wasm_manager(mut self, wasm_manager: Option>) -> Self { + self.wasm_manager = wasm_manager; + self + } + pub fn build(self) -> Result { let router_config = self .router_config @@ -249,6 +258,7 @@ impl AppContextBuilder { mcp_manager: self .mcp_manager .ok_or(AppContextBuildError("mcp_manager"))?, + wasm_manager: self.wasm_manager, }) } @@ -272,6 +282,7 @@ impl AppContextBuilder { .with_workflow_engine() .with_mcp_manager(&router_config) .await? + .with_wasm_manager(&router_config)? 
.router_config(router_config)) } @@ -505,6 +516,19 @@ impl AppContextBuilder { self.mcp_manager = Some(mcp_manager_lock); Ok(self) } + + /// Create wasm manager if enabled in config + fn with_wasm_manager(mut self, config: &RouterConfig) -> Result { + self.wasm_manager = if config.enable_wasm { + Some(Arc::new( + WasmModuleManager::new(WasmRuntimeConfig::default()) + .map_err(|e| format!("Failed to initialize WASM module manager: {}", e))?, + )) + } else { + None + }; + Ok(self) + } } impl Default for AppContextBuilder { diff --git a/sgl-router/src/config/builder.rs b/sgl-router/src/config/builder.rs index f84c3df33e2..b93ed1ebd08 100644 --- a/sgl-router/src/config/builder.rs +++ b/sgl-router/src/config/builder.rs @@ -327,6 +327,13 @@ impl RouterConfigBuilder { self } + // ==================== WASM ==================== + + pub fn enable_wasm(mut self, enable: bool) -> Self { + self.config.enable_wasm = enable; + self + } + pub fn model_path>(mut self, path: S) -> Self { self.config.model_path = Some(path.into()); self diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index d25c3106d3e..adc7b7278de 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -68,6 +68,9 @@ pub struct RouterConfig { /// Loaded from mcp_config_path during config creation #[serde(skip)] pub mcp_config: Option, + /// Enable WASM support + #[serde(default)] + pub enable_wasm: bool, } /// Tokenizer cache configuration @@ -437,6 +440,7 @@ impl Default for RouterConfig { client_identity: None, ca_certificates: vec![], mcp_config: None, + enable_wasm: false, } } } diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index be6aac9dbe2..7d0e7010aaa 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -20,6 +20,7 @@ pub mod service_discovery; pub mod tokenizer; pub mod tool_parser; pub mod tree; +pub mod wasm; use crate::metrics::PrometheusConfig; #[pyclass(eq)] @@ -212,6 +213,7 @@ struct Router { client_cert_path: Option, 
client_key_path: Option, ca_cert_paths: Vec, + enable_wasm: bool, } impl Router { @@ -371,6 +373,7 @@ impl Router { self.client_key_path.as_ref(), ) .add_ca_certificates(self.ca_cert_paths.clone()) + .enable_wasm(self.enable_wasm) .build() } } @@ -449,6 +452,7 @@ impl Router { client_cert_path = None, client_key_path = None, ca_cert_paths = vec![], + enable_wasm = false, ))] #[allow(clippy::too_many_arguments)] fn new( @@ -522,6 +526,7 @@ impl Router { client_cert_path: Option, client_key_path: Option, ca_cert_paths: Vec, + enable_wasm: bool, ) -> PyResult { let mut all_urls = worker_urls.clone(); @@ -609,6 +614,7 @@ impl Router { client_cert_path, client_key_path, ca_cert_paths, + enable_wasm, }) } diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index 12987cb4172..919a024e7ff 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -318,6 +318,9 @@ struct CliArgs { #[arg(long)] mcp_config_path: Option, + + #[arg(long, default_value_t = false)] + enable_wasm: bool, } enum OracleConnectSource { @@ -601,6 +604,7 @@ impl CliArgs { .dp_aware(self.dp_aware) .retries(!self.disable_retries) .circuit_breaker(!self.disable_circuit_breaker) + .enable_wasm(self.enable_wasm) .igw(self.enable_igw); builder.build() diff --git a/sgl-router/src/middleware.rs b/sgl-router/src/middleware.rs index e9fb86220b1..b276c37e0af 100644 --- a/sgl-router/src/middleware.rs +++ b/sgl-router/src/middleware.rs @@ -21,7 +21,20 @@ use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer}; use tracing::{debug, error, field::Empty, info, info_span, warn, Span}; pub use crate::core::token_bucket::TokenBucket; -use crate::{metrics::RouterMetrics, server::AppState}; +use crate::{ + metrics::RouterMetrics, + server::AppState, + wasm::{ + module::{MiddlewareAttachPoint, WasmModuleAttachPoint}, + spec::{ + apply_modify_action_to_headers, build_wit_headers_from_axum_headers, + sgl::router::middleware_types::{ + Action, Request as WitRequest, Response as WitResponse, + }, + }, +
types::WasmComponentInput, + }, +}; #[derive(Clone)] pub struct AuthConfig { @@ -554,3 +567,219 @@ pub async fn concurrency_limit_middleware( } } } + +pub async fn wasm_middleware( + State(app_state): State>, + request: Request, + next: Next, +) -> Result { + // Check if WASM is enabled + if !app_state.context.router_config.enable_wasm { + return Ok(next.run(request).await); + } + + // Get WASM manager + let wasm_manager = match &app_state.context.wasm_manager { + Some(manager) => manager, + None => { + return Ok(next.run(request).await); + } + }; + + // Get request ID from extensions or generate one + let request_id = request + .extensions() + .get::() + .map(|r| r.0.clone()) + .unwrap_or_else(|| generate_request_id(request.uri().path())); + + // ===== OnRequest Phase ===== + let on_request_attach_point = + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnRequest); + + let modules_on_request = + match wasm_manager.get_modules_by_attach_point(on_request_attach_point.clone()) { + Ok(modules) => modules, + Err(e) => { + error!("Failed to get WASM modules for OnRequest: {}", e); + return Ok(next.run(request).await); + } + }; + + // Extract request body once before processing modules + let method = request.method().clone(); + let uri = request.uri().clone(); + let mut headers = request.headers().clone(); + let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await { + Ok(bytes) => bytes.to_vec(), + Err(e) => { + error!("Failed to read request body: {}", e); + // Create a minimal request with empty body for error recovery + let error_request = Request::builder() + .uri(uri) + .body(Body::empty()) + .unwrap_or_else(|_| Request::new(Body::empty())); + return Ok(next.run(error_request).await); + } + }; + + // Process each OnRequest module + let mut modified_body = body_bytes; + + for module in modules_on_request { + // Build WIT request from collected data + let wit_headers = build_wit_headers_from_axum_headers(&headers); + let wit_request 
= WitRequest { + method: method.to_string(), + path: uri.path().to_string(), + query: uri.query().unwrap_or("").to_string(), + headers: wit_headers, + body: modified_body.clone(), + request_id: request_id.clone(), + now_epoch_ms: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| { + // Fallback to 0 if system time is before UNIX_EPOCH + // This should never happen in practice, but provides a safe fallback + Duration::from_millis(0) + }) + .as_millis() as u64, + }; + + // Execute WASM component + let action = match wasm_manager + .execute_module_for_attach_point( + &module, + on_request_attach_point.clone(), + WasmComponentInput::MiddlewareRequest(wit_request), + ) + .await + { + Some(action) => action, + None => continue, // Continue to next module on error + }; + + // Process action + match action { + Action::Continue => { + // Continue to next module or request processing + } + Action::Reject(status) => { + // Immediately reject the request + return Err(StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST)); + } + Action::Modify(modify) => { + // Apply modifications to headers and body + apply_modify_action_to_headers(&mut headers, &modify); + // Apply body_replace + if let Some(body_bytes) = modify.body_replace { + modified_body = body_bytes; + } + } + } + } + + // Reconstruct request with modifications + let mut final_request = Request::builder() + .method(method) + .uri(uri) + .body(Body::from(modified_body)) + .unwrap_or_else(|_| Request::new(Body::empty())); + *final_request.headers_mut() = headers; + + // Continue with request processing + let response = next.run(final_request).await; + + // ===== OnResponse Phase ===== + let on_response_attach_point = + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnResponse); + + let modules_on_response = + match wasm_manager.get_modules_by_attach_point(on_response_attach_point.clone()) { + Ok(modules) => modules, + Err(e) => { + error!("Failed to get WASM 
modules for OnResponse: {}", e); + return Ok(response); + } + }; + + // Extract response data once before processing modules + let mut status = response.status(); + let mut headers = response.headers().clone(); + let mut body_bytes = match axum::body::to_bytes(response.into_body(), usize::MAX).await { + Ok(bytes) => bytes.to_vec(), + Err(e) => { + error!("Failed to read response body: {}", e); + // Create a minimal response with empty body for error recovery + let error_response = Response::builder() + .status(status) + .body(Body::empty()) + .unwrap_or_else(|_| Response::new(Body::empty())); + return Ok(error_response); + } + }; + + // Process each OnResponse module + for module in modules_on_response { + // Build WIT response from collected data + let wit_headers = build_wit_headers_from_axum_headers(&headers); + let wit_response = WitResponse { + status: status.as_u16(), + headers: wit_headers, + body: body_bytes.clone(), + }; + + // Execute WASM component + let action = match wasm_manager + .execute_module_for_attach_point( + &module, + on_response_attach_point.clone(), + WasmComponentInput::MiddlewareResponse(wit_response), + ) + .await + { + Some(action) => action, + None => continue, // Continue to next module on error + }; + + // Process action - apply modifications incrementally + match action { + Action::Continue => { + // Continue to next module + } + Action::Reject(status_code) => { + // Override response status + status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_REQUEST); + // Return immediately with current state + let final_response = Response::builder() + .status(status) + .body(Body::from(body_bytes)) + .unwrap_or_else(|_| Response::new(Body::empty())); + let mut final_response = final_response; + *final_response.headers_mut() = headers; + return Ok(final_response); + } + Action::Modify(modify) => { + // Apply status modification + if let Some(new_status) = modify.status { + status = 
StatusCode::from_u16(new_status).unwrap_or(status); + } + // Apply headers modifications + apply_modify_action_to_headers(&mut headers, &modify); + // Apply body_replace + if let Some(new_body) = modify.body_replace { + body_bytes = new_body; + } + } + } + } + + // Reconstruct final response with all modifications + let final_response = Response::builder() + .status(status) + .body(Body::from(body_bytes)) + .unwrap_or_else(|_| Response::new(Body::empty())); + let mut final_response = final_response; + *final_response.headers_mut() = headers; + Ok(final_response) +} diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index c70ac9c4618..0c96346e97a 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -45,6 +45,7 @@ use crate::{ }, routers::{router_manager::RouterManager, RouterTrait}, service_discovery::{start_service_discovery, ServiceDiscoveryConfig}, + wasm::route::{add_wasm_module, list_wasm_modules, remove_wasm_module}, }; #[derive(Clone)] @@ -634,6 +635,10 @@ pub fn build_app( .route_layer(axum::middleware::from_fn_with_state( auth_config.clone(), middleware::auth_middleware, + )) + .route_layer(axum::middleware::from_fn_with_state( + app_state.clone(), + middleware::wasm_middleware, )); let public_routes = Router::new() @@ -648,6 +653,9 @@ pub fn build_app( let admin_routes = Router::new() .route("/flush_cache", post(flush_cache)) .route("/get_loads", get(get_loads)) + .route("/wasm", post(add_wasm_module)) + .route("/wasm/{module_uuid}", delete(remove_wasm_module)) + .route("/wasm", get(list_wasm_modules)) .route_layer(axum::middleware::from_fn_with_state( auth_config.clone(), middleware::auth_middleware, diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 575203f7b74..c14c0aaf186 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -594,6 +594,7 @@ mod tests { worker_job_queue: Arc::new(std::sync::OnceLock::new()), workflow_engine: 
Arc::new(std::sync::OnceLock::new()), mcp_manager: Arc::new(std::sync::OnceLock::new()), + wasm_manager: None, }) } diff --git a/sgl-router/src/wasm/README.md b/sgl-router/src/wasm/README.md new file mode 100644 index 00000000000..c6f5466ef2c --- /dev/null +++ b/sgl-router/src/wasm/README.md @@ -0,0 +1,228 @@ +# WebAssembly (WASM) Extensibility for sgl-router + +This module provides WebAssembly-based extensibility for sgl-router, enabling dynamic, safe, and portable middleware execution without requiring router restarts or recompilation. + +## Overview + +The WASM module allows you to extend sgl-router functionality by deploying WebAssembly components that can: + +- **Intercept requests/responses** at various lifecycle points (OnRequest, OnResponse) +- **Modify HTTP headers and bodies** before/after processing +- **Reject requests** with custom status codes +- **Execute custom logic** in a sandboxed, isolated environment + +## Architecture + +### Components + +The WASM module consists of several key components: + +``` +src/wasm/ +├── module.rs # Data structures (metadata, types, attach points) +├── module_manager.rs # Module lifecycle management (add/remove/list) +├── runtime.rs # WASM execution engine and thread pool +├── route.rs # HTTP API endpoints for module management +├── spec.rs # WIT bindings and type conversions +├── types.rs # Generic input/output types +├── errors.rs # Error definitions +├── config.rs # Runtime configuration +└── wit/ + └── spec.wit # WebAssembly Interface Types (WIT) definitions +``` + +### Execution Flow + +``` +1. HTTP Request arrives at router + ↓ +2. Middleware chain checks for WASM modules attached to OnRequest + ↓ +3. For each module: + a. Module manager retrieves pre-loaded WASM bytes + b. Runtime executes component in isolated worker thread + c. Component processes request via WIT interface + d. Returns Action (Continue/Reject/Modify) + ↓ +4. 
If Continue: proceed to next middleware/upstream + If Reject: return error response immediately + If Modify: apply changes (headers, body, status) + ↓ +5. After upstream response: + - Modules attached to OnResponse process response + - Apply modifications + ↓ +6. Return final response to client +``` + +### WIT (WebAssembly Interface Types) + +The module uses the WebAssembly Component Model with WIT for type-safe communication between host and WASM components: + +- **Request Processing**: `middleware-on-request::on-request(req: Request) -> Action` +- **Response Processing**: `middleware-on-response::on-response(resp: Response) -> Action` +- **Actions**: `Continue`, `Reject(status)`, or `Modify(modify-action)` + +See [`wit/spec.wit`](./wit/spec.wit) for the complete interface definition. + +## Usage + +### Prerequisites + +- sgl-router compiled with WASM support +- Rust toolchain (for building WASM components) +- `wasm32-wasip2` target: `rustup target add wasm32-wasip2` +- `wasm-tools`: `cargo install wasm-tools` + +### Starting the Router + +Enable WASM support when starting the router: + +```bash +./sgl-router --enable-wasm --worker-urls=http://0.0.0.0:30000 --port=3000 +``` + +### Deploying a WASM Module + +Use the `/wasm` POST endpoint to deploy modules: + +```bash +curl -X POST http://localhost:3000/wasm \ + -H "Content-Type: application/json" \ + -d '{ + "modules": [{ + "name": "my-middleware", + "file_path": "/path/to/my-component.component.wasm", + "module_type": "Middleware", + "attach_points": [{"Middleware": "OnRequest"}] + }] + }' +``` + +### Managing Modules + +**List all modules:** +```bash +curl http://localhost:3000/wasm +``` + +**Remove a module:** +```bash +curl -X DELETE http://localhost:3000/wasm/{module-uuid} +``` + +### Module Configuration + +Each module requires: + +- **name**: Unique identifier for the module +- **file_path**: Absolute path to the WASM component file +- **module_type**: Currently supports `"Middleware"` +- **attach_points**: 
List of attachment points, e.g., `[{"Middleware": "OnRequest"}]` + +Supported attachment points: +- `{"Middleware": "OnRequest"}` - Execute before forwarding to upstream +- `{"Middleware": "OnResponse"}` - Execute after receiving upstream response +- `{"Middleware": "OnError"}` - Not yet implemented + +## Examples + +See [`examples/wasm/`](../../examples/wasm/) for complete examples: + +1. **[wasm-guest-auth](../../examples/wasm/wasm-guest-auth/)** - API key authentication middleware +2. **[wasm-guest-logging](../../examples/wasm/wasm-guest-logging/)** - Request tracking and status code conversion +3. **[wasm-guest-ratelimit](../../examples/wasm/wasm-guest-ratelimit/)** - Rate limiting middleware + +Each example includes: +- Complete source code +- Build instructions +- Deployment examples +- Testing guidelines + +## Security and Resource Management + +### Sandboxing + +WASM modules run in isolated environments provided by wasmtime, preventing: +- Direct system access +- Memory corruption of the host process +- Unauthorized network access +- File system access (unless explicitly granted via WASI) + +### Resource Limits + +Runtime configuration allows setting limits: + +```rust +WasmRuntimeConfig { + max_memory_pages: 1024, // 64MB limit + max_execution_time_ms: 1000, // 1 second timeout + max_stack_size: 1024 * 1024, // 1MB stack + thread_pool_size: 4, // Worker threads + module_cache_size: 10, // Cached modules per worker +} +``` + +### Error Handling + +- Failed module executions are logged and don't crash the router +- Invalid WASM components are rejected during load time +- Metrics track execution success/failure rates + +## Metrics + +The module exposes execution metrics via the `/wasm` GET endpoint: + +```json +{ + "modules": [...], + "metrics": { + "total_executions": 1000, + "successful_executions": 995, + "failed_executions": 5, + "total_execution_time_ms": 50000, + "max_execution_time_ms": 150, + "average_execution_time_ms": 50.0 + } +} +``` + +## 
Development + +### Building WASM Components + +WASM components must be built using the Component Model. For Rust: + +```bash +# 1. Build as WASM module +cargo build --target wasm32-wasip2 --release + +# 2. Wrap into component format +wasm-tools component new target/wasm32-wasip2/release/my_module.wasm \ + -o my_module.component.wasm +``` + +### WIT Interface + +Define your component using the WIT interface from `wit/spec.wit`: + +```rust +wit_bindgen::generate!({ + path: "../../../src/wasm/wit", + world: "sgl-router", +}); + +use exports::sgl::router::middleware_on_request::Guest as OnRequestGuest; +use sgl::router::middleware_types::{Request, Action}; + +struct Middleware; + +impl OnRequestGuest for Middleware { + fn on_request(req: Request) -> Action { + // Your logic here + Action::Continue + } +} + +export!(Middleware); +``` diff --git a/sgl-router/src/wasm/config.rs b/sgl-router/src/wasm/config.rs new file mode 100644 index 00000000000..6954b57554c --- /dev/null +++ b/sgl-router/src/wasm/config.rs @@ -0,0 +1,297 @@ +//! WASM Runtime Configuration +//! +//! Defines configuration parameters for the WASM runtime, +//! including memory limits, execution timeouts, and thread pool settings. 
+ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WasmRuntimeConfig { + /// Maximum memory size in pages (64KB per page) + pub max_memory_pages: u32, + /// Maximum execution time in milliseconds + pub max_execution_time_ms: u64, + /// Maximum stack size in bytes + pub max_stack_size: usize, + /// Number of worker threads in the pool + pub thread_pool_size: usize, + /// Maximum number of modules to cache per worker + pub module_cache_size: usize, +} + +impl Default for WasmRuntimeConfig { + fn default() -> Self { + let default_thread_pool_size = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(4) + .max(1); + + Self { + max_memory_pages: 1024, // 64MB + max_execution_time_ms: 1000, // 1 second + max_stack_size: 1024 * 1024, // 1MB + thread_pool_size: default_thread_pool_size, // based on cpu count + module_cache_size: 10, // Cache up to 10 modules per worker + } + } +} + +impl WasmRuntimeConfig { + /// Validate the configuration parameters + pub fn validate(&self) -> Result<(), String> { + // Validate max_memory_pages + if self.max_memory_pages == 0 { + return Err("max_memory_pages cannot be 0".to_string()); + } + if self.max_memory_pages > 65536 { + return Err("max_memory_pages cannot exceed 65536 (4GB)".to_string()); + } + + // Validate max_execution_time_ms + if self.max_execution_time_ms == 0 { + return Err("max_execution_time_ms cannot be 0".to_string()); + } + if self.max_execution_time_ms > 300000 { + return Err("max_execution_time_ms cannot exceed 300000ms (5 minutes)".to_string()); + } + + // Validate max_stack_size + if self.max_stack_size == 0 { + return Err("max_stack_size cannot be 0".to_string()); + } + if self.max_stack_size < 64 * 1024 { + return Err("max_stack_size must be at least 64KB".to_string()); + } + if self.max_stack_size > 16 * 1024 * 1024 { + return Err("max_stack_size cannot exceed 16MB".to_string()); + } + + // Validate thread_pool_size + if self.thread_pool_size
== 0 { + return Err("thread_pool_size cannot be 0".to_string()); + } + if self.thread_pool_size > 128 { + return Err("thread_pool_size cannot exceed 128".to_string()); + } + + // Validate module_cache_size + if self.module_cache_size == 0 { + return Err("module_cache_size cannot be 0".to_string()); + } + if self.module_cache_size > 1000 { + return Err("module_cache_size cannot exceed 1000".to_string()); + } + + Ok(()) + } + + /// Create a new config with validation + pub fn new( + max_memory_pages: u32, + max_execution_time_ms: u64, + max_stack_size: usize, + thread_pool_size: usize, + module_cache_size: usize, + ) -> Result { + let config = Self { + max_memory_pages, + max_execution_time_ms, + max_stack_size, + thread_pool_size, + module_cache_size, + }; + config.validate()?; + Ok(config) + } + + /// Get the total memory size in bytes + pub fn get_total_memory_bytes(&self) -> u64 { + self.max_memory_pages as u64 * 64 * 1024 // 64KB per page + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config_validation() { + let config = WasmRuntimeConfig::default(); + assert!(config.validate().is_ok()); + } + + #[test] + fn test_config_new_with_validation() { + let config = WasmRuntimeConfig::new(1024, 1000, 1024 * 1024, 2, 10); + assert!(config.is_ok()); + } + + #[test] + fn test_validation_max_memory_pages_zero() { + let config = WasmRuntimeConfig { + max_memory_pages: 0, + max_execution_time_ms: 1000, + max_stack_size: 1024 * 1024, + thread_pool_size: 2, + module_cache_size: 10, + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("max_memory_pages cannot be 0")); + } + + #[test] + fn test_validation_max_memory_pages_too_large() { + let config = WasmRuntimeConfig { + max_memory_pages: 65537, // Exceeds 4GB limit + max_execution_time_ms: 1000, + max_stack_size: 1024 * 1024, + thread_pool_size: 2, + module_cache_size: 10, + }; + let result = config.validate(); + assert!(result.is_err()); + 
assert!(result + .unwrap_err() + .contains("max_memory_pages cannot exceed 65536")); + } + + #[test] + fn test_validation_max_execution_time_zero() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 0, + max_stack_size: 1024 * 1024, + thread_pool_size: 2, + module_cache_size: 10, + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("max_execution_time_ms cannot be 0")); + } + + #[test] + fn test_validation_max_execution_time_too_large() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 300001, // Exceeds 5 minutes + max_stack_size: 1024 * 1024, + thread_pool_size: 2, + module_cache_size: 10, + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("max_execution_time_ms cannot exceed 300000ms")); + } + + #[test] + fn test_validation_max_stack_size_too_small() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 1000, + max_stack_size: 32 * 1024, // Less than 64KB + thread_pool_size: 2, + module_cache_size: 10, + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("max_stack_size must be at least 64KB")); + } + + #[test] + fn test_validation_max_stack_size_too_large() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 1000, + max_stack_size: 17 * 1024 * 1024, // Exceeds 16MB + thread_pool_size: 2, + module_cache_size: 10, + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("max_stack_size cannot exceed 16MB")); + } + + #[test] + fn test_validation_thread_pool_size_zero() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 1000, + max_stack_size: 1024 * 1024, + thread_pool_size: 0, + module_cache_size: 10, + }; + let result = config.validate(); + 
assert!(result.is_err()); + assert!(result.unwrap_err().contains("thread_pool_size cannot be 0")); + } + + #[test] + fn test_validation_thread_pool_size_too_large() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 1000, + max_stack_size: 1024 * 1024, + thread_pool_size: 129, // Exceeds 128 + module_cache_size: 10, + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("thread_pool_size cannot exceed 128")); + } + + #[test] + fn test_validation_module_cache_size_zero() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 1000, + max_stack_size: 1024 * 1024, + thread_pool_size: 2, + module_cache_size: 0, + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("module_cache_size cannot be 0")); + } + + #[test] + fn test_validation_module_cache_size_too_large() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 1000, + max_stack_size: 1024 * 1024, + thread_pool_size: 2, + module_cache_size: 1001, // Exceeds 1000 + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("module_cache_size cannot exceed 1000")); + } + + #[test] + fn test_get_total_memory_bytes() { + let config = WasmRuntimeConfig { + max_memory_pages: 1024, + max_execution_time_ms: 1000, + max_stack_size: 1024 * 1024, + thread_pool_size: 2, + module_cache_size: 10, + }; + // 1024 pages * 64KB = 64MB + assert_eq!(config.get_total_memory_bytes(), 64 * 1024 * 1024); + } +} diff --git a/sgl-router/src/wasm/errors.rs b/sgl-router/src/wasm/errors.rs new file mode 100644 index 00000000000..c5b9187ed4c --- /dev/null +++ b/sgl-router/src/wasm/errors.rs @@ -0,0 +1,117 @@ +//! WASM Error Types +//! +//! Defines comprehensive error types for the WASM subsystem, +//! including module, manager, and runtime errors. 
+ +use std::fmt; + +use thiserror::Error; + +pub type Result = std::result::Result; + +/// SHA256 hash wrapper for display purposes +#[derive(Debug, Clone, Copy)] +pub struct Sha256Hash(pub [u8; 32]); + +impl fmt::Display for Sha256Hash { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let hex_string: String = self.0.iter().map(|b| format!("{:02x}", b)).collect(); + write!(f, "{}", hex_string) + } +} + +impl From<[u8; 32]> for Sha256Hash { + fn from(hash: [u8; 32]) -> Self { + Sha256Hash(hash) + } +} + +#[derive(Debug, Error)] +pub enum WasmError { + #[error(transparent)] + Module(#[from] WasmModuleError), + + #[error(transparent)] + Manager(#[from] WasmManagerError), + + #[error(transparent)] + Runtime(#[from] WasmRuntimeError), + + #[error(transparent)] + Io(#[from] std::io::Error), + + #[error("{0}")] + Other(String), +} + +#[derive(Debug, Error)] +pub enum WasmModuleError { + #[error("invalid module descriptor: {0}")] + InvalidDescriptor(String), + + #[error("module with same sha256 already exists: {0}")] + DuplicateSha256(Sha256Hash), + + #[error("module not found: {0}")] + NotFound(uuid::Uuid), + + #[error("failed to read module file: {0}")] + FileRead(String), + + #[error("validation failed: {0}")] + ValidationFailed(String), + + #[error("attach point missing: {0}")] + AttachPointMissing(String), + + #[error("invalid function for attach point: {0}")] + AttachPointFunctionInvalid(String), +} + +#[derive(Debug, Error)] +pub enum WasmManagerError { + #[error("failed to acquire lock: {0}")] + LockFailed(String), + + #[error("module add failed: {0}")] + ModuleAddFailed(String), + + #[error("module remove failed: {0}")] + ModuleRemoveFailed(String), + + #[error("runtime unavailable")] + RuntimeUnavailable, + + #[error("execution failed: {0}")] + ExecutionFailed(String), + + #[error("module {0} not found")] + ModuleNotFound(uuid::Uuid), +} + +#[derive(Debug, Error)] +pub enum WasmRuntimeError { + #[error("failed to create engine: {0}")] + 
EngineCreateFailed(String), + + #[error("failed to compile module: {0}")] + CompileFailed(String), + + #[error("failed to create instance: {0}")] + InstanceCreateFailed(String), + + #[error("function not found: {0}")] + FunctionNotFound(String), + + #[error("execution timeout")] + Timeout, + + #[error("execution failed: {0}")] + CallFailed(String), +} + +impl From for WasmError { + fn from(value: wasmtime::Error) -> Self { + WasmError::Runtime(WasmRuntimeError::CallFailed(value.to_string())) + } +} diff --git a/sgl-router/src/wasm/mod.rs b/sgl-router/src/wasm/mod.rs new file mode 100644 index 00000000000..1a866955ece --- /dev/null +++ b/sgl-router/src/wasm/mod.rs @@ -0,0 +1,13 @@ +//! WebAssembly (WASM) module support for sgl-router +//! +//! This module provides WASM component execution capabilities using the WebAssembly Component Model (WIT). +//! It supports middleware execution at various attach points (OnRequest, OnResponse) with async support. + +pub mod config; +pub mod errors; +pub mod module; +pub mod module_manager; +pub mod route; +pub mod runtime; +pub mod spec; +pub mod types; diff --git a/sgl-router/src/wasm/module.rs b/sgl-router/src/wasm/module.rs new file mode 100644 index 00000000000..8e7c1923bf5 --- /dev/null +++ b/sgl-router/src/wasm/module.rs @@ -0,0 +1,142 @@ +//! WASM Module Data Structures and Types +//! +//! This module defines the core data structures for managing WebAssembly components: +//! - Module metadata (UUID, name, file path, hash, timestamps, metrics) +//! - Module types and attachment points (Middleware hooks: OnRequest, OnResponse, OnError) +//! - API request/response types for module management +//! - Execution metrics and statistics +//! +//! The module provides custom serialization for: +//! - SHA256 hashes (hex string representation) +//! 
- Timestamps (ISO 8601 format for JSON output) + +use serde::{Deserialize, Serialize, Serializer}; +use uuid::Uuid; + +/// Serialize [u8; 32] as hex string +fn serialize_sha256_hash(hash: &[u8; 32], serializer: S) -> Result +where + S: Serializer, +{ + let hex_string = hash + .iter() + .map(|b| format!("{:02x}", b)) + .collect::(); + serializer.serialize_str(&hex_string) +} + +/// Serialize u64 timestamp (nanoseconds since epoch) as ISO 8601 string +fn serialize_timestamp(timestamp: &u64, serializer: S) -> Result +where + S: Serializer, +{ + use chrono::{DateTime, Utc}; + + // Convert nanoseconds to seconds and remaining nanoseconds + let secs = (*timestamp / 1_000_000_000) as i64; + let nanos = (*timestamp % 1_000_000_000) as u32; + + match DateTime::::from_timestamp(secs, nanos) { + Some(dt) => { + let s = dt.to_rfc3339_opts(chrono::SecondsFormat::Nanos, true); + serializer.serialize_str(&s) + } + None => { + // Fallback: format manually if timestamp is out of range + let s = format!("{}", timestamp); + serializer.serialize_str(&s) + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WasmModule { + // unique identifier for the module + pub module_uuid: Uuid, + pub module_meta: WasmModuleMeta, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum WasmModuleAddResult { + Success(Uuid), + Error(String), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WasmModuleDescriptor { + pub name: String, + pub file_path: String, + pub module_type: WasmModuleType, + pub attach_points: Vec, + pub add_result: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WasmModuleMeta { + // module name provided by the user + pub name: String, + // path to the module file + pub file_path: String, + // sha256 hash of the module file + #[serde(serialize_with = "serialize_sha256_hash")] + pub sha256_hash: [u8; 32], + // size of the module file in bytes + pub size_bytes: u64, + // timestamp of when the module was created 
(nanoseconds since epoch) + #[serde(serialize_with = "serialize_timestamp")] + pub created_at: u64, + // timestamp of when the module was last accessed (nanoseconds since epoch) + #[serde(serialize_with = "serialize_timestamp")] + pub last_accessed_at: u64, + // number of times the module was accessed + pub access_count: u64, + // attach points for the module + pub attach_points: Vec, + // Pre-loaded WASM component bytes (loaded into memory for faster execution) + #[serde(skip)] + pub wasm_bytes: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +pub enum WasmModuleType { + Middleware, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +pub enum MiddlewareAttachPoint { + OnRequest, + OnResponse, + OnError, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +pub enum WasmModuleAttachPoint { + Middleware(MiddlewareAttachPoint), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WasmModuleAddRequest { + pub modules: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WasmModuleAddResponse { + pub modules: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WasmModuleListResponse { + pub modules: Vec, + pub metrics: WasmMetrics, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WasmMetrics { + pub total_executions: u64, + pub successful_executions: u64, + pub failed_executions: u64, + pub total_execution_time_ms: u64, + pub max_execution_time_ms: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub average_execution_time_ms: Option, +} diff --git a/sgl-router/src/wasm/module_manager.rs b/sgl-router/src/wasm/module_manager.rs new file mode 100644 index 00000000000..a33f632b5c6 --- /dev/null +++ b/sgl-router/src/wasm/module_manager.rs @@ -0,0 +1,375 @@ +//! 
WASM Module Manager + +use std::{ + collections::HashMap, + fs::File, + io::Read, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, RwLock, + }, +}; + +use sha2::{Digest, Sha256}; +use uuid::Uuid; +use wasmtime::{component::Component, Config, Engine}; + +use crate::wasm::{ + config::WasmRuntimeConfig, + errors::{Result, WasmError, WasmManagerError, WasmModuleError, WasmRuntimeError}, + module::{WasmModule, WasmModuleAttachPoint, WasmModuleDescriptor, WasmModuleMeta}, + runtime::WasmRuntime, + types::{WasmComponentInput, WasmComponentOutput}, +}; + +pub struct WasmModuleManager { + modules: Arc>>, + runtime: Arc, + // Metrics + total_executions: AtomicU64, + successful_executions: AtomicU64, + failed_executions: AtomicU64, + total_execution_time_ms: AtomicU64, + max_execution_time_ms: AtomicU64, +} + +impl WasmModuleManager { + pub fn new(config: WasmRuntimeConfig) -> Result { + let runtime = Arc::new(WasmRuntime::new(config)?); + Ok(Self { + modules: Arc::new(RwLock::new(HashMap::new())), + runtime, + total_executions: AtomicU64::new(0), + successful_executions: AtomicU64::new(0), + failed_executions: AtomicU64::new(0), + total_execution_time_ms: AtomicU64::new(0), + max_execution_time_ms: AtomicU64::new(0), + }) + } + + pub fn with_default_config() -> Result { + Self::new(WasmRuntimeConfig::default()) + } + + fn check_duplicate_sha256_hash(&self, sha256_hash: &[u8; 32]) -> Result<()> { + let modules = self + .modules + .read() + .map_err(|e| WasmManagerError::LockFailed(e.to_string()))?; + if modules + .values() + .any(|module: &WasmModule| module.module_meta.sha256_hash == *sha256_hash) + { + return Err(WasmModuleError::DuplicateSha256((*sha256_hash).into()).into()); + } + Ok(()) + } + + fn calculate_size_bytes(&self, file_path: &str) -> Result { + let file = File::open(file_path).map_err(|e| WasmError::from(e))?; + let metadata = file.metadata().map_err(|e| WasmError::from(e))?; + Ok(metadata.len()) + } + + fn calculate_sha256_hash(&self, file_path: &str) -> 
Result<[u8; 32]> { + let mut file = File::open(file_path).map_err(|e| WasmError::from(e))?; + let mut hasher = Sha256::new(); + let mut buffer = [0; 1024]; + loop { + let bytes_read = file.read(&mut buffer).map_err(|e| WasmError::from(e))?; + if bytes_read == 0 { + break; + } + hasher.update(&buffer[..bytes_read]); + } + Ok(hasher.finalize().into()) + } + + fn validate_module_descriptor(&self, descriptor: &WasmModuleDescriptor) -> Result<()> { + if descriptor.name.is_empty() { + return Err(WasmModuleError::InvalidDescriptor( + "Module name cannot be empty".to_string(), + ) + .into()); + } + if descriptor.file_path.is_empty() { + return Err(WasmModuleError::InvalidDescriptor( + "Module file path cannot be empty".to_string(), + ) + .into()); + } + if self.calculate_size_bytes(&descriptor.file_path)? == 0 { + return Err(WasmModuleError::ValidationFailed( + "Module file size cannot be 0".to_string(), + ) + .into()); + } + Ok(()) + } + + pub fn add_module(&self, descriptor: WasmModuleDescriptor) -> Result { + // validate the module descriptor + self.validate_module_descriptor(&descriptor)?; + + // calculate the sha256 hash of the module file + let sha256_hash = self.calculate_sha256_hash(&descriptor.file_path)?; + self.check_duplicate_sha256_hash(&sha256_hash)?; + + // calculate size before moving descriptor + let size_bytes = self.calculate_size_bytes(&descriptor.file_path)?; + + // Pre-load WASM bytes into memory for faster execution + let wasm_bytes = std::fs::read(&descriptor.file_path).map_err(|e| { + WasmError::from(WasmModuleError::FileRead(format!( + "Failed to read WASM file: {}", + e + ))) + })?; + + // Validate that the WASM file is a valid component by attempting to compile it + // This catches errors early during module addition rather than during execution + self.validate_wasm_component(&wasm_bytes)?; + + // now safe, insert the module into the manager + // SystemTime::duration_since only fails if the system time is before UNIX_EPOCH, + // which should 
never happen in normal operation. If it does, use current time as fallback. + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| { + // Fallback to a reasonable timestamp if system time is invalid + // This should never occur in practice, but provides a safe fallback + std::time::Duration::from_nanos(0) + }) + .as_nanos() as u64; + let module_uuid = Uuid::new_v4(); + let module = WasmModule { + module_uuid, + module_meta: WasmModuleMeta { + name: descriptor.name, + file_path: descriptor.file_path, + sha256_hash, + size_bytes, + created_at: now, + last_accessed_at: now, + access_count: 0, + attach_points: descriptor.attach_points.clone(), + wasm_bytes, + }, + }; + + let mut modules = self + .modules + .write() + .map_err(|e| WasmManagerError::LockFailed(e.to_string()))?; + modules.insert(module_uuid, module); + Ok(module_uuid) + } + + /// Validate that WASM bytes represent a valid component + fn validate_wasm_component(&self, wasm_bytes: &[u8]) -> Result<()> { + // Create a temporary engine to validate the component + let mut config = Config::new(); + config.async_support(true); + config.wasm_component_model(true); + let engine = Engine::new(&config) + .map_err(|e| WasmError::from(WasmRuntimeError::EngineCreateFailed(e.to_string())))?; + + // Attempt to compile the component to validate it + Component::new(&engine, wasm_bytes) + .map_err(|e| WasmError::from(WasmRuntimeError::CompileFailed(format!( + "Invalid WASM component: {}. \ + Hint: The WASM file must be in component format. 
\ + If you're using wit-bindgen, use 'wasm-tools component new' to wrap the WASM module into a component.", + e + ))))?; + + Ok(()) + } + + pub fn remove_module(&self, module_uuid: Uuid) -> Result<()> { + let mut modules = self + .modules + .write() + .map_err(|e| WasmManagerError::LockFailed(e.to_string()))?; + if !modules.contains_key(&module_uuid) { + return Err(WasmManagerError::ModuleNotFound(module_uuid).into()); + } + // Remove the module - the wasm_bytes Vec will be dropped automatically, + // releasing the memory + modules.remove(&module_uuid); + Ok(()) + } + + pub fn get_all_modules(&self) -> Result> { + let modules = self + .modules + .read() + .map_err(|e| WasmManagerError::LockFailed(e.to_string()))?; + Ok(modules.values().cloned().collect()) + } + + pub fn get_module(&self, module_uuid: Uuid) -> Result> { + let modules = self + .modules + .read() + .map_err(|e| WasmManagerError::LockFailed(e.to_string()))?; + Ok(modules.get(&module_uuid).cloned()) + } + + pub fn get_modules(&self) -> Result> { + let modules = self + .modules + .read() + .map_err(|e| WasmManagerError::LockFailed(e.to_string()))?; + Ok(modules.values().cloned().collect()) + } + + /// get modules by attach point + pub fn get_modules_by_attach_point( + &self, + attach_point: WasmModuleAttachPoint, + ) -> Result> { + let modules = self + .modules + .read() + .map_err(|e| WasmManagerError::LockFailed(e.to_string()))?; + Ok(modules + .values() + .filter(|module| module.module_meta.attach_points.contains(&attach_point)) + .cloned() + .collect()) + } + + pub fn get_runtime(&self) -> &Arc { + &self.runtime + } + + /// Execute WASM module using WIT component model based on attach_point + pub async fn execute_module_wit( + &self, + module_uuid: Uuid, + attach_point: WasmModuleAttachPoint, + input: WasmComponentInput, + ) -> Result { + let start_time = std::time::Instant::now(); + + // First, get the WASM bytes with a read lock (faster) + let wasm_bytes = { + let modules = self + .modules + 
.read() + .map_err(|e| WasmManagerError::LockFailed(e.to_string()))?; + let module = modules + .get(&module_uuid) + .ok_or_else(|| WasmError::from(WasmManagerError::ModuleNotFound(module_uuid)))?; + + // Clone the pre-loaded WASM bytes (already in memory, no file I/O) + module.module_meta.wasm_bytes.clone() + }; + + { + let mut modules = self + .modules + .write() + .map_err(|e| WasmManagerError::LockFailed(e.to_string()))?; + if let Some(module) = modules.get_mut(&module_uuid) { + // SystemTime::duration_since only fails if the system time is before UNIX_EPOCH, + // which should never happen in normal operation. If it does, use current time as fallback. + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| { + // Fallback to a reasonable timestamp if system time is invalid + // This should never occur in practice, but provides a safe fallback + std::time::Duration::from_nanos(0) + }) + .as_nanos() as u64; + module.module_meta.last_accessed_at = now; + module.module_meta.access_count += 1; + } + } + + let result = self + .runtime + .execute_component_async(wasm_bytes, attach_point, input) + .await; + + // Record metrics + let execution_time_ms = start_time.elapsed().as_millis() as u64; + self.total_executions.fetch_add(1, Ordering::Relaxed); + self.total_execution_time_ms + .fetch_add(execution_time_ms, Ordering::Relaxed); + // Update max execution time + self.max_execution_time_ms + .fetch_max(execution_time_ms, Ordering::Relaxed); + + if result.is_ok() { + self.successful_executions.fetch_add(1, Ordering::Relaxed); + } else { + self.failed_executions.fetch_add(1, Ordering::Relaxed); + } + + result + } + + /// Execute WASM module using WIT component model (sync version) + pub fn execute_module_wit_sync( + &self, + module_uuid: Uuid, + attach_point: WasmModuleAttachPoint, + input: WasmComponentInput, + ) -> Result { + let handle = tokio::runtime::Handle::current(); + 
handle.block_on(self.execute_module_wit(module_uuid, attach_point, input)) + } + + /// Get current metrics + pub fn get_metrics(&self) -> (u64, u64, u64, u64, u64) { + ( + self.total_executions.load(Ordering::Relaxed), + self.successful_executions.load(Ordering::Relaxed), + self.failed_executions.load(Ordering::Relaxed), + self.total_execution_time_ms.load(Ordering::Relaxed), + self.max_execution_time_ms.load(Ordering::Relaxed), + ) + } + + /// Execute a WASM module for a given attach point + /// Returns the Action if successful, or None if execution failed + /// + /// This is a convenience method that wraps execute_module_wit and handles + /// error logging automatically. + pub async fn execute_module_for_attach_point( + &self, + module: &WasmModule, + attach_point: WasmModuleAttachPoint, + input: WasmComponentInput, + ) -> Option { + use tracing::error; + + let action_result = self + .execute_module_wit(module.module_uuid, attach_point, input) + .await; + + match action_result { + Ok(output) => match output { + WasmComponentOutput::MiddlewareAction(action) => Some(action), + }, + Err(e) => { + error!( + "Failed to execute WASM module {}: {}", + module.module_meta.name, e + ); + None + } + } + } +} + +impl Default for WasmModuleManager { + fn default() -> Self { + // with_default_config() should always succeed with default configuration. + // If it fails, it indicates a critical system configuration error. + Self::with_default_config() + .expect("Failed to create WasmModuleManager with default config. This should never happen with valid default configuration.") + } +} diff --git a/sgl-router/src/wasm/route.rs b/sgl-router/src/wasm/route.rs new file mode 100644 index 00000000000..61329498772 --- /dev/null +++ b/sgl-router/src/wasm/route.rs @@ -0,0 +1,90 @@ +//! WASM HTTP API Routes +//! +//! Provides REST API endpoints for managing WASM modules: +//! - POST /wasm - Add modules +//! - DELETE /wasm/:uuid - Remove a module +//! 
- GET /wasm - List all modules with metrics + +use std::sync::Arc; + +use axum::{ + extract::{Json, Path, State}, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use uuid::Uuid; + +use crate::{ + server::AppState, + wasm::module::{ + WasmMetrics, WasmModuleAddRequest, WasmModuleAddResponse, WasmModuleAddResult, + WasmModuleListResponse, + }, +}; + +pub async fn add_wasm_module( + State(state): State>, + Json(config): Json, +) -> Response { + let Some(wasm_manager) = state.context.wasm_manager.as_ref() else { + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + }; + let mut status = StatusCode::OK; + let mut modules = config.modules.clone(); + for module in modules.iter_mut() { + let result = wasm_manager.add_module(module.clone()); + if let Ok(module_uuid) = result { + module.add_result = Some(WasmModuleAddResult::Success(module_uuid)); + } else { + // We know result is Err here, so unwrap_err() is safe + module.add_result = Some(WasmModuleAddResult::Error(result.unwrap_err().to_string())); + status = StatusCode::BAD_REQUEST; + } + } + + let response = WasmModuleAddResponse { modules }; + (status, Json(response)).into_response() +} + +pub async fn remove_wasm_module( + State(state): State>, + Path(module_uuid_str): Path, +) -> Response { + let Ok(module_uuid) = Uuid::parse_str(&module_uuid_str) else { + return StatusCode::BAD_REQUEST.into_response(); + }; + let Some(wasm_manager) = state.context.wasm_manager.as_ref() else { + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + }; + if let Err(e) = wasm_manager.remove_module(module_uuid) { + return (StatusCode::BAD_REQUEST, e.to_string()).into_response(); + } + (StatusCode::OK, "Module removed successfully").into_response() +} + +pub async fn list_wasm_modules(State(state): State>) -> Response { + let Some(wasm_manager) = state.context.wasm_manager.as_ref() else { + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + }; + let modules = wasm_manager.get_modules(); + if let 
Ok(modules) = modules { + let (total, success, failed, total_time_ms, max_time_ms) = wasm_manager.get_metrics(); + let average_execution_time_ms = if total > 0 { + Some(total_time_ms as f64 / total as f64) + } else { + None + }; + let metrics = WasmMetrics { + total_executions: total, + successful_executions: success, + failed_executions: failed, + total_execution_time_ms: total_time_ms, + max_execution_time_ms: max_time_ms, + average_execution_time_ms, + }; + let response = WasmModuleListResponse { modules, metrics }; + (StatusCode::OK, Json(response)).into_response() + } else { + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } +} diff --git a/sgl-router/src/wasm/runtime.rs b/sgl-router/src/wasm/runtime.rs new file mode 100644 index 00000000000..f27b9518309 --- /dev/null +++ b/sgl-router/src/wasm/runtime.rs @@ -0,0 +1,430 @@ +//! WASM Runtime +//! +//! Manages WASM component execution using wasmtime with async support. +//! Provides a thread pool for concurrent WASM execution and metrics tracking. 
+ +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, +}; + +use tokio::sync::oneshot; +use tracing::{debug, error, info}; +use wasmtime::{ + component::{Component, Linker, ResourceTable}, + Config, Engine, Store, +}; +use wasmtime_wasi::WasiCtx; + +use crate::wasm::{ + config::WasmRuntimeConfig, + errors::{Result, WasmError, WasmRuntimeError}, + module::{MiddlewareAttachPoint, WasmModuleAttachPoint}, + spec::SglRouter, + types::{WasiState, WasmComponentInput, WasmComponentOutput}, +}; + +pub struct WasmRuntime { + config: WasmRuntimeConfig, + thread_pool: Arc, + // Metrics + total_executions: AtomicU64, + successful_executions: AtomicU64, + failed_executions: AtomicU64, + total_execution_time_ms: AtomicU64, + max_execution_time_ms: AtomicU64, +} + +pub struct WasmThreadPool { + sender: async_channel::Sender, + receiver: async_channel::Receiver, + workers: Vec>, + // Metrics + total_tasks: AtomicU64, + completed_tasks: AtomicU64, + failed_tasks: AtomicU64, +} + +pub enum WasmTask { + ExecuteComponent { + wasm_bytes: Vec, + attach_point: WasmModuleAttachPoint, + input: WasmComponentInput, + response: oneshot::Sender>, + }, +} + +impl WasmRuntime { + pub fn new(config: WasmRuntimeConfig) -> Result { + let thread_pool = Arc::new(WasmThreadPool::new(config.clone())?); + + Ok(Self { + config, + thread_pool, + total_executions: AtomicU64::new(0), + successful_executions: AtomicU64::new(0), + failed_executions: AtomicU64::new(0), + total_execution_time_ms: AtomicU64::new(0), + max_execution_time_ms: AtomicU64::new(0), + }) + } + + pub fn with_default_config() -> Result { + Self::new(WasmRuntimeConfig::default()) + } + + pub fn get_config(&self) -> &WasmRuntimeConfig { + &self.config + } + + /// get available cpu count and max recommended cpu count + pub fn get_cpu_info() -> (usize, usize) { + let cpu_count = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(4); + let max_recommended = cpu_count.max(1); + (cpu_count, max_recommended) + } + + /// 
get current thread pool status + pub fn get_thread_pool_info(&self) -> (usize, usize) { + let (_cpu_count, max_recommended) = Self::get_cpu_info(); + let current_workers = self.thread_pool.workers.len(); + (current_workers, max_recommended) + } + + /// Execute WASM component using WIT interface based on attach_point + pub async fn execute_component_async( + &self, + wasm_bytes: Vec, + attach_point: WasmModuleAttachPoint, + input: WasmComponentInput, + ) -> Result { + let start_time = std::time::Instant::now(); + let (response_tx, response_rx) = oneshot::channel(); + + let task = WasmTask::ExecuteComponent { + wasm_bytes, + attach_point, + input, + response: response_tx, + }; + + self.thread_pool.sender.send(task).await.map_err(|e| { + WasmRuntimeError::CallFailed(format!("Failed to send task to thread pool: {}", e)) + })?; + + let result = response_rx.await.map_err(|e| { + WasmRuntimeError::CallFailed(format!( + "Failed to receive response from thread pool: {}", + e + )) + })?; + + let execution_time_ms = start_time.elapsed().as_millis() as u64; + self.total_executions.fetch_add(1, Ordering::Relaxed); + self.total_execution_time_ms + .fetch_add(execution_time_ms, Ordering::Relaxed); + // Update max execution time + self.max_execution_time_ms + .fetch_max(execution_time_ms, Ordering::Relaxed); + + if result.is_ok() { + self.successful_executions.fetch_add(1, Ordering::Relaxed); + } else { + self.failed_executions.fetch_add(1, Ordering::Relaxed); + } + + result + } + + /// Get current metrics + pub fn get_metrics(&self) -> (u64, u64, u64, u64, u64) { + ( + self.total_executions.load(Ordering::Relaxed), + self.successful_executions.load(Ordering::Relaxed), + self.failed_executions.load(Ordering::Relaxed), + self.total_execution_time_ms.load(Ordering::Relaxed), + self.max_execution_time_ms.load(Ordering::Relaxed), + ) + } +} + +impl WasmThreadPool { + pub fn new(config: WasmRuntimeConfig) -> Result { + let (sender, receiver) = async_channel::unbounded(); + + let mut 
workers = Vec::new(); + // set thread pool size based on cpu count + let max_workers = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(4) + .max(1); + let num_workers = config.thread_pool_size.clamp(1, max_workers); + + info!( + target: "sglang_router_rs::wasm::runtime", + "Initializing WASM runtime with {} workers", + num_workers + ); + + for worker_id in 0..num_workers { + let receiver = receiver.clone(); + let config = config.clone(); + + let worker = std::thread::spawn(move || { + // create independent tokio runtime for this thread + let rt = match tokio::runtime::Runtime::new() { + Ok(rt) => rt, + Err(e) => { + error!( + target: "sglang_router_rs::wasm::runtime", + worker_id = worker_id, + "Failed to create tokio runtime: {}", + e + ); + return; + } + }; + + rt.block_on(async { + Self::worker_loop(worker_id, receiver, config).await; + }); + }); + + workers.push(worker); + } + + Ok(Self { + sender, + receiver, + workers, + total_tasks: AtomicU64::new(0), + completed_tasks: AtomicU64::new(0), + failed_tasks: AtomicU64::new(0), + }) + } + + /// Get current thread pool metrics + pub fn get_metrics(&self) -> (u64, u64, u64) { + ( + self.total_tasks.load(Ordering::Relaxed), + self.completed_tasks.load(Ordering::Relaxed), + self.failed_tasks.load(Ordering::Relaxed), + ) + } + + async fn worker_loop( + worker_id: usize, + receiver: async_channel::Receiver, + config: WasmRuntimeConfig, + ) { + debug!( + target: "sglang_router_rs::wasm::runtime", + worker_id = worker_id, + thread_id = ?std::thread::current().id(), + "Worker started" + ); + + let mut wasmtime_config = Config::new(); + wasmtime_config.async_stack_size(config.max_stack_size); + wasmtime_config.async_support(true); + wasmtime_config.wasm_component_model(true); // Enable component model + + let engine = match Engine::new(&wasmtime_config) { + Ok(engine) => engine, + Err(e) => { + error!( + target: "sglang_router_rs::wasm::runtime", + worker_id = worker_id, + "Failed to create engine: 
{}", + e + ); + return; + } + }; + + loop { + let task = match receiver.recv().await { + Ok(task) => task, + Err(_) => { + debug!( + target: "sglang_router_rs::wasm::runtime", + worker_id = worker_id, + "Worker shutting down" + ); + break; // channel closed, exit loop + } + }; + + match task { + WasmTask::ExecuteComponent { + wasm_bytes, + attach_point, + input, + response, + } => { + let result = Self::execute_component_in_worker( + &engine, + wasm_bytes, + attach_point, + input, + &config, + ) + .await; + + let _ = response.send(result); + } + } + } + } + + async fn execute_component_in_worker( + engine: &Engine, + wasm_bytes: Vec, + attach_point: WasmModuleAttachPoint, + input: WasmComponentInput, + _config: &WasmRuntimeConfig, + ) -> Result { + // Compile component from bytes + // Note: The WASM file must be in component format (not plain WASM module) + // Use `wasm-tools component new` to wrap a WASM module into a component if needed + let component = Component::new(engine, &wasm_bytes).map_err(|e| { + WasmRuntimeError::CompileFailed(format!( + "failed to parse WebAssembly component: {}. \ + Hint: The WASM file must be in component format. 
\ + If you're using wit-bindgen, use 'wasm-tools component new' to wrap the WASM module into a component.", + e + )) + })?; + + let mut linker = Linker::::new(engine); + wasmtime_wasi::p2::add_to_linker_async(&mut linker)?; + let mut builder = WasiCtx::builder(); + let mut store = Store::new( + engine, + WasiState { + ctx: builder.build(), + table: ResourceTable::new(), + }, + ); + + let output = match attach_point { + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnRequest) => { + let request = match input { + WasmComponentInput::MiddlewareRequest(req) => req, + _ => { + return Err(WasmError::from(WasmRuntimeError::CallFailed( + "Expected MiddlewareRequest input for OnRequest attach point" + .to_string(), + ))); + } + }; + + // Instantiate component (must use async instantiation when async support is enabled) + let bindings = SglRouter::instantiate_async(&mut store, &component, &linker) + .await + .map_err(|e| { + WasmError::from(WasmRuntimeError::InstanceCreateFailed(e.to_string())) + })?; + + // Call on-request (async call when async support is enabled) + let action_result = bindings + .sgl_router_middleware_on_request() + .call_on_request(&mut store, &request) + .await + .map_err(|e| WasmError::from(WasmRuntimeError::CallFailed(e.to_string())))?; + + WasmComponentOutput::MiddlewareAction(action_result) + } + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnResponse) => { + // Extract Response input + let response = match input { + WasmComponentInput::MiddlewareResponse(resp) => resp, + _ => { + return Err(WasmError::from(WasmRuntimeError::CallFailed( + "Expected MiddlewareResponse input for OnResponse attach point" + .to_string(), + ))); + } + }; + + // Instantiate component (must use async instantiation when async support is enabled) + let bindings = SglRouter::instantiate_async(&mut store, &component, &linker) + .await + .map_err(|e| { + WasmError::from(WasmRuntimeError::InstanceCreateFailed(e.to_string())) + })?; + + // Call on-response 
(async call when async support is enabled) + let action_result = bindings + .sgl_router_middleware_on_response() + .call_on_response(&mut store, &response) + .await + .map_err(|e| WasmError::from(WasmRuntimeError::CallFailed(e.to_string())))?; + + WasmComponentOutput::MiddlewareAction(action_result) + } + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnError) => { + return Err(WasmError::from(WasmRuntimeError::CallFailed( + "OnError attach point not yet implemented".to_string(), + ))); + } + }; + + Ok(output) + } +} + +impl Drop for WasmThreadPool { + fn drop(&mut self) { + // close sender and receiver + self.sender.close(); + self.receiver.close(); + + // wait for all workers to complete + for worker in self.workers.drain(..) { + let _ = worker.join(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::wasm::config::WasmRuntimeConfig; + + #[test] + fn test_get_cpu_info() { + let (cpu_count, max_recommended) = WasmRuntime::get_cpu_info(); + assert!(cpu_count > 0); + assert!(max_recommended > 0); + assert!(max_recommended >= cpu_count); + } + + #[test] + fn test_config_default_values() { + let config = WasmRuntimeConfig::default(); + + assert_eq!(config.max_memory_pages, 1024); + assert_eq!(config.max_execution_time_ms, 1000); + assert_eq!(config.max_stack_size, 1024 * 1024); + assert!(config.thread_pool_size > 0); + assert_eq!(config.module_cache_size, 10); + } + + #[test] + fn test_config_clone() { + let config = WasmRuntimeConfig::default(); + let cloned_config = config.clone(); + + assert_eq!(config.max_memory_pages, cloned_config.max_memory_pages); + assert_eq!( + config.max_execution_time_ms, + cloned_config.max_execution_time_ms + ); + assert_eq!(config.max_stack_size, cloned_config.max_stack_size); + assert_eq!(config.thread_pool_size, cloned_config.thread_pool_size); + assert_eq!(config.module_cache_size, cloned_config.module_cache_size); + } +} diff --git a/sgl-router/src/wasm/spec.rs b/sgl-router/src/wasm/spec.rs new file mode 
100644 index 00000000000..1a2b364a254 --- /dev/null +++ b/sgl-router/src/wasm/spec.rs @@ -0,0 +1,60 @@ +//! WIT Bindings and Type Conversions +//! +//! Contains wasmtime component bindings generated from WIT definitions, +//! and helper functions to convert between Axum HTTP types and WIT types. + +use axum::http::{header, HeaderMap, HeaderValue}; + +wasmtime::component::bindgen!({ + path: "src/wasm/wit", + world: "sgl-router", + imports: { default: async | trappable }, + exports: { default: async }, +}); + +/// Build WIT headers from Axum HeaderMap +pub fn build_wit_headers_from_axum_headers( + headers: &HeaderMap, +) -> Vec { + let mut wit_headers = Vec::new(); + for (name, value) in headers.iter() { + if let Ok(value_str) = value.to_str() { + wit_headers.push(sgl::router::middleware_types::Header { + name: name.as_str().to_string(), + value: value_str.to_string(), + }); + } + } + wit_headers +} + +/// Apply ModifyAction header modifications to Axum HeaderMap +pub fn apply_modify_action_to_headers( + headers: &mut HeaderMap, + modify: &sgl::router::middleware_types::ModifyAction, +) { + // Apply headers_set + for header_mod in &modify.headers_set { + if let (Ok(name), Ok(value)) = ( + header_mod.name.parse::(), + header_mod.value.parse::(), + ) { + headers.insert(name, value); + } + } + // Apply headers_add + for header_mod in &modify.headers_add { + if let (Ok(name), Ok(value)) = ( + header_mod.name.parse::(), + header_mod.value.parse::(), + ) { + headers.append(name, value); + } + } + // Apply headers_remove + for name_str in &modify.headers_remove { + if let Ok(name) = name_str.parse::() { + headers.remove(name); + } + } +} diff --git a/sgl-router/src/wasm/types.rs b/sgl-router/src/wasm/types.rs new file mode 100644 index 00000000000..5eaca82bf2b --- /dev/null +++ b/sgl-router/src/wasm/types.rs @@ -0,0 +1,101 @@ +//! WASM Component Type System +//! +//! Provides generic input/output types for WASM component execution +//! based on attach points. 
+ +use wasmtime::component::ResourceTable; +use wasmtime_wasi::{WasiCtx, WasiCtxView, WasiView}; + +use crate::wasm::{ + module::{MiddlewareAttachPoint, WasmModuleAttachPoint}, + spec::sgl::router::middleware_types, +}; + +/// Generic input type for WASM component execution +/// +/// This enum represents all possible input types that can be passed +/// to a WASM component, determined by the attach_point. +#[derive(Debug, Clone)] +pub enum WasmComponentInput { + /// Middleware OnRequest input + MiddlewareRequest(middleware_types::Request), + /// Middleware OnResponse input + MiddlewareResponse(middleware_types::Response), +} + +/// Generic output type from WASM component execution +/// +/// This enum represents all possible output types that can be returned +/// from a WASM component, determined by the attach_point. +#[derive(Debug, Clone)] +pub enum WasmComponentOutput { + /// Middleware Action output + MiddlewareAction(middleware_types::Action), +} + +impl WasmComponentInput { + /// Create input based on attach_point and raw data + /// + /// This helper function validates that the attach_point matches + /// the expected input type. + pub fn from_attach_point(attach_point: &WasmModuleAttachPoint) -> Result { + match attach_point { + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnRequest) => { + // OnRequest expects a Request type, but we can't construct it here + // The caller should use MiddlewareRequest variant directly + Err("OnRequest requires MiddlewareRequest input. Use WasmComponentInput::MiddlewareRequest directly.".to_string()) + } + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnResponse) => { + // OnResponse expects a Response type + Err("OnResponse requires MiddlewareResponse input. 
Use WasmComponentInput::MiddlewareResponse directly.".to_string()) + } + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnError) => { + Err("OnError attach point not yet implemented".to_string()) + } + } + } + + /// Get the expected attach_point for this input type + pub fn expected_attach_point(&self) -> WasmModuleAttachPoint { + match self { + WasmComponentInput::MiddlewareRequest(_) => { + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnRequest) + } + WasmComponentInput::MiddlewareResponse(_) => { + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnResponse) + } + } + } +} + +impl WasmComponentOutput { + /// Get the attach_point that produced this output type + pub fn from_attach_point(attach_point: &WasmModuleAttachPoint) -> Result { + match attach_point { + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnRequest) => { + // This would be set after execution + Err("Cannot create output before execution".to_string()) + } + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnResponse) => { + Err("Cannot create output before execution".to_string()) + } + WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnError) => { + Err("OnError attach point not yet implemented".to_string()) + } + } + } +} + +pub struct WasiState { + pub ctx: WasiCtx, + pub table: ResourceTable, +} + +impl WasiView for WasiState { + fn ctx(&mut self) -> WasiCtxView<'_> { + WasiCtxView { + ctx: &mut self.ctx, + table: &mut self.table, + } + } +} diff --git a/sgl-router/src/wasm/wit/spec.wit b/sgl-router/src/wasm/wit/spec.wit new file mode 100644 index 00000000000..1f865c9710f --- /dev/null +++ b/sgl-router/src/wasm/wit/spec.wit @@ -0,0 +1,54 @@ +package sgl:router; + +interface middleware-types { + record header { name: string, value: string } + + // onRequest + record request { + method: string, + path: string, + query: string, + headers: list
, + body: list, + request-id: string, + now-epoch-ms: u64, + } + + // onResponse + record response { + status: u16, + headers: list
, + body: list, + } + + // modify action + record modify-action { + status: option, + headers-set: list
, + headers-add: list
, + headers-remove: list, + body-replace: option>, + } + + // return actions + variant action { + continue, + reject(u16), // status code + modify(modify-action), + } +} + +interface middleware-on-request { + use middleware-types.{request, action}; + on-request: func(req: request) -> action; +} + +interface middleware-on-response { + use middleware-types.{response, action}; + on-response: func(resp: response) -> action; +} + +world sgl-router { + export middleware-on-request; + export middleware-on-response; +}