Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
prefill_dp_rank=req.data_parallel_rank,
prefill_dp_rank=req.prefill_data_parallel_rank,
)

req.add_latency(RequestStage.DECODE_PREPARE)
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/entrypoints/EngineBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def generate(
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
prefill_data_parallel_rank: Optional[int] = None,
rid: Optional[Union[List[str], str]] = None,
) -> Union[Dict, Iterator[Dict]]:
"""Generate outputs based on given inputs."""
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def generate(
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
prefill_data_parallel_rank: Optional[int] = None,
rid: Optional[Union[List[str], str]] = None,
) -> Union[Dict, Iterator[Dict]]:
"""
Expand Down Expand Up @@ -266,6 +267,7 @@ def generate(
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
prefill_data_parallel_rank=prefill_data_parallel_rank,
rid=rid,
)
generator = self.tokenizer_manager.generate_request(obj, None)
Expand Down Expand Up @@ -315,6 +317,7 @@ async def async_generate(
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
prefill_data_parallel_rank: Optional[int] = None,
rid: Optional[Union[List[str], str]] = None,
) -> Union[Dict, AsyncIterator[Dict]]:
"""
Expand Down Expand Up @@ -352,6 +355,7 @@ async def async_generate(
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
prefill_data_parallel_rank=prefill_data_parallel_rank,
rid=rid,
)
generator = self.tokenizer_manager.generate_request(obj, None)
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ class CompletionRequest(BaseModel):

# For data parallel rank routing
data_parallel_rank: Optional[int] = None
prefill_data_parallel_rank: Optional[int] = None

# For request id
rid: Optional[Union[List[str], str]] = None
Expand Down Expand Up @@ -536,6 +537,7 @@ class ChatCompletionRequest(BaseModel):

# For data parallel rank routing
data_parallel_rank: Optional[int] = None
prefill_data_parallel_rank: Optional[int] = None

# OpenAI/SGLang default sampling parameters
_DEFAULT_SAMPLING_PARAMS = {
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def _convert_to_internal_request(
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
data_parallel_rank=request.data_parallel_rank,
prefill_data_parallel_rank=request.prefill_data_parallel_rank,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
extra_key=self._compute_extra_key(request),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def _convert_to_internal_request(
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
data_parallel_rank=request.data_parallel_rank,
prefill_data_parallel_rank=request.prefill_data_parallel_rank,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
extra_key=self._compute_extra_key(request),
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin):

# For data parallel rank routing
data_parallel_rank: Optional[int] = None
prefill_data_parallel_rank: Optional[int] = None

# For background responses (OpenAI responses API)
background: bool = False
Expand Down Expand Up @@ -655,6 +656,11 @@ def __getitem__(self, i):
data_parallel_rank=(
self.data_parallel_rank if self.data_parallel_rank is not None else None
),
prefill_data_parallel_rank=(
self.prefill_data_parallel_rank
if self.prefill_data_parallel_rank is not None
else None
),
conversation_id=self.conversation_id,
priority=self.priority,
extra_key=self.extra_key,
Expand Down Expand Up @@ -723,6 +729,7 @@ class TokenizedGenerateReqInput(BaseReq):

# For data parallel rank routing
data_parallel_rank: Optional[int] = None
prefill_data_parallel_rank: Optional[int] = None

# Priority for the request
priority: Optional[int] = None
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def __init__(
bootstrap_room: Optional[int] = None,
disagg_mode: Optional[DisaggregationMode] = None,
data_parallel_rank: Optional[int] = None,
prefill_data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None,
priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
Expand Down Expand Up @@ -737,6 +738,7 @@ def __init__(

# For data parallel rank routing
self.data_parallel_rank: Optional[int] = data_parallel_rank
self.prefill_data_parallel_rank: Optional[int] = prefill_data_parallel_rank

# the start index of the sent kv cache
# We want to send it chunk by chunk for chunked prefill.
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,7 @@ def handle_generate_request(
bootstrap_room=recv_req.bootstrap_room,
disagg_mode=self.disaggregation_mode,
data_parallel_rank=recv_req.data_parallel_rank,
prefill_data_parallel_rank=recv_req.prefill_data_parallel_rank,
vocab_size=self.model_config.vocab_size,
priority=recv_req.priority,
metrics_collector=(
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,7 @@ def _create_tokenized_object(
return_hidden_states=obj.return_hidden_states,
return_routed_experts=obj.return_routed_experts,
data_parallel_rank=obj.data_parallel_rank,
prefill_data_parallel_rank=obj.prefill_data_parallel_rank,
priority=obj.priority,
extra_key=obj.extra_key,
need_wait_for_image=obj.need_wait_for_image,
Expand Down
Loading