diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 9df3668d0bd5..b05fa214adbe 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -30,11 +30,11 @@
 import weakref
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from datetime import timedelta
 from multiprocessing import shared_memory
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import torch
 import torch.distributed
@@ -2324,3 +2324,266 @@ def monkey_patch_vllm_parallel_state(reverse: bool = False):
     setattr(vllm_parrlel_state, "get_pp_group", get_pp_group)
     setattr(vllm_parrlel_state, "get_tp_group", get_tp_group)
     setattr(vllm_parrlel_state, "get_world_group", get_world_group)
+
+
+@dataclass
+class RankParallelismConfig:
+    """
+    Complete parallelism configuration for a single inference rank.
+
+    This configuration captures all the parallelism settings needed to recreate
+    a model shard outside of sglang. It supports:
+    - TP/PP/EP for model parallelism
+    - MoE-TP/MoE-DP for MoE layers
+    - Attn-TP/Attn-DP/Attn-CP for DP attention and context parallelism
+    """
+
+    tp_size: int = 1
+    tp_rank: int = 0
+    pp_size: int = 1
+    pp_rank: int = 0
+    ep_size: int = 1
+    ep_rank: int = 0
+    moe_tp_size: int = 1
+    moe_tp_rank: int = 0
+    attn_tp_size: int = 1
+    attn_tp_rank: int = 0
+    attn_dp_size: int = 1
+    attn_dp_rank: int = 0
+    attn_cp_size: int = 1
+    attn_cp_rank: int = 0
+    moe_dp_size: int = 1
+    moe_dp_rank: int = 0
+
+    world_size: int = 1
+    global_rank: int = 0
+    local_rank: int = 0
+
+    @property
+    def has_dp_attention(self) -> bool:
+        """Check if DP attention is enabled."""
+        return self.attn_dp_size > 1
+
+    @property
+    def has_expert_parallelism(self) -> bool:
+        """Check if expert parallelism is enabled."""
+        return self.ep_size > 1
+
+    @property
+    def has_context_parallelism(self) -> bool:
+        """Check if context parallelism is enabled."""
+        return self.attn_cp_size > 1
+
+    @property
+    def has_moe_data_parallelism(self) -> bool:
+        """Check if MoE data parallelism is enabled."""
+        return self.moe_dp_size > 1
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "RankParallelismConfig":
+        """Create from dictionary, filtering unknown fields."""
+        import dataclasses
+
+        valid_fields = {f.name for f in dataclasses.fields(cls)}
+        filtered_data = {k: v for k, v in data.items() if k in valid_fields}
+        return cls(**filtered_data)
+
+    @classmethod
+    def from_parallel_state(cls, local_rank: int = 0) -> "RankParallelismConfig":
+        """Extract current parallelism settings from the global parallel state."""
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+
+        # Import dp_attention lazily to avoid circular imports
+        from sglang.srt.layers.dp_attention import (
+            get_attention_cp_rank,
+            get_attention_cp_size,
+            get_attention_dp_rank,
+            get_attention_dp_size,
+            get_attention_tp_rank,
+            get_attention_tp_size,
+        )
+
+        return cls(
+            tp_size=tp_size,
+            tp_rank=tp_rank,
+            pp_size=get_pipeline_model_parallel_world_size(),
+            pp_rank=get_pipeline_model_parallel_rank(),
+            ep_size=get_moe_expert_parallel_world_size(),
+            ep_rank=get_moe_expert_parallel_rank(),
+            moe_tp_size=get_moe_tensor_parallel_world_size(),
+            moe_tp_rank=get_moe_tensor_parallel_rank(),
+            attn_tp_size=get_attention_tp_size(),
+            attn_tp_rank=get_attention_tp_rank(),
+            attn_dp_size=get_attention_dp_size(),
+            attn_dp_rank=get_attention_dp_rank(),
+            attn_cp_size=get_attention_cp_size(),
+            attn_cp_rank=get_attention_cp_rank(),
+            moe_dp_size=get_moe_data_parallel_world_size(),
+            moe_dp_rank=get_moe_data_parallel_rank(),
+            world_size=(
+                torch.distributed.get_world_size()
+                if torch.distributed.is_initialized()
+                else 1
+            ),
+            global_rank=(
+                torch.distributed.get_rank()
+                if torch.distributed.is_initialized()
+                else 0
+            ),
+            local_rank=local_rank,
+        )
+
+    def __repr__(self) -> str:
+        parts = [
+            f"TP={self.tp_rank}/{self.tp_size}",
+            f"PP={self.pp_rank}/{self.pp_size}",
+        ]
+        if self.has_expert_parallelism:
+            parts.append(f"EP={self.ep_rank}/{self.ep_size}")
+            parts.append(f"MoE-TP={self.moe_tp_rank}/{self.moe_tp_size}")
+        if self.has_dp_attention:
+            parts.append(f"AttnTP={self.attn_tp_rank}/{self.attn_tp_size}")
+            parts.append(f"AttnDP={self.attn_dp_rank}/{self.attn_dp_size}")
+        if self.has_context_parallelism:
+            parts.append(f"AttnCP={self.attn_cp_rank}/{self.attn_cp_size}")
+        if self.has_moe_data_parallelism:
+            parts.append(f"MoE-DP={self.moe_dp_rank}/{self.moe_dp_size}")
+        parts.append(f"Global={self.global_rank}/{self.world_size}")
+        return f"RankParallelismConfig({', '.join(parts)})"
+
+
+class ParallelismContext:
+    """
+    Context manager for creating model replicas with specific parallelism settings.
+
+    This context manager temporarily sets global variables to allow creating model
+    shards outside of a real distributed environment.
+
+    Usage:
+        server_args = ...  # Get from engine.get_server_info()
+        parallelism_info = engine.get_parallelism_config(rank)
+        sglang.srt.server_args._global_server_args = server_args
+
+        with ParallelismContext(RankParallelismConfig.from_dict(parallelism_info)):
+            model = get_model(
+                model_config=model_config,
+                load_config=load_config,
+                device_config=device_config,
+            )
+    """
+
+    def __init__(self, parallelism_config: RankParallelismConfig):
+        self.config = parallelism_config
+        self._original_globals: Dict[str, Any] = {}
+
+    def _create_mock_group(self, world_size: int, rank_in_group: int):
+        """Create a mock group coordinator with all necessary properties."""
+        mock_group = MagicMock()
+        mock_group.world_size = world_size
+        mock_group.rank_in_group = rank_in_group
+        mock_group.rank = rank_in_group  # Use rank_in_group for .rank as well
+        mock_group.local_rank = rank_in_group
+        mock_group.ranks = list(range(world_size))
+        mock_group.first_rank = 0
+        mock_group.last_rank = world_size - 1
+        mock_group.is_first_rank = rank_in_group == 0
+        mock_group.is_last_rank = rank_in_group == world_size - 1
+        mock_group.next_rank = mock_group.ranks[(rank_in_group + 1) % world_size]
+        mock_group.prev_rank = mock_group.ranks[(rank_in_group - 1) % world_size]
+        return mock_group
+
+    def __enter__(self):
+        conf = self.config
+
+        # Import the modules we need to modify
+        from sglang.srt.distributed import parallel_state
+        from sglang.srt.layers import dp_attention
+
+        # Save original global variables to restore later
+        self._original_globals["_TP"] = getattr(parallel_state, "_TP", None)
+        self._original_globals["_PP"] = getattr(parallel_state, "_PP", None)
+        self._original_globals["_MOE_EP"] = getattr(parallel_state, "_MOE_EP", None)
+        self._original_globals["_MOE_TP"] = getattr(parallel_state, "_MOE_TP", None)
+        self._original_globals["_ATTN_TP"] = getattr(parallel_state, "_ATTN_TP", None)
+        self._original_globals["_ATTN_CP"] = getattr(parallel_state, "_ATTN_CP", None)
+        self._original_globals["_MOE_DP"] = getattr(parallel_state, "_MOE_DP", None)
+        self._original_globals["_ATTN_TP_RANK"] = getattr(
+            dp_attention, "_ATTN_TP_RANK", None
+        )
+        self._original_globals["_ATTN_TP_SIZE"] = getattr(
+            dp_attention, "_ATTN_TP_SIZE", None
+        )
+        self._original_globals["_ATTN_DP_RANK"] = getattr(
+            dp_attention, "_ATTN_DP_RANK", None
+        )
+        self._original_globals["_ATTN_DP_SIZE"] = getattr(
+            dp_attention, "_ATTN_DP_SIZE", None
+        )
+        self._original_globals["_ENABLE_DP_ATTENTION_FLAG"] = getattr(
+            dp_attention, "_ENABLE_DP_ATTENTION_FLAG", False
+        )
+
+        # Create mock group objects with the correct attributes
+        mock_tp_group = self._create_mock_group(conf.tp_size, conf.tp_rank)
+        mock_pp_group = self._create_mock_group(conf.pp_size, conf.pp_rank)
+        mock_ep_group = self._create_mock_group(conf.ep_size, conf.ep_rank)
+        mock_moe_tp_group = self._create_mock_group(conf.moe_tp_size, conf.moe_tp_rank)
+        mock_attn_tp_group = self._create_mock_group(conf.attn_tp_size, conf.attn_tp_rank)
+        mock_attn_cp_group = self._create_mock_group(conf.attn_cp_size, conf.attn_cp_rank)
+        mock_moe_dp_group = self._create_mock_group(conf.moe_dp_size, conf.moe_dp_rank)
+
+        # Set the global group objects directly on the parallel_state module
+        parallel_state._TP = mock_tp_group
+        parallel_state._PP = mock_pp_group
+        parallel_state._MOE_EP = mock_ep_group
+        parallel_state._MOE_TP = mock_moe_tp_group
+        parallel_state._ATTN_TP = mock_attn_tp_group
+        parallel_state._ATTN_CP = mock_attn_cp_group
+        parallel_state._MOE_DP = mock_moe_dp_group
+
+        # Set dp_attention globals directly
+        dp_attention._ATTN_TP_RANK = conf.attn_tp_rank
+        dp_attention._ATTN_TP_SIZE = conf.attn_tp_size
+        dp_attention._ATTN_DP_RANK = conf.attn_dp_rank
+        dp_attention._ATTN_DP_SIZE = conf.attn_dp_size
+        # Enable the DP attention flag if attn_dp_size > 1
+        dp_attention._ENABLE_DP_ATTENTION_FLAG = conf.attn_dp_size > 1
+
+        logger.info(
+            f"[ParallelismContext] Activated: TP={conf.tp_rank}/{conf.tp_size}, "
+            f"PP={conf.pp_rank}/{conf.pp_size}, EP={conf.ep_rank}/{conf.ep_size}, "
+            f"MoE-TP={conf.moe_tp_rank}/{conf.moe_tp_size}, "
+            f"AttnTP={conf.attn_tp_rank}/{conf.attn_tp_size}, "
+            f"AttnDP={conf.attn_dp_rank}/{conf.attn_dp_size}, "
+            f"AttnCP={conf.attn_cp_rank}/{conf.attn_cp_size}, "
+            f"MoE-DP={conf.moe_dp_rank}/{conf.moe_dp_size}"
+        )
+        return self
+
+    def __exit__(self, *args):
+        from sglang.srt.distributed import parallel_state
+        from sglang.srt.layers import dp_attention
+
+        # Restore original parallel_state globals
+        parallel_state._TP = self._original_globals.get("_TP")
+        parallel_state._PP = self._original_globals.get("_PP")
+        parallel_state._MOE_EP = self._original_globals.get("_MOE_EP")
+        parallel_state._MOE_TP = self._original_globals.get("_MOE_TP")
+        parallel_state._ATTN_TP = self._original_globals.get("_ATTN_TP")
+        parallel_state._ATTN_CP = self._original_globals.get("_ATTN_CP")
+        parallel_state._MOE_DP = self._original_globals.get("_MOE_DP")
+
+        # Restore original dp_attention globals
+        dp_attention._ATTN_TP_RANK = self._original_globals.get("_ATTN_TP_RANK")
+        dp_attention._ATTN_TP_SIZE = self._original_globals.get("_ATTN_TP_SIZE")
+        dp_attention._ATTN_DP_RANK = self._original_globals.get("_ATTN_DP_RANK")
+        dp_attention._ATTN_DP_SIZE = self._original_globals.get("_ATTN_DP_SIZE")
+        dp_attention._ENABLE_DP_ATTENTION_FLAG = self._original_globals.get(
+            "_ENABLE_DP_ATTENTION_FLAG", False
+        )
+
+        logger.info("[ParallelismContext] Deactivated")
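+        # Returning False (rather than True) lets any exception raised inside
+        # the with-block propagate to the caller after the globals are restored.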
+        return False
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 87c8c0a19f89..bf429aa712e0 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -79,6 +79,7 @@
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.managers.tokenizer_manager_multiitem_mixin import ScoreResult
 from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    parse_parallelism_config_from_scheduler_infos,
     parse_remote_instance_transfer_engine_info_from_scheduler_infos,
 )
 from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info
@@ -200,6 +201,9 @@ def __init__(self, **kwargs):
                 scheduler_init_result.scheduler_infos
             )
         )
+        self.parallelism_config = parse_parallelism_config_from_scheduler_infos(
+            scheduler_init_result.scheduler_infos
+        )
 
         # Initialize ZMQ sockets
         context = zmq.Context(2)
@@ -1172,7 +1176,11 @@ def _wait_for_scheduler_ready(
 
             raise RuntimeError(
                 "Initialization failed. Please see the error messages above."
             )
-        scheduler_infos.append(data)
+
+        if "_dp_scheduler_infos" in data:
+            scheduler_infos.extend(data["_dp_scheduler_infos"])
+        else:
+            scheduler_infos.append(data)
 
     return scheduler_infos
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index d57108c410ac..f9350580b6f9 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -154,6 +154,7 @@
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
 from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    parse_parallelism_config_from_scheduler_infos,
     parse_remote_instance_transfer_engine_info_from_scheduler_infos,
 )
 from sglang.srt.observability.func_timer import enable_func_timer
@@ -199,6 +200,7 @@ class _GlobalState:
     #     )
     # }
     remote_instance_transfer_engine_info: Optional[Dict] = None
+    parallelism_config_info: Optional[Dict] = None
 
 
 _global_state: Optional[_GlobalState] = None
@@ -1049,6 +1051,30 @@ async def get_remote_instance_transfer_engine_info(rank: int = None):
     return Response(status_code=HTTPStatus.BAD_REQUEST)
 
 
+@app.get("/parallelism_config")
+async def parallelism_config(rank: int = 0):
+    """Get parallelism config for a specific TP rank."""
+    if rank < 0:
+        return Response(status_code=HTTPStatus.BAD_REQUEST)
+
+    if (
+        _global_state.parallelism_config_info is None
+        or len(_global_state.parallelism_config_info) == 0
+    ):
+        logger.error("Parallelism config info is not available.")
+        return Response(status_code=HTTPStatus.BAD_REQUEST)
+
+    try:
+        result = {
+            "rank": rank,
+            **dataclasses.asdict(_global_state.parallelism_config_info[rank]),
+        }
+        return result
+    except Exception as e:
+        logger.error(f"Exception: {e}")
+        return Response(status_code=HTTPStatus.BAD_REQUEST)
+
+
 @app.post("/init_weights_update_group")
 @auth_level(AuthLevel.ADMIN_OPTIONAL)
 async def init_weights_update_group(
@@ -1983,6 +2009,9 @@ def _setup_and_run_http_server(
             scheduler_infos
         )
     )
+    parallelism_config_info = parse_parallelism_config_from_scheduler_infos(
+        scheduler_infos
+    )
 
     # Set global states
     set_global_state(
@@ -1991,6 +2020,7 @@ def _setup_and_run_http_server(
             template_manager=template_manager,
             scheduler_info=scheduler_infos[0],
             remote_instance_transfer_engine_info=remote_instance_transfer_engine_info,
+            parallelism_config_info=parallelism_config_info,
         )
     )
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index ae5b38ec7160..6b720820a0bb 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -492,6 +492,7 @@ def launch_tensor_parallel_group(
 
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
+        self.scheduler_infos = scheduler_info
 
     def maybe_external_dp_rank_routing(self, req: Req):
         if req.routed_dp_rank is not None:
@@ -590,6 +591,7 @@ def run_data_parallel_controller_process(
                 "status": "ready",
                 "max_total_num_tokens": controller.max_total_num_tokens,
                 "max_req_input_len": controller.max_req_input_len,
+                "_dp_scheduler_infos": controller.scheduler_infos,
             }
         )
         if server_args.node_rank == 0:
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index a7d3c4550a7e..df9e23ed3d78 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1128,11 +1128,13 @@ def get_init_info(self) -> Dict[str, Any]:
             remote_instance_transfer_engine_session_id,
             remote_instance_transfer_engine_weights_info_dict,
         ) = self.get_remote_instance_transfer_engine_info()
+        parallelism_config = self.get_parallelism_config()
         result_dict.update(
             {
                 "tp_rank": self.tp_rank,
                 "remote_instance_transfer_engine_session_id": remote_instance_transfer_engine_session_id,
                 "remote_instance_transfer_engine_weights_info_dict": remote_instance_transfer_engine_weights_info_dict,
+                "parallelism_config_info": parallelism_config,
             }
         )
 
@@ -3051,6 +3053,9 @@ def update_cache_from_scheduler(
     def get_remote_instance_transfer_engine_info(self):
         return self.tp_worker.get_remote_instance_transfer_engine_info()
 
+    def get_parallelism_config(self):
+        return self.tp_worker.get_parallelism_config()
+
 
 class IdleSleeper:
     """
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 8408cc9055f2..b10de5e16aba 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -355,6 +355,8 @@ def _init_model_runner(self):
             req_to_token_pool=self.req_to_token_pool,
             token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
             draft_model_idx=0 if self.is_multi_layer_eagle else None,
+            attn_cp_rank=self.attn_cp_rank,
+            moe_dp_rank=self.moe_dp_rank,
         )
 
     def _init_multi_layer_eagle_model_runners(self):
@@ -380,6 +382,8 @@ def _init_multi_layer_eagle_model_runners(self):
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                     draft_model_idx=i,
+                    attn_cp_rank=self.attn_cp_rank,
+                    moe_dp_rank=self.moe_dp_rank,
                 )
             )
 
@@ -439,6 +443,9 @@ def get_remote_instance_transfer_engine_info(self):
             self.model_runner.remote_instance_transfer_engine_weight_info,
         )
 
+    def get_parallelism_config(self):
+        return self.model_runner.parallelism_config
+
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 9b615cce8499..823ff0306ce1 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -68,7 +68,10 @@
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     use_symmetric_memory,
 )
-from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
+from sglang.srt.distributed.parallel_state import (
+    RankParallelismConfig,
+    monkey_patch_vllm_parallel_state,
+)
 from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
 from sglang.srt.elastic_ep.expert_backup_client import ExpertBackupClient
 from sglang.srt.environ import envs
@@ -343,6 +346,7 @@ def __init__(
         self.remote_instance_transfer_engine = None
         self.remote_instance_transfer_engine_session_id = ""
         self.remote_instance_transfer_engine_weight_info = None
+        self.parallelism_config = None
         # auxiliary hidden capture mode. TODO: expose this to server args?
         self.eagle_use_aux_hidden_state = False
         if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
@@ -456,6 +460,9 @@ def initialize(self, pre_model_load_memory: float):
 
         if self.server_args.remote_instance_weight_loader_use_transfer_engine():
             self.remote_instance_init_transfer_engine()
+        self.parallelism_config = RankParallelismConfig.from_parallel_state(
+            self.tp_rank
+        )
 
         if not self.is_draft_worker:
             set_global_expert_location_metadata(
diff --git a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py
index c063ea342d6b..cf7cb78208a3 100644
--- a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py
+++ b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py
@@ -120,6 +120,14 @@ def parse_remote_instance_transfer_engine_info_from_scheduler_infos(scheduler_in
     return remote_instance_transfer_engine_info
 
 
+def parse_parallelism_config_from_scheduler_infos(scheduler_infos):
+    parallelism_config_info = {}
+    for data in scheduler_infos:
+        if "tp_rank" in data and "parallelism_config_info" in data:
+            parallelism_config_info[data["tp_rank"]] = data["parallelism_config_info"]
+    return parallelism_config_info
+
+
 def register_memory_region(model, transfer_engine):
     if importlib.util.find_spec("torch") is None:
         return register_memory_region_v1(model, transfer_engine)
diff --git a/test/registered/distributed/test_parallelism_context_integration.py b/test/registered/distributed/test_parallelism_context_integration.py
new file mode 100644
index 000000000000..e777dbf1160c
--- /dev/null
+++ b/test/registered/distributed/test_parallelism_context_integration.py
@@ -0,0 +1,288 @@
+"""
+Integration tests for ParallelismContext with real sglang servers.
+
+Tests that ParallelismContext can instantiate models with correct tensor parallel
+sharding by comparing parameter names and sizes against a running sglang server.
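+
+The recreated shard is loaded with dummy weights (load_format="dummy"), so the
+comparison covers parameter names and per-parameter byte sizes, not weight values.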
+
+Run with:
+    pytest test/registered/distributed/test_parallelism_context_integration.py -v
+
+Full test suite (non-CI):
+    - TP=2 small model (Qwen2.5-1.5B-Instruct)
+    - EP=2 small MoE model (DeepSeek-Coder-V2-Lite-Instruct)
+    - MLA model with hybrid DP attention (DeepSeek-Coder-V2-Lite-Instruct)
+
+CI test (reduced):
+    - TP=2 small model only
+"""
+
+import dataclasses
+import gc
+from typing import Dict, List, Tuple
+
+import pytest
+import requests
+import torch
+
+from sglang.srt.distributed.parallel_state import RankParallelismConfig
+from sglang.test.test_utils import (
+    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    is_in_ci,
+    popen_launch_server,
+)
+from sglang.utils import terminate_process
+
+
+def get_transfer_engine_info(url: str, rank: int) -> Dict:
+    """Get transfer engine info (parameter names and sizes) for a rank."""
+    response = requests.get(
+        f"{url}/get_remote_instance_transfer_engine_info",
+        params={"rank": rank},
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def get_parallelism_config(url: str, rank: int) -> Dict:
+    """Get parallelism config for a rank."""
+    response = requests.get(f"{url}/parallelism_config", params={"rank": rank})
+    response.raise_for_status()
+    return response.json()
+
+
+def get_server_info(url: str) -> Dict:
+    """Get server info."""
+    response = requests.get(f"{url}/server_info")
+    response.raise_for_status()
+    return response.json()
+
+
+def verify_model_params_match_for_rank(
+    url: str,
+    rank: int,
+    server_info: Dict,
+    test_gpu_id: int,
+):
+    """Verify model parameters match for a specific rank by recreating a model shard.
+
+    Args:
+        url: Server URL
+        rank: The rank to verify
+        server_info: Server info dict
+        test_gpu_id: GPU ID to use for instantiating the test model
+    """
+    transfer_info = get_transfer_engine_info(url, rank)
+    server_weights_info = transfer_info["remote_instance_transfer_engine_info"][1]
+
+    # Get the parallelism config from the running server
+    parallelism_config_data = get_parallelism_config(url, rank)
+    parallelism_config = RankParallelismConfig.from_dict(parallelism_config_data)
+
+    # Rebuild ServerArgs from the server info
+    from sglang.srt.server_args import ServerArgs
+
+    valid_fields = {f.name for f in dataclasses.fields(ServerArgs)}
+    filtered_info = {k: v for k, v in server_info.items() if k in valid_fields}
+    filtered_info.pop("model_config", None)
+    server_args = ServerArgs(**filtered_info)
+
+    from sglang.srt import server_args as server_args_module
+    from sglang.srt.distributed.parallel_state import ParallelismContext
+
+    original_global_server_args = server_args_module._global_server_args
+
+    try:
+        # Inside a mocked ParallelismContext, instantiate the model for this rank.
+        # Use a separate GPU (test_gpu_id) to avoid memory conflicts with the running server.
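+        # Note: the context only swaps in mock group bookkeeping (sizes/ranks);
+        # combined with the dummy weight loading below, constructing the model
+        # should not trigger any real collective communication.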
+        server_args_module._global_server_args = server_args
+        with ParallelismContext(parallelism_config):
+            from sglang.srt.configs.device_config import DeviceConfig
+            from sglang.srt.configs.load_config import LoadConfig
+            from sglang.srt.configs.model_config import ModelConfig
+            from sglang.srt.model_loader import get_model
+
+            model_config = ModelConfig.from_server_args(server_args)
+            load_config = LoadConfig(load_format="dummy")
+            device_config = DeviceConfig(device="cuda", gpu_id=test_gpu_id)
+
+            torch.cuda.set_device(test_gpu_id)
+            model = get_model(
+                model_config=model_config,
+                load_config=load_config,
+                device_config=device_config,
+            )
+            model_params = {}
+            for name, param in model.named_parameters():
+                model_params[name] = param.numel() * param.element_size()
+
+            # Verify all server parameters exist in the model with the same size
+            mismatches = []
+            missing = []
+            for param_name, (ptr, numel, elem_size) in server_weights_info.items():
+                expected_size = numel * elem_size
+                if param_name not in model_params:
+                    missing.append(param_name)
+                elif model_params[param_name] != expected_size:
+                    mismatches.append(
+                        f"{param_name}: model={model_params[param_name]}, server={expected_size}"
+                    )
+
+            assert not missing, f"Rank {rank}: Missing parameters: {missing}"
+            assert not mismatches, f"Rank {rank}: Size mismatches: {mismatches}"
+
+            del model
+            torch.cuda.empty_cache()
+
+    finally:
+        server_args_module._global_server_args = original_global_server_args
+
+
+# Each tuple: (test_id, model_name, tp_size, extra_args, min_gpus, trust_remote_code)
+TEST_CONFIGS: List[Tuple[str, str, int, List[str], int, bool]] = [
+    # Basic TP=2 test (the only config exercised in CI)
+    (
+        "tp2_small",
+        DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
+        2,
+        [],
+        2,
+        False,
+    ),
+    # EP=2: MoE experts split across 2 groups, moe_tp=1 per group
+    (
+        "mla_ep2",
+        DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+        2,
+        ["--ep-size", "2"],
+        2,
+        True,
+    ),
+    (
+        "mla_dp2_tp4",
+        DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+        4,
+        ["--enable-dp-attention", "--dp", "2"],
+        4,
+        True,
+    ),
+    (
+        "mla_dp2_ep2_tp4",
+        DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+        4,
+        ["--enable-dp-attention", "--dp", "2", "--ep-size", "2"],
+        4,
+        True,
+    ),
+    (
+        "mla_dp2_ep4_tp4",
+        DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+        4,
+        ["--enable-dp-attention", "--dp", "2", "--ep-size", "4"],
+        4,
+        True,
+    ),
+    (
+        "mla_dp4_ep2_tp4",
+        DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+        4,
+        ["--enable-dp-attention", "--dp", "4", "--ep-size", "2"],
+        4,
+        True,
+    ),
+]
+
+
+def get_test_configs():
+    if is_in_ci():
+        return [TEST_CONFIGS[0]]
+    else:
+        return TEST_CONFIGS
+
+
+def _get_test_params():
+    """Generate pytest parameters based on the test configs."""
+    configs = get_test_configs()
+    params = []
+    for (
+        test_id,
+        model_name,
+        tp_size,
+        extra_args,
+        min_gpus,
+        trust_remote_code,
+    ) in configs:
+        params.append(
+            pytest.param(
+                (model_name, tp_size, extra_args, min_gpus, trust_remote_code),
+                id=test_id,
+            )
+        )
+    return params
+
+
+class TestParallelismContextIntegration:
+    """
+    Test that ParallelismContext can instantiate models with the same
+    parameter names and sizes as the sglang server engine.
+    """
+
+    @pytest.mark.parametrize("config", _get_test_params())
+    def test_model_instantiation_matches_server(self, config):
+        """
+        Test that a model instantiated with ParallelismContext has the same
+        parameter names and sizes as the model in the sglang server.
+
+        This test:
+        1. Starts a server with the specified parallelism config
+        2. Gets transfer_engine_info for all ranks (contains param names and sizes)
+        3. Gets parallelism_config and server_info
+        4. Uses ParallelismContext to instantiate a model for each rank
+        5. Compares the parameter names and sizes
+        """
+        model_name, tp_size, extra_args, min_gpus, trust_remote_code = config
+        url = DEFAULT_URL_FOR_TEST
+
+        # Need min_gpus for the server + 1 extra GPU for test model instantiation
+        required_gpus = min_gpus + 1
+        if torch.cuda.device_count() < required_gpus:
+            pytest.skip(
+                f"Need at least {required_gpus} GPUs (server={min_gpus} + test=1), have {torch.cuda.device_count()}"
+            )
+
+        # The test model is instantiated on the GPU after the server's GPUs
+        test_gpu_id = min_gpus  # e.g., if the server uses GPUs 0-1, the test uses 2
+
+        # Build server args
+        other_args = [
+            "--tp-size",
+            str(tp_size),
+            "--remote-instance-weight-loader-start-seed-via-transfer-engine",
+        ]
+        if trust_remote_code:
+            other_args.append("--trust-remote-code")
+        other_args.extend(extra_args)
+
+        process = None
+        try:
+            process = popen_launch_server(
+                model_name,
+                url,
+                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                other_args=other_args,
+            )
+            server_info = get_server_info(url)
+
+            for rank in range(tp_size):
+                verify_model_params_match_for_rank(url, rank, server_info, test_gpu_id)
+
+        finally:
+            if process is not None:
+                terminate_process(process)
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])