Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions docs/advanced_features/rfork.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ To learn more details about R-Fork, please check **<a href=https://lmsys.org/blo
| Argument | Usage |
|--------------|--------------------------------------------|
| load-format | set to `remote_instance` to enable R-Fork. |
| remote-instance-weight-loader-backend | `nccl` or `transfer_engine`, default value is `nccl` |
| remote-instance-weight-loader-seed-instance-ip | IP address of the seed instance who will provide the model weight |
| remote-instance-weight-loader-seed-instance-service-port | the port that the seed instance's HTTP server is listening on |
| remote-instance-weight-loader-send-weights-group-ports | the list of available ports on the seed instance that will be used to build NCCL communication groups between seed and client instance. This argument is only needed by `nccl` backend. |
| remote-instance-weight-loader-start-seed-via-transfer-engine | set to start seed service that supports TransferEngine as backend. It is needed for seed instances when using `transfer_engine` as backend. |
| remote-instance-weight-loader-backend | `nccl`, `transfer_engine`, or `modelexpress`. Default is `nccl`. |
| remote-instance-weight-loader-seed-instance-ip | IP address of the seed instance that will provide the model weights. Used by `nccl` and `transfer_engine` backends. |
| remote-instance-weight-loader-seed-instance-service-port | the port that the seed instance's HTTP server is listening on. Used by `nccl` and `transfer_engine` backends. |
| remote-instance-weight-loader-send-weights-group-ports | the list of available ports on the seed instance that will be used to build NCCL communication groups between seed and client instance. Only needed by `nccl` backend. |
| remote-instance-weight-loader-start-seed-via-transfer-engine | set to start seed service that supports TransferEngine as backend. Needed for seed instances when using `transfer_engine` as backend. |
| modelexpress-config | JSON config for `modelexpress` backend. Keys: `"url"` (required, gRPC host:port of ModelExpress server), `"model_name"` (optional, defaults to `--model-path`), `"source"` (optional bool, `true` for seed mode). |

### NCCL as backend

Expand Down Expand Up @@ -47,3 +48,25 @@ python -m sglang.launch_server [args] \
--remote-instance-weight-loader-seed-instance-service-port [seed_instance_service_port] \
--remote-instance-weight-loader-backend transfer_engine
```

### ModelExpress as backend

[ModelExpress](https://github.com/ai-dynamo/modelexpress) is a coordination service that manages P2P weight transfer metadata. It removes the need for direct seed IP/port configuration by providing a centralized registry that seeds publish to and clients discover from. Under the hood it uses TransferEngine (Mooncake) for the actual RDMA data transfer.

A running ModelExpress server is required. See the [ModelExpress documentation](https://github.com/ai-dynamo/modelexpress) for setup instructions.

seed instance:
```shell
python -m sglang.launch_server [args] \
--modelexpress-config '{"url": "[modelexpress_grpc_host:port]", "model_name": "[model_name]", "source": true}'
```

client instance:
```shell
python -m sglang.launch_server [args] \
--load-format remote_instance \
--remote-instance-weight-loader-backend modelexpress \
--modelexpress-config '{"url": "[modelexpress_grpc_host:port]", "model_name": "[model_name]"}'
```

The seed publishes its TransferEngine session ID and tensor layout to ModelExpress. The client queries ModelExpress to discover the seed, then pulls weights directly via RDMA. This enables dynamic seed discovery without hardcoding IPs, and supports multiple models through a single ModelExpress instance.
2 changes: 2 additions & 0 deletions python/sglang/srt/configs/load_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class LoadConfig:
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
remote_instance_weight_loader_backend: Optional[str] = None
remote_instance_weight_loader_transfer_engine: Optional[Any] = None
modelexpress_url: Optional[str] = None
modelexpress_model_name: Optional[str] = None

# ModelOpt-specific loading options
modelopt_checkpoint_restore_path: Optional[str] = None
Expand Down
90 changes: 90 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,74 @@ def remote_instance_init_transfer_engine(self):
f"{local_ip}:{self.remote_instance_transfer_engine.get_rpc_port()}"
)

def _publish_modelexpress_metadata(self):
    """Publish TransferEngine metadata to ModelExpress server (seed mode)."""
    # modelexpress is an optional dependency; import lazily so deployments
    # that never use this backend do not need it installed.
    try:
        from modelexpress import p2p_pb2
        from modelexpress.client import MxClient
    except ImportError as exc:
        raise ImportError(
            "ModelExpress support requires the 'modelexpress' package. "
            "Install it with: pip install modelexpress"
        ) from exc

    args = self.server_args
    model_name = args.modelexpress_model_name or args.model_path
    session_id = self.remote_instance_transfer_engine_session_id
    weight_info = self.remote_instance_transfer_engine_weight_info

    # Nothing to advertise without a live TransferEngine session and its
    # registered weight regions; warn and bail out.
    if not session_id or weight_info is None:
        logger.warning(
            "ModelExpress source: skipping publish -- "
            "TransferEngine not initialized or no weight info"
        )
        return

    # One descriptor per registered weight region; weight_info maps
    # name -> (addr, numel, element_size).
    descriptors = [
        p2p_pb2.TensorDescriptor(
            name=name,
            addr=addr,
            size=numel * element_size,
            device_id=self.gpu_id,
        )
        for name, (addr, numel, element_size) in weight_info.items()
    ]

    worker_meta = p2p_pb2.WorkerMetadata(
        worker_rank=self.tp_rank,
        transfer_engine_session_id=session_id,
        tensors=descriptors,
    )

    client = MxClient(server_url=args.modelexpress_url)
    try:
        logger.info(
            "ModelExpress source: publishing metadata for model=%s, "
            "tp_rank=%d, session=%s, %d tensors",
            model_name,
            self.tp_rank,
            session_id,
            len(descriptors),
        )
        client.publish_metadata(model_name, [worker_meta])
        # Flag this worker as ready so clients blocked on the ModelExpress
        # readiness check can proceed to fetch the metadata just published.
        client.publish_ready(
            model_name,
            worker_id=self.tp_rank,
            session_id=client.session_id,
            metadata_hash="",
        )
        logger.info(
            "ModelExpress source: published ready for model=%s, tp_rank=%d",
            model_name,
            self.tp_rank,
        )
    finally:
        client.close()

def model_specific_adjustment(self):
server_args = self.server_args

Expand Down Expand Up @@ -941,6 +1009,9 @@ def load_model(self):
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
remote_instance_weight_loader_backend=self.server_args.remote_instance_weight_loader_backend,
remote_instance_weight_loader_transfer_engine=self.remote_instance_transfer_engine,
modelexpress_url=self.server_args.modelexpress_url,
modelexpress_model_name=self.server_args.modelexpress_model_name
or self.server_args.model_path,
modelopt_config=modelopt_config,
rl_quant_profile=self.server_args.rl_quant_profile,
draft_model_idx=self.draft_model_idx,
Expand Down Expand Up @@ -993,6 +1064,25 @@ def load_model(self):
)
monkey_patch_vllm_parallel_state(reverse=True)

# Publish metadata to ModelExpress if running as seed source
if self.server_args.modelexpress_source:
# Seed loads via DefaultModelLoader (load_format=auto), which doesn't
# call register_memory_region(). Do it here so weight_info is populated.
if (
self.remote_instance_transfer_engine_weight_info is None
and self.remote_instance_transfer_engine is not None
):
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
register_memory_region,
)

self.remote_instance_transfer_engine_weight_info = (
register_memory_region(
self.model, self.remote_instance_transfer_engine
)
)
self._publish_modelexpress_metadata()

get_offloader().post_init()

# Register model for layerwise NVTX profiling if enabled
Expand Down
138 changes: 138 additions & 0 deletions python/sglang/srt/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,6 +2151,15 @@ def load_model(
raise RuntimeError(
"Failed to load weights from remote instance via transfer engine."
)
elif (
load_config.remote_instance_weight_loader_backend
== RemoteInstanceWeightLoaderBackend.MODELEXPRESS
):
self.load_model_from_modelexpress(
model,
load_config,
device_config,
)
else:
raise ValueError("Invalid remote instance weight loader backend.")

Expand Down Expand Up @@ -2264,6 +2273,135 @@ def load_model_from_remote_instance_by_transfer_engine(

return True

def load_model_from_modelexpress(
    self,
    model,
    load_config: LoadConfig,
    device_config: DeviceConfig,
):
    """Load weights via ModelExpress coordination + TransferEngine RDMA.

    Flow:
      1. Register this client's parameter memory with TransferEngine so
         it is a valid RDMA destination.
      2. Wait until the seed worker with the same tp_rank reports ready
         via the ModelExpress server, then fetch its published metadata
         (TransferEngine session id + tensor descriptors).
      3. Pull all parameters from the seed with one batched RDMA read.

    Args:
        model: locally constructed model whose named parameters are
            filled in place.
        load_config: carries the ModelExpress URL/model name, the
            TransferEngine handle, and this worker's tp_rank.
        device_config: unused here; kept for signature parity with the
            other remote-instance load paths.

    Raises:
        ImportError: if the optional 'modelexpress' package is missing.
        RuntimeError: on missing TransferEngine, readiness timeout,
            absent or mismatched seed metadata, or a failed transfer.
    """
    try:
        from modelexpress.client import MxClient
    except ImportError as exc:
        raise ImportError(
            "ModelExpress support requires the 'modelexpress' package. "
            "Install it with: pip install modelexpress"
        ) from exc

    transfer_engine = load_config.remote_instance_weight_loader_transfer_engine
    if transfer_engine is None:
        raise RuntimeError(
            "TransferEngine is not initialized for modelexpress backend."
        )
    tp_rank = load_config.tp_rank
    model_name = load_config.modelexpress_model_name

    logger.info(
        "ModelExpress: registering memory regions for tp_rank=%d...", tp_rank
    )
    self.remote_instance_transfer_engine_weight_info = register_memory_region(
        model, transfer_engine
    )

    # Wait for seed to be ready via ModelExpress
    mx_client = MxClient(server_url=load_config.modelexpress_url)
    try:
        logger.info(
            "ModelExpress: waiting for seed ready (model=%s)...",
            model_name,
        )
        # Only the readiness flag is needed here; the session id is
        # re-derived from the metadata response below.
        ready, _session_id, _metadata_hash = mx_client.wait_for_ready(
            model_name,
            worker_id=tp_rank,
        )
        if not ready:
            raise RuntimeError(
                f"ModelExpress: timed out waiting for seed ready "
                f"(model={model_name}, worker={tp_rank})"
            )

        response = mx_client.get_metadata(model_name)
        if not response.found:
            raise RuntimeError(
                f"ModelExpress: no metadata found for model={model_name}"
            )

        # Each client rank pulls from the seed worker with the same
        # tp_rank (sharding is assumed identical on both sides).
        source_worker = next(
            (w for w in response.workers if w.worker_rank == tp_rank), None
        )
        if source_worker is None:
            raise RuntimeError(
                f"ModelExpress: no worker metadata for rank={tp_rank}"
            )

        # Extract session_id from oneof backend_metadata
        backend_field = source_worker.WhichOneof("backend_metadata")
        if backend_field == "transfer_engine_session_id":
            seed_session_id = source_worker.transfer_engine_session_id
        else:
            raise RuntimeError(
                f"ModelExpress: expected transfer_engine_session_id, "
                f"got backend_metadata={backend_field}"
            )

        # Build {name: (addr, size_bytes)} from seed tensor descriptors
        seed_weight_info = {
            td.name: (td.addr, td.size) for td in source_worker.tensors
        }

        logger.info(
            "ModelExpress: got %d tensor descriptors from seed (session=%s)",
            len(seed_weight_info),
            seed_session_id,
        )
    finally:
        mx_client.close()

    # Assemble per-tensor source/destination/length lists for one batched
    # RDMA read; validate name and size agreement first so a mismatched
    # checkpoint fails fast instead of corrupting memory.
    seed_ptr_list = []
    client_ptr_list = []
    client_len_list = []
    for name, tensor in model.named_parameters():
        weight_info = seed_weight_info.get(name)
        if weight_info is None:
            raise RuntimeError(
                f"ModelExpress: cannot find weight info for {name} "
                f"in seed metadata"
            )
        seed_ptr, seed_size = weight_info
        local_size = tensor.numel() * tensor.element_size()
        if seed_size != local_size:
            raise RuntimeError(
                f"ModelExpress: size mismatch for {name}: "
                f"seed={seed_size} bytes, local={local_size} bytes"
            )
        seed_ptr_list.append(seed_ptr)
        client_ptr_list.append(tensor.data_ptr())
        client_len_list.append(local_size)

    logger.info(
        "ModelExpress: starting RDMA transfer of %d tensors...",
        len(seed_ptr_list),
    )
    ret = transfer_engine.batch_transfer_sync_read(
        seed_session_id,
        client_ptr_list,
        seed_ptr_list,
        client_len_list,
    )
    if ret < 0:
        raise RuntimeError(
            f"ModelExpress: batch_transfer_sync_read failed, error={ret}"
        )

    # Some models need post-processing (e.g. retying/repacking) after the
    # raw weight bytes land in memory.
    if hasattr(model, "post_load_weights"):
        model.post_load_weights()

    logger.info("ModelExpress: weight transfer complete for tp_rank=%d", tp_rank)


class RemoteModelLoader(BaseModelLoader):
"""Model loader that can load Tensors from remote database."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
class RemoteInstanceWeightLoaderBackend(str, enum.Enum):
    """Transport used to pull model weights from a remote (seed) instance."""

    # NCCL communication groups between seed and client instances.
    NCCL = "nccl"
    # Direct transfer via TransferEngine.
    TRANSFER_ENGINE = "transfer_engine"
    # Seed discovery via a ModelExpress coordination server; data moves
    # over TransferEngine underneath.
    MODELEXPRESS = "modelexpress"


def trigger_init_weights_send_group_for_remote_instance_request(
Expand Down
Loading
Loading