diff --git a/deploy/inference-gateway/epp/pkg/plugins/disagg/decode_scorer.go b/deploy/inference-gateway/epp/pkg/plugins/disagg/decode_scorer.go index 8c4f0b1fc91f..206460b55fb3 100644 --- a/deploy/inference-gateway/epp/pkg/plugins/disagg/decode_scorer.go +++ b/deploy/inference-gateway/epp/pkg/plugins/disagg/decode_scorer.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "fmt" + "strconv" "sync" log "sigs.k8s.io/controller-runtime/pkg/log" @@ -39,6 +40,8 @@ const ( WorkerIDHeader = "x-worker-instance-id" PrefillWorkerIDHeader = "x-prefill-instance-id" + DpRankHeader = "x-dp-rank" + PrefillDpRankHeader = "x-prefill-dp-rank" RoutingModeHeader = "x-dynamo-routing-mode" // decodeStateKey is the key used to store routing state in PluginState @@ -55,6 +58,7 @@ var _ rc.ResponseComplete = &DynDecodeScorer{} // DecodeRoutingState holds routing information passed from Score() to PreRequest(). type DecodeRoutingState struct { WorkerID string + DpRank uint32 PrefillWorkerID string TokenData []int64 } @@ -66,6 +70,7 @@ func (s *DecodeRoutingState) Clone() plugins.StateData { } clone := &DecodeRoutingState{ WorkerID: s.WorkerID, + DpRank: s.DpRank, PrefillWorkerID: s.PrefillWorkerID, } if s.TokenData != nil { @@ -157,8 +162,10 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl } workerIDStr := fmt.Sprintf("%d", result.WorkerID) + dpRankStr := strconv.FormatUint(uint64(result.DpRank), 10) logger.V(logutil.DEFAULT).Info("DynDecodeScorer: decode worker selected", "decodeWorkerID", workerIDStr, + "decodeDpRank", result.DpRank, "isDisaggregated", isDisaggregated, "tokenCount", len(result.TokenData)) @@ -167,6 +174,7 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl req.Headers = map[string]string{} } req.Headers[WorkerIDHeader] = workerIDStr + req.Headers[DpRankHeader] = dpRankStr if isDisaggregated { req.Headers[RoutingModeHeader] = "disaggregated" @@ -188,6 +196,7 @@ func (s *DynDecodeScorer) Score(ctx 
context.Context, cycleState *schedtypes.Cycl if req.RequestId != "" { routingState := &DecodeRoutingState{ WorkerID: workerIDStr, + DpRank: result.DpRank, TokenData: result.TokenData, } s.pluginState.Write(req.RequestId, plugins.StateKey(decodeStateKey), routingState) @@ -226,7 +235,7 @@ func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LL return } - if addErr := dynscorer.CallAddRequest(request.RequestId, state.TokenData, workerIDUint, 0); addErr != nil { + if addErr := dynscorer.CallAddRequest(request.RequestId, state.TokenData, workerIDUint, state.DpRank); addErr != nil { logger.V(logutil.DEFAULT).Error(addErr, "DynDecodeScorer PreRequest: failed to add request", "requestID", request.RequestId) return @@ -235,6 +244,7 @@ func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LL logger.V(logutil.VERBOSE).Info("DynDecodeScorer PreRequest: registered request", "requestID", request.RequestId, "workerID", state.WorkerID, + "dpRank", state.DpRank, "tokenCount", len(state.TokenData)) } diff --git a/deploy/inference-gateway/epp/pkg/plugins/disagg/prefill_scorer.go b/deploy/inference-gateway/epp/pkg/plugins/disagg/prefill_scorer.go index 55c3da4365c6..6ee9686b4705 100644 --- a/deploy/inference-gateway/epp/pkg/plugins/disagg/prefill_scorer.go +++ b/deploy/inference-gateway/epp/pkg/plugins/disagg/prefill_scorer.go @@ -120,11 +120,13 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc } prefillWorkerID := strconv.FormatUint(result.WorkerID, 10) + prefillDpRank := strconv.FormatUint(uint64(result.DpRank), 10) logger.V(logutil.DEFAULT).Info("DynPrefillScorer: prefill worker selected", "prefillWorkerID", prefillWorkerID, + "prefillDpRank", result.DpRank, "tokenCount", len(result.TokenData)) - // Set the prefill worker ID header directly on the request. + // Set the prefill worker ID and DP rank headers directly on the request. 
// The request object is shared across all profile runs in the scheduling // cycle, so the decode scorer (which runs in the next profile) will see it. // This is more reliable than CycleState which may be scoped per profile. @@ -132,6 +134,7 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc req.Headers = map[string]string{} } req.Headers[PrefillWorkerIDHeader] = prefillWorkerID + req.Headers[PrefillDpRankHeader] = prefillDpRank // Score: 1.0 for all pods. The label-filter has already restricted to prefill workers, // and the FFI router's internal selection is authoritative. diff --git a/deploy/inference-gateway/epp/pkg/plugins/dynamo_kv_scorer/plugin.go b/deploy/inference-gateway/epp/pkg/plugins/dynamo_kv_scorer/plugin.go index 4269ce4a9383..0c66bdc926bc 100644 --- a/deploy/inference-gateway/epp/pkg/plugins/dynamo_kv_scorer/plugin.go +++ b/deploy/inference-gateway/epp/pkg/plugins/dynamo_kv_scorer/plugin.go @@ -52,6 +52,8 @@ typedef struct { bool is_disaggregated; uint64_t prefill_worker_id; uint64_t decode_worker_id; + uint32_t prefill_dp_rank; + uint32_t decode_dp_rank; uint32_t *token_ids; size_t token_count; } CRoutingResult; @@ -411,6 +413,7 @@ func CallFreeRequest(requestID string) error { // RoutingResult holds the result of a prefill or decode routing call. type RoutingResult struct { WorkerID uint64 + DpRank uint32 TokenData []int64 } @@ -455,9 +458,10 @@ func CallRoutePrefillRequest(requestJSON string, podsJSON string) (*RoutingResul } workerID := uint64(result.prefill_worker_id) + dpRank := uint32(result.prefill_dp_rank) C.free_routing_result(&result) - return &RoutingResult{WorkerID: workerID, TokenData: tokens64}, nil + return &RoutingResult{WorkerID: workerID, DpRank: dpRank, TokenData: tokens64}, nil } // CallRouteDecodeRequest routes a request to the best decode worker. 
@@ -501,7 +505,8 @@ func CallRouteDecodeRequest(requestJSON string, podsJSON string, isDisaggregated } workerID := uint64(result.decode_worker_id) + dpRank := uint32(result.decode_dp_rank) C.free_routing_result(&result) - return &RoutingResult{WorkerID: workerID, TokenData: tokens64}, nil + return &RoutingResult{WorkerID: workerID, DpRank: dpRank, TokenData: tokens64}, nil } diff --git a/docs/kubernetes/inference-gateway.md b/docs/kubernetes/inference-gateway.md index 89acdd22d3c2..c7a16be292d4 100644 --- a/docs/kubernetes/inference-gateway.md +++ b/docs/kubernetes/inference-gateway.md @@ -10,12 +10,17 @@ title: Inference Gateway (GAIE) Integrate Dynamo with the Gateway API Inference Extension for intelligent KV-aware request routing at the gateway layer. -EPP's default kv-routing approach is not token-aware because the prompt is not tokenized. But the Dynamo plugin uses a token-aware KV algorithm. It employs the dynamo router which implements kv routing by running your model's tokenizer inline. The EPP plugin configuration lives in [`helm/dynamo-gaie/epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml), following the checked-in GAIE/EPP configuration layout used by this repository. +## Features -Dynamo Integration with the Inference Gateway supports Aggregated and Disaggregated Serving. A request only exercises disaggregated routing when the EPP config defines a `prefill` profile and prefill workers are available. The standalone [`epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml) currently only defines a `decode` profile, while the recipe examples use separate aggregated and disaggregated configs under `recipes/llama-3-70b/vllm/agg/gaie/` and `recipes/llama-3-70b/vllm/disagg-single-node/gaie/`. 
Unless `DYN_ENFORCE_DISAGG=true`, deployments without a `prefill` profile or prefill workers fall back to aggregated serving. -If you want to use LoRA deploy Dynamo without the Inference Gateway. +- EPP's default kv-routing approach is not token-aware because the prompt is not tokenized. But the Dynamo plugin uses a token-aware KV algorithm. It employs the dynamo router which implements kv routing by running your model's tokenizer inline. The EPP plugin configuration lives in [`helm/dynamo-gaie/epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml), following the checked-in GAIE/EPP configuration layout used by this repository. -Currently, these setups are only supported with the kGateway based Inference Gateway. +- Dynamo Integration with the Inference Gateway supports Aggregated and Disaggregated Serving. A request only exercises disaggregated routing when the EPP config defines a `prefill` profile and prefill workers are available. The standalone [`epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml) currently only defines a `decode` profile, while the recipe examples use separate aggregated and disaggregated configs under `recipes/llama-3-70b/vllm/agg/gaie/` and `recipes/llama-3-70b/vllm/disagg-single-node/gaie/`. Unless `DYN_ENFORCE_DISAGG=true`, deployments without a `prefill` profile or prefill workers fall back to aggregated serving. + +- GAIE integration supports Data Parallelism. + +- If you want to use LoRA, deploy Dynamo without the Inference Gateway. + +- Currently, these setups are only tested with the kGateway Inference Gateway. 
## Prerequisites diff --git a/lib/bindings/c/src/lib.rs b/lib/bindings/c/src/lib.rs index 30c9fd710f0d..84c156fe6228 100644 --- a/lib/bindings/c/src/lib.rs +++ b/lib/bindings/c/src/lib.rs @@ -404,6 +404,10 @@ pub struct CRoutingResult { pub prefill_worker_id: u64, /// Decode worker ID pub decode_worker_id: u64, + /// Data parallel rank selected for the prefill worker + pub prefill_dp_rank: u32, + /// Data parallel rank selected for the decode worker + pub decode_dp_rank: u32, /// Token IDs (needed for add_request callback) pub token_ids: *mut u32, /// Number of tokens in the request @@ -416,6 +420,8 @@ impl Default for CRoutingResult { is_disaggregated: false, prefill_worker_id: 0, decode_worker_id: 0, + prefill_dp_rank: 0, + decode_dp_rank: 0, token_ids: ptr::null_mut(), token_count: 0, } @@ -449,7 +455,7 @@ impl RouterHandles { lora_name: Option, priority_jump: f64, allowed_worker_ids: Option>, - ) -> Result { + ) -> Result<(u64, u32), QueryRouterResult> { if let Some(ref ids) = allowed_worker_ids { self.prefill_router.register_workers(ids); } @@ -464,7 +470,6 @@ impl RouterHandles { allowed_worker_ids, ) .await - .map(|(worker_id, _dp_rank)| worker_id) .map_err(|e| { tracing::error!(error = ?e, "Prefill query failed"); QueryRouterResult::ErrQueryFailed @@ -1203,25 +1208,27 @@ pub unsafe extern "C" fn route_prefill_request( let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) }; let result = handles.runtime.secondary().block_on(async { - let prefill_worker_id = handles + let (prefill_worker_id, prefill_dp_rank) = handles .query_prefill_worker(&tokens, None, false, None, 0.0, allowed_worker_ids) .await?; tracing::info!( prefill_worker_id = prefill_worker_id, + prefill_dp_rank = prefill_dp_rank, token_count = tokens.len(), "Routed prefill request" ); - Ok(prefill_worker_id) + Ok((prefill_worker_id, prefill_dp_rank)) }); match result { - Ok(prefill_worker_id) => { + Ok((prefill_worker_id, prefill_dp_rank)) => { let out = unsafe { &mut *out_result }; *out = 
CRoutingResult::default(); out.is_disaggregated = true; out.prefill_worker_id = prefill_worker_id; + out.prefill_dp_rank = prefill_dp_rank; write_tokens_to_result(&tokens, out); QueryRouterResult::Ok } @@ -1290,6 +1297,7 @@ pub unsafe extern "C" fn route_decode_request( *out = CRoutingResult::default(); out.is_disaggregated = is_disaggregated; out.decode_worker_id = decode_worker.worker_id; + out.decode_dp_rank = decode_worker.dp_rank; write_tokens_to_result(&tokens, out); QueryRouterResult::Ok } diff --git a/lib/llm/src/kv_router/prefill_router/execution.rs b/lib/llm/src/kv_router/prefill_router/execution.rs index 11193e72640c..4f189230795e 100644 --- a/lib/llm/src/kv_router/prefill_router/execution.rs +++ b/lib/llm/src/kv_router/prefill_router/execution.rs @@ -40,7 +40,11 @@ impl PrefillRouter { // Worker selection let (worker_id, dp_rank) = if let Some(id) = preselected_worker { - let dp_rank = req.routing.as_ref().and_then(|r| r.dp_rank).unwrap_or(0); + let dp_rank = req + .routing + .as_ref() + .and_then(|r| r.prefill_dp_rank.or(r.dp_rank)) + .unwrap_or(0); tracing::debug!( worker_id = id, dp_rank = dp_rank, diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 4413e4725341..7bbfda997e5d 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -324,7 +324,8 @@ impl OpenAIPreprocessor { backend_instance_id: nvext.backend_instance_id, prefill_worker_id: nvext.prefill_worker_id, decode_worker_id: nvext.decode_worker_id, - dp_rank: None, // dp_rank is set later in the pipeline + dp_rank: nvext.dp_rank, + prefill_dp_rank: nvext.prefill_dp_rank, expected_output_tokens: hints.and_then(|h| h.osl), priority_jump: hints.and_then(|h| { h.priority diff --git a/lib/llm/src/protocols/common/preprocessor.rs b/lib/llm/src/protocols/common/preprocessor.rs index fb4c9f2df9e1..27335fcc20ec 100644 --- a/lib/llm/src/protocols/common/preprocessor.rs +++ b/lib/llm/src/protocols/common/preprocessor.rs @@ -34,10 +34,14 @@ pub struct 
RoutingHints { #[serde(default, skip_serializing_if = "Option::is_none")] pub decode_worker_id: Option, - /// Data parallel rank for the request + /// Data parallel rank for the decode worker #[serde(default, skip_serializing_if = "Option::is_none")] pub dp_rank: Option, + /// Data parallel rank for the prefill worker in disaggregated serving + #[serde(default, skip_serializing_if = "Option::is_none")] + pub prefill_dp_rank: Option, + /// Expected number of output tokens for this request. /// Used as a hint for routing decisions to estimate resource requirements. #[serde(default, skip_serializing_if = "Option::is_none")] diff --git a/lib/llm/src/protocols/openai/nvext.rs b/lib/llm/src/protocols/openai/nvext.rs index 0d4e146b32ac..4fbc4ea6f3d0 100644 --- a/lib/llm/src/protocols/openai/nvext.rs +++ b/lib/llm/src/protocols/openai/nvext.rs @@ -11,12 +11,16 @@ pub use crate::protocols::common::timing::TimingInfo; pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id"; pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id"; +pub const HEADER_DP_RANK: &str = "x-dp-rank"; +pub const HEADER_PREFILL_DP_RANK: &str = "x-prefill-dp-rank"; /// Apply routing overrides from HTTP headers to nvext. /// /// Header mappings: /// - `x-worker-instance-id` -> `backend_instance_id` and `decode_worker_id` /// - `x-prefill-instance-id` -> `prefill_worker_id` +/// - `x-dp-rank` -> `dp_rank` (decode worker's DP rank) +/// - `x-prefill-dp-rank` -> `prefill_dp_rank` /// /// Headers take priority over existing nvext values when present. /// If no headers are present, returns the original nvext unchanged. 
@@ -31,7 +35,18 @@ pub fn apply_header_routing_overrides(nvext: Option, headers: &HeaderMap) .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse::().ok()); - if worker_id.is_none() && prefill_id.is_none() { + let dp_rank = headers + .get(HEADER_DP_RANK) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + + let prefill_dp_rank = headers + .get(HEADER_PREFILL_DP_RANK) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + + if worker_id.is_none() && prefill_id.is_none() && dp_rank.is_none() && prefill_dp_rank.is_none() + { return nvext; } @@ -43,6 +58,12 @@ pub fn apply_header_routing_overrides(nvext: Option, headers: &HeaderMap) if let Some(id) = prefill_id { ext.prefill_worker_id = Some(id); } + if let Some(rank) = dp_rank { + ext.dp_rank = Some(rank); + } + if let Some(rank) = prefill_dp_rank { + ext.prefill_dp_rank = Some(rank); + } Some(ext) } @@ -164,6 +185,19 @@ pub struct NvExt { #[serde(default, skip_serializing_if = "Option::is_none")] pub decode_worker_id: Option, + /// Data parallel rank for the decode worker, set by the EPP via the + /// `x-dp-rank` header. When a worker hosts multiple DP engines, + /// this steers the request to the correct engine instance. + #[builder(default, setter(strip_option))] + #[serde(default, skip_serializing_if = "Option::is_none")] + pub dp_rank: Option, + + /// Data parallel rank for the prefill worker in disaggregated serving, + /// set by the EPP via the `x-prefill-dp-rank` header. + #[builder(default, setter(strip_option))] + #[serde(default, skip_serializing_if = "Option::is_none")] + pub prefill_dp_rank: Option, + /// Agent-provided hints for request handling. 
#[builder(default, setter(strip_option))] #[serde(default, skip_serializing_if = "Option::is_none")] @@ -372,29 +406,22 @@ mod tests { assert!(nv_ext.validate().is_ok()); } - // Test apply_header_routing_overrides - worker header present, prefill header absent #[test] fn test_apply_header_routing_overrides() { use axum::http::HeaderMap; - // Only HEADER_WORKER_INSTANCE_ID is in the header let mut headers = HeaderMap::new(); headers.insert(HEADER_WORKER_INSTANCE_ID, "123".parse().unwrap()); - // Note: HEADER_PREFILL_INSTANCE_ID is NOT in the header - - let nvext = NvExt::builder() - .backend_instance_id(999) - .decode_worker_id(888) - .prefill_worker_id(777) - .build() - .unwrap(); + headers.insert(HEADER_PREFILL_INSTANCE_ID, "456".parse().unwrap()); + headers.insert(HEADER_DP_RANK, "3".parse().unwrap()); + headers.insert(HEADER_PREFILL_DP_RANK, "5".parse().unwrap()); - let result = apply_header_routing_overrides(Some(nvext), &headers).unwrap(); + let result = apply_header_routing_overrides(None, &headers).unwrap(); - // Header should override backend_instance_id and decode_worker_id assert_eq!(result.backend_instance_id, Some(123)); assert_eq!(result.decode_worker_id, Some(123)); - // prefill_worker_id should remain from original nvext (not overwritten by header) - assert_eq!(result.prefill_worker_id, Some(777)); + assert_eq!(result.prefill_worker_id, Some(456)); + assert_eq!(result.dp_rank, Some(3)); + assert_eq!(result.prefill_dp_rank, Some(5)); } } diff --git a/tests/router/common.py b/tests/router/common.py index f7909d972b2f..46765164a1ec 100644 --- a/tests/router/common.py +++ b/tests/router/common.py @@ -2161,6 +2161,8 @@ async def run_direct_mode_tests(): headers = { "x-worker-instance-id": str(target_decode), "x-prefill-instance-id": str(target_prefill), + "x-dp-rank": "0", + "x-prefill-dp-rank": "0", } async with aiohttp.ClientSession() as session: