267 changes: 265 additions & 2 deletions python/sglang/srt/distributed/parallel_state.py
@@ -30,11 +30,11 @@
import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import torch
import torch.distributed
@@ -2324,3 +2324,266 @@ def monkey_patch_vllm_parallel_state(reverse: bool = False):
setattr(vllm_parrlel_state, "get_pp_group", get_pp_group)
setattr(vllm_parrlel_state, "get_tp_group", get_tp_group)
setattr(vllm_parrlel_state, "get_world_group", get_world_group)


@dataclass
class RankParallelismConfig:
"""
Complete parallelism configuration for a single inference rank.

This configuration captures all the parallelism settings needed to recreate
a model shard outside of sglang. It supports:
- TP/PP/EP for model parallelism
- MoE-TP/MoE-DP for MoE layers, and Attn-TP/Attn-DP/Attn-CP for DP attention.
"""

tp_size: int = 1
tp_rank: int = 0
pp_size: int = 1
pp_rank: int = 0
ep_size: int = 1
ep_rank: int = 0
moe_tp_size: int = 1
moe_tp_rank: int = 0
attn_tp_size: int = 1
attn_tp_rank: int = 0
attn_dp_size: int = 1
attn_dp_rank: int = 0
attn_cp_size: int = 1
attn_cp_rank: int = 0
moe_dp_size: int = 1
moe_dp_rank: int = 0

world_size: int = 1
global_rank: int = 0
local_rank: int = 0

@property
def has_dp_attention(self) -> bool:
"""Check if DP attention is enabled."""
return self.attn_dp_size > 1

@property
def has_expert_parallelism(self) -> bool:
"""Check if expert parallelism is enabled."""
return self.ep_size > 1

@property
def has_context_parallelism(self) -> bool:
"""Check if context parallelism is enabled."""
return self.attn_cp_size > 1

@property
def has_moe_data_parallelism(self) -> bool:
"""Check if MoE data parallelism is enabled."""
return self.moe_dp_size > 1

def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for serialization."""
return asdict(self)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "RankParallelismConfig":
"""Create from dictionary, filtering unknown fields."""
import dataclasses

valid_fields = {f.name for f in dataclasses.fields(cls)}
filtered_data = {k: v for k, v in data.items() if k in valid_fields}
return cls(**filtered_data)

@classmethod
def from_parallel_state(cls, local_rank: int = 0) -> "RankParallelismConfig":
"""Extract current parallelism settings from the global parallel state."""
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

# Import dp_attention lazily to avoid circular imports
from sglang.srt.layers.dp_attention import (
get_attention_cp_rank,
get_attention_cp_size,
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
)

return cls(
tp_size=tp_size,
tp_rank=tp_rank,
pp_size=get_pipeline_model_parallel_world_size(),
pp_rank=get_pipeline_model_parallel_rank(),
ep_size=get_moe_expert_parallel_world_size(),
ep_rank=get_moe_expert_parallel_rank(),
moe_tp_size=get_moe_tensor_parallel_world_size(),
moe_tp_rank=get_moe_tensor_parallel_rank(),
attn_tp_size=get_attention_tp_size(),
attn_tp_rank=get_attention_tp_rank(),
attn_dp_size=get_attention_dp_size(),
attn_dp_rank=get_attention_dp_rank(),
attn_cp_size=get_attention_cp_size(),
attn_cp_rank=get_attention_cp_rank(),
moe_dp_size=get_moe_data_parallel_world_size(),
moe_dp_rank=get_moe_data_parallel_rank(),
world_size=(
torch.distributed.get_world_size()
if torch.distributed.is_initialized()
else 1
),
global_rank=(
torch.distributed.get_rank()
if torch.distributed.is_initialized()
else 0
),
local_rank=local_rank,
)

def __repr__(self) -> str:
parts = [
f"TP={self.tp_rank}/{self.tp_size}",
f"PP={self.pp_rank}/{self.pp_size}",
]
if self.has_expert_parallelism:
parts.append(f"EP={self.ep_rank}/{self.ep_size}")
parts.append(f"MoE-TP={self.moe_tp_rank}/{self.moe_tp_size}")
if self.has_dp_attention:
parts.append(f"AttnTP={self.attn_tp_rank}/{self.attn_tp_size}")
parts.append(f"AttnDP={self.attn_dp_rank}/{self.attn_dp_size}")
if self.has_context_parallelism:
parts.append(f"AttnCP={self.attn_cp_rank}/{self.attn_cp_size}")
if self.has_moe_data_parallelism:
parts.append(f"MoE-DP={self.moe_dp_rank}/{self.moe_dp_size}")
parts.append(f"Global={self.global_rank}/{self.world_size}")
return f"RankParallelismConfig({', '.join(parts)})"


class ParallelismContext:
"""
Context manager for creating model replicas with specific parallelism settings.

This context manager temporarily sets global variables to allow creating model
shards outside of a real distributed environment.

Usage:
server_args = ... # Get from engine.get_server_info()
parallelism_info = engine.get_parallelism_config(rank)
sglang.srt.server_args._global_server_args = server_args

with ParallelismContext(RankParallelismConfig.from_dict(parallelism_info)):
model = get_model(
model_config=model_config,
load_config=load_config,
device_config=device_config,
)
"""

def __init__(self, parallelism_config: RankParallelismConfig):
self.config = parallelism_config
self._original_globals: Dict[str, Any] = {}

def _create_mock_group(self, world_size: int, rank_in_group: int):
"""Create a mock group coordinator with all necessary properties."""
mock_group = MagicMock()
mock_group.world_size = world_size
mock_group.rank_in_group = rank_in_group
mock_group.rank = rank_in_group # Use rank_in_group for .rank as well
mock_group.local_rank = rank_in_group
mock_group.ranks = list(range(world_size))
mock_group.first_rank = 0
mock_group.last_rank = world_size - 1
mock_group.is_first_rank = rank_in_group == 0
mock_group.is_last_rank = rank_in_group == world_size - 1
mock_group.next_rank = mock_group.ranks[(rank_in_group + 1) % world_size]
mock_group.prev_rank = mock_group.ranks[(rank_in_group - 1) % world_size]
return mock_group

def __enter__(self):
conf = self.config

# Import the modules we need to modify
from sglang.srt.distributed import parallel_state
from sglang.srt.layers import dp_attention

# Save original global variables to restore later
self._original_globals["_TP"] = getattr(parallel_state, "_TP", None)
self._original_globals["_PP"] = getattr(parallel_state, "_PP", None)
self._original_globals["_MOE_EP"] = getattr(parallel_state, "_MOE_EP", None)
self._original_globals["_MOE_TP"] = getattr(parallel_state, "_MOE_TP", None)
self._original_globals["_ATTN_TP"] = getattr(parallel_state, "_ATTN_TP", None)
self._original_globals["_ATTN_CP"] = getattr(parallel_state, "_ATTN_CP", None)
self._original_globals["_MOE_DP"] = getattr(parallel_state, "_MOE_DP", None)
self._original_globals["_ATTN_TP_RANK"] = getattr(
dp_attention, "_ATTN_TP_RANK", None
)
self._original_globals["_ATTN_TP_SIZE"] = getattr(
dp_attention, "_ATTN_TP_SIZE", None
)
self._original_globals["_ATTN_DP_RANK"] = getattr(
dp_attention, "_ATTN_DP_RANK", None
)
self._original_globals["_ATTN_DP_SIZE"] = getattr(
dp_attention, "_ATTN_DP_SIZE", None
)
self._original_globals["_ENABLE_DP_ATTENTION_FLAG"] = getattr(
dp_attention, "_ENABLE_DP_ATTENTION_FLAG", False
)

# Create mock group objects with the correct attributes
mock_tp_group = self._create_mock_group(conf.tp_size, conf.tp_rank)
mock_pp_group = self._create_mock_group(conf.pp_size, conf.pp_rank)
mock_ep_group = self._create_mock_group(conf.ep_size, conf.ep_rank)
mock_moe_tp_group = self._create_mock_group(conf.moe_tp_size, conf.moe_tp_rank)
mock_attn_tp_group = self._create_mock_group(conf.attn_tp_size, conf.attn_tp_rank)
mock_attn_cp_group = self._create_mock_group(conf.attn_cp_size, conf.attn_cp_rank)
mock_moe_dp_group = self._create_mock_group(conf.moe_dp_size, conf.moe_dp_rank)

# Set the global group objects directly on parallel_state module
parallel_state._TP = mock_tp_group
parallel_state._PP = mock_pp_group
parallel_state._MOE_EP = mock_ep_group
parallel_state._MOE_TP = mock_moe_tp_group
parallel_state._ATTN_TP = mock_attn_tp_group
parallel_state._ATTN_CP = mock_attn_cp_group
parallel_state._MOE_DP = mock_moe_dp_group

# Set dp_attention globals directly
dp_attention._ATTN_TP_RANK = conf.attn_tp_rank
dp_attention._ATTN_TP_SIZE = conf.attn_tp_size
dp_attention._ATTN_DP_RANK = conf.attn_dp_rank
dp_attention._ATTN_DP_SIZE = conf.attn_dp_size
# Enable DP attention flag if attn_dp_size > 1
dp_attention._ENABLE_DP_ATTENTION_FLAG = conf.attn_dp_size > 1

logger.info(
f"[ParallelismContext] Activated: TP={conf.tp_rank}/{conf.tp_size}, "
f"PP={conf.pp_rank}/{conf.pp_size}, EP={conf.ep_rank}/{conf.ep_size}, "
f"MoE-TP={conf.moe_tp_rank}/{conf.moe_tp_size}, "
f"AttnTP={conf.attn_tp_rank}/{conf.attn_tp_size}, "
f"AttnDP={conf.attn_dp_rank}/{conf.attn_dp_size}, "
f"AttnCP={conf.attn_cp_rank}/{conf.attn_cp_size}, "
f"MoE-DP={conf.moe_dp_rank}/{conf.moe_dp_size}"
)
return self

def __exit__(self, *args):
from sglang.srt.distributed import parallel_state
from sglang.srt.layers import dp_attention

# Restore original parallel_state globals
parallel_state._TP = self._original_globals.get("_TP")
parallel_state._PP = self._original_globals.get("_PP")
parallel_state._MOE_EP = self._original_globals.get("_MOE_EP")
parallel_state._MOE_TP = self._original_globals.get("_MOE_TP")
parallel_state._ATTN_TP = self._original_globals.get("_ATTN_TP")
parallel_state._ATTN_CP = self._original_globals.get("_ATTN_CP")
parallel_state._MOE_DP = self._original_globals.get("_MOE_DP")

# Restore original dp_attention globals
dp_attention._ATTN_TP_RANK = self._original_globals.get("_ATTN_TP_RANK")
dp_attention._ATTN_TP_SIZE = self._original_globals.get("_ATTN_TP_SIZE")
dp_attention._ATTN_DP_RANK = self._original_globals.get("_ATTN_DP_RANK")
dp_attention._ATTN_DP_SIZE = self._original_globals.get("_ATTN_DP_SIZE")
dp_attention._ENABLE_DP_ATTENTION_FLAG = self._original_globals.get(
"_ENABLE_DP_ATTENTION_FLAG", False
)

logger.info("[ParallelismContext] Deactivated")
return False
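
As a usage sketch (illustrative, not prescriptive): the dataclass and context manager above are meant to be used together roughly as follows. The config values are arbitrary, and the body of the `with` block is a placeholder, as in the ParallelismContext docstring.

from sglang.srt.distributed.parallel_state import (
    ParallelismContext,
    RankParallelismConfig,
)

# Arbitrary example values; in practice they come from a serving rank.
cfg = RankParallelismConfig(tp_size=4, tp_rank=2, pp_size=1, pp_rank=0)

# to_dict()/from_dict() round-trip; from_dict() silently drops unknown keys.
restored = RankParallelismConfig.from_dict({**cfg.to_dict(), "unknown_key": "ignored"})
assert restored == cfg

with ParallelismContext(restored):
    # Inside the context the parallel_state getters should reflect the mocked
    # groups (e.g. TP world size 4 and TP rank 2 here), so a model shard can be
    # instantiated without a real torch.distributed process group.
    ...
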
10 changes: 9 additions & 1 deletion python/sglang/srt/entrypoints/engine.py
@@ -79,6 +79,7 @@
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.managers.tokenizer_manager_multiitem_mixin import ScoreResult
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
parse_parallelism_config_from_scheduler_infos,
parse_remote_instance_transfer_engine_info_from_scheduler_infos,
)
from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info
@@ -200,6 +201,9 @@ def __init__(self, **kwargs):
scheduler_init_result.scheduler_infos
)
)
self.parallelism_config = parse_parallelism_config_from_scheduler_infos(
scheduler_init_result.scheduler_infos
)

# Initialize ZMQ sockets
context = zmq.Context(2)
@@ -1172,7 +1176,11 @@ def _wait_for_scheduler_ready(
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
scheduler_infos.append(data)

if "_dp_scheduler_infos" in data:
scheduler_infos.extend(data["_dp_scheduler_infos"])
else:
scheduler_infos.append(data)
return scheduler_infos


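
To illustrate the readiness-message handling above: when data parallelism is enabled, the controller (see the data_parallel_controller.py change later in this diff) attaches the per-scheduler infos under "_dp_scheduler_infos", and _wait_for_scheduler_ready flattens them into one entry per scheduler. A small sketch with invented field values:

# Without DP the message is appended as-is; with DP the nested list is flattened.
single = {"status": "ready", "max_total_num_tokens": 100000, "max_req_input_len": 8192}
dp = {
    "status": "ready",
    "max_total_num_tokens": 100000,
    "max_req_input_len": 8192,
    "_dp_scheduler_infos": [
        {"status": "ready", "max_total_num_tokens": 100000, "max_req_input_len": 8192},
        {"status": "ready", "max_total_num_tokens": 100000, "max_req_input_len": 8192},
    ],
}

scheduler_infos = []
for data in (single, dp):
    if "_dp_scheduler_infos" in data:
        scheduler_infos.extend(data["_dp_scheduler_infos"])
    else:
        scheduler_infos.append(data)
# scheduler_infos now holds one dict per scheduler, with or without DP.
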
30 changes: 30 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
@@ -154,6 +154,7 @@
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
parse_parallelism_config_from_scheduler_infos,
parse_remote_instance_transfer_engine_info_from_scheduler_infos,
)
from sglang.srt.observability.func_timer import enable_func_timer
@@ -199,6 +200,7 @@ class _GlobalState:
# )
# }
remote_instance_transfer_engine_info: Optional[Dict] = None
parallelism_config_info: Optional[Dict] = None


_global_state: Optional[_GlobalState] = None
@@ -1049,6 +1051,30 @@ async def get_remote_instance_transfer_engine_info(rank: int = None):
return Response(status_code=HTTPStatus.BAD_REQUEST)


@app.get("/parallelism_config")
async def parallelism_config(rank: int = 0):
"""Get parallelism config for a specific TP rank."""
if rank < 0:
return Response(status_code=HTTPStatus.BAD_REQUEST)

if (
_global_state.parallelism_config_info is None
or len(_global_state.parallelism_config_info) == 0
):
logger.error("Parallelism config info is not available.")
return Response(status_code=HTTPStatus.BAD_REQUEST)

try:
result = {
"rank": rank,
**dataclasses.asdict(_global_state.parallelism_config_info[rank]),
}
return result
except Exception as e:
logger.error(f"Exception: {e}")
return Response(status_code=HTTPStatus.BAD_REQUEST)


@app.post("/init_weights_update_group")
@auth_level(AuthLevel.ADMIN_OPTIONAL)
async def init_weights_update_group(
@@ -1983,6 +2009,9 @@ def _setup_and_run_http_server(
scheduler_infos
)
)
parallelism_config_info = parse_parallelism_config_from_scheduler_infos(
scheduler_infos
)

# Set global states
set_global_state(
@@ -1991,6 +2020,7 @@
template_manager=template_manager,
scheduler_info=scheduler_infos[0],
remote_instance_transfer_engine_info=remote_instance_transfer_engine_info,
parallelism_config_info=parallelism_config_info,
)
)

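
A possible client-side flow for the new /parallelism_config endpoint (illustrative; the host and port are assumptions, and the response keys follow the handler above, i.e. "rank" plus the RankParallelismConfig fields):

import requests

from sglang.srt.distributed.parallel_state import (
    ParallelismContext,
    RankParallelismConfig,
)

resp = requests.get("http://localhost:30000/parallelism_config", params={"rank": 0})
resp.raise_for_status()
payload = resp.json()

# from_dict() filters out the extra "rank" key and any other unknown fields.
cfg = RankParallelismConfig.from_dict(payload)

with ParallelismContext(cfg):
    # Recreate the same shard layout as the queried rank, e.g. via get_model()
    # as shown in the ParallelismContext docstring.
    ...
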
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/data_parallel_controller.py
@@ -492,6 +492,7 @@ def launch_tensor_parallel_group(

self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
self.scheduler_infos = scheduler_info

def maybe_external_dp_rank_routing(self, req: Req):
if req.routed_dp_rank is not None:
@@ -590,6 +591,7 @@ def run_data_parallel_controller_process(
"status": "ready",
"max_total_num_tokens": controller.max_total_num_tokens,
"max_req_input_len": controller.max_req_input_len,
"_dp_scheduler_infos": controller.scheduler_infos,
}
)
if server_args.node_rank == 0: