Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/sglang/srt/configs/load_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ class LoadConfig:
remote_instance_weight_loader_transfer_engine: Optional[Any] = None
modelexpress_url: Optional[str] = None
modelexpress_model_name: Optional[str] = None
# Fields for building SourceIdentity (needed by both seed and client)
modelexpress_tp_size: Optional[int] = None
modelexpress_pp_size: Optional[int] = None
modelexpress_ep_size: Optional[int] = None
modelexpress_dtype: Optional[str] = None
modelexpress_quantization: Optional[str] = None

# ModelOpt-specific loading options
modelopt_checkpoint_restore_path: Optional[str] = None
Expand Down
39 changes: 31 additions & 8 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import socket
import threading
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -781,6 +782,17 @@ def _publish_modelexpress_metadata(self):
)
return

# Build SourceIdentity for this instance
identity = p2p_pb2.SourceIdentity(
model_name=model_name,
backend_framework=p2p_pb2.BACKEND_FRAMEWORK_SGLANG,
tensor_parallel_size=self.server_args.tp_size,
pipeline_parallel_size=self.server_args.pp_size,
expert_parallel_size=self.server_args.ep_size,
dtype=self.server_args.dtype or "",
quantization=self.server_args.quantization or "",
)

# Build tensor descriptors from weight_info dict
tensors = []
for name, (addr, numel, element_size) in weight_info.items():
Expand All @@ -799,27 +811,33 @@ def _publish_modelexpress_metadata(self):
tensors=tensors,
)

# Generate a unique worker_id for this running instance
worker_id = str(uuid.uuid4())

mx_client = MxClient(server_url=mx_url)
try:
logger.info(
"ModelExpress source: publishing metadata for model=%s, "
"tp_rank=%d, session=%s, %d tensors",
"tp_rank=%d, session=%s, %d tensors, worker_id=%s",
model_name,
self.tp_rank,
session_id,
len(tensors),
worker_id,
)
mx_client.publish_metadata(model_name, [worker])
mx_client.publish_ready(
model_name,
worker_id=self.tp_rank,
session_id=mx_client.session_id,
metadata_hash="",
mx_source_id = mx_client.publish_metadata(identity, worker, worker_id)
mx_client.update_status(
mx_source_id=mx_source_id,
worker_id=worker_id,
worker_rank=self.tp_rank,
status=p2p_pb2.SOURCE_STATUS_READY,
)
logger.info(
"ModelExpress source: published ready for model=%s, tp_rank=%d",
"ModelExpress source: published ready for model=%s, "
"tp_rank=%d, mx_source_id=%s",
model_name,
self.tp_rank,
mx_source_id,
)
finally:
mx_client.close()
Expand Down Expand Up @@ -1114,6 +1132,11 @@ def load_model(self):
modelexpress_url=self.server_args.modelexpress_url,
modelexpress_model_name=self.server_args.modelexpress_model_name
or self.server_args.model_path,
modelexpress_tp_size=self.server_args.tp_size,
modelexpress_pp_size=self.server_args.pp_size,
modelexpress_ep_size=self.server_args.ep_size,
modelexpress_dtype=self.server_args.dtype,
modelexpress_quantization=self.server_args.quantization or "",
modelopt_config=modelopt_config,
rl_quant_profile=self.server_args.rl_quant_profile,
draft_model_idx=self.draft_model_idx,
Expand Down
75 changes: 54 additions & 21 deletions python/sglang/srt/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2288,6 +2288,8 @@ def load_model_from_modelexpress(
):
"""Load weights via ModelExpress coordination + TransferEngine RDMA."""
try:
import grpc
from modelexpress import p2p_pb2
from modelexpress.client import MxClient
except ImportError as exc:
raise ImportError(
Expand All @@ -2303,47 +2305,78 @@ def load_model_from_modelexpress(
tp_rank = load_config.tp_rank
model_name = load_config.modelexpress_model_name

target_device = torch.device(device_config.device)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)

logger.info(
"ModelExpress: registering memory regions for tp_rank=%d...", tp_rank
)
self.remote_instance_transfer_engine_weight_info = register_memory_region(
model, transfer_engine
)

# Wait for seed to be ready via ModelExpress
# Build SourceIdentity matching the seed's identity
identity = p2p_pb2.SourceIdentity(
model_name=model_name,
backend_framework=p2p_pb2.BACKEND_FRAMEWORK_SGLANG,
tensor_parallel_size=load_config.modelexpress_tp_size or 1,
pipeline_parallel_size=load_config.modelexpress_pp_size or 1,
expert_parallel_size=load_config.modelexpress_ep_size or 1,
dtype=load_config.modelexpress_dtype or "",
quantization=load_config.modelexpress_quantization or "",
)

# Query MX server for a READY source matching our identity and rank
mx_client = MxClient(server_url=load_config.modelexpress_url)
try:
logger.info(
"ModelExpress: waiting for seed ready (model=%s)...",
model_name,
)
ready, session_id, metadata_hash = mx_client.wait_for_ready(
"ModelExpress: looking for seed (model=%s, rank=%d)...",
model_name,
worker_id=tp_rank,
tp_rank,
)
if not ready:
raise RuntimeError(
f"ModelExpress: timed out waiting for seed ready "
f"(model={model_name}, worker={tp_rank})"
try:
resp = mx_client.list_sources(
identity=identity,
status_filter=p2p_pb2.SOURCE_STATUS_READY,
)
except grpc.RpcError as e:
raise RuntimeError(
f"ModelExpress: cannot reach server at "
f"{load_config.modelexpress_url}: "
f"{e.code()}: {e.details()}"
) from e

source_ref = None
for inst in resp.instances:
if inst.worker_rank == tp_rank:
source_ref = inst
break

response = mx_client.get_metadata(model_name)
if not response.found:
if source_ref is None:
raise RuntimeError(
f"ModelExpress: no metadata found for model={model_name}"
f"ModelExpress: no READY source found for "
f"model={model_name}, rank={tp_rank}. "
f"Ensure the seed instance is running and has published metadata."
)

# Find the worker matching our tp_rank
source_worker = None
for w in response.workers:
if w.worker_rank == tp_rank:
source_worker = w
break
if source_worker is None:
# Fetch full metadata for the discovered worker
response = mx_client.get_metadata(
mx_source_id=source_ref.mx_source_id,
worker_id=source_ref.worker_id,
)
if not response.found:
raise RuntimeError(
f"ModelExpress: no worker metadata for rank={tp_rank}"
f"ModelExpress: no metadata found for "
f"source_id={source_ref.mx_source_id}, "
f"worker_id={source_ref.worker_id}"
)

source_worker = response.worker

# Extract session_id from oneof backend_metadata
backend_field = source_worker.WhichOneof("backend_metadata")
if backend_field == "transfer_engine_session_id":
Expand Down
Loading