From c01d0bc1aa75a9f6938fb216261dd1902312a2e9 Mon Sep 17 00:00:00 2001 From: Zhongdongming Dai Date: Mon, 23 Mar 2026 11:24:45 -0700 Subject: [PATCH 1/6] Update ModelExpress metadata API to match new SourceIdentity-based schema --- python/sglang/srt/configs/load_config.py | 6 ++ .../sglang/srt/model_executor/model_runner.py | 40 ++++++++--- python/sglang/srt/model_loader/loader.py | 68 +++++++++++++------ 3 files changed, 85 insertions(+), 29 deletions(-) diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 443ba643d083..30a3f7455eb3 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -78,6 +78,12 @@ class LoadConfig: remote_instance_weight_loader_transfer_engine: Optional[Any] = None modelexpress_url: Optional[str] = None modelexpress_model_name: Optional[str] = None + # Fields for building SourceIdentity (needed by both seed and client) + modelexpress_tp_size: int = 1 + modelexpress_pp_size: int = 1 + modelexpress_ep_size: int = 1 + modelexpress_dtype: Optional[str] = None + modelexpress_quantization: Optional[str] = None # ModelOpt-specific loading options modelopt_checkpoint_restore_path: Optional[str] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fc9afafac90b..115334b69703 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -703,6 +703,8 @@ def remote_instance_init_transfer_engine(self): def _publish_modelexpress_metadata(self): """Publish TransferEngine metadata to ModelExpress server (seed mode).""" try: + import uuid + from modelexpress import p2p_pb2 from modelexpress.client import MxClient except ImportError as exc: @@ -725,6 +727,17 @@ def _publish_modelexpress_metadata(self): ) return + # Build SourceIdentity for this instance + identity = p2p_pb2.SourceIdentity( + model_name=model_name, + backend_framework=p2p_pb2.BACKEND_FRAMEWORK_SGLANG, + tensor_parallel_size=self.server_args.tp_size, + pipeline_parallel_size=self.server_args.pp_size, + expert_parallel_size=self.server_args.ep_size, + dtype=self.server_args.dtype or "", + quantization=self.server_args.quantization or "", + ) + # Build tensor descriptors from weight_info dict tensors = [] for name, (addr, numel, element_size) in weight_info.items(): @@ -743,27 +756,33 @@ def _publish_modelexpress_metadata(self): tensors=tensors, ) + # Generate a unique worker_id for this running instance + worker_id = str(uuid.uuid4()) + mx_client = MxClient(server_url=mx_url) try: logger.info( "ModelExpress source: publishing metadata for model=%s, " - "tp_rank=%d, session=%s, %d tensors", + "tp_rank=%d, session=%s, %d tensors, worker_id=%s", model_name, self.tp_rank, session_id, len(tensors), + worker_id, ) - mx_client.publish_metadata(model_name, [worker]) - mx_client.publish_ready( - model_name, - worker_id=self.tp_rank, - session_id=mx_client.session_id, - metadata_hash="", + mx_source_id = mx_client.publish_metadata(identity, worker, worker_id) + mx_client.update_status( + mx_source_id=mx_source_id, + worker_id=worker_id, + worker_rank=self.tp_rank, + status=p2p_pb2.SOURCE_STATUS_READY, ) logger.info( - "ModelExpress source: published ready for model=%s, tp_rank=%d", + "ModelExpress source: published ready for model=%s, " + "tp_rank=%d, mx_source_id=%s", model_name, self.tp_rank, + mx_source_id, ) finally: mx_client.close() @@ -1058,6 +1077,11 @@ def load_model(self): modelexpress_url=self.server_args.modelexpress_url, modelexpress_model_name=self.server_args.modelexpress_model_name or self.server_args.model_path, + modelexpress_tp_size=self.server_args.tp_size, + modelexpress_pp_size=self.server_args.pp_size, + modelexpress_ep_size=self.server_args.ep_size, + modelexpress_dtype=self.server_args.dtype, + modelexpress_quantization=self.server_args.quantization or "", modelopt_config=modelopt_config, rl_quant_profile=self.server_args.rl_quant_profile, draft_model_idx=self.draft_model_idx, diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 27d189d65622..f0126e83add1 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2290,6 +2290,9 @@ def load_model_from_modelexpress( ): """Load weights via ModelExpress coordination + TransferEngine RDMA.""" try: + import time + + from modelexpress import p2p_pb2 from modelexpress.client import MxClient except ImportError as exc: raise ImportError( @@ -2312,40 +2315,63 @@ def load_model_from_modelexpress( model, transfer_engine ) - # Wait for seed to be ready via ModelExpress + # Build SourceIdentity matching the seed's identity + identity = p2p_pb2.SourceIdentity( + model_name=model_name, + backend_framework=p2p_pb2.BACKEND_FRAMEWORK_SGLANG, + tensor_parallel_size=load_config.modelexpress_tp_size, + pipeline_parallel_size=load_config.modelexpress_pp_size, + expert_parallel_size=load_config.modelexpress_ep_size, + dtype=load_config.modelexpress_dtype or "", + quantization=load_config.modelexpress_quantization or "", + ) + + # Poll list_sources until a READY worker with matching rank is found mx_client = MxClient(server_url=load_config.modelexpress_url) try: logger.info( - "ModelExpress: waiting for seed ready (model=%s)...", + "ModelExpress: waiting for seed ready (model=%s, rank=%d)...", model_name, + tp_rank, ) - ready, session_id, metadata_hash = mx_client.wait_for_ready( - model_name, - worker_id=tp_rank, - ) - if not ready: - raise RuntimeError( - f"ModelExpress: timed out waiting for seed ready " - f"(model={model_name}, worker={tp_rank})" + max_wait_secs = 300 + poll_interval = 2.0 + elapsed = 0.0 + source_ref = None + while elapsed < max_wait_secs: + resp = mx_client.list_sources( + identity=identity, + status_filter=p2p_pb2.SOURCE_STATUS_READY, ) + for inst in resp.instances: + if inst.worker_rank == tp_rank: + source_ref = inst + break + if source_ref is not None: + break + time.sleep(poll_interval) + elapsed += poll_interval - response = mx_client.get_metadata(model_name) - if not response.found: + if source_ref is None: raise RuntimeError( - f"ModelExpress: no metadata found for model={model_name}" + f"ModelExpress: timed out ({max_wait_secs}s) waiting for " + f"READY source (model={model_name}, rank={tp_rank})" ) - # Find the worker matching our tp_rank - source_worker = None - for w in response.workers: - if w.worker_rank == tp_rank: - source_worker = w - break - if source_worker is None: + # Fetch full metadata for the discovered worker + response = mx_client.get_metadata( + mx_source_id=source_ref.mx_source_id, + worker_id=source_ref.worker_id, + ) + if not response.found: raise RuntimeError( - f"ModelExpress: no worker metadata for rank={tp_rank}" + f"ModelExpress: no metadata found for " + f"source_id={source_ref.mx_source_id}, " + f"worker_id={source_ref.worker_id}" ) + source_worker = response.worker + # Extract session_id from oneof backend_metadata backend_field = source_worker.WhichOneof("backend_metadata") if backend_field == "transfer_engine_session_id": From fae47438d9485aff956f2c43dd4197adc1f08ef7 Mon Sep 17 00:00:00 2001 From: Zhongdongming Dai Date: Mon, 23 Mar 2026 11:50:08 -0700 Subject: [PATCH 2/6] Use Optional defaults for LoadConfig modelexpress fields and add MX server health check --- python/sglang/srt/configs/load_config.py | 6 +++--- python/sglang/srt/model_loader/loader.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 30a3f7455eb3..338add44797f 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -79,9 +79,9 @@ class LoadConfig: modelexpress_url: Optional[str] = None modelexpress_model_name: Optional[str] = None # Fields for building SourceIdentity (needed by both seed and client) - modelexpress_tp_size: int = 1 - modelexpress_pp_size: int = 1 - modelexpress_ep_size: int = 1 + modelexpress_tp_size: Optional[int] = None + modelexpress_pp_size: Optional[int] = None + modelexpress_ep_size: Optional[int] = None modelexpress_dtype: Optional[str] = None modelexpress_quantization: Optional[str] = None diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index f0126e83add1..b6547daa9933 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2292,6 +2292,7 @@ def load_model_from_modelexpress( try: import time + import grpc from modelexpress import p2p_pb2 from modelexpress.client import MxClient except ImportError as exc: @@ -2319,9 +2320,9 @@ def load_model_from_modelexpress( identity = p2p_pb2.SourceIdentity( model_name=model_name, backend_framework=p2p_pb2.BACKEND_FRAMEWORK_SGLANG, - tensor_parallel_size=load_config.modelexpress_tp_size, - pipeline_parallel_size=load_config.modelexpress_pp_size, - expert_parallel_size=load_config.modelexpress_ep_size, + tensor_parallel_size=load_config.modelexpress_tp_size or 1, + pipeline_parallel_size=load_config.modelexpress_pp_size or 1, + expert_parallel_size=load_config.modelexpress_ep_size or 1, dtype=load_config.modelexpress_dtype or "", quantization=load_config.modelexpress_quantization or "", ) @@ -2329,6 +2330,15 @@ def load_model_from_modelexpress( # Poll list_sources until a READY worker with matching rank is found mx_client = MxClient(server_url=load_config.modelexpress_url) try: + # Verify MX server is reachable before entering the wait loop + try: + mx_client.list_sources() + except grpc.RpcError as e: + raise RuntimeError( + f"ModelExpress: cannot reach server at " + f"{load_config.modelexpress_url}: {e}" + ) from e + logger.info( "ModelExpress: waiting for seed ready (model=%s, rank=%d)...", model_name, From a00a9d7db747b246ea49238e50f0ebace314dd86 Mon Sep 17 00:00:00 2001 From: Zhongdongming Dai Date: Mon, 23 Mar 2026 11:55:38 -0700 Subject: [PATCH 3/6] Replace polling loop with single list_sources call and fail-fast error --- python/sglang/srt/model_loader/loader.py | 43 +++++++++--------------- 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index b6547daa9933..67b1d2520925 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2290,8 +2290,6 @@ def load_model_from_modelexpress( ): """Load weights via ModelExpress coordination + TransferEngine RDMA.""" try: - import time - import grpc from modelexpress import p2p_pb2 from modelexpress.client import MxClient @@ -2327,45 +2325,36 @@ def load_model_from_modelexpress( quantization=load_config.modelexpress_quantization or "", ) - # Poll list_sources until a READY worker with matching rank is found + # Query MX server for a READY source matching our identity and rank mx_client = MxClient(server_url=load_config.modelexpress_url) try: - # Verify MX server is reachable before entering the wait loop - try: - mx_client.list_sources() - except grpc.RpcError as e: - raise RuntimeError( - f"ModelExpress: cannot reach server at " - f"{load_config.modelexpress_url}: {e}" - ) from e - logger.info( - "ModelExpress: waiting for seed ready (model=%s, rank=%d)...", + "ModelExpress: looking for seed (model=%s, rank=%d)...", model_name, tp_rank, ) - max_wait_secs = 300 - poll_interval = 2.0 - elapsed = 0.0 - source_ref = None - while elapsed < max_wait_secs: + try: resp = mx_client.list_sources( identity=identity, status_filter=p2p_pb2.SOURCE_STATUS_READY, ) - for inst in resp.instances: - if inst.worker_rank == tp_rank: - source_ref = inst - break - if source_ref is not None: + except grpc.RpcError as e: + raise RuntimeError( + f"ModelExpress: cannot reach server at " + f"{load_config.modelexpress_url}: {e}" + ) from e + + source_ref = None + for inst in resp.instances: + if inst.worker_rank == tp_rank: + source_ref = inst break - time.sleep(poll_interval) - elapsed += poll_interval if source_ref is None: raise RuntimeError( - f"ModelExpress: timed out ({max_wait_secs}s) waiting for " - f"READY source (model={model_name}, rank={tp_rank})" + f"ModelExpress: no READY source found for " + f"model={model_name}, rank={tp_rank}. " + f"Ensure the seed instance is running and has published metadata." ) # Fetch full metadata for the discovered worker From 4fe3b08421cfe6c35043cf96c90a0f713ba731e6 Mon Sep 17 00:00:00 2001 From: Zhongdongming Dai Date: Tue, 31 Mar 2026 09:46:56 -0700 Subject: [PATCH 4/6] Move uuid import to module level per review feedback --- python/sglang/srt/model_executor/model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 115334b69703..f2691e3330ea 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,6 +24,7 @@ import socket import threading import time +import uuid from collections import defaultdict from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union @@ -703,8 +704,6 @@ def remote_instance_init_transfer_engine(self): def _publish_modelexpress_metadata(self): """Publish TransferEngine metadata to ModelExpress server (seed mode).""" try: - import uuid - from modelexpress import p2p_pb2 from modelexpress.client import MxClient except ImportError as exc: From 2872fadde7f92dda6ab7a2092fccbfc0c281d276 Mon Sep 17 00:00:00 2001 From: Zhongdongming Dai Date: Wed, 1 Apr 2026 13:35:27 -0700 Subject: [PATCH 5/6] fix weight info mismatch bug --- python/sglang/srt/model_loader/loader.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 67b1d2520925..9b4f1ef1685d 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2307,6 +2307,13 @@ def load_model_from_modelexpress( tp_rank = load_config.tp_rank model_name = load_config.modelexpress_model_name + target_device = torch.device(device_config.device) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + logger.info( "ModelExpress: registering memory regions for tp_rank=%d...", tp_rank ) From a22ffb3395d600b08e819ef6aa9c17aa116bd7d2 Mon Sep 17 00:00:00 2001 From: Zhongdongming Dai Date: Mon, 6 Apr 2026 10:38:59 -0700 Subject: [PATCH 6/6] Improve gRPC error message with code and details --- python/sglang/srt/model_loader/loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 9b4f1ef1685d..ccf3b4e45c0d 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2348,7 +2348,8 @@ def load_model_from_modelexpress( except grpc.RpcError as e: raise RuntimeError( f"ModelExpress: cannot reach server at " - f"{load_config.modelexpress_url}: {e}" + f"{load_config.modelexpress_url}: " + f"{e.code()}: {e.details()}" ) from e source_ref = None