From 9ce68bda0e0cd2edb36c078432e96ac5175a38d5 Mon Sep 17 00:00:00 2001 From: Ishan Dhanani Date: Thu, 5 Mar 2026 02:20:46 +0000 Subject: [PATCH 1/9] Add ModelExpress coordination for remote instance weight loading Add MODEL_EXPRESS backend for remote instance weight loading that uses ModelExpress gRPC server for metadata coordination instead of direct HTTP between seed and target instances. Supports FP8 and BF16 models with per-tensor byte-size matching for mixed-dtype transfers. New CLI args: --model-express-url, --model-express-model-name, --model-express-source --- python/sglang/srt/configs/load_config.py | 2 + .../sglang/srt/model_executor/model_runner.py | 96 ++++++++++++ python/sglang/srt/model_loader/loader.py | 148 ++++++++++++++++++ .../remote_instance_weight_loader_utils.py | 1 + python/sglang/srt/server_args.py | 52 +++++- 5 files changed, 293 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index ddf8d2967ce6..6781b912d528 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -76,6 +76,8 @@ class LoadConfig: remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None remote_instance_weight_loader_backend: Optional[str] = None remote_instance_weight_loader_transfer_engine: Optional[Any] = None + model_express_url: Optional[str] = None + model_express_model_name: Optional[str] = None # ModelOpt-specific loading options modelopt_checkpoint_restore_path: Optional[str] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 126d8b161945..944e53fed79c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -677,6 +677,81 @@ def remote_instance_init_transfer_engine(self): f"{local_ip}:{self.remote_instance_transfer_engine.get_rpc_port()}" ) + def _publish_model_express_metadata(self): + """Publish TransferEngine metadata to ModelExpress server (seed mode).""" + from modelexpress.client import MxClient + from modelexpress import p2p_pb2 + + model_name = ( + self.server_args.model_express_model_name + or self.server_args.model_path + ) + mx_url = self.server_args.model_express_url + session_id = self.remote_instance_transfer_engine_session_id + weight_info = self.remote_instance_transfer_engine_weight_info + + if not session_id or weight_info is None: + logger.warning( + "ModelExpress source: skipping publish -- " + "TransferEngine not initialized or no weight info" + ) + return + + # Build tensor descriptors from weight_info dict + # Use actual per-tensor element_size to derive dtype (FP8 models have mixed dtypes) + element_size_to_dtype = {1: "float8_e4m3fn", 2: "bfloat16", 4: "float32", 8: "float64"} + tensors = [] + for name, (addr, numel, element_size) in weight_info.items(): + tensors.append(p2p_pb2.TensorDescriptor( + name=name, + addr=addr, + size=numel * element_size, + device_id=self.gpu_id, + dtype=element_size_to_dtype.get(element_size, "unknown"), + )) + + worker = p2p_pb2.WorkerMetadata( + worker_rank=self.tp_rank, + transfer_engine_session_id=session_id, + tensors=tensors, + ) + + mx_client = MxClient(server_url=mx_url) + logger.info( + "ModelExpress source: publishing metadata for model=%s, " + "tp_rank=%d, session=%s, %d tensors", + model_name, self.tp_rank, session_id, len(tensors), + ) + mx_client.publish_metadata(model_name, [worker]) + mx_client.publish_ready( + model_name, + worker_id=self.tp_rank, + session_id=mx_client.session_id, + metadata_hash="", + ) + logger.info( + "ModelExpress source: published ready for model=%s, tp_rank=%d", + model_name, self.tp_rank, + ) + mx_client.close() + + def _get_model_dtype_str(self) -> str: + """Return the model dtype as a string for tensor descriptors.""" + import torch + dtype_map = { + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float32: "float32", + torch.float64: "float64", + torch.int8: "int8", + torch.uint8: "uint8", + } + if hasattr(torch, "float8_e4m3fn"): + dtype_map[torch.float8_e4m3fn] = "float8_e4m3fn" + if hasattr(torch, "float8_e5m2"): + dtype_map[torch.float8_e5m2] = "float8_e5m2" + return dtype_map.get(self.model_config.dtype, str(self.model_config.dtype)) + def model_specific_adjustment(self): server_args = self.server_args @@ -941,6 +1016,8 @@ def load_model(self): remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports, remote_instance_weight_loader_backend=self.server_args.remote_instance_weight_loader_backend, remote_instance_weight_loader_transfer_engine=self.remote_instance_transfer_engine, + model_express_url=self.server_args.model_express_url, + model_express_model_name=self.server_args.model_express_model_name or self.server_args.model_path, modelopt_config=modelopt_config, rl_quant_profile=self.server_args.rl_quant_profile, draft_model_idx=self.draft_model_idx, @@ -993,6 +1070,25 @@ def load_model(self): ) monkey_patch_vllm_parallel_state(reverse=True) + # Publish metadata to ModelExpress if running as seed source + if self.server_args.model_express_source: + # Seed loads via DefaultModelLoader (load_format=auto), which doesn't + # call register_memory_region(). Do it here so weight_info is populated. + if ( + self.remote_instance_transfer_engine_weight_info is None + and self.remote_instance_transfer_engine is not None + ): + from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( + register_memory_region, + ) + + self.remote_instance_transfer_engine_weight_info = ( + register_memory_region( + self.model, self.remote_instance_transfer_engine + ) + ) + self._publish_model_express_metadata() + get_offloader().post_init() # Register model for layerwise NVTX profiling if enabled diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index da5c315afb84..c8c7127df794 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2067,6 +2067,29 @@ def load_model( return model +_DTYPE_ELEMENT_SIZES = { + "float16": 2, "half": 2, + "bfloat16": 2, + "float32": 4, "float": 4, + "float64": 8, "double": 8, + "int8": 1, + "int16": 2, + "int32": 4, + "int64": 8, + "uint8": 1, + "float8_e4m3fn": 1, + "float8_e5m2": 1, +} + + +def _dtype_to_element_size(dtype_str: str) -> int: + """Convert dtype string to element size in bytes.""" + size = _DTYPE_ELEMENT_SIZES.get(dtype_str) + if size is None: + raise ValueError(f"Unknown dtype: {dtype_str}") + return size + + class RemoteInstanceModelLoader(BaseModelLoader): """Model loader that can load Tensors from remote sglang instance.""" @@ -2148,6 +2171,13 @@ def load_model( raise RuntimeError( "Failed to load weights from remote instance via transfer engine." ) + elif ( + load_config.remote_instance_weight_loader_backend + == RemoteInstanceWeightLoaderBackend.MODEL_EXPRESS + ): + self.load_model_from_model_express( + model, load_config, device_config, + ) else: raise ValueError("Invalid remote instance weight loader backend.") @@ -2261,6 +2291,124 @@ def load_model_from_remote_instance_by_transfer_engine( return True + def load_model_from_model_express( + self, model, load_config: LoadConfig, device_config: DeviceConfig, + ): + """Load weights via ModelExpress coordination + TransferEngine RDMA.""" + from modelexpress.client import MxClient + + transfer_engine = load_config.remote_instance_weight_loader_transfer_engine + if transfer_engine is None: + raise RuntimeError( + "TransferEngine is not initialized for model_express backend." + ) + tp_rank = load_config.tp_rank + model_name = load_config.model_express_model_name + + logger.info( + "ModelExpress: registering memory regions for tp_rank=%d...", tp_rank + ) + self.remote_instance_transfer_engine_weight_info = register_memory_region( + model, transfer_engine + ) + + # Wait for seed to be ready via ModelExpress + mx_client = MxClient(server_url=load_config.model_express_url) + logger.info( + "ModelExpress: waiting for seed ready (model=%s, worker=%d)...", + model_name, tp_rank, + ) + ready, session_id, metadata_hash = mx_client.wait_for_ready( + model_name, worker_id=tp_rank, + ) + if not ready: + raise RuntimeError( + f"ModelExpress: timed out waiting for seed ready " + f"(model={model_name}, worker={tp_rank})" + ) + + # Get source metadata from ModelExpress + response = mx_client.get_metadata(model_name) + if not response.found: + raise RuntimeError( + f"ModelExpress: no metadata found for model={model_name}" + ) + + # Find the worker matching our tp_rank + source_worker = None + for w in response.workers: + if w.worker_rank == tp_rank: + source_worker = w + break + if source_worker is None: + raise RuntimeError( + f"ModelExpress: no worker metadata for rank={tp_rank}" + ) + + # Extract session_id from oneof backend_metadata + backend_field = source_worker.WhichOneof("backend_metadata") + if backend_field == "transfer_engine_session_id": + seed_session_id = source_worker.transfer_engine_session_id + else: + raise RuntimeError( + f"ModelExpress: expected transfer_engine_session_id, " + f"got backend_metadata={backend_field}" + ) + + # Convert tensor descriptors to {name: (addr, size_bytes)} format + # Use raw byte sizes -- RDMA is a memcpy, dtype matching is not required + seed_weight_info = {} + for td in source_worker.tensors: + seed_weight_info[td.name] = (td.addr, td.size) + + logger.info( + "ModelExpress: got %d tensor descriptors from seed (session=%s)", + len(seed_weight_info), seed_session_id, + ) + + # Transfer weights via TransferEngine RDMA + seed_ptr_list = [] + client_ptr_list = [] + client_len_list = [] + for name, tensor in model.named_parameters(): + weight_info = seed_weight_info.get(name, None) + if weight_info is None: + raise RuntimeError( + f"ModelExpress: cannot find weight info for {name} " + f"in seed metadata" + ) + seed_ptr, seed_size = weight_info + local_size = tensor.numel() * tensor.element_size() + if seed_size != local_size: + raise RuntimeError( + f"ModelExpress: size mismatch for {name}: " + f"seed={seed_size} bytes, local={local_size} bytes" + ) + seed_ptr_list.append(seed_ptr) + client_ptr_list.append(tensor.data_ptr()) + client_len_list.append(local_size) + + logger.info( + "ModelExpress: starting RDMA transfer of %d tensors...", + len(seed_ptr_list), + ) + ret = transfer_engine.batch_transfer_sync_read( + seed_session_id, + client_ptr_list, + seed_ptr_list, + client_len_list, + ) + if ret < 0: + raise RuntimeError( + f"ModelExpress: batch_transfer_sync_read failed, error={ret}" + ) + + if hasattr(model, "post_load_weights"): + model.post_load_weights() + + logger.info("ModelExpress: weight transfer complete for tp_rank=%d", tp_rank) + mx_client.close() + class RemoteModelLoader(BaseModelLoader): """Model loader that can load Tensors from remote database.""" diff --git a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py index c063ea342d6b..8a8f3e9e205e 100644 --- a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py +++ b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py @@ -15,6 +15,7 @@ class RemoteInstanceWeightLoaderBackend(str, enum.Enum): NCCL = "nccl" TRANSFER_ENGINE = "transfer_engine" + MODEL_EXPRESS = "model_express" def trigger_init_weights_send_group_for_remote_instance_request( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9a2208a4cc93..9f4523765fdd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -685,8 +685,11 @@ class ServerArgs: remote_instance_weight_loader_seed_instance_ip: Optional[str] = None remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None - remote_instance_weight_loader_backend: Literal["transfer_engine", "nccl"] = "nccl" + remote_instance_weight_loader_backend: Literal["transfer_engine", "nccl", "model_express"] = "nccl" remote_instance_weight_loader_start_seed_via_transfer_engine: bool = False + model_express_url: Optional[str] = None + model_express_model_name: Optional[str] = None + model_express_source: bool = False # For PD-Multiplexing enable_pdmux: bool = False @@ -2716,7 +2719,19 @@ def _handle_load_format(self): self.custom_weight_loader = [] if self.load_format == "remote_instance": - if ( + if self.remote_instance_weight_loader_backend == "model_express": + # ModelExpress backend: requires --model-express-url, not seed IP/port + if self.model_express_url is None: + logger.warning( + "Fallback load_format to 'auto' due to missing --model-express-url." + ) + self.load_format = "auto" + elif not self.validate_transfer_engine(): + logger.warning( + "Fallback load_format to 'auto' due to 'transfer_engine' (required by model_express) not being supported." + ) + self.load_format = "auto" + elif ( self.remote_instance_weight_loader_seed_instance_ip is None or self.remote_instance_weight_loader_seed_instance_service_port is None ): @@ -5214,15 +5229,32 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--remote-instance-weight-loader-backend", type=str, - choices=["transfer_engine", "nccl"], + choices=["transfer_engine", "nccl", "model_express"], default=ServerArgs.remote_instance_weight_loader_backend, - help="The backend for loading weights from remote instance. Can be 'transfer_engine' or 'nccl'. Default is 'nccl'.", + help="The backend for loading weights from remote instance. Can be 'transfer_engine', 'nccl', or 'model_express'. Default is 'nccl'.", ) parser.add_argument( "--remote-instance-weight-loader-start-seed-via-transfer-engine", action="store_true", help="Start seed server via transfer engine backend for remote instance weight loader.", ) + parser.add_argument( + "--model-express-url", + type=str, + default=ServerArgs.model_express_url, + help="The URL of the ModelExpress gRPC server (host:port).", + ) + parser.add_argument( + "--model-express-model-name", + type=str, + default=ServerArgs.model_express_model_name, + help="The model name to use for ModelExpress metadata coordination.", + ) + parser.add_argument( + "--model-express-source", + action="store_true", + help="Run as a ModelExpress seed source: publish transfer metadata to the ModelExpress server after loading weights.", + ) # For PD-Multiplexing parser.add_argument( @@ -5770,7 +5802,11 @@ def adjust_mem_fraction_for_vlm(self, model_config): ) def validate_transfer_engine(self): - if importlib.util.find_spec("mooncake.engine") is None: + try: + mooncake_available = importlib.util.find_spec("mooncake.engine") is not None + except (ModuleNotFoundError, ValueError): + mooncake_available = False + if not mooncake_available: logger.warning( "Failed to import mooncake.engine. Does not support using TransferEngine as remote instance weight loader backend." ) @@ -5787,10 +5823,14 @@ def remote_instance_weight_loader_use_transfer_engine(self): # Use TransferEngine as seed backend. if self.remote_instance_weight_loader_start_seed_via_transfer_engine: return True + # ModelExpress source mode also needs TransferEngine init. + if self.model_express_source: + return True # Use TransferEngine as client backend. elif ( self.load_format == "remote_instance" - and self.remote_instance_weight_loader_backend == "transfer_engine" + and self.remote_instance_weight_loader_backend + in ("transfer_engine", "model_express") ): return True else: From 815ce3e2a0b76880cb6e19b9c41948735f121b3a Mon Sep 17 00:00:00 2001 From: Ishan Dhanani Date: Thu, 5 Mar 2026 07:16:22 +0000 Subject: [PATCH 2/9] Collapse model_express CLI args into single JSON config Replace --model-express-url, --model-express-model-name, --model-express-source with single --model-express-config JSON arg. Properties provide backwards-compatible access for all downstream code (model_runner, loader, load_config). Co-Authored-By: Claude Opus 4.6 --- python/sglang/srt/server_args.py | 51 ++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9f4523765fdd..f6d1203f9217 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -687,9 +687,7 @@ class ServerArgs: remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None remote_instance_weight_loader_backend: Literal["transfer_engine", "nccl", "model_express"] = "nccl" remote_instance_weight_loader_start_seed_via_transfer_engine: bool = False - model_express_url: Optional[str] = None - model_express_model_name: Optional[str] = None - model_express_source: bool = False + model_express_config: Optional[str] = None # For PD-Multiplexing enable_pdmux: bool = False @@ -2720,10 +2718,10 @@ def _handle_load_format(self): if self.load_format == "remote_instance": if self.remote_instance_weight_loader_backend == "model_express": - # ModelExpress backend: requires --model-express-url, not seed IP/port + # ModelExpress backend: requires url in --model-express-config if self.model_express_url is None: logger.warning( - "Fallback load_format to 'auto' due to missing --model-express-url." + "Fallback load_format to 'auto' due to missing 'url' in --model-express-config." ) self.load_format = "auto" elif not self.validate_transfer_engine(): @@ -5239,21 +5237,10 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Start seed server via transfer engine backend for remote instance weight loader.", ) parser.add_argument( - "--model-express-url", + "--model-express-config", type=str, - default=ServerArgs.model_express_url, - help="The URL of the ModelExpress gRPC server (host:port).", - ) - parser.add_argument( - "--model-express-model-name", - type=str, - default=ServerArgs.model_express_model_name, - help="The model name to use for ModelExpress metadata coordination.", - ) - parser.add_argument( - "--model-express-source", - action="store_true", - help="Run as a ModelExpress seed source: publish transfer metadata to the ModelExpress server after loading weights.", + default=ServerArgs.model_express_config, + help='JSON config for ModelExpress P2P weight loading. Keys: "url" (required, gRPC host:port), "model_name" (optional, defaults to --model-path), "source" (optional bool, true for seed mode). Example: \'{"url": "localhost:8001", "model_name": "my-model", "source": true}\'', ) # For PD-Multiplexing @@ -5819,6 +5806,32 @@ def validate_transfer_engine(self): else: return True + @property + def _parsed_model_express_config(self) -> dict: + cache = getattr(self, "_mx_config_cache", None) + if cache is not None: + return cache + if self.model_express_config is None: + result = {} + elif isinstance(self.model_express_config, str): + result = json.loads(self.model_express_config) + else: + result = self.model_express_config + object.__setattr__(self, "_mx_config_cache", result) + return result + + @property + def model_express_url(self) -> Optional[str]: + return self._parsed_model_express_config.get("url") + + @property + def model_express_model_name(self) -> Optional[str]: + return self._parsed_model_express_config.get("model_name") + + @property + def model_express_source(self) -> bool: + return self._parsed_model_express_config.get("source", False) + def remote_instance_weight_loader_use_transfer_engine(self): # Use TransferEngine as seed backend. if self.remote_instance_weight_loader_start_seed_via_transfer_engine: From bca0e9d498d5acda2b422fc7510bec4159b4d652 Mon Sep 17 00:00:00 2001 From: Ishan Dhanani Date: Fri, 6 Mar 2026 04:16:41 +0000 Subject: [PATCH 3/9] Remove unused _DTYPE_ELEMENT_SIZES dict and _dtype_to_element_size() Dead code from initial MX integration. We switched to raw byte size comparison instead of dtype string conversion. --- python/sglang/srt/model_loader/loader.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index c8c7127df794..fb3b5a94037f 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2067,29 +2067,6 @@ def load_model( return model -_DTYPE_ELEMENT_SIZES = { - "float16": 2, "half": 2, - "bfloat16": 2, - "float32": 4, "float": 4, - "float64": 8, "double": 8, - "int8": 1, - "int16": 2, - "int32": 4, - "int64": 8, - "uint8": 1, - "float8_e4m3fn": 1, - "float8_e5m2": 1, -} - - -def _dtype_to_element_size(dtype_str: str) -> int: - """Convert dtype string to element size in bytes.""" - size = _DTYPE_ELEMENT_SIZES.get(dtype_str) - if size is None: - raise ValueError(f"Unknown dtype: {dtype_str}") - return size - - class RemoteInstanceModelLoader(BaseModelLoader): """Model loader that can load Tensors from remote sglang instance.""" From c7c085b1252adceb83a73592330265234f0abc20 Mon Sep 17 00:00:00 2001 From: Ishan Dhanani Date: Fri, 6 Mar 2026 04:21:06 +0000 Subject: [PATCH 4/9] Clean up ModelExpress integration code - Remove unused _get_model_dtype_str() method - Drop lossy element_size_to_dtype reverse mapping from seed publish (dtype field was never read on target side) - Wrap MxClient usage in try/finally to prevent gRPC channel leaks - Close MxClient before starting RDMA transfers (connection not needed during transfer phase) --- .../sglang/srt/model_executor/model_runner.py | 56 ++++------- python/sglang/srt/model_loader/loader.py | 92 +++++++++---------- 2 files changed, 65 insertions(+), 83 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 944e53fed79c..d33d6d905cc5 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -698,8 +698,6 @@ def _publish_model_express_metadata(self): return # Build tensor descriptors from weight_info dict - # Use actual per-tensor element_size to derive dtype (FP8 models have mixed dtypes) - element_size_to_dtype = {1: "float8_e4m3fn", 2: "bfloat16", 4: "float32", 8: "float64"} tensors = [] for name, (addr, numel, element_size) in weight_info.items(): tensors.append(p2p_pb2.TensorDescriptor( @@ -707,7 +705,6 @@ def _publish_model_express_metadata(self): addr=addr, size=numel * element_size, device_id=self.gpu_id, - dtype=element_size_to_dtype.get(element_size, "unknown"), )) worker = p2p_pb2.WorkerMetadata( @@ -717,40 +714,25 @@ def _publish_model_express_metadata(self): ) mx_client = MxClient(server_url=mx_url) - logger.info( - "ModelExpress source: publishing metadata for model=%s, " - "tp_rank=%d, session=%s, %d tensors", - model_name, self.tp_rank, session_id, len(tensors), - ) - mx_client.publish_metadata(model_name, [worker]) - mx_client.publish_ready( - model_name, - worker_id=self.tp_rank, - session_id=mx_client.session_id, - metadata_hash="", - ) - logger.info( - "ModelExpress source: published ready for model=%s, tp_rank=%d", - model_name, self.tp_rank, - ) - mx_client.close() - - def _get_model_dtype_str(self) -> str: - """Return the model dtype as a string for tensor descriptors.""" - import torch - dtype_map = { - torch.float16: "float16", - torch.bfloat16: "bfloat16", - torch.float32: "float32", - torch.float64: "float64", - torch.int8: "int8", - torch.uint8: "uint8", - } - if hasattr(torch, "float8_e4m3fn"): - dtype_map[torch.float8_e4m3fn] = "float8_e4m3fn" - if hasattr(torch, "float8_e5m2"): - dtype_map[torch.float8_e5m2] = "float8_e5m2" - return dtype_map.get(self.model_config.dtype, str(self.model_config.dtype)) + try: + logger.info( + "ModelExpress source: publishing metadata for model=%s, " + "tp_rank=%d, session=%s, %d tensors", + model_name, self.tp_rank, session_id, len(tensors), + ) + mx_client.publish_metadata(model_name, [worker]) + mx_client.publish_ready( + model_name, + worker_id=self.tp_rank, + session_id=mx_client.session_id, + metadata_hash="", + ) + logger.info( + "ModelExpress source: published ready for model=%s, tp_rank=%d", + model_name, self.tp_rank, + ) + finally: + mx_client.close() def model_specific_adjustment(self): server_args = self.server_args diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index fb3b5a94037f..3b21d125301a 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2291,57 +2291,58 @@ def load_model_from_model_express( # Wait for seed to be ready via ModelExpress mx_client = MxClient(server_url=load_config.model_express_url) - logger.info( - "ModelExpress: waiting for seed ready (model=%s, worker=%d)...", - model_name, tp_rank, - ) - ready, session_id, metadata_hash = mx_client.wait_for_ready( - model_name, worker_id=tp_rank, - ) - if not ready: - raise RuntimeError( - f"ModelExpress: timed out waiting for seed ready " - f"(model={model_name}, worker={tp_rank})" + try: + logger.info( + "ModelExpress: waiting for seed ready (model=%s)...", + model_name, ) - - # Get source metadata from ModelExpress - response = mx_client.get_metadata(model_name) - if not response.found: - raise RuntimeError( - f"ModelExpress: no metadata found for model={model_name}" + ready, session_id, metadata_hash = mx_client.wait_for_ready( + model_name, worker_id=tp_rank, ) + if not ready: + raise RuntimeError( + f"ModelExpress: timed out waiting for seed ready " + f"(model={model_name}, worker={tp_rank})" + ) - # Find the worker matching our tp_rank - source_worker = None - for w in response.workers: - if w.worker_rank == tp_rank: - source_worker = w - break - if source_worker is None: - raise RuntimeError( - f"ModelExpress: no worker metadata for rank={tp_rank}" - ) + response = mx_client.get_metadata(model_name) + if not response.found: + raise RuntimeError( + f"ModelExpress: no metadata found for model={model_name}" + ) - # Extract session_id from oneof backend_metadata - backend_field = source_worker.WhichOneof("backend_metadata") - if backend_field == "transfer_engine_session_id": - seed_session_id = source_worker.transfer_engine_session_id - else: - raise RuntimeError( - f"ModelExpress: expected transfer_engine_session_id, " - f"got backend_metadata={backend_field}" - ) + # Find the worker matching our tp_rank + source_worker = None + for w in response.workers: + if w.worker_rank == tp_rank: + source_worker = w + break + if source_worker is None: + raise RuntimeError( + f"ModelExpress: no worker metadata for rank={tp_rank}" + ) + + # Extract session_id from oneof backend_metadata + backend_field = source_worker.WhichOneof("backend_metadata") + if backend_field == "transfer_engine_session_id": + seed_session_id = source_worker.transfer_engine_session_id + else: + raise RuntimeError( + f"ModelExpress: expected transfer_engine_session_id, " + f"got backend_metadata={backend_field}" + ) - # Convert tensor descriptors to {name: (addr, size_bytes)} format - # Use raw byte sizes -- RDMA is a memcpy, dtype matching is not required - seed_weight_info = {} - for td in source_worker.tensors: - seed_weight_info[td.name] = (td.addr, td.size) + # Build {name: (addr, size_bytes)} from seed tensor descriptors + seed_weight_info = {} + for td in source_worker.tensors: + seed_weight_info[td.name] = (td.addr, td.size) - logger.info( - "ModelExpress: got %d tensor descriptors from seed (session=%s)", - len(seed_weight_info), seed_session_id, - ) + logger.info( + "ModelExpress: got %d tensor descriptors from seed (session=%s)", + len(seed_weight_info), seed_session_id, + ) + finally: + mx_client.close() # Transfer weights via TransferEngine RDMA seed_ptr_list = [] @@ -2384,7 +2385,6 @@ def load_model_from_model_express( model.post_load_weights() logger.info("ModelExpress: weight transfer complete for tp_rank=%d", tp_rank) - mx_client.close() class RemoteModelLoader(BaseModelLoader): From de918d1418dd51c21364a717a1ae6a5935454ffd Mon Sep 17 00:00:00 2001 From: Ishan Dhanani Date: Fri, 13 Mar 2026 21:38:45 +0000 Subject: [PATCH 5/9] Rename model_express to modelexpress in args and identifiers Address review nit: remove separator from model_express/model-express naming to use modelexpress consistently across CLI args, field names, enum values, and method names. --- python/sglang/srt/configs/load_config.py | 4 +- .../sglang/srt/model_executor/model_runner.py | 14 +++--- python/sglang/srt/model_loader/loader.py | 12 ++--- .../remote_instance_weight_loader_utils.py | 2 +- python/sglang/srt/server_args.py | 48 +++++++++---------- 5 files changed, 40 insertions(+), 40 deletions(-) diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 6781b912d528..443ba643d083 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -76,8 +76,8 @@ class LoadConfig: remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None remote_instance_weight_loader_backend: Optional[str] = None remote_instance_weight_loader_transfer_engine: Optional[Any] = None - model_express_url: Optional[str] = None - model_express_model_name: Optional[str] = None + modelexpress_url: Optional[str] = None + modelexpress_model_name: Optional[str] = None # ModelOpt-specific loading options modelopt_checkpoint_restore_path: Optional[str] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7bcaf5c47c68..fcfa82225dd0 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -676,16 +676,16 @@ def remote_instance_init_transfer_engine(self): f"{local_ip}:{self.remote_instance_transfer_engine.get_rpc_port()}" ) - def _publish_model_express_metadata(self): + def _publish_modelexpress_metadata(self): """Publish TransferEngine metadata to ModelExpress server (seed mode).""" from modelexpress.client import MxClient from modelexpress import p2p_pb2 model_name = ( - self.server_args.model_express_model_name + self.server_args.modelexpress_model_name or self.server_args.model_path ) - mx_url = self.server_args.model_express_url + mx_url = self.server_args.modelexpress_url session_id = self.remote_instance_transfer_engine_session_id weight_info = self.remote_instance_transfer_engine_weight_info @@ -998,8 +998,8 @@ def load_model(self): remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports, remote_instance_weight_loader_backend=self.server_args.remote_instance_weight_loader_backend, remote_instance_weight_loader_transfer_engine=self.remote_instance_transfer_engine, - model_express_url=self.server_args.model_express_url, - model_express_model_name=self.server_args.model_express_model_name or self.server_args.model_path, + modelexpress_url=self.server_args.modelexpress_url, + modelexpress_model_name=self.server_args.modelexpress_model_name or self.server_args.model_path, modelopt_config=modelopt_config, rl_quant_profile=self.server_args.rl_quant_profile, draft_model_idx=self.draft_model_idx, @@ -1053,7 +1053,7 @@ def load_model(self): monkey_patch_vllm_parallel_state(reverse=True) # Publish metadata to ModelExpress if running as seed source - if self.server_args.model_express_source: + if self.server_args.modelexpress_source: # Seed loads via DefaultModelLoader (load_format=auto), which doesn't # call register_memory_region(). Do it here so weight_info is populated. if ( @@ -1069,7 +1069,7 @@ def load_model(self): self.model, self.remote_instance_transfer_engine ) ) - self._publish_model_express_metadata() + self._publish_modelexpress_metadata() get_offloader().post_init() diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 384fa2ca38af..a3698e5f4c84 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2153,9 +2153,9 @@ def load_model( ) elif ( load_config.remote_instance_weight_loader_backend - == RemoteInstanceWeightLoaderBackend.MODEL_EXPRESS + == RemoteInstanceWeightLoaderBackend.MODELEXPRESS ): - self.load_model_from_model_express( + self.load_model_from_modelexpress( model, load_config, device_config, ) else: @@ -2271,7 +2271,7 @@ def load_model_from_remote_instance_by_transfer_engine( return True - def load_model_from_model_express( + def load_model_from_modelexpress( self, model, load_config: LoadConfig, device_config: DeviceConfig, ): """Load weights via ModelExpress coordination + TransferEngine RDMA.""" @@ -2280,10 +2280,10 @@ def load_model_from_model_express( transfer_engine = load_config.remote_instance_weight_loader_transfer_engine if transfer_engine is None: raise RuntimeError( - "TransferEngine is not initialized for model_express backend." + "TransferEngine is not initialized for modelexpress backend." ) tp_rank = load_config.tp_rank - model_name = load_config.model_express_model_name + model_name = load_config.modelexpress_model_name logger.info( "ModelExpress: registering memory regions for tp_rank=%d...", tp_rank @@ -2293,7 +2293,7 @@ def load_model_from_model_express( ) # Wait for seed to be ready via ModelExpress - mx_client = MxClient(server_url=load_config.model_express_url) + mx_client = MxClient(server_url=load_config.modelexpress_url) try: logger.info( "ModelExpress: waiting for seed ready (model=%s)...", diff --git a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py index 8a8f3e9e205e..8a945bb4c2e3 100644 --- a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py +++ b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py @@ -15,7 +15,7 @@ class RemoteInstanceWeightLoaderBackend(str, enum.Enum): NCCL = "nccl" TRANSFER_ENGINE = "transfer_engine" - MODEL_EXPRESS = "model_express" + MODELEXPRESS = "modelexpress" def trigger_init_weights_send_group_for_remote_instance_request( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 373f94be7183..6ef4484e945b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -697,9 +697,9 @@ class ServerArgs: remote_instance_weight_loader_seed_instance_ip: Optional[str] = None remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None - remote_instance_weight_loader_backend: Literal["transfer_engine", "nccl", "model_express"] = "nccl" + remote_instance_weight_loader_backend: Literal["transfer_engine", "nccl", "modelexpress"] = "nccl" remote_instance_weight_loader_start_seed_via_transfer_engine: bool = False - model_express_config: Optional[str] = None + modelexpress_config: Optional[str] = None # For PD-Multiplexing enable_pdmux: bool = False @@ -2896,16 +2896,16 @@ def _handle_load_format(self): self.custom_weight_loader = [] if self.load_format == "remote_instance": - if self.remote_instance_weight_loader_backend == "model_express": - # ModelExpress backend: requires url in --model-express-config - if self.model_express_url is None: + if self.remote_instance_weight_loader_backend == "modelexpress": + # ModelExpress backend: requires url in --modelexpress-config + if self.modelexpress_url is None: logger.warning( - "Fallback load_format to 'auto' due to missing 'url' in --model-express-config." + "Fallback load_format to 'auto' due to missing 'url' in --modelexpress-config." ) self.load_format = "auto" elif not self.validate_transfer_engine(): logger.warning( - "Fallback load_format to 'auto' due to 'transfer_engine' (required by model_express) not being supported." + "Fallback load_format to 'auto' due to 'transfer_engine' (required by modelexpress) not being supported." ) self.load_format = "auto" elif ( @@ -5472,9 +5472,9 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--remote-instance-weight-loader-backend", type=str, - choices=["transfer_engine", "nccl", "model_express"], + choices=["transfer_engine", "nccl", "modelexpress"], default=ServerArgs.remote_instance_weight_loader_backend, - help="The backend for loading weights from remote instance. Can be 'transfer_engine', 'nccl', or 'model_express'. Default is 'nccl'.", + help="The backend for loading weights from remote instance. Can be 'transfer_engine', 'nccl', or 'modelexpress'. Default is 'nccl'.", ) parser.add_argument( "--remote-instance-weight-loader-start-seed-via-transfer-engine", @@ -5482,9 +5482,9 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Start seed server via transfer engine backend for remote instance weight loader.", ) parser.add_argument( - "--model-express-config", + "--modelexpress-config", type=str, - default=ServerArgs.model_express_config, + default=ServerArgs.modelexpress_config, help='JSON config for ModelExpress P2P weight loading. Keys: "url" (required, gRPC host:port), "model_name" (optional, defaults to --model-path), "source" (optional bool, true for seed mode). Example: \'{"url": "localhost:8001", "model_name": "my-model", "source": true}\'', ) @@ -6097,43 +6097,43 @@ def validate_transfer_engine(self): return True @property - def _parsed_model_express_config(self) -> dict: + def _parsed_modelexpress_config(self) -> dict: cache = getattr(self, "_mx_config_cache", None) if cache is not None: return cache - if self.model_express_config is None: + if self.modelexpress_config is None: result = {} - elif isinstance(self.model_express_config, str): - result = json.loads(self.model_express_config) + elif isinstance(self.modelexpress_config, str): + result = json.loads(self.modelexpress_config) else: - result = self.model_express_config + result = self.modelexpress_config object.__setattr__(self, "_mx_config_cache", result) return result @property - def model_express_url(self) -> Optional[str]: - return self._parsed_model_express_config.get("url") + def modelexpress_url(self) -> Optional[str]: + return self._parsed_modelexpress_config.get("url") @property - def model_express_model_name(self) -> Optional[str]: - return self._parsed_model_express_config.get("model_name") + def modelexpress_model_name(self) -> Optional[str]: + return self._parsed_modelexpress_config.get("model_name") @property - def model_express_source(self) -> bool: - return self._parsed_model_express_config.get("source", False) + def modelexpress_source(self) -> bool: + return self._parsed_modelexpress_config.get("source", False) def remote_instance_weight_loader_use_transfer_engine(self): # Use TransferEngine as seed backend. if self.remote_instance_weight_loader_start_seed_via_transfer_engine: return True # ModelExpress source mode also needs TransferEngine init. - if self.model_express_source: + if self.modelexpress_source: return True # Use TransferEngine as client backend. elif ( self.load_format == "remote_instance" and self.remote_instance_weight_loader_backend - in ("transfer_engine", "model_express") + in ("transfer_engine", "modelexpress") ): return True else: From 22cebb5bb039d3153a8776d7453380686918ce54 Mon Sep 17 00:00:00 2001 From: Ishan Dhanani Date: Fri, 13 Mar 2026 21:43:59 +0000 Subject: [PATCH 6/9] Add ModelExpress backend documentation to R-Fork docs Document modelexpress as a third R-Fork backend option alongside NCCL and TransferEngine, including seed/client usage examples and the --modelexpress-config argument. --- docs/advanced_features/rfork.md | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/docs/advanced_features/rfork.md b/docs/advanced_features/rfork.md index 5e01aa111216..e4b513328ecf 100644 --- a/docs/advanced_features/rfork.md +++ b/docs/advanced_features/rfork.md @@ -9,11 +9,12 @@ To learn more details about R-Fork, please check ** Date: Tue, 17 Mar 2026 13:09:26 -0700 Subject: [PATCH 7/9] lint --- .../sglang/srt/model_executor/model_runner.py | 30 +++++++++++-------- python/sglang/srt/model_loader/loader.py | 15 +++++++--- python/sglang/srt/server_args.py | 4 ++- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fcfa82225dd0..4c3ed347a0ff 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -678,12 +678,11 @@ def remote_instance_init_transfer_engine(self): def _publish_modelexpress_metadata(self): """Publish TransferEngine metadata to ModelExpress server (seed mode).""" - from modelexpress.client import MxClient from modelexpress import p2p_pb2 + from modelexpress.client import MxClient model_name = ( - self.server_args.modelexpress_model_name - or self.server_args.model_path + self.server_args.modelexpress_model_name or self.server_args.model_path ) mx_url = self.server_args.modelexpress_url session_id = self.remote_instance_transfer_engine_session_id @@ -699,12 +698,14 @@ def _publish_modelexpress_metadata(self): # Build tensor descriptors from weight_info dict tensors = [] for name, (addr, numel, element_size) in weight_info.items(): - tensors.append(p2p_pb2.TensorDescriptor( - name=name, - addr=addr, - size=numel * element_size, - device_id=self.gpu_id, - )) + tensors.append( + p2p_pb2.TensorDescriptor( + name=name, + addr=addr, + size=numel * element_size, + device_id=self.gpu_id, + ) + ) worker = p2p_pb2.WorkerMetadata( worker_rank=self.tp_rank, @@ -717,7 +718,10 @@ def _publish_modelexpress_metadata(self): logger.info( "ModelExpress source: publishing metadata for model=%s, " "tp_rank=%d, session=%s, %d tensors", - model_name, self.tp_rank, session_id, len(tensors), + model_name, + self.tp_rank, + session_id, + len(tensors), ) mx_client.publish_metadata(model_name, [worker]) mx_client.publish_ready( @@ -728,7 +732,8 @@ def _publish_modelexpress_metadata(self): ) logger.info( "ModelExpress source: published ready for model=%s, tp_rank=%d", - model_name, self.tp_rank, + model_name, + self.tp_rank, ) finally: mx_client.close() @@ -999,7 +1004,8 @@ def load_model(self): remote_instance_weight_loader_backend=self.server_args.remote_instance_weight_loader_backend, remote_instance_weight_loader_transfer_engine=self.remote_instance_transfer_engine, modelexpress_url=self.server_args.modelexpress_url, - modelexpress_model_name=self.server_args.modelexpress_model_name or self.server_args.model_path, + modelexpress_model_name=self.server_args.modelexpress_model_name + or self.server_args.model_path, modelopt_config=modelopt_config, rl_quant_profile=self.server_args.rl_quant_profile, draft_model_idx=self.draft_model_idx, diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 93ef22d5d6b2..44281d3ce504 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2156,7 +2156,9 @@ def load_model( == RemoteInstanceWeightLoaderBackend.MODELEXPRESS ): self.load_model_from_modelexpress( - model, load_config, device_config, + model, + load_config, + device_config, ) else: raise ValueError("Invalid remote instance weight loader backend.") @@ -2272,7 +2274,10 @@ def load_model_from_remote_instance_by_transfer_engine( return True def load_model_from_modelexpress( - self, model, load_config: LoadConfig, device_config: DeviceConfig, + self, + model, + load_config: LoadConfig, + device_config: DeviceConfig, ): """Load weights via ModelExpress coordination + TransferEngine RDMA.""" from modelexpress.client import MxClient @@ -2300,7 +2305,8 @@ def load_model_from_modelexpress( model_name, ) ready, session_id, metadata_hash = mx_client.wait_for_ready( - model_name, worker_id=tp_rank, + model_name, + worker_id=tp_rank, ) if not ready: raise RuntimeError( @@ -2342,7 +2348,8 @@ def load_model_from_modelexpress( logger.info( "ModelExpress: got %d tensor descriptors from seed (session=%s)", - len(seed_weight_info), seed_session_id, + len(seed_weight_info), + seed_session_id, ) finally: mx_client.close() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index afda867ba52f..40ee22080f46 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -698,7 +698,9 @@ class ServerArgs: remote_instance_weight_loader_seed_instance_ip: Optional[str] = None remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None - remote_instance_weight_loader_backend: Literal["transfer_engine", "nccl", "modelexpress"] = "nccl" + remote_instance_weight_loader_backend: Literal[ + "transfer_engine", "nccl", "modelexpress" + ] = "nccl" remote_instance_weight_loader_start_seed_via_transfer_engine: bool = False modelexpress_config: Optional[str] = None From 42216f76b774a6240898c4aa69b6dd9ccc4f701e Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Tue, 17 Mar 2026 13:16:23 -0700 Subject: [PATCH 8/9] Gate optional modelexpress imports --- python/sglang/srt/model_executor/model_runner.py | 4 ++-- python/sglang/srt/model_loader/loader.py | 3 ++- .../remote_instance_weight_loader_utils.py | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4c3ed347a0ff..8e81e09706fb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -143,6 +143,7 @@ from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( RemoteInstanceWeightLoaderBackend, + get_modelexpress_modules, register_memory_region, trigger_init_weights_send_group_for_remote_instance_request, ) @@ -678,8 +679,7 @@ def remote_instance_init_transfer_engine(self): def _publish_modelexpress_metadata(self): """Publish TransferEngine metadata to ModelExpress server (seed mode).""" - from modelexpress import p2p_pb2 - from modelexpress.client import MxClient + p2p_pb2, MxClient = get_modelexpress_modules() model_name = ( self.server_args.modelexpress_model_name or self.server_args.model_path diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 44281d3ce504..4a9433daf5ae 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -38,6 +38,7 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( RemoteInstanceWeightLoaderBackend, get_remote_instance_transfer_engine_info_per_rank, + get_modelexpress_modules, register_memory_region, ) from sglang.srt.server_args import get_global_server_args @@ -2280,7 +2281,7 @@ def load_model_from_modelexpress( device_config: DeviceConfig, ): """Load weights via ModelExpress coordination + TransferEngine RDMA.""" - from modelexpress.client import MxClient + _, MxClient = get_modelexpress_modules() transfer_engine = load_config.remote_instance_weight_loader_transfer_engine if transfer_engine is None: diff --git a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py index 8a945bb4c2e3..19597ab67f2a 100644 --- a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py +++ b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py @@ -18,6 +18,12 @@ class RemoteInstanceWeightLoaderBackend(str, enum.Enum): MODELEXPRESS = "modelexpress" +MODELEXPRESS_INSTALL_MESSAGE = ( + "ModelExpress support requires the 'modelexpress' package. " + "Install it with: pip install modelexpress" +) + + def trigger_init_weights_send_group_for_remote_instance_request( remote_instance_weight_loader_seed_instance_ip: str, remote_instance_weight_loader_seed_instance_service_port: int, @@ -121,6 +127,15 @@ def parse_remote_instance_transfer_engine_info_from_scheduler_infos(scheduler_in return remote_instance_transfer_engine_info +def get_modelexpress_modules(): + try: + p2p_pb2 = importlib.import_module("modelexpress.p2p_pb2") + mx_client_module = importlib.import_module("modelexpress.client") + except ImportError as exc: + raise ImportError(MODELEXPRESS_INSTALL_MESSAGE) from exc + return p2p_pb2, mx_client_module.MxClient + + def register_memory_region(model, transfer_engine): if importlib.util.find_spec("torch") is None: return register_memory_region_v1(model, transfer_engine) From 961354ef72d56d45f641c18a2c2462ebdebacaa6 Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Tue, 17 Mar 2026 13:21:28 -0700 Subject: [PATCH 9/9] Simplify modelexpress import gating --- python/sglang/srt/model_executor/model_runner.py | 10 ++++++++-- python/sglang/srt/model_loader/loader.py | 9 +++++++-- .../remote_instance_weight_loader_utils.py | 15 --------------- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8e81e09706fb..289d882fb0af 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -143,7 +143,6 @@ from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( RemoteInstanceWeightLoaderBackend, - get_modelexpress_modules, register_memory_region, trigger_init_weights_send_group_for_remote_instance_request, ) @@ -679,7 +678,14 @@ def remote_instance_init_transfer_engine(self): def _publish_modelexpress_metadata(self): """Publish TransferEngine metadata to ModelExpress server (seed mode).""" - p2p_pb2, MxClient = get_modelexpress_modules() + try: + from modelexpress import p2p_pb2 + from modelexpress.client import MxClient + except ImportError as exc: + raise ImportError( + "ModelExpress support requires the 'modelexpress' package. " + "Install it with: pip install modelexpress" + ) from exc model_name = ( self.server_args.modelexpress_model_name or self.server_args.model_path diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 4a9433daf5ae..3087f0e85f6e 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -38,7 +38,6 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( RemoteInstanceWeightLoaderBackend, get_remote_instance_transfer_engine_info_per_rank, - get_modelexpress_modules, register_memory_region, ) from sglang.srt.server_args import get_global_server_args @@ -2281,7 +2280,13 @@ def load_model_from_modelexpress( device_config: DeviceConfig, ): """Load weights via ModelExpress coordination + TransferEngine RDMA.""" - _, MxClient = get_modelexpress_modules() + try: + from modelexpress.client import MxClient + except ImportError as exc: + raise ImportError( + "ModelExpress support requires the 'modelexpress' package. " + "Install it with: pip install modelexpress" + ) from exc transfer_engine = load_config.remote_instance_weight_loader_transfer_engine if transfer_engine is None: diff --git a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py index 19597ab67f2a..8a945bb4c2e3 100644 --- a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py +++ b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py @@ -18,12 +18,6 @@ class RemoteInstanceWeightLoaderBackend(str, enum.Enum): MODELEXPRESS = "modelexpress" -MODELEXPRESS_INSTALL_MESSAGE = ( - "ModelExpress support requires the 'modelexpress' package. " - "Install it with: pip install modelexpress" -) - - def trigger_init_weights_send_group_for_remote_instance_request( remote_instance_weight_loader_seed_instance_ip: str, remote_instance_weight_loader_seed_instance_service_port: int, @@ -127,15 +121,6 @@ def parse_remote_instance_transfer_engine_info_from_scheduler_infos(scheduler_in return remote_instance_transfer_engine_info -def get_modelexpress_modules(): - try: - p2p_pb2 = importlib.import_module("modelexpress.p2p_pb2") - mx_client_module = importlib.import_module("modelexpress.client") - except ImportError as exc: - raise ImportError(MODELEXPRESS_INSTALL_MESSAGE) from exc - return p2p_pb2, mx_client_module.MxClient - - def register_memory_region(model, transfer_engine): if importlib.util.find_spec("torch") is None: return register_memory_region_v1(model, transfer_engine)