diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py
index 6c1d3c2866fb..ea61d595bc66 100644
--- a/python/sglang/srt/entrypoints/openai/serving_base.py
+++ b/python/sglang/srt/entrypoints/openai/serving_base.py
@@ -268,3 +268,47 @@ def extract_routing_key(self, raw_request):
         if raw_request is None:
             return None
         return raw_request.headers.get("x-smg-routing-key")
+
+    def extract_routed_dp_rank_from_header(
+        self, raw_request: Request, body_routed_dp_rank: Optional[int] = None
+    ) -> Optional[int]:
+        """Resolve the effective routed_dp_rank, preferring the HTTP header.
+
+        Header name: X-Data-Parallel-Rank (case-insensitive in HTTP/1.1/2).
+        When present, it takes priority over routed_dp_rank in the body.
+
+        Args:
+            raw_request: Incoming HTTP request; may be None.
+            body_routed_dp_rank: routed_dp_rank taken from the request body.
+
+        Returns:
+            The header value parsed as an int when the header is present,
+            otherwise body_routed_dp_rank unchanged.
+
+        Raises:
+            HTTPException: status 400 when the header is not an integer.
+        """
+        if raw_request is None:
+            return body_routed_dp_rank
+
+        header_value = raw_request.headers.get("x-data-parallel-rank")
+        if header_value is None:
+            return body_routed_dp_rank
+
+        # Only int() can raise ValueError, so keep the try body minimal.
+        try:
+            header_dp_rank = int(header_value)
+        except ValueError:
+            # The parse error adds nothing beyond the 400 detail; drop the chain.
+            raise HTTPException(
+                status_code=400,
+                detail=f"Invalid X-Data-Parallel-Rank header: must be an integer, got '{header_value}'",
+            ) from None
+
+        if body_routed_dp_rank is not None and header_dp_rank != body_routed_dp_rank:
+            logger.debug(
+                "X-Data-Parallel-Rank header (%s) overrides body routed_dp_rank (%s)",
+                header_dp_rank,
+                body_routed_dp_rank,
+            )
+        return header_dp_rank
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index 7913af172c3b..45efacb090e0 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -276,6 +276,11 @@ def _convert_to_internal_request(
         # Extract custom labels from raw request headers
         custom_labels = self.extract_custom_labels(raw_request)
 
+        # Extract routed_dp_rank from header (has higher priority than body)
+        effective_routed_dp_rank = self.extract_routed_dp_rank_from_header(
+            raw_request, request.routed_dp_rank
+        )
+
         # Resolve LoRA adapter from model parameter or explicit lora_path
         lora_path = self._resolve_lora_path(request.model, request.lora_path)
         img_max_dynamic_patch, vid_max_dynamic_patch = _extract_max_dynamic_patch(
@@ -297,7 +302,7 @@ def _convert_to_internal_request(
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
-            routed_dp_rank=request.routed_dp_rank,
+            routed_dp_rank=effective_routed_dp_rank,
             disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,
             return_hidden_states=request.return_hidden_states,
             return_routed_experts=request.return_routed_experts,
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index 63b11eca4202..2554fa738876 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -97,6 +97,11 @@ def _convert_to_internal_request(
         # Extract custom labels from raw request headers
         custom_labels = self.extract_custom_labels(raw_request)
 
+        # Extract routed_dp_rank from header (has higher priority than body)
+        effective_routed_dp_rank = self.extract_routed_dp_rank_from_header(
+            raw_request, request.routed_dp_rank
+        )
+
         # Resolve LoRA adapter from model parameter or explicit lora_path
         lora_path = self._resolve_lora_path(request.model, request.lora_path)
 
@@ -112,7 +117,7 @@ def _convert_to_internal_request(
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
-            routed_dp_rank=request.routed_dp_rank,
+            routed_dp_rank=effective_routed_dp_rank,
             disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,
             return_hidden_states=request.return_hidden_states,
             return_routed_experts=request.return_routed_experts,
diff --git a/test/registered/openai_server/basic/test_serving_chat.py b/test/registered/openai_server/basic/test_serving_chat.py
index 2ad3ec2e7e95..27dfa7ea0d7e 100644
--- a/test/registered/openai_server/basic/test_serving_chat.py
+++ b/test/registered/openai_server/basic/test_serving_chat.py
@@ -783,6 +783,58 @@ async def run_stream():
         self.assertEqual(len(chunks), 2)
         self.assertIn("error", chunks[0])
 
+    # ------------- X-Data-Parallel-Rank header tests -------------
+    def test_extract_routed_dp_rank_from_header_no_header(self):
+        """Test that None is returned when no header is present."""
+        self.fastapi_request.headers = {}
+        result = self.chat.extract_routed_dp_rank_from_header(
+            self.fastapi_request, body_routed_dp_rank=None
+        )
+        self.assertIsNone(result)
+
+    def test_extract_routed_dp_rank_from_header_with_header(self):
+        """Test that header value is extracted correctly."""
+        self.fastapi_request.headers = {"x-data-parallel-rank": "2"}
+        result = self.chat.extract_routed_dp_rank_from_header(
+            self.fastapi_request, body_routed_dp_rank=None
+        )
+        self.assertEqual(result, 2)
+
+    def test_extract_routed_dp_rank_header_overrides_body(self):
+        """Test that header value has higher priority than body."""
+        self.fastapi_request.headers = {"x-data-parallel-rank": "3"}
+        result = self.chat.extract_routed_dp_rank_from_header(
+            self.fastapi_request, body_routed_dp_rank=1
+        )
+        self.assertEqual(result, 3)  # header wins
+
+    def test_extract_routed_dp_rank_from_header_invalid(self):
+        """Test that invalid header value raises HTTPException."""
+        from fastapi import HTTPException
+
+        self.fastapi_request.headers = {"x-data-parallel-rank": "abc"}
+        with self.assertRaises(HTTPException) as context:
+            self.chat.extract_routed_dp_rank_from_header(
+                self.fastapi_request, body_routed_dp_rank=None
+            )
+        self.assertEqual(context.exception.status_code, 400)
+        self.assertIn("must be an integer", context.exception.detail)
+
+    def test_extract_routed_dp_rank_falls_back_to_body(self):
+        """Test that the body value is used when the header is absent."""
+        self.fastapi_request.headers = {}
+        result = self.chat.extract_routed_dp_rank_from_header(
+            self.fastapi_request, body_routed_dp_rank=1
+        )
+        self.assertEqual(result, 1)
+
+    def test_extract_routed_dp_rank_without_raw_request(self):
+        """Test that the body value is returned when there is no raw request."""
+        result = self.chat.extract_routed_dp_rank_from_header(
+            None, body_routed_dp_rank=1
+        )
+        self.assertEqual(result, 1)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)