Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,15 +355,15 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
req.retraction_mb_id = None
self.retracted_queue.append(req)
else:
dp_rank = self._resolve_dp_rank(req)
if dp_rank is None:
prefill_dp_rank = self._resolve_prefill_dp_rank(req)
if prefill_dp_rank is None:
self.pending_reqs.append(req)
return
self._create_receiver_and_enqueue(req, dp_rank)
self._create_receiver_and_enqueue(req, prefill_dp_rank)

def _resolve_dp_rank(self, req: Req) -> Optional[int]:
if req.data_parallel_rank is not None:
return req.data_parallel_rank
def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]:
if req.disagg_prefill_dp_rank is not None:
return req.disagg_prefill_dp_rank

if _is_fake_transfer(req, self.scheduler.server_args):
return 0
Expand All @@ -379,7 +379,7 @@ def _resolve_dp_rank(self, req: Req) -> Optional[int]:

return None

def _create_receiver_and_enqueue(self, req: Req, dp_rank: int) -> None:
def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None:
backend = (
TransferBackend.FAKE
if _is_fake_transfer(req, self.scheduler.server_args)
Expand All @@ -391,7 +391,7 @@ def _create_receiver_and_enqueue(self, req: Req, dp_rank: int) -> None:
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
prefill_dp_rank=dp_rank,
prefill_dp_rank=prefill_dp_rank,
)

self.queue.append(
Expand Down Expand Up @@ -493,26 +493,26 @@ def _update_handshake_waiters(
raise ValueError(f"Unexpected poll case: {poll}")

def _resolve_pending_reqs(self) -> None:
"""Batch-resolve dp_ranks for pending requests and create receivers."""
"""Batch-resolve prefill_dp_ranks for pending requests and create receivers."""
if not self.pending_reqs:
return

bootstrap_addr = f"{self.pending_reqs[0].bootstrap_host}:{self.pending_reqs[0].bootstrap_port}"

# If a request is following the bootstrap room,
# we need get the prefill info before resolving the dp_rank,
# we need to get the prefill info before resolving the prefill_dp_ranks,
# which is a conflict with the lazy resolve logic in CommonKVReceiver,
# so we need to ensure the parallel info before resolving the dp_rank
# so we need to ensure the parallel info before resolving it.
if not self.kv_manager.ensure_parallel_info(bootstrap_addr):
return

resolved = []
need_query = []
for req in self.pending_reqs:
# NOTE: we need to resolve it again because we may ensure the parallel info here
dp_rank = self._resolve_dp_rank(req)
if dp_rank is not None:
resolved.append((req, dp_rank))
prefill_dp_rank = self._resolve_prefill_dp_rank(req)
if prefill_dp_rank is not None:
resolved.append((req, prefill_dp_rank))
else:
need_query.append(req)

Expand All @@ -534,8 +534,8 @@ def _resolve_pending_reqs(self) -> None:
else:
self.pending_reqs = []

for req, dp_rank in resolved:
self._create_receiver_and_enqueue(req, dp_rank)
for req, prefill_dp_rank in resolved:
self._create_receiver_and_enqueue(req, prefill_dp_rank)

def pop_preallocated(
self, rids_to_check: Optional[List[str]] = None
Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/disaggregation/encode_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def __init__(
skip_mm_pool=True,
)

def create_req(self, recv_req):
def create_req(self, recv_req: TokenizedGenerateReqInput):
req = Req(
recv_req.rid,
recv_req.input_text,
Expand All @@ -362,7 +362,8 @@ def create_req(self, recv_req):
bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room,
disagg_mode=self.scheduler.disaggregation_mode,
data_parallel_rank=recv_req.data_parallel_rank,
routed_dp_rank=recv_req.routed_dp_rank,
disagg_prefill_dp_rank=recv_req.disagg_prefill_dp_rank,
vocab_size=self.scheduler.model_config.vocab_size,
priority=recv_req.priority,
metrics_collector=(
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/entrypoints/EngineBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def generate(
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
routed_dp_rank: Optional[int] = None,
disagg_prefill_dp_rank: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
rid: Optional[Union[List[str], str]] = None,
priority: Optional[int] = None,
Expand Down
67 changes: 45 additions & 22 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,35 @@ def __init__(self, **kwargs):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

def _resolve_routed_dp_rank(
    self,
    routed_dp_rank: Optional[int],
    data_parallel_rank: Optional[int],
) -> Optional[int]:
    """Merge the deprecated ``data_parallel_rank`` argument into
    ``routed_dp_rank`` and validate the result against the DP config.

    Returns the effective routed DP rank, or ``None`` when the caller did
    not pin a rank (default dispatch).  Raises ``ValueError`` when DP
    attention is enabled and the rank is outside ``[0, dp_size)``.
    """
    if data_parallel_rank is not None:
        import warnings

        warnings.warn(
            "'data_parallel_rank' is deprecated, use 'routed_dp_rank' instead.",
            DeprecationWarning,
            stacklevel=3,
        )
        # The new-style argument wins when both are supplied.
        if routed_dp_rank is None:
            routed_dp_rank = data_parallel_rank

    if self.server_args.enable_dp_attention:
        dp_size = self.server_args.dp_size
        if routed_dp_rank is None:
            logger.debug("routed_dp_rank not provided, using default dispatch")
        elif routed_dp_rank < 0:
            raise ValueError("routed_dp_rank must be non-negative")
        elif routed_dp_rank >= dp_size:
            raise ValueError(
                f"routed_dp_rank must be less than dp_size: {dp_size}"
            )

    logger.debug(f"routed_dp_rank: {routed_dp_rank}")
    return routed_dp_rank

def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
Expand Down Expand Up @@ -232,6 +261,9 @@ def generate(
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
routed_dp_rank: Optional[int] = None,
disagg_prefill_dp_rank: Optional[int] = None,
# Deprecated: use routed_dp_rank instead
data_parallel_rank: Optional[int] = None,
external_trace_header: Optional[Dict] = None,
rid: Optional[Union[List[str], str]] = None,
Expand All @@ -242,15 +274,9 @@ def generate(
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
if self.server_args.enable_dp_attention:
if data_parallel_rank is None:
logger.debug("data_parallel_rank not provided, using default dispatch")
elif data_parallel_rank < 0:
raise ValueError("data_parallel_rank must be non-negative")
elif data_parallel_rank >= self.server_args.dp_size:
raise ValueError(
f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}"
)
routed_dp_rank = self._resolve_routed_dp_rank(
routed_dp_rank, data_parallel_rank
)

obj = GenerateReqInput(
text=prompt,
Expand All @@ -271,7 +297,8 @@ def generate(
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
routed_dp_rank=routed_dp_rank,
disagg_prefill_dp_rank=disagg_prefill_dp_rank,
external_trace_header=external_trace_header,
rid=rid,
session_params=session_params,
Expand Down Expand Up @@ -324,6 +351,9 @@ async def async_generate(
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
routed_dp_rank: Optional[int] = None,
disagg_prefill_dp_rank: Optional[int] = None,
# Deprecated: use routed_dp_rank instead
data_parallel_rank: Optional[int] = None,
external_trace_header: Optional[Dict] = None,
rid: Optional[Union[List[str], str]] = None,
Expand All @@ -334,18 +364,10 @@ async def async_generate(
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
routed_dp_rank = self._resolve_routed_dp_rank(
routed_dp_rank, data_parallel_rank
)

if self.server_args.enable_dp_attention:
if data_parallel_rank is None:
logger.debug("data_parallel_rank not provided, using default dispatch")
elif data_parallel_rank < 0:
raise ValueError("data_parallel_rank must be non-negative")
elif data_parallel_rank >= self.server_args.dp_size:
raise ValueError(
f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
)

logger.debug(f"data_parallel_rank: {data_parallel_rank}")
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
Expand All @@ -365,7 +387,8 @@ async def async_generate(
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
routed_dp_rank=routed_dp_rank,
disagg_prefill_dp_rank=disagg_prefill_dp_rank,
external_trace_header=external_trace_header,
rid=rid,
session_params=session_params,
Expand Down
36 changes: 34 additions & 2 deletions python/sglang/srt/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,20 @@ class BatchResponse(BaseModel):
metadata: Optional[dict] = None


def _migrate_deprecated_dp_rank(values: dict) -> dict:
if isinstance(values, dict) and values.get("data_parallel_rank") is not None:
import warnings

warnings.warn(
"'data_parallel_rank' is deprecated, use 'routed_dp_rank' instead.",
DeprecationWarning,
stacklevel=2,
)
if values.get("routed_dp_rank") is None:
values["routed_dp_rank"] = values["data_parallel_rank"]
return values


class CompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
Expand Down Expand Up @@ -285,7 +299,11 @@ class CompletionRequest(BaseModel):
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None

# For data parallel rank routing
# For DP routing — external router assigns a specific DP worker
routed_dp_rank: Optional[int] = None
# For PD disagg — hint telling decode which prefill DP worker has the KV cache
disagg_prefill_dp_rank: Optional[int] = None
# Deprecated: use routed_dp_rank instead
data_parallel_rank: Optional[int] = None

# For request id
Expand All @@ -300,6 +318,11 @@ class CompletionRequest(BaseModel):
# For custom metric labels
custom_labels: Optional[Dict[str, str]] = None

@model_validator(mode="before")
@classmethod
def _handle_deprecated_dp_rank(cls, values):
return _migrate_deprecated_dp_rank(values)

@field_validator("max_tokens")
@classmethod
def validate_max_tokens_positive(cls, v):
Expand Down Expand Up @@ -614,7 +637,11 @@ class ChatCompletionRequest(BaseModel):
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None

# For data parallel rank routing
# For DP routing — external router assigns a specific DP worker
routed_dp_rank: Optional[int] = None
# For PD disagg — hint telling decode which prefill DP worker has the KV cache
disagg_prefill_dp_rank: Optional[int] = None
# Deprecated: use routed_dp_rank instead
data_parallel_rank: Optional[int] = None

# OpenAI/SGLang default sampling parameters
Expand All @@ -626,6 +653,11 @@ class ChatCompletionRequest(BaseModel):
"repetition_penalty": 1.0,
}

@model_validator(mode="before")
@classmethod
def _handle_deprecated_dp_rank(cls, values):
return _migrate_deprecated_dp_rank(values)

@model_validator(mode="before")
@classmethod
def set_tool_choice_default(cls, values):
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ def _convert_to_internal_request(
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
data_parallel_rank=request.data_parallel_rank,
routed_dp_rank=request.routed_dp_rank,
disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,
return_hidden_states=request.return_hidden_states,
return_routed_experts=request.return_routed_experts,
rid=request.rid,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def _convert_to_internal_request(
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
data_parallel_rank=request.data_parallel_rank,
routed_dp_rank=request.routed_dp_rank,
disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,
return_hidden_states=request.return_hidden_states,
return_routed_experts=request.return_routed_experts,
rid=request.rid,
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,9 @@ def launch_tensor_parallel_group(
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]

def maybe_external_dp_rank_routing(self, req: "Req") -> bool:
    """Route *req* directly to an externally assigned DP worker.

    If the request carries a ``routed_dp_rank`` (set by an external
    router), send it straight to that worker's socket and report True;
    otherwise report False so the caller falls back to the default
    dispatch policy.

    NOTE(review): this span of the diff page interleaved the pre-change
    body (``data_parallel_rank``) with the post-change one; this is the
    merged post-change version.
    """
    if req.routed_dp_rank is None:
        return False
    logger.debug(f"Direct routing to DP rank {req.routed_dp_rank}")
    self.workers[req.routed_dp_rank].send_pyobj(req)
    return True

Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/detokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
retraction_counts=recv_obj.retraction_counts,
token_steps=recv_obj.token_steps,
load=recv_obj.load,
dp_ranks=recv_obj.dp_ranks,
time_stats=recv_obj.time_stats,
)

Expand Down
Loading
Loading