Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions python/sglang/srt/entrypoints/openai/serving_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,34 @@ def extract_routing_key(self, raw_request):
if raw_request is None:
return None
return raw_request.headers.get("x-smg-routing-key")

def extract_routed_dp_rank_from_header(
    self, raw_request: Request, body_routed_dp_rank: Optional[int] = None
) -> Optional[int]:
    """Resolve the effective routed_dp_rank, letting the HTTP header override the body.

    Header name: X-Data-Parallel-Rank (case-insensitive in HTTP/1.1/2)

    Args:
        raw_request: Incoming request whose headers are inspected. May be
            None, in which case the body value is returned untouched.
        body_routed_dp_rank: routed_dp_rank taken from the request body; used
            as the fallback whenever the header is absent.

    Returns:
        The integer parsed from the header when it is present, otherwise
        ``body_routed_dp_rank`` unchanged (which may itself be None).

    Raises:
        HTTPException: With status 400 when the header is present but cannot
            be parsed by ``int()``.
    """
    if raw_request is None:
        return body_routed_dp_rank

    header_value = raw_request.headers.get("x-data-parallel-rank")
    if header_value is None:
        return body_routed_dp_rank

    # Only the int() conversion can raise ValueError, so keep the try body
    # limited to it and chain the original error for easier debugging.
    try:
        header_dp_rank = int(header_value)
    except ValueError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid X-Data-Parallel-Rank header: must be an integer, got '{header_value}'",
        ) from exc

    # Surface conflicting header/body values; the header always wins.
    if body_routed_dp_rank is not None and header_dp_rank != body_routed_dp_rank:
        logger.debug(
            "X-Data-Parallel-Rank header (%s) overrides body routed_dp_rank (%s)",
            header_dp_rank,
            body_routed_dp_rank,
        )
    return header_dp_rank
7 changes: 6 additions & 1 deletion python/sglang/srt/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,11 @@ def _convert_to_internal_request(
# Extract custom labels from raw request headers
custom_labels = self.extract_custom_labels(raw_request)

# Extract routed_dp_rank from header (has higher priority than body)
effective_routed_dp_rank = self.extract_routed_dp_rank_from_header(
raw_request, request.routed_dp_rank
)

# Resolve LoRA adapter from model parameter or explicit lora_path
lora_path = self._resolve_lora_path(request.model, request.lora_path)
img_max_dynamic_patch, vid_max_dynamic_patch = _extract_max_dynamic_patch(
Expand All @@ -297,7 +302,7 @@ def _convert_to_internal_request(
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
routed_dp_rank=request.routed_dp_rank,
routed_dp_rank=effective_routed_dp_rank,
disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,
return_hidden_states=request.return_hidden_states,
return_routed_experts=request.return_routed_experts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ def _convert_to_internal_request(
# Extract custom labels from raw request headers
custom_labels = self.extract_custom_labels(raw_request)

# Extract routed_dp_rank from header (has higher priority than body)
effective_routed_dp_rank = self.extract_routed_dp_rank_from_header(
raw_request, request.routed_dp_rank
)

# Resolve LoRA adapter from model parameter or explicit lora_path
lora_path = self._resolve_lora_path(request.model, request.lora_path)

Expand All @@ -112,7 +117,7 @@ def _convert_to_internal_request(
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
routed_dp_rank=request.routed_dp_rank,
routed_dp_rank=effective_routed_dp_rank,
disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,
return_hidden_states=request.return_hidden_states,
return_routed_experts=request.return_routed_experts,
Expand Down
37 changes: 37 additions & 0 deletions test/registered/openai_server/basic/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,43 @@ async def run_stream():
self.assertEqual(len(chunks), 2)
self.assertIn("error", chunks[0])

# ------------- X-Data-Parallel-Rank header tests -------------
def test_extract_routed_dp_rank_from_header_no_header(self):
    """With empty headers and no body rank, the extracted rank is None."""
    request = self.fastapi_request
    request.headers = {}
    self.assertIsNone(
        self.chat.extract_routed_dp_rank_from_header(
            request, body_routed_dp_rank=None
        )
    )

def test_extract_routed_dp_rank_from_header_with_header(self):
    """A numeric header string is parsed and returned as an int."""
    request = self.fastapi_request
    request.headers = {"x-data-parallel-rank": "2"}
    rank = self.chat.extract_routed_dp_rank_from_header(
        request, body_routed_dp_rank=None
    )
    self.assertEqual(rank, 2)

def test_extract_routed_dp_rank_header_overrides_body(self):
    """When header and body disagree, the header value is the one returned."""
    request = self.fastapi_request
    request.headers = {"x-data-parallel-rank": "3"}
    rank = self.chat.extract_routed_dp_rank_from_header(
        request, body_routed_dp_rank=1
    )
    # Body asked for rank 1, header asked for rank 3: the header takes priority.
    self.assertEqual(rank, 3)

def test_extract_routed_dp_rank_from_header_invalid(self):
    """A non-integer header value is rejected with a 400 HTTPException."""
    from fastapi import HTTPException

    request = self.fastapi_request
    request.headers = {"x-data-parallel-rank": "abc"}
    with self.assertRaises(HTTPException) as raised:
        self.chat.extract_routed_dp_rank_from_header(
            request, body_routed_dp_rank=None
        )

    error = raised.exception
    self.assertEqual(error.status_code, 400)
    self.assertIn("must be an integer", error.detail)


# Run the unittest runner with verbose output when this module is executed directly.
if __name__ == "__main__":
unittest.main(verbosity=2)
Loading