diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 2a7a4ea2e33d..e642bdc27abe 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -79,6 +79,12 @@ class LoadConfig: remote_instance_weight_loader_transfer_engine: Optional[Any] = None modelexpress_url: Optional[str] = None modelexpress_model_name: Optional[str] = None + # Fields for building SourceIdentity (needed by both seed and client) + modelexpress_tp_size: Optional[int] = None + modelexpress_pp_size: Optional[int] = None + modelexpress_ep_size: Optional[int] = None + modelexpress_dtype: Optional[str] = None + modelexpress_quantization: Optional[str] = None # ModelOpt-specific loading options modelopt_checkpoint_restore_path: Optional[str] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a59742b94354..c0551c0ec353 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,6 +24,7 @@ import socket import threading import time +import uuid from collections import defaultdict from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union @@ -781,6 +782,17 @@ def _publish_modelexpress_metadata(self): ) return + # Build SourceIdentity for this instance + identity = p2p_pb2.SourceIdentity( + model_name=model_name, + backend_framework=p2p_pb2.BACKEND_FRAMEWORK_SGLANG, + tensor_parallel_size=self.server_args.tp_size, + pipeline_parallel_size=self.server_args.pp_size, + expert_parallel_size=self.server_args.ep_size, + dtype=self.server_args.dtype or "", + quantization=self.server_args.quantization or "", + ) + # Build tensor descriptors from weight_info dict tensors = [] for name, (addr, numel, element_size) in weight_info.items(): @@ -799,27 +811,33 @@ def _publish_modelexpress_metadata(self): tensors=tensors, ) + # Generate a unique worker_id for this running instance + worker_id = str(uuid.uuid4()) + mx_client = MxClient(server_url=mx_url) try: logger.info( "ModelExpress source: publishing metadata for model=%s, " - "tp_rank=%d, session=%s, %d tensors", + "tp_rank=%d, session=%s, %d tensors, worker_id=%s", model_name, self.tp_rank, session_id, len(tensors), + worker_id, ) - mx_client.publish_metadata(model_name, [worker]) - mx_client.publish_ready( - model_name, - worker_id=self.tp_rank, - session_id=mx_client.session_id, - metadata_hash="", + mx_source_id = mx_client.publish_metadata(identity, worker, worker_id) + mx_client.update_status( + mx_source_id=mx_source_id, + worker_id=worker_id, + worker_rank=self.tp_rank, + status=p2p_pb2.SOURCE_STATUS_READY, ) logger.info( - "ModelExpress source: published ready for model=%s, tp_rank=%d", + "ModelExpress source: published ready for model=%s, " + "tp_rank=%d, mx_source_id=%s", model_name, self.tp_rank, + mx_source_id, ) finally: mx_client.close() @@ -1114,6 +1132,11 @@ def load_model(self): modelexpress_url=self.server_args.modelexpress_url, modelexpress_model_name=self.server_args.modelexpress_model_name or self.server_args.model_path, + modelexpress_tp_size=self.server_args.tp_size, + modelexpress_pp_size=self.server_args.pp_size, + modelexpress_ep_size=self.server_args.ep_size, + modelexpress_dtype=self.server_args.dtype, + modelexpress_quantization=self.server_args.quantization or "", modelopt_config=modelopt_config, rl_quant_profile=self.server_args.rl_quant_profile, draft_model_idx=self.draft_model_idx, diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index b0884c68153e..cce20ee7cb0d 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2288,6 +2288,8 @@ def load_model_from_modelexpress( ): """Load weights via ModelExpress coordination + TransferEngine RDMA.""" try: + import grpc + from modelexpress import p2p_pb2 from modelexpress.client import MxClient except ImportError as exc: raise ImportError( @@ -2303,6 +2305,13 @@ def load_model_from_modelexpress( tp_rank = load_config.tp_rank model_name = load_config.modelexpress_model_name + target_device = torch.device(device_config.device) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + logger.info( "ModelExpress: registering memory regions for tp_rank=%d...", tp_rank ) @@ -2310,40 +2319,64 @@ def load_model_from_modelexpress( model, transfer_engine ) - # Wait for seed to be ready via ModelExpress + # Build SourceIdentity matching the seed's identity + identity = p2p_pb2.SourceIdentity( + model_name=model_name, + backend_framework=p2p_pb2.BACKEND_FRAMEWORK_SGLANG, + tensor_parallel_size=load_config.modelexpress_tp_size or 1, + pipeline_parallel_size=load_config.modelexpress_pp_size or 1, + expert_parallel_size=load_config.modelexpress_ep_size or 1, + dtype=load_config.modelexpress_dtype or "", + quantization=load_config.modelexpress_quantization or "", + ) + + # Query MX server for a READY source matching our identity and rank mx_client = MxClient(server_url=load_config.modelexpress_url) try: logger.info( - "ModelExpress: waiting for seed ready (model=%s)...", - model_name, - ) - ready, session_id, metadata_hash = mx_client.wait_for_ready( + "ModelExpress: looking for seed (model=%s, rank=%d)...", model_name, - worker_id=tp_rank, + tp_rank, ) - if not ready: - raise RuntimeError( - f"ModelExpress: timed out waiting for seed ready " - f"(model={model_name}, worker={tp_rank})" + try: + resp = mx_client.list_sources( + identity=identity, + status_filter=p2p_pb2.SOURCE_STATUS_READY, ) + except grpc.RpcError as e: + raise RuntimeError( + f"ModelExpress: cannot reach server at " + f"{load_config.modelexpress_url}: " + f"{e.code()}: {e.details()}" + ) from e + + source_ref = None + for inst in resp.instances: + if inst.worker_rank == tp_rank: + source_ref = inst + break - response = mx_client.get_metadata(model_name) - if not response.found: + if source_ref is None: raise RuntimeError( - f"ModelExpress: no metadata found for model={model_name}" + f"ModelExpress: no READY source found for " + f"model={model_name}, rank={tp_rank}. " + f"Ensure the seed instance is running and has published metadata." ) - # Find the worker matching our tp_rank - source_worker = None - for w in response.workers: - if w.worker_rank == tp_rank: - source_worker = w - break - if source_worker is None: + # Fetch full metadata for the discovered worker + response = mx_client.get_metadata( + mx_source_id=source_ref.mx_source_id, + worker_id=source_ref.worker_id, + ) + if not response.found: raise RuntimeError( - f"ModelExpress: no worker metadata for rank={tp_rank}" + f"ModelExpress: no metadata found for " + f"source_id={source_ref.mx_source_id}, " + f"worker_id={source_ref.worker_id}" ) + source_worker = response.worker + # Extract session_id from oneof backend_metadata backend_field = source_worker.WhichOneof("backend_metadata") if backend_field == "transfer_engine_session_id":