diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index b58f713a1376..0407a12f7fab 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -340,7 +340,7 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
-            prefill_dp_rank=req.data_parallel_rank,
+            prefill_dp_rank=req.prefill_data_parallel_rank,
         )

         req.add_latency(RequestStage.DECODE_PREPARE)
diff --git a/python/sglang/srt/entrypoints/EngineBase.py b/python/sglang/srt/entrypoints/EngineBase.py
index 5d3162afd514..b11258bd4d45 100644
--- a/python/sglang/srt/entrypoints/EngineBase.py
+++ b/python/sglang/srt/entrypoints/EngineBase.py
@@ -29,6 +29,7 @@ def generate(
         bootstrap_port: Optional[Union[List[int], int]] = None,
         bootstrap_room: Optional[Union[List[int], int]] = None,
         data_parallel_rank: Optional[int] = None,
+        prefill_data_parallel_rank: Optional[int] = None,
         rid: Optional[Union[List[str], str]] = None,
     ) -> Union[Dict, Iterator[Dict]]:
         """Generate outputs based on given inputs."""
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 1c00faf3ad05..da31e2013622 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -230,6 +230,7 @@ def generate(
         bootstrap_port: Optional[Union[List[int], int]] = None,
         bootstrap_room: Optional[Union[List[int], int]] = None,
         data_parallel_rank: Optional[int] = None,
+        prefill_data_parallel_rank: Optional[int] = None,
         rid: Optional[Union[List[str], str]] = None,
     ) -> Union[Dict, Iterator[Dict]]:
         """
@@ -266,6 +267,7 @@ def generate(
             bootstrap_port=bootstrap_port,
             bootstrap_room=bootstrap_room,
             data_parallel_rank=data_parallel_rank,
+            prefill_data_parallel_rank=prefill_data_parallel_rank,
             rid=rid,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)
@@ -315,6 +317,7 @@ async def async_generate(
         bootstrap_port: Optional[Union[List[int], int]] = None,
         bootstrap_room: Optional[Union[List[int], int]] = None,
         data_parallel_rank: Optional[int] = None,
+        prefill_data_parallel_rank: Optional[int] = None,
         rid: Optional[Union[List[str], str]] = None,
     ) -> Union[Dict, AsyncIterator[Dict]]:
         """
@@ -352,6 +355,7 @@ async def async_generate(
             bootstrap_port=bootstrap_port,
             bootstrap_room=bootstrap_room,
             data_parallel_rank=data_parallel_rank,
+            prefill_data_parallel_rank=prefill_data_parallel_rank,
             rid=rid,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)
diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 34aa364cf27b..129e2023bca7 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -259,6 +259,7 @@ class CompletionRequest(BaseModel):

     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
+    prefill_data_parallel_rank: Optional[int] = None

     # For request id
     rid: Optional[Union[List[str], str]] = None
@@ -536,6 +537,7 @@ class ChatCompletionRequest(BaseModel):

     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
+    prefill_data_parallel_rank: Optional[int] = None

     # OpenAI/SGLang default sampling parameters
     _DEFAULT_SAMPLING_PARAMS = {
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index cb0c084a3f0a..1f3c91f32cf3 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -212,6 +212,7 @@ def _convert_to_internal_request(
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
             data_parallel_rank=request.data_parallel_rank,
+            prefill_data_parallel_rank=request.prefill_data_parallel_rank,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
             extra_key=self._compute_extra_key(request),
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index c6bd4bc9f35c..cb8ac95cedf2 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -117,6 +117,7 @@ def _convert_to_internal_request(
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
             data_parallel_rank=request.data_parallel_rank,
+            prefill_data_parallel_rank=request.prefill_data_parallel_rank,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
             extra_key=self._compute_extra_key(request),
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 02131604dcf0..fb62ea0ab059 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -230,6 +230,7 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin):

     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
+    prefill_data_parallel_rank: Optional[int] = None

     # For background responses (OpenAI responses API)
     background: bool = False
@@ -655,6 +656,11 @@ def __getitem__(self, i):
             data_parallel_rank=(
                 self.data_parallel_rank if self.data_parallel_rank is not None else None
             ),
+            prefill_data_parallel_rank=(
+                self.prefill_data_parallel_rank
+                if self.prefill_data_parallel_rank is not None
+                else None
+            ),
             conversation_id=self.conversation_id,
             priority=self.priority,
             extra_key=self.extra_key,
@@ -723,6 +729,7 @@ class TokenizedGenerateReqInput(BaseReq):

     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
+    prefill_data_parallel_rank: Optional[int] = None

     # Priority for the request
     priority: Optional[int] = None
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 05125bd3a8c9..72564461781c 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -510,6 +510,7 @@ def __init__(
         bootstrap_room: Optional[int] = None,
         disagg_mode: Optional[DisaggregationMode] = None,
         data_parallel_rank: Optional[int] = None,
+        prefill_data_parallel_rank: Optional[int] = None,
         vocab_size: Optional[int] = None,
         priority: Optional[int] = None,
         metrics_collector: Optional[SchedulerMetricsCollector] = None,
@@ -737,6 +738,7 @@ def __init__(

         # For data parallel rank routing
         self.data_parallel_rank: Optional[int] = data_parallel_rank
+        self.prefill_data_parallel_rank: Optional[int] = prefill_data_parallel_rank

         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 454998cb1440..f0ac81d265fd 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1447,6 +1447,7 @@ def handle_generate_request(
             bootstrap_room=recv_req.bootstrap_room,
             disagg_mode=self.disaggregation_mode,
             data_parallel_rank=recv_req.data_parallel_rank,
+            prefill_data_parallel_rank=recv_req.prefill_data_parallel_rank,
             vocab_size=self.model_config.vocab_size,
             priority=recv_req.priority,
             metrics_collector=(
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 62cde8d7ff94..f7ab8904c71e 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -914,6 +914,7 @@ def _create_tokenized_object(
             return_hidden_states=obj.return_hidden_states,
             return_routed_experts=obj.return_routed_experts,
             data_parallel_rank=obj.data_parallel_rank,
+            prefill_data_parallel_rank=obj.prefill_data_parallel_rank,
             priority=obj.priority,
             extra_key=obj.extra_key,
             need_wait_for_image=obj.need_wait_for_image,
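
Not part of the patch, for reviewers: a minimal sketch of how the new field would be exercised end to end once this lands, using the offline Engine API that the diff extends. The model path, bootstrap host/port/room, and prompt below are placeholder assumptions, not values taken from the diff.

    import sglang as sgl

    # Decode-side engine in a PD-disaggregated deployment (placeholder model).
    llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

    out = llm.generate(
        prompt="Hello, my name is",
        sampling_params={"temperature": 0, "max_new_tokens": 16},
        # Bootstrap info identifying the prefill instance (placeholder values).
        bootstrap_host="10.0.0.1",
        bootstrap_port=8998,
        bootstrap_room=12345,
        # New in this diff: tells the decode-side KV receiver which DP rank of
        # the prefill instance to pull KV cache from; previously decode.py
        # reused req.data_parallel_rank for this purpose.
        prefill_data_parallel_rank=1,
    )
    print(out["text"])

The same field is accepted on the OpenAI-compatible endpoints through CompletionRequest and ChatCompletionRequest, and is threaded through GenerateReqInput -> TokenizedGenerateReqInput -> Req so the scheduler can hand it to the KV receiver.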