From 789f31142f8e1bf9e516600e17ab036bdc8a8739 Mon Sep 17 00:00:00 2001 From: JD-ETH Date: Mon, 15 Dec 2025 01:22:12 +0000 Subject: [PATCH 01/11] rdma impl rebase with profile --- .../megatron_utils/update_weight/common.py | 11 + .../update_weight/remote_transfer_plan.py | 234 ++++++++++++++++++ .../update_weight_from_distributed.py | 163 +----------- .../update_weight/update_weight_from_rdma.py | 159 ++++++++++++ .../update_weight_from_remote.py | 193 +++++++++++++++ slime/backends/sglang_utils/sglang_engine.py | 2 + slime/utils/arguments.py | 21 ++ slime/utils/timer.py | 7 +- tests/test_qwen2.5_0.5B_gsm8k_async.py | 9 +- tests/test_weight_transfer.py | 114 +++++++++ tests/timing_comparison_guide.md | 65 +++++ 11 files changed, 820 insertions(+), 158 deletions(-) create mode 100644 slime/backends/megatron_utils/update_weight/remote_transfer_plan.py create mode 100644 slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py create mode 100644 slime/backends/megatron_utils/update_weight/update_weight_from_remote.py create mode 100644 tests/test_weight_transfer.py create mode 100644 tests/timing_comparison_guide.md diff --git a/slime/backends/megatron_utils/update_weight/common.py b/slime/backends/megatron_utils/update_weight/common.py index a2e4e129bc..d515b2f963 100644 --- a/slime/backends/megatron_utils/update_weight/common.py +++ b/slime/backends/megatron_utils/update_weight/common.py @@ -130,6 +130,17 @@ def named_params_and_buffers( return ans +def split_expert_and_non_expert_param_names(params: Iterator[str]): + expert_param_names = [] + non_expert_param_names = [] + for param in params: + if ".experts." in param: + expert_param_names.append(param) + else: + non_expert_param_names.append(param) + return expert_param_names, non_expert_param_names + + def _maybe_get_cpu_backup(x: torch.Tensor): from torch_memory_saver import torch_memory_saver diff --git a/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py b/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py new file mode 100644 index 0000000000..a6941f1385 --- /dev/null +++ b/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py @@ -0,0 +1,234 @@ +""" +Remote Transfer Plan - Abstract transfer planning for NCCL and RDMA weight updates. + +This module provides a unified interface for determining transfer sources and planning +weight transfer tasks across different communication backends (NCCL, RDMA). +""" + +import logging +from argparse import Namespace +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal + +import torch +from megatron.core import mpu + +from .common import named_params_and_buffers, split_expert_and_non_expert_param_names + +logger = logging.getLogger(__name__) + + +@dataclass +class TransferTask: + """ + Attributes: + session: Session identifier (e.g., NCCL group name or Transfer Engine Session Id) + tensor_names: Full list of tensor names to be transferred in this task as they appear in the engine API interface. + """ + + tensor_names: list[str] + session: str # NCCL group name or target entity id. + + +@dataclass +class TransferTaskP2PMeta: + """ + Specifies a engine rollout rank to connect to. + """ + + engine_ind: int # The index of the target rollout engine. + engine_rank: int # The shard within the target rollout engine. 
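+    # Which parameter group this target consumes. Expert and non-expert
+    # parameters are planned independently because their gathered replica
+    # counts differ (see RemoteTransferPlan._get_parallel_info).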
+    group: Literal["expert", "non-expert"]
+
+
+class RemoteTransferPlan:
+    """
+    Plans and manages remote weight transfers for both the NCCL and RDMA backends, assuming static training and rollout placements.
+
+    At the moment, the plan assumes an all-gather in the tp/ep dimension on a bucketed basis.
+
+    NCCL Plan: Use a single broadcast from the DP=TP=0 rank of each PP stage to all rollout engines in a new process group.
+    RDMA P2P Plan:
+        The current execution plan prioritizes simplicity and general applicability for all supported models. It reuses existing
+        components of the slime distributed update path as well as the sglang remote instance load mechanism. The plan proceeds as follows:
+        1. Calculate the total number of full source replicas (up to the pp dimension) after the all-gather in the tp/ep dimension, for both
+        expert and non-expert parameters.
+        2. For each rollout engine, assign source ranks in a round-robin manner for both expert and non-expert parameters.
+        3. During initialization, query each target rollout engine rank for remote parameter names and session identifiers.
+        4. Generate transfer tasks for each source rank based on remote and local parameter availability.
+
+    """
+
+    def __init__(
+        self, args: Namespace, model: Sequence[torch.nn.Module], mode: Literal["nccl", "rdma"] = "nccl"
+    ) -> None:
+        """
+        Initialize the transfer plan.
+
+        Args:
+            args: Configuration namespace containing parallelism settings
+            mode: Transfer backend mode - either "nccl" or "rdma"
+        """
+        self.mode = mode
+        self._get_parallel_info(args)
+        # Cache the parameter names for transfer planning
+        self.expert_param_names, self.param_names = split_expert_and_non_expert_param_names(
+            name for name, _ in named_params_and_buffers(args, model)
+        )
+        self.targets: list[TransferTaskP2PMeta] = self._plan_p2p() if mode == "rdma" else []
+        self.transfer_tasks: list[TransferTask] = []
+
+        # Whether duplicated transfer tasks have been merged
+        self._merge_checked = False
+
+    def _get_parallel_info(self, args: Namespace) -> None:
+        # Gather the source (current trainer) information.
+        self._pp_rank, self._pp_size = (
+            mpu.get_pipeline_model_parallel_rank(),
+            mpu.get_pipeline_model_parallel_world_size(),
+        )
+        self._ep_rank, self._ep_size = mpu.get_expert_model_parallel_rank(), mpu.get_expert_model_parallel_world_size()
+        self._tp_rank, self._tp_size = mpu.get_tensor_model_parallel_rank(), mpu.get_tensor_model_parallel_world_size()
+        self._etp_rank, self._etp_size = (
+            mpu.get_expert_tensor_parallel_rank(),
+            mpu.get_expert_tensor_parallel_world_size(),
+        )
+        self._dp_rank, self._dp_size = mpu.get_data_parallel_rank(
+            with_context_parallel=True
+        ), mpu.get_data_parallel_world_size(with_context_parallel=True)
+
+        # Gather the target (rollout engine count and parallelism) information.
+        self._rollout_tp_size = args.sglang_tp_size
+        self._rollout_dp_size = args.sglang_dp_size
+        self._rollout_ep_size = args.sglang_ep_size
+        # PP sizes are not supported currently.
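+        # (Assumption: the rollout side runs PP=1; only the EP constraint is
+        # enforced explicitly below.)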
+ self._rollout_pp_size = args.sglang_pp_size + if self._rollout_ep_size != 1: + raise NotImplementedError("Rollout expert parallelism is not supported yet.") + num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node) + self._rollout_engine_count = args.rollout_num_gpus // num_gpu_per_engine + logger.info( + f"RemoteTransferPlan initialized: mode={self.mode}, pp_rank={self._pp_rank}/{self._pp_size}, tp_rank={self._tp_rank}/{self._tp_size}, " + f"ep_rank={self._ep_rank}/{self._ep_size}, etp_rank={self._etp_rank}/{self._etp_size}, dp_rank={self._dp_rank}/{self._dp_size}" + ) + logger.info( + f"Rollout engine count: {self._rollout_engine_count}, tp_size={self._rollout_tp_size}, ep_size={self._rollout_ep_size}, dp_size={self._rollout_dp_size}" + ) + + # Expert and non expert parameters can have different parallel groups after all-gather. + self._gathered_dp_size = self._dp_size * self._tp_size + self._gathered_dp_rank = self._dp_rank * self._tp_size + self._tp_rank + expert_tp_size = self._ep_size * self._etp_size + self._gathered_expert_dp_size = self._dp_size * expert_tp_size + self._gathered_expert_dp_rank = ( + self._dp_rank * expert_tp_size + self._ep_rank * self._etp_size + self._etp_rank + ) + logger.info( + f"Gathered dp_size={self._gathered_dp_size}, gathered expert dp_size={self._gathered_expert_dp_size}" + ) + logger.info( + f"Gathered dp_rank={self._gathered_dp_rank}, gathered expert dp_rank={self._gathered_expert_dp_rank}" + ) + + def _plan_p2p(self) -> list[TransferTaskP2PMeta]: + def plan( + source_size: int, + source_rank: int, + num_rank_in_target: int, + num_targets: int, + params: str, + cur_active_rank: int = 0, + ) -> list[TransferTaskP2PMeta]: + transfer_tasks = [] + for target_ind in range(num_targets): + for target_rank in range(num_rank_in_target): + if cur_active_rank % source_size == source_rank: + transfer_tasks.append( + TransferTaskP2PMeta(engine_ind=target_ind, engine_rank=target_rank, group=params) + ) + logger.info( + f"Planned P2P transfer task: source_rank={source_rank} -> target_engine_ind={target_ind}, target_engine_rank={target_rank}, group={params}" + ) + cur_active_rank += 1 + return transfer_tasks + + non_expert_plan = plan( + source_size=self._gathered_dp_size, + source_rank=self._gathered_dp_rank, + num_rank_in_target=self._rollout_dp_size * self._rollout_tp_size, + num_targets=self._rollout_engine_count, + params="non-expert", + ) + offset = len(non_expert_plan) + # Offset the current active rank by the number of non-expert transfer tasks to avoid overloading first few ranks. + return non_expert_plan + plan( + source_size=self._gathered_expert_dp_size, + source_rank=self._gathered_expert_dp_rank, + num_rank_in_target=self._rollout_dp_size * self._rollout_ep_size, + num_targets=self._rollout_engine_count, + params="expert", + cur_active_rank=offset, + ) + + def is_source(self) -> bool: + """ + Determine if the current rank needs to initiate weight transfer. + + Returns: + bool - True if the current rank is a source for weight transfer, False otherwise. + """ + if self.mode == "nccl": + # NCCL only load from DP=TP=0 PP ranks to all rollout engines. 
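+            # That is, exactly one source rank per PP stage; all other ranks
+            # only feed the TP/EP all-gather that assembles full parameters.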
+ return ( + mpu.get_data_parallel_rank(with_context_parallel=True) == 0 + and mpu.get_tensor_model_parallel_rank() == 0 + ) + return len(self.targets) > 0 + + def add_transfer_task( + self, session: str, remote_tensor_names: list[str], param_group: Literal["expert", "non-expert"] + ) -> None: + """ + Add a transfer task to the plan using remote instance session and tensor names; Only transfer parameters that are + available to this rank using the pre-cached parameter names. + """ + if param_group == "expert": + expected_names = set(self.expert_param_names) + self.transfer_tasks.append( + TransferTask(session=session, tensor_names=set(remote_tensor_names).intersection(expected_names)) + ) + logger.info( + f"Added expert parameter transfer task: session={session}, num_tensors={len(remote_tensor_names)}" + ) + else: + expected_names = set(self.param_names) + self.transfer_tasks.append( + TransferTask(session=session, tensor_names=set(remote_tensor_names).intersection(expected_names)) + ) + logger.info(f"Added parameter transfer task: session={session}, num_tensors={len(remote_tensor_names)}") + + def clear_transfer_tasks(self) -> None: + self.transfer_tasks = [] + + def _merge_transfer_tasks(self) -> None: + # In case transfer tasks share the same sesssion, merge them. + if not self._merge_checked: + tasks_by_session: dict[str, TransferTask] = {} + for task in self.transfer_tasks: + if task.session not in tasks_by_session: + tasks_by_session[task.session] = TransferTask(session=task.session, tensor_names=[]) + tasks_by_session[task.session].tensor_names.extend(task.tensor_names) + self.transfer_tasks = list(tasks_by_session.values()) + self._merge_checked = True + + def get_transfer_tasks(self) -> list[TransferTask]: + # Generate session identifier based on mode + if self.mode == "nccl": + session = f"slime-pp_{self._pp_rank}" + # In NCCL mode, the transfer is simply a broadcast from DP=TP=0 to all rollout engines. + return [TransferTask(session=session, tensor_names=self.param_names + self.expert_param_names)] + if self.targets and not self.transfer_tasks: + raise RuntimeError("RDMA need to query target engine information for transfer task generations.") + self._merge_transfer_tasks() + return self.transfer_tasks diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index d6317daa66..34082d94e4 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -6,18 +6,16 @@ import ray import torch import torch.distributed as dist -from megatron.core import mpu from ray import ObjectRef from ray.actor import ActorHandle from tqdm import tqdm -from slime.utils.distributed_utils import get_gloo_group, init_process_group +from slime.utils.distributed_utils import init_process_group -from ..megatron_to_hf import convert_to_hf -from .common import all_gather_param, named_params_and_buffers +from .update_weight_from_remote import UpdateWeightFromRemote -class UpdateWeightFromDistributed: +class UpdateWeightFromDistributed(UpdateWeightFromRemote): """ Update distributed engines via NCCL. Each PP rank: group "slime-pp_{pp_rank}", only DP=TP=0 broadcasts. Non-expert (TP) and expert (EP) params separate. @@ -35,11 +33,7 @@ def __init__( """ Initialize. Groups created in connect_rollout_engines. 
""" - self.args = args - self.model = model - self.model_name = model_name - self.quantization_config = quantization_config - self.weight_version = 0 + super().__init__(args, model, weights_getter, model_name=model_name, quantization_config=quantization_config) self._model_update_groups = None def connect_rollout_engines( @@ -54,14 +48,12 @@ def connect_rollout_engines( # For TP: # 1. AllGather paramters to rank 0 # 2. Broadcast parameters from rank 0 to all sglang engines - self._is_pp_src_rank = ( - mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0 - ) - pp_rank = mpu.get_pipeline_model_parallel_rank() - if self._is_pp_src_rank: - self._group_name = f"slime-pp_{pp_rank}" - - if self._is_pp_src_rank: + self._is_source = self.transfer_plan.is_source() + if self._is_source: + transfer_tasks = self.transfer_plan.get_transfer_tasks() + assert self.transfer_plan.mode == "nccl", "Only NCCL supported currently." + assert len(transfer_tasks) == 1, "Only single transfer task supported currently." + self._group_name, self._tensor_names = transfer_tasks[0].session, transfer_tasks[0].tensor_names if self._model_update_groups is not None: disconnect_rollout_engines_from_distributed( self.args, self._group_name, self._model_update_groups, self.rollout_engines @@ -70,139 +62,7 @@ def connect_rollout_engines( self.args, self._group_name, rollout_engines ) - @torch.no_grad() - def update_weights(self) -> None: - """ - Pause → flush → non-expert (TP) → expert (EP) → continue. Progress on PP source. - """ - self.weight_version += 1 - - if dist.get_rank() == 0: - ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) - ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) - dist.barrier(group=get_gloo_group()) - - buffer_size = 0 - converted_named_tensors = [] - # non expert params - pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_pp_src_rank else None - - for name, param in named_params_and_buffers(self.args, self.model): - if ".experts." in name: - continue - buffer_size = self._update_weight_from_distributed( - name, param, converted_named_tensors, buffer_size, pbar=pbar - ) - - if converted_named_tensors: - self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar) - - dist.barrier(group=get_gloo_group()) - - buffer_size = 0 - named_tensors = [] - for name, param in named_params_and_buffers(self.args, self.model): - if ".experts." not in name: - continue - buffer_size = self._update_expert_weight_from_distributed( - name, param, named_tensors, buffer_size, pbar=pbar - ) - - if named_tensors: - self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) - - dist.barrier(group=get_gloo_group()) - if dist.get_rank() == 0: - ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) - dist.barrier(group=get_gloo_group()) - - def _update_weight_from_distributed( - self, - name: str, - param: torch.nn.Parameter, - converted_named_tensors: list[tuple[str, torch.Tensor]], - buffer_size: int, - pbar: tqdm | None = None, - ) -> int | None: - """ - Non-expert: gather TP → rm pad → HF → buffer (flush if full). All gather, PP source buffers. - Returns updated bytes on source, None on non-source. 
- """ - param = all_gather_param(name, param) - if not self._is_pp_src_rank: - return - - param_size = param.numel() * param.element_size() - if buffer_size + param_size > self.args.update_weight_buffer_size: - self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar) - buffer_size = 0 - converted_named_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) - buffer_size += param_size - return buffer_size - - def _update_expert_weight_from_distributed( - self, - name: str, - param: torch.nn.Parameter, - named_tensors: list[tuple[str, torch.Tensor]], - buffer_size: int, - pbar: tqdm | None = None, - ) -> int: - """ - Expert: gather TP → rm pad → buffer. EP gather + HF deferred. Threshold × EP size. - """ - param = all_gather_param(name, param) - - param_size = param.numel() * param.element_size() - if ( - buffer_size + param_size - ) * mpu.get_expert_model_parallel_world_size() > self.args.update_weight_buffer_size: - self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) - buffer_size = 0 - - named_tensors.append((name, param)) - buffer_size += param_size - return buffer_size - - def _update_expert_bucket_weights_from_distributed( - self, named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None - ) -> None: - """ - Gather EP → HF → broadcast. Clears buffer. - """ - names = [name for name, _ in named_tensors] - all_names = [None] * mpu.get_expert_model_parallel_world_size() - dist.all_gather_object(all_names, names, group=mpu.get_expert_model_parallel_group()) - - for names in all_names: - assert len(named_tensors) == len(names), f"mismatch names length: {len(named_tensors)} != {len(names)}" - - all_gathered_params = [[] for _ in range(mpu.get_expert_model_parallel_world_size())] - handles = [] - for i, (_name, param) in enumerate(named_tensors): - params = [ - torch.empty_like(param.data, device=torch.cuda.current_device()) - for _ in range(mpu.get_expert_model_parallel_world_size()) - ] - handle = dist.all_gather(params, param.data, group=mpu.get_expert_model_parallel_group(), async_op=True) - handles.append(handle) - for ep_rank, names in enumerate(all_names): - all_gathered_params[ep_rank].append((names[i], params[ep_rank])) - for handle in handles: - handle.wait() - - named_tensors.clear() - if not self._is_pp_src_rank: - return - - all_gathered_params = sum(all_gathered_params, []) - converted_hf_tensors = [] - for name, param in all_gathered_params: - converted_hf_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) - - self._update_bucket_weights_from_distributed(converted_hf_tensors, pbar) - - def _update_bucket_weights_from_distributed( + def _update_bucket_weights_from_remote( self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None ) -> None: """ @@ -211,7 +71,6 @@ def _update_bucket_weights_from_distributed( # lock the rollout engines to prevent dead lock on broadcast. 
while not ray.get(self.rollout_engine_lock.acquire.remote()):
             time.sleep(0.1)
-
         refs = update_weights_from_distributed(
             self._group_name,
             self._model_update_groups,
             self.weight_version,
             self.rollout_engines,
diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py
new file mode 100644
index 0000000000..659e5cc3da
--- /dev/null
+++ b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py
@@ -0,0 +1,159 @@
+import logging
+import socket
+import time
+from argparse import Namespace
+from collections.abc import Callable, Mapping, Sequence
+
+import ray
+import torch
+from ray.actor import ActorHandle
+from tqdm import tqdm
+
+from .common import split_expert_and_non_expert_param_names
+from .update_weight_from_remote import UpdateWeightFromRemote
+
+logger = logging.getLogger(__name__)
+
+
+class UpdateWeightFromRDMA(UpdateWeightFromRemote):
+    """
+    Update weights from RDMA using Transfer Engine.
+
+    Similar to UpdateWeightFromDistributed (the NCCL path) but uses a P2P RDMA transfer engine for the underlying
+    weight transfer. The workflow consists of the following steps:
+    1. Based on the transfer plan, query the target rollout engines for remote session and weight info during connect_rollout_engines.
+    2. Do a TP/EP all-gather for bucketed weights on the parameters needing transfer from this rank, just as in the NCCL case.
+    3. Convert the gathered tensors into HF format and target shape, and register them with the engine.
+    4. Call the engine to batch-transfer weights for each transfer task.
+    """
+
+    def __init__(
+        self,
+        args: Namespace,
+        model: Sequence[torch.nn.Module],
+        weights_getter: Callable[[], Mapping[str, torch.Tensor]],
+        *,
+        model_name: str,
+        quantization_config: dict[str, int | str | list[str]] | None,
+        vocab_size: int,
+    ) -> None:
+        """
+        Initialize. P2PTrainingTransferEngine created in connect_rollout_engines.
+        Calls parent constructor and adds P2P RDMA specific attributes.
+        """
+        # Call parent constructor to initialize all base attributes
+        super().__init__(
+            args,
+            model,
+            weights_getter,
+            model_name=model_name,
+            quantization_config=quantization_config,
+            vocab_size=vocab_size,
+        )
+
+        # P2P RDMA specific initialization
+        self.training_p2p_transfer_engine = None
+        self.session_id = None
+
+    def connect_rollout_engines(
+        self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle
+    ) -> None:
+        """
+        Initialize P2PTrainingTransferEngine if serves as a source.
+ """ + # Store rollout engines and lock + self.rollout_engines = rollout_engines + self.rollout_engine_lock = rollout_engine_lock + self._is_source = self.transfer_plan.is_source() + + # Initialize P2PTrainingTransferEngine on source rank + if self._is_source: + if self.training_p2p_transfer_engine is not None: + self.training_p2p_transfer_engine.stop() + self.session_id = None + self.transfer_plan.clear_transfer_tasks() + + # Get master address and port for P2P communication + local_ip = ray._private.services.get_node_ip_address() + with socket.socket() as sock: + sock.bind(("", 0)) + port = sock.getsockname()[1] + + # Initialize P2PTrainingTransferEngine + # self.training_p2p_transfer_engine = P2PTrainingTransferEngine( + # master_ip=local_ip, + # master_port=port, + # gpu_id=None, + # ib_device=None, + # ) + self.training_p2p_transfer_engine.start() + self.session_id = f"{local_ip}:{port}" + logger.info(f"P2PTrainingTransferEngine started on {local_ip}:{port}") + + # Query Engine session and weight info from rollout instances according to the transfer plan + self.remote_weight_infos_by_engine_and_rank = {} + for target in self.transfer_plan.targets: + if (target.engine_ind, target.engine_rank) not in self.remote_weight_infos_by_engine_and_rank: + self.remote_weight_infos_by_engine_and_rank[(target.engine_ind, target.engine_rank)] = ray.get( + self.rollout_engines[target.engine_ind].get_remote_instance_transfer_engine_info.remote( + rank=target.engine_rank + )["remote_instance_transfer_engine_info"] + ) + logger.info( + f"Obtained remote session info from rollout engine {target.engine_ind} rank {target.engine_rank}" + ) + remote_session_id, remote_weight_info = self.remote_weight_infos_by_engine_and_rank[ + (target.engine_ind, target.engine_rank) + ] + expert_params, non_expert_params = split_expert_and_non_expert_param_names(remote_weight_info.keys()) + self.transfer_plan.add_transfer_task( + session=remote_session_id, + remote_tensor_names=expert_params if target.group == "expert" else non_expert_params, + param_group=target.group, + ) + + def _update_bucket_weights_from_remote( + self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None + ) -> None: + """ + Register weights with P2PTrainingTransferEngine and wait for transfers to complete. + Based on lines 518-545 in SGLang test: register_weights pattern. + Overrides parent method to use P2P RDMA instead of NCCL broadcast. + """ + + # TODO(jd): pin the memory for GPUs after resharding to avoid expensive registration. 
+ if not self._is_source or not converted_named_tensors: + return + + # Lock the rollout engines to prevent concurrent operations (same as parent) + while not ray.get(self.rollout_engine_lock.acquire.remote()): + time.sleep(0.1) + + try: + # Register all weights with the P2P training transfer engine + # This follows the pattern from SGLang test lines 518-537 + for name, tensor in converted_named_tensors: + self.training_p2p_transfer_engine.register_buffer(name, tensor) + + # Initiate weight transfer to all rollout engines as per the transfer plan + refs = [ + engine.update_weights_from_distributed.remote( + names=[name for name, _ in converted_named_tensors], + dtypes=[param.dtype for _, param in converted_named_tensors], + shapes=[param.shape for _, param in converted_named_tensors], + group_name=self._group_name, + weight_version=str(self.weight_version), + session_id=f"{self.master_addr}:{self.master_port}", # Pass P2P session info + ) + for engine in self.rollout_engines + ] + + # Wait for all P2P transfers to complete + ray.get(refs) + converted_named_tensors.clear() + + finally: + # Release the lock (same as parent) + ray.get(self.rollout_engine_lock.release.remote()) + if pbar: + pbar.update(1) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py new file mode 100644 index 0000000000..472f5cf887 --- /dev/null +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py @@ -0,0 +1,193 @@ +from abc import abstractmethod +from argparse import Namespace +from collections.abc import Callable, Mapping, Sequence + +import ray +import torch +import torch.distributed as dist +from megatron.core import mpu +from ray.actor import ActorHandle +from tqdm import tqdm + +from slime.utils.distributed_utils import get_gloo_group +from slime.utils.timer import timer + +from ..megatron_to_hf import convert_to_hf +from .common import all_gather_param, named_params_and_buffers +from .remote_transfer_plan import RemoteTransferPlan + + +class UpdateWeightFromRemote: + """ + Abstract base class for remote bucketed tensor weight update. Weights are all-gathered in TP EP dimension + for each bucket of predefined size and processed into HF format. + """ + + def __init__( + self, + args: Namespace, + model: Sequence[torch.nn.Module], + weights_getter: Callable[[], Mapping[str, torch.Tensor]], + *, + model_name: str, + quantization_config: dict[str, int | str | list[str]] | None, + ) -> None: + """ + Initialize. Groups created in connect_rollout_engines. + """ + self.args = args + self.model = model + self.model_name = model_name + self.quantization_config = quantization_config + self.weight_version = 0 + self.transfer_plan = RemoteTransferPlan(args, model, args.update_weight_transfer_mode) + + @abstractmethod + def connect_rollout_engines( + self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle + ) -> None: + """ + Establish connection to remote rollout engines. + """ + + @abstractmethod + def _update_bucket_weights_from_remote( + self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None + ) -> None: + """ + Implementation of the bucketed parameter update from remote. + """ + + @torch.no_grad() + def update_weights(self) -> None: + """ + Pause → flush → non-expert (TP) → expert (EP) → continue. Progress on PP source. 
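+        Buckets are flushed whenever the next parameter would push the staged
+        bytes past --update-weight-buffer-size, bounding peak staging memory.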
+ """ + self.weight_version += 1 + + if dist.get_rank() == 0: + ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) + ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) + dist.barrier(group=get_gloo_group()) + with timer("update_weights_implementation"): + buffer_size = 0 + converted_named_tensors = [] + # non expert params + pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_source else None + + for name, param in named_params_and_buffers(self.args, self.model): + # transfer tp tensors + if name not in self._tensor_names or ".experts." in name: + continue + buffer_size = self._update_weight_from_remote( + name, param, converted_named_tensors, buffer_size, pbar=pbar + ) + + if converted_named_tensors: + self._update_bucket_weights_from_remote(converted_named_tensors, pbar=pbar) + + dist.barrier(group=get_gloo_group()) + + buffer_size = 0 + named_tensors = [] + for name, param in named_params_and_buffers(self.args, self.model): + # transfer expert tensors + if name not in self._tensor_names or ".experts." not in name: + continue + buffer_size = self._update_expert_weight_from_remote( + name, param, named_tensors, buffer_size, pbar=pbar + ) + + if named_tensors: + self._update_expert_bucket_weights_from_remote(named_tensors, pbar=pbar) + + dist.barrier(group=get_gloo_group()) + if dist.get_rank() == 0: + ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + dist.barrier(group=get_gloo_group()) + + def _update_weight_from_remote( + self, + name: str, + param: torch.nn.Parameter, + converted_named_tensors: list[tuple[str, torch.Tensor]], + buffer_size: int, + pbar: tqdm | None = None, + ) -> int | None: + """ + Non-expert: gather TP → rm pad → HF → buffer (flush if full). All gather, PP source buffers. + Returns updated bytes on source, None on non-source. + """ + param = all_gather_param(name, param) + if not self._is_source: + return + + param_size = param.numel() * param.element_size() + if buffer_size + param_size > self.args.update_weight_buffer_size: + self._update_bucket_weights_from_remote(converted_named_tensors, pbar=pbar) + buffer_size = 0 + converted_named_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) + buffer_size += param_size + return buffer_size + + def _update_expert_weight_from_remote( + self, + name: str, + param: torch.nn.Parameter, + named_tensors: list[tuple[str, torch.Tensor]], + buffer_size: int, + pbar: tqdm | None = None, + ) -> int: + """ + Expert: gather TP → rm pad → buffer. EP gather + HF deferred. Threshold × EP size. + """ + param = all_gather_param(name, param) + + param_size = param.numel() * param.element_size() + if ( + buffer_size + param_size + ) * mpu.get_expert_model_parallel_world_size() > self.args.update_weight_buffer_size and named_tensors: + self._update_expert_bucket_weights_from_remote(named_tensors, pbar=pbar) + buffer_size = 0 + + named_tensors.append((name, param)) + buffer_size += param_size + return buffer_size + + def _update_expert_bucket_weights_from_remote( + self, named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None + ) -> None: + """ + Gather EP → HF → broadcast. Clears buffer. 
+ """ + names = [name for name, _ in named_tensors] + all_names = [None] * mpu.get_expert_model_parallel_world_size() + dist.all_gather_object(all_names, names, group=mpu.get_expert_model_parallel_group()) + + for names in all_names: + assert len(named_tensors) == len(names), f"mismatch names length: {len(named_tensors)} != {len(names)}" + + all_gathered_params = [[] for _ in range(mpu.get_expert_model_parallel_world_size())] + handles = [] + for i, (_name, param) in enumerate(named_tensors): + params = [ + torch.empty_like(param.data, device=torch.cuda.current_device()) + for _ in range(mpu.get_expert_model_parallel_world_size()) + ] + handle = dist.all_gather(params, param.data, group=mpu.get_expert_model_parallel_group(), async_op=True) + handles.append(handle) + for ep_rank, names in enumerate(all_names): + all_gathered_params[ep_rank].append((names[i], params[ep_rank])) + for handle in handles: + handle.wait() + + named_tensors.clear() + if not self._is_source: + return + + all_gathered_params = sum(all_gathered_params, []) + converted_hf_tensors = [] + for name, param in all_gathered_params: + converted_hf_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) + + self._update_bucket_weights_from_remote(converted_hf_tensors, pbar) diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index ab9ad9e9fd..c2c12d06fa 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -182,7 +182,9 @@ def _make_request(self, endpoint: str, payload: dict | None = None): return url = f"http://{self.server_host}:{self.server_port}/{endpoint}" + response = requests.post(url, json=payload or {}) + try: response.raise_for_status() except requests.exceptions.HTTPError as e: diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index ade651cedf..d8ae7670d5 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -148,6 +148,12 @@ def add_train_arguments(parser): default="raw", help="The method to convert megatron weights to hugging face weights for SGLang.", ) + parser.add_argument( + "--update-weight-transfer-mode", + choices=["nccl", "rdma"], + default="nccl", + help="The method to transfer weights to remote rollout engines during update weight.", + ) return parser @@ -381,6 +387,21 @@ def add_rollout_arguments(parser): nargs="+", help="Address and ports of the external engines.", ) + # from https://github.com/Risc-lt/sglang/blob/cc8883ff7bf63dda8627cc696d49055e0c573d5b/python/sglang/srt/model_executor/model_runner.py + parser.add_argument( + "--update-weights-p2p-transfer", + action="store_true", + default=False, + help="Enable P2P weight transfer between GPUs when updating model weights in RL training.", + ) + parser.add_argument( + "--p2p-transfer-ib-device", + type=str, + default=None, + help="The InfiniBand devices for P2P transfer, accepts single device (e.g., --p2p-transfer-ib-device mlx5_0) " + "or multiple comma-separated devices (e.g., --p2p-transfer-ib-device mlx5_0,mlx5_1). 
" + "Default is None, which triggers automatic device detection when mooncake backend is enabled.", + ) return parser def add_fault_tolerance_arguments(parser): diff --git a/slime/utils/timer.py b/slime/utils/timer.py index f5610842f6..242c524f87 100644 --- a/slime/utils/timer.py +++ b/slime/utils/timer.py @@ -11,6 +11,8 @@ logger = logging.getLogger(__name__) +LOGFILE = "slime_timer" + class Timer(metaclass=SingletonMeta): def __init__(self): @@ -28,8 +30,11 @@ def end(self, name): elapsed_time = time() - self.start_time[name] self.add(name, elapsed_time) del self.start_time[name] - if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + if rank == 0: logger.info(f"Timer {name} end (elapsed: {elapsed_time:.1f}s)") + with open(f"{LOGFILE}_{rank}.log", "a") as f: + f.write(f"Timer {name} end (elapsed: {elapsed_time*1000:.1f}ms)\n") def reset(self, name=None): if name is None: diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index a42bb42a53..03a66fc905 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -1,4 +1,3 @@ -import os import slime.utils.external_utils.command_utils as U FEW_GPU = U.get_bool_env_var("SLIME_TEST_FEW_GPU", "1") @@ -125,8 +124,8 @@ def execute(): if __name__ == "__main__": prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + # os.environ.pop("http_proxy") + # os.environ.pop("https_proxy") + # os.environ.pop("HTTP_PROXY") + # os.environ.pop("HTTPS_PROXY") execute() diff --git a/tests/test_weight_transfer.py b/tests/test_weight_transfer.py new file mode 100644 index 0000000000..443d285a3d --- /dev/null +++ b/tests/test_weight_transfer.py @@ -0,0 +1,114 @@ +import os + +import slime.utils.external_utils.command_utils as U + + +TIGHT_HOST_MEMORY = bool(int(os.environ.get("SLIME_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Qwen3-4B" +MODEL_TYPE = "qwen3-4B" +NUM_GPUS = 3 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command("hf download Qwen/Qwen3-4B --local-dir /root/models/Qwen3-4B") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 100 " + "--rollout-temperature 0.8 " + "--global-batch-size 32 " + "--balance-data " + ) + # Training parallellism settings + perf_args = ( + "--tensor-model-parallel-size 1 " + # "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 2048 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + # "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 
1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 1 " "--rollout-num-gpus 2 " "--sglang-mem-fraction-static 0.8 " + + # ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 1 " + # 1GB buffer for weight update + f"--update-weight-buffer-size {1 * 1024 ** 3} " + # "--check-weight-update-equal " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{sglang_args} " + # f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + train_script="train_async.py", + # extra_env_vars={"RAY_DEBUG": "1"}, + ) + + +if __name__ == "__main__": + prepare() + execute() diff --git a/tests/timing_comparison_guide.md b/tests/timing_comparison_guide.md new file mode 100644 index 0000000000..278fc7cfbc --- /dev/null +++ b/tests/timing_comparison_guide.md @@ -0,0 +1,65 @@ +# Qwen32B P2P vs Distributed Transfer Performance Comparison + +## Overview +This document describes the timing measurements for comparing P2P Transfer Engine performance against torch.distributed baseline. + +## Test Methods + +### 1. P2P Transfer Test: `test_qwen32b_model_transfer` +**Command**: +```bash +python -m pytest tests/test_p2p_engine_model_qwen2.py::TestQwen32BP2PTransfer::test_qwen32b_model_transfer -v -s --tb=short --capture=no +``` + +**Results**: Saved to `/root/slime/tests/p2p_timing_results.json` + +### 2. Distributed Baseline Test: `test_qwen32b_model_transfer_baseline` +**Command**: +```bash +python -m pytest tests/test_p2p_engine_model_qwen2.py::TestQwen32BP2PTransfer::test_qwen32b_model_transfer_baseline -v -s --tb=short --capture=no +``` + +**Results**: Saved to `/root/slime/tests/distributed_timing_results.json` + +## Timing Metrics + +### P2P Transfer Engine Metrics + +#### Training Side: +- `register_and_start_time`: Time to register weights and start the P2P training engine +- `update_weights_time`: Time to update/re-register weights +- `stop_time`: Time to stop the training engine + + +#### Rollout Side: +- `submit_tasks_time`: Time to submit all transfer tasks +- `wait_transfers_time`: Time waiting for all transfers to complete +- `sync_time`: CUDA synchronization time +- `total_transfer_time`: Pure transfer time (submit to sync) + +### Distributed Baseline Metrics + +#### Training Side: +- `init_process_group_time`: Time to initialize NCCL process group +- `broadcast_time`: Time to broadcast all weights to other processes +- `destroy_group_time`: Time to destroy the process group + + +#### Rollout Side: +- `init_process_group_time`: Time to initialize NCCL process group +- `broadcast_time`: Time to receive weights via broadcast +- `sync_time`: CUDA synchronization time +- `destroy_group_time`: Time to destroy the process group +- `total_transfer_time`: Pure transfer time (broadcast + sync) + + +## Key Performance Comparisons + +1. **Core Transfer Time**: + - P2P: `wait_transfers_time` vs Distributed: `broadcast_time` + +2. 
**Setup/Teardown Overhead**: + - P2P: `register_and_start_time` + `stop_time` vs Distributed: `init_process_group_time` + `destroy_group_time` + +3. **Total Transfer Pipeline**: + - P2P: `total_transfer_time` vs Distributed: `total_transfer_time` From 91d3ab9435d580381b9d67e1f41a1e3380ffffe2 Mon Sep 17 00:00:00 2001 From: JD-ETH Date: Mon, 15 Dec 2025 01:23:14 +0000 Subject: [PATCH 02/11] remove old benchmark --- tests/timing_comparison_guide.md | 65 -------------------------------- 1 file changed, 65 deletions(-) delete mode 100644 tests/timing_comparison_guide.md diff --git a/tests/timing_comparison_guide.md b/tests/timing_comparison_guide.md deleted file mode 100644 index 278fc7cfbc..0000000000 --- a/tests/timing_comparison_guide.md +++ /dev/null @@ -1,65 +0,0 @@ -# Qwen32B P2P vs Distributed Transfer Performance Comparison - -## Overview -This document describes the timing measurements for comparing P2P Transfer Engine performance against torch.distributed baseline. - -## Test Methods - -### 1. P2P Transfer Test: `test_qwen32b_model_transfer` -**Command**: -```bash -python -m pytest tests/test_p2p_engine_model_qwen2.py::TestQwen32BP2PTransfer::test_qwen32b_model_transfer -v -s --tb=short --capture=no -``` - -**Results**: Saved to `/root/slime/tests/p2p_timing_results.json` - -### 2. Distributed Baseline Test: `test_qwen32b_model_transfer_baseline` -**Command**: -```bash -python -m pytest tests/test_p2p_engine_model_qwen2.py::TestQwen32BP2PTransfer::test_qwen32b_model_transfer_baseline -v -s --tb=short --capture=no -``` - -**Results**: Saved to `/root/slime/tests/distributed_timing_results.json` - -## Timing Metrics - -### P2P Transfer Engine Metrics - -#### Training Side: -- `register_and_start_time`: Time to register weights and start the P2P training engine -- `update_weights_time`: Time to update/re-register weights -- `stop_time`: Time to stop the training engine - - -#### Rollout Side: -- `submit_tasks_time`: Time to submit all transfer tasks -- `wait_transfers_time`: Time waiting for all transfers to complete -- `sync_time`: CUDA synchronization time -- `total_transfer_time`: Pure transfer time (submit to sync) - -### Distributed Baseline Metrics - -#### Training Side: -- `init_process_group_time`: Time to initialize NCCL process group -- `broadcast_time`: Time to broadcast all weights to other processes -- `destroy_group_time`: Time to destroy the process group - - -#### Rollout Side: -- `init_process_group_time`: Time to initialize NCCL process group -- `broadcast_time`: Time to receive weights via broadcast -- `sync_time`: CUDA synchronization time -- `destroy_group_time`: Time to destroy the process group -- `total_transfer_time`: Pure transfer time (broadcast + sync) - - -## Key Performance Comparisons - -1. **Core Transfer Time**: - - P2P: `wait_transfers_time` vs Distributed: `broadcast_time` - -2. **Setup/Teardown Overhead**: - - P2P: `register_and_start_time` + `stop_time` vs Distributed: `init_process_group_time` + `destroy_group_time` - -3. 
**Total Transfer Pipeline**: - - P2P: `total_transfer_time` vs Distributed: `total_transfer_time` From 3ed9867ac0cc613866fae5db2c9581815dce6943 Mon Sep 17 00:00:00 2001 From: JD-ETH Date: Mon, 15 Dec 2025 06:32:54 +0000 Subject: [PATCH 03/11] initial impl --- slime/backends/megatron_utils/actor.py | 10 +- .../megatron_utils/update_weight/common.py | 102 +++++++++++- .../update_weight/remote_transfer_plan.py | 64 +++----- .../update_weight_from_distributed.py | 25 ++- .../update_weight/update_weight_from_rdma.py | 145 ++++++++++-------- .../update_weight_from_remote.py | 101 +++++++----- slime/backends/sglang_utils/sglang_engine.py | 8 + tests/test_weight_transfer.py | 30 +++- 8 files changed, 312 insertions(+), 173 deletions(-) diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 75f6f8c5b8..d21342a97a 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -35,6 +35,7 @@ from .model import forward_only, initialize_model_and_optimizer, save, train from .update_weight.common import named_params_and_buffers from .update_weight.update_weight_from_distributed import UpdateWeightFromDistributed +from .update_weight.update_weight_from_rdma import UpdateWeightFromRDMA from .update_weight.update_weight_from_tensor import UpdateWeightFromTensor logging.getLogger("megatron").setLevel(logging.WARNING) @@ -112,8 +113,13 @@ def init( if self.args.vocab_size is None: self.args.vocab_size = self.tokenizer.vocab_size - - update_weight_cls = UpdateWeightFromTensor if self.args.colocate else UpdateWeightFromDistributed + if self.args.colocate: + update_weight_cls = UpdateWeightFromTensor + else: + if self.args.update_weight_transfer_mode == "nccl": + update_weight_cls = UpdateWeightFromDistributed + else: + update_weight_cls = UpdateWeightFromRDMA self.weight_updater = update_weight_cls( self.args, self.model, diff --git a/slime/backends/megatron_utils/update_weight/common.py b/slime/backends/megatron_utils/update_weight/common.py index d515b2f963..51fa0759ca 100644 --- a/slime/backends/megatron_utils/update_weight/common.py +++ b/slime/backends/megatron_utils/update_weight/common.py @@ -130,15 +130,47 @@ def named_params_and_buffers( return ans -def split_expert_and_non_expert_param_names(params: Iterator[str]): - expert_param_names = [] - non_expert_param_names = [] - for param in params: - if ".experts." in param: - expert_param_names.append(param) +def split_expert_and_non_expert_param_names(param_names: Sequence[str]) -> tuple[list[str], list[str]]: + expert_params = [] + non_expert_params = [] + for name in param_names: + if ".experts." in name: + expert_params.append(name) else: - non_expert_param_names.append(param) - return expert_param_names, non_expert_param_names + non_expert_params.append(name) + return expert_params, non_expert_params + + +def non_expert_named_params_and_buffers( + args: Namespace, + model: Sequence[torch.nn.Module], + convert_to_global_name: bool = True, + translate_gpu_to_cpu: bool = False, +) -> Iterator[tuple[str, torch.Tensor]]: + for name, tensor in named_params_and_buffers( + args, + model, + convert_to_global_name, + translate_gpu_to_cpu, + ): + if ".experts." 
not in name:
+            yield name, tensor
+
+
+def expert_named_params_and_buffers(
+    args: Namespace,
+    model: Sequence[torch.nn.Module],
+    convert_to_global_name: bool = True,
+    translate_gpu_to_cpu: bool = False,
+) -> Iterator[tuple[str, torch.Tensor]]:
+    for name, tensor in named_params_and_buffers(
+        args,
+        model,
+        convert_to_global_name,
+        translate_gpu_to_cpu,
+    ):
+        if ".experts." in name:
+            yield name, tensor
 
 
 def _maybe_get_cpu_backup(x: torch.Tensor):
@@ -244,3 +276,57 @@ def _named_params_and_buffers_global(
         layer_idx, rest = match.groups()
         layer_idx = int(layer_idx) + layer_offset
         yield f"module.module.decoder.layers.{layer_idx}.{rest}", buffer
+
+
+def register_memory_transfer_engine(
+    named_param_with_buffers: Sequence[tuple[str, torch.Tensor]], engine
+) -> dict[str, tuple[int, int, int]]:
+    """
+    Efficient memory registration for the transfer engine: reduces the total registration count by merging
+    contiguous memory regions before registering them.
+
+    Returns a mapping from tensor name to (data_ptr, numel, element_size).
+    """
+
+    weight_mr_dict = {}
+    weight_addr_set = set()
+    for name, weight in named_param_with_buffers:
+        weight_mr_dict[name] = (
+            weight.data_ptr(),
+            weight.numel(),
+            weight.element_size(),
+        )
+        weight_addr_set.add(weight.data_ptr())
+    memory_snapshot = torch.cuda.memory.memory_snapshot()
+    weight_blocks_for_reg_mr = []
+    # Blocks in each segment occupy contiguous addresses,
+    # so adjacent ones can be merged for memory registration.
+    for segment in memory_snapshot:
+        current_weight_block = None
+        blocks = segment.get("blocks", [])
+        for block in blocks:
+            address = block.get("address", -1)
+            size = block.get("size", -1)
+            state = block.get("state", "")
+            if address < 0 or size < 0 or state == "":
+                continue
+            # Only register active allocated memory blocks that hold weights.
+            if state == "active_allocated":
+                if address in weight_addr_set:
+                    if current_weight_block is None:
+                        current_weight_block = (address, size)
+                    elif current_weight_block[0] + current_weight_block[1] == address:
+                        current_weight_block = (
+                            current_weight_block[0],
+                            current_weight_block[1] + size,
+                        )
+                    else:
+                        weight_blocks_for_reg_mr.append(current_weight_block)
+                        current_weight_block = (address, size)
+        if current_weight_block is not None:
+            weight_blocks_for_reg_mr.append(current_weight_block)
+
+    # Register merged memory blocks that hold weights.
+    for weight_block in weight_blocks_for_reg_mr:
+        address, size = weight_block
+        ret = engine.register_memory(address, size)
+        if ret != 0:
+            raise RuntimeError(
+                f"register memory failed for weight block at address {address} with size {size}, error: {ret}"
+            )
+    return weight_mr_dict
diff --git a/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py b/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py
index a6941f1385..9806b6f9ef 100644
--- a/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py
+++ b/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py
@@ -14,7 +14,7 @@
 import torch
 from megatron.core import mpu
 
-from .common import named_params_and_buffers, split_expert_and_non_expert_param_names
+from .common import expert_named_params_and_buffers, non_expert_named_params_and_buffers
 
 logger = logging.getLogger(__name__)
 
@@ -24,11 +24,13 @@ class TransferTask:
     """
     Attributes:
         session: Session identifier (e.g., NCCL group name or Transfer Engine Session Id)
-        tensor_names: Full list of tensor names to be transferred in this task as they appear in the engine API interface.
+        named_params_and_buffers: tensors to be transferred from this rank.
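+            Cached once at plan construction and reused for every update.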
+ tensor_type: "expert" or "non-expert" are two diverse types of tasks. """ - tensor_names: list[str] + named_params_and_buffers: list[tuple[str, torch.Tensor]] session: str # NCCL group name or target entity id. + tensor_type: Literal["expert", "non-expert"] @dataclass @@ -72,15 +74,10 @@ def __init__( """ self.mode = mode self._get_parallel_info(args) - # Cache the parameter names for transfer planning - self.expert_param_names, self.param_names = split_expert_and_non_expert_param_names( - name for name, _ in named_params_and_buffers(args, model) - ) self.targets: list[TransferTaskP2PMeta] = self._plan_p2p() if mode == "rdma" else [] self.transfer_tasks: list[TransferTask] = [] - - # Whether duplicated transfer tasks have been merged - self._merge_checked = False + self.non_expert_params_buffers = list(non_expert_named_params_and_buffers(args, model)) + self.expert_params_buffers = list(expert_named_params_and_buffers(args, model)) def _get_parallel_info(self, args: Namespace) -> None: # Gather the source (current trainer) information. @@ -186,49 +183,32 @@ def is_source(self) -> bool: ) return len(self.targets) > 0 - def add_transfer_task( - self, session: str, remote_tensor_names: list[str], param_group: Literal["expert", "non-expert"] - ) -> None: + def add_transfer_task(self, session: str, param_group: Literal["expert", "non-expert"]) -> None: """ - Add a transfer task to the plan using remote instance session and tensor names; Only transfer parameters that are - available to this rank using the pre-cached parameter names. + Add a transfer task to the plan using remote instance session and tensor names. """ - if param_group == "expert": - expected_names = set(self.expert_param_names) - self.transfer_tasks.append( - TransferTask(session=session, tensor_names=set(remote_tensor_names).intersection(expected_names)) - ) - logger.info( - f"Added expert parameter transfer task: session={session}, num_tensors={len(remote_tensor_names)}" - ) - else: - expected_names = set(self.param_names) - self.transfer_tasks.append( - TransferTask(session=session, tensor_names=set(remote_tensor_names).intersection(expected_names)) - ) - logger.info(f"Added parameter transfer task: session={session}, num_tensors={len(remote_tensor_names)}") + params = self.non_expert_params_buffers if param_group == "non-expert" else self.expert_params_buffers + self.transfer_tasks.append( + TransferTask(session=session, named_params_and_buffers=params, tensor_type=param_group) + ) + logger.info(f"Added {param_group} parameter transfer task: session={session}, num_tensors={len(params)}") def clear_transfer_tasks(self) -> None: self.transfer_tasks = [] - def _merge_transfer_tasks(self) -> None: - # In case transfer tasks share the same sesssion, merge them. - if not self._merge_checked: - tasks_by_session: dict[str, TransferTask] = {} - for task in self.transfer_tasks: - if task.session not in tasks_by_session: - tasks_by_session[task.session] = TransferTask(session=task.session, tensor_names=[]) - tasks_by_session[task.session].tensor_names.extend(task.tensor_names) - self.transfer_tasks = list(tasks_by_session.values()) - self._merge_checked = True - def get_transfer_tasks(self) -> list[TransferTask]: # Generate session identifier based on mode if self.mode == "nccl": session = f"slime-pp_{self._pp_rank}" # In NCCL mode, the transfer is simply a broadcast from DP=TP=0 to all rollout engines. 
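+            # The two-task split below mirrors the expert/non-expert structure of
+            # the RDMA plan, so callers can stage the two groups separately.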
- return [TransferTask(session=session, tensor_names=self.param_names + self.expert_param_names)] + return [ + TransferTask( + session=session, named_params_and_buffers=self.non_expert_params_buffers, tensor_type="non-expert" + ), + TransferTask( + session=session, named_params_and_buffers=self.expert_params_buffers, tensor_type="expert" + ), + ] if self.targets and not self.transfer_tasks: raise RuntimeError("RDMA need to query target engine information for transfer task generations.") - self._merge_transfer_tasks() return self.transfer_tasks diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 34082d94e4..efae91eb49 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -33,7 +33,20 @@ def __init__( """ Initialize. Groups created in connect_rollout_engines. """ - super().__init__(args, model, weights_getter, model_name=model_name, quantization_config=quantization_config) + super().__init__( + args, + model, + weights_getter, + model_name=model_name, + quantization_config=quantization_config, + weight_update_mode="nccl", + ) + + if self._is_source: + transfer_tasks = self.transfer_plan.get_transfer_tasks() + assert self.transfer_plan.mode == "nccl", "Only NCCL supported currently." + assert len(transfer_tasks) == 2, "Only two transfer tasks supported currently." + # Indicates if the nccl group has been established. self._model_update_groups = None def connect_rollout_engines( @@ -48,13 +61,11 @@ def connect_rollout_engines( # For TP: # 1. AllGather paramters to rank 0 # 2. Broadcast parameters from rank 0 to all sglang engines - self._is_source = self.transfer_plan.is_source() if self._is_source: transfer_tasks = self.transfer_plan.get_transfer_tasks() - assert self.transfer_plan.mode == "nccl", "Only NCCL supported currently." - assert len(transfer_tasks) == 1, "Only single transfer task supported currently." - self._group_name, self._tensor_names = transfer_tasks[0].session, transfer_tasks[0].tensor_names + self._group_name = transfer_tasks[0].session if self._model_update_groups is not None: + # Reestablish group if already connected, e.g. new instance has joined. disconnect_rollout_engines_from_distributed( self.args, self._group_name, self._model_update_groups, self.rollout_engines ) @@ -63,7 +74,7 @@ def connect_rollout_engines( ) def _update_bucket_weights_from_remote( - self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None + self, converted_named_tensors: list[tuple[str, torch.Tensor]], session_id: str, pbar: tqdm | None = None ) -> None: """ Lock → broadcast → clear → unlock → pbar++. Lock prevents NCCL deadlock. 
@@ -72,7 +83,7 @@ def _update_bucket_weights_from_remote( while not ray.get(self.rollout_engine_lock.acquire.remote()): time.sleep(0.1) refs = update_weights_from_distributed( - self._group_name, + session_id, self._model_update_groups, self.weight_version, self.rollout_engines, diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py index 659e5cc3da..9cdc9f8f63 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py @@ -1,5 +1,4 @@ import logging -import socket import time from argparse import Namespace from collections.abc import Callable, Mapping, Sequence @@ -7,9 +6,10 @@ import ray import torch from ray.actor import ActorHandle +from srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from tqdm import tqdm -from .common import split_expert_and_non_expert_param_names +from .common import register_memory_transfer_engine, split_expert_and_non_expert_param_names from .update_weight_from_remote import UpdateWeightFromRemote logger = logging.getLogger(__name__) @@ -38,8 +38,7 @@ def __init__( vocab_size: int, ) -> None: """ - Initialize. P2PTrainingTransferEngine created in connect_rollout_engines. - Calls parent constructor and adds P2P RDMA specific attributes. + Initialize transfer engine. """ # Call parent constructor to initialize all base attributes super().__init__( @@ -49,11 +48,10 @@ def __init__( model_name=model_name, quantization_config=quantization_config, vocab_size=vocab_size, + weight_update_mode="rdma", ) - # P2P RDMA specific initialization - self.training_p2p_transfer_engine = None - self.session_id = None + self.transfer_engine = None def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle @@ -68,60 +66,57 @@ def connect_rollout_engines( # Initialize P2PTrainingTransferEngine on source rank if self._is_source: - if self.training_p2p_transfer_engine is not None: - self.training_p2p_transfer_engine.stop() - self.session_id = None - self.transfer_plan.clear_transfer_tasks() - # Get master address and port for P2P communication local_ip = ray._private.services.get_node_ip_address() - with socket.socket() as sock: - sock.bind(("", 0)) - port = sock.getsockname()[1] - - # Initialize P2PTrainingTransferEngine - # self.training_p2p_transfer_engine = P2PTrainingTransferEngine( - # master_ip=local_ip, - # master_port=port, - # gpu_id=None, - # ib_device=None, - # ) - self.training_p2p_transfer_engine.start() - self.session_id = f"{local_ip}:{port}" - logger.info(f"P2PTrainingTransferEngine started on {local_ip}:{port}") + self.transfer_engine = MooncakeTransferEngine(hostname=local_ip) + logger.info(f"Transfer Engine initialized at {self.transfer_engine.session_id}") # Query Engine session and weight info from rollout instances according to the transfer plan - self.remote_weight_infos_by_engine_and_rank = {} + self.remote_weight_infos_by_session_id = {} + targets_to_query = set((target.engine_ind, target.engine_rank) for target in self.transfer_plan.targets) + targets_to_session_id = {} + for engine_ind, engine_rank in targets_to_query: + session_id, weights_info = ray.get( + self.rollout_engines[engine_ind].get_remote_instance_transfer_engine_info.remote(rank=engine_rank)[ + "remote_instance_transfer_engine_info" + ] + ) + logger.info(f"Obtained remote session info from rollout engine {engine_ind} rank 
{engine_rank}") + self.remote_weight_infos_by_session_id[session_id] = weights_info + targets_to_session_id[(engine_ind, engine_rank)] = session_id + + # Associate transfer tasks based on obtained session and weight info for target in self.transfer_plan.targets: - if (target.engine_ind, target.engine_rank) not in self.remote_weight_infos_by_engine_and_rank: - self.remote_weight_infos_by_engine_and_rank[(target.engine_ind, target.engine_rank)] = ray.get( - self.rollout_engines[target.engine_ind].get_remote_instance_transfer_engine_info.remote( - rank=target.engine_rank - )["remote_instance_transfer_engine_info"] - ) - logger.info( - f"Obtained remote session info from rollout engine {target.engine_ind} rank {target.engine_rank}" - ) - remote_session_id, remote_weight_info = self.remote_weight_infos_by_engine_and_rank[ - (target.engine_ind, target.engine_rank) - ] - expert_params, non_expert_params = split_expert_and_non_expert_param_names(remote_weight_info.keys()) + session_id = targets_to_session_id[(target.engine_ind, target.engine_rank)] + expert_params, non_expert_params = split_expert_and_non_expert_param_names( + self.remote_weight_infos_by_session_id[session_id].keys() + ) + params = expert_params if target.group == "expert" else non_expert_params self.transfer_plan.add_transfer_task( - session=remote_session_id, - remote_tensor_names=expert_params if target.group == "expert" else non_expert_params, + session=session_id, param_group=target.group, ) + logger.info( + f"Added transfer task for session {session_id} with {len(params)} tensors in group {target.group}." + ) + + def leader_post_update(self) -> None: + ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + ray.get( + [ + engine.update_weight_version.remote(weight_version=self.weight_version) + for engine in self.rollout_engines + ] + ) + return def _update_bucket_weights_from_remote( - self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None + self, converted_named_tensors: list[tuple[str, torch.Tensor]], session_id: str, pbar: tqdm | None = None ) -> None: """ - Register weights with P2PTrainingTransferEngine and wait for transfers to complete. - Based on lines 518-545 in SGLang test: register_weights pattern. - Overrides parent method to use P2P RDMA instead of NCCL broadcast. + The RDMA P2P weight update is implemented as a single side write, meaning the trainer writes its weights directly to the rollout engines' memory. """ - # TODO(jd): pin the memory for GPUs after resharding to avoid expensive registration. if not self._is_source or not converted_named_tensors: return @@ -130,26 +125,44 @@ def _update_bucket_weights_from_remote( time.sleep(0.1) try: - # Register all weights with the P2P training transfer engine - # This follows the pattern from SGLang test lines 518-537 + # Features still missing for MVP: + # TODO(jd): Implement resharding logic, right now it's not handled. + # TODO: Some model may need target size handling like post_load_weights, currently not handled. + # TODO(jd): Need a correctness test of the model weights similar to the test:https://github.com/sgl-project/sglang/pull/14997/changes#diff-6efab5fd819ef0efa7a1f43989320bb28231702f8840897eb7acacf174f6e71f + + # Potential optimization not implemented: + # TODO: currently implementation does not guarantee single traversal of model dict and submits to mutliple targets in order. 
+ # If there are more targets than sources, we are registering/all-gather/resharding/deregistering multiple times for the same weights. + # TODO: maybe pin the memory for GPUs instead of register + deregister each time after resharding. + # TODO: increase concurrency with non-blocking transfers somehow. Note the reshaped tensors are temporary. + # TODO: skip the forced all-gather for same shard tensors and instead convert directly. + # TODO: finer-granularity weight transfer where multiple source instances can update a single target instance. + + _ = register_memory_transfer_engine(converted_named_tensors, self.engine) + + # Verify the 1-to-1 mapping between registered weights and remote weights expected. + source_ptrs, target_ptrs, source_lens = [], [], [] for name, tensor in converted_named_tensors: - self.training_p2p_transfer_engine.register_buffer(name, tensor) - - # Initiate weight transfer to all rollout engines as per the transfer plan - refs = [ - engine.update_weights_from_distributed.remote( - names=[name for name, _ in converted_named_tensors], - dtypes=[param.dtype for _, param in converted_named_tensors], - shapes=[param.shape for _, param in converted_named_tensors], - group_name=self._group_name, - weight_version=str(self.weight_version), - session_id=f"{self.master_addr}:{self.master_port}", # Pass P2P session info - ) - for engine in self.rollout_engines - ] - - # Wait for all P2P transfers to complete - ray.get(refs) + if name not in self.remote_weight_infos_by_session_id[session_id]: + raise RuntimeError( + f"Registered weight {name} not found in remote weight info for session {session_id}." + ) + remote_ptr, remote_numel, remote_element_size = self.remote_weight_infos_by_session_id[session_id][ + name + ] + if tensor.numel() != remote_numel or tensor.element_size() != remote_element_size: + raise RuntimeError( + f"Registered weight {name} numel {tensor.numel()} size {tensor.element_size()} does not match remote numel {remote_numel} size {remote_element_size}." + ) + source_ptrs.append(tensor.data_ptr()) + target_ptrs.append(remote_ptr) + source_lens.append(tensor.numel() * tensor.element_size()) + + # Batch transfer weights through RDMA + ret = self.transfer_engine.batch_transfer_sync_write(session_id, source_ptrs, target_ptrs, source_lens) + if ret < 0: + raise RuntimeError(f"Batch transfer weights via RDMA failed with error code {ret}.") + self.transfer_engine.deregister_memory_batch(source_ptrs) converted_named_tensors.clear() finally: diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py index 472f5cf887..af8f096c77 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py @@ -1,7 +1,7 @@ from abc import abstractmethod from argparse import Namespace from collections.abc import Callable, Mapping, Sequence - +from typing import Literal import ray import torch import torch.distributed as dist @@ -13,7 +13,7 @@ from slime.utils.timer import timer from ..megatron_to_hf import convert_to_hf -from .common import all_gather_param, named_params_and_buffers +from .common import all_gather_param from .remote_transfer_plan import RemoteTransferPlan @@ -31,6 +31,7 @@ def __init__( *, model_name: str, quantization_config: dict[str, int | str | list[str]] | None, + weight_update_mode: Literal["nccl", "rdma"] = "nccl", ) -> None: """ Initialize.
Groups created in connect_rollout_engines. @@ -40,7 +41,8 @@ def __init__( self.model_name = model_name self.quantization_config = quantization_config self.weight_version = 0 - self.transfer_plan = RemoteTransferPlan(args, model, args.update_weight_transfer_mode) + self.transfer_plan = RemoteTransferPlan(args, model, weight_update_mode) + self._is_source = self.transfer_plan.is_source() @abstractmethod def connect_rollout_engines( @@ -52,10 +54,12 @@ def connect_rollout_engines( @abstractmethod def _update_bucket_weights_from_remote( - self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None + self, converted_named_tensors: list[tuple[str, torch.Tensor]], session_id: str, pbar: tqdm | None = None ) -> None: """ - Implementation of the bucketed parameter update from remote. + Implementation of the bucketed parameter update from remote. session_id is used as the identifier + for the operation, either NCCL group name or Transfer Engine session id. + TODO(jd): to avoid traversing the model dict multiple times we need session_id to be a list. """ @torch.no_grad() @@ -69,49 +73,65 @@ def update_weights(self) -> None: ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) - with timer("update_weights_implementation"): - buffer_size = 0 - converted_named_tensors = [] - # non expert params - pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_source else None - - for name, param in named_params_and_buffers(self.args, self.model): - # transfer tp tensors - if name not in self._tensor_names or ".experts." in name: - continue - buffer_size = self._update_weight_from_remote( - name, param, converted_named_tensors, buffer_size, pbar=pbar - ) - - if converted_named_tensors: - self._update_bucket_weights_from_remote(converted_named_tensors, pbar=pbar) - - dist.barrier(group=get_gloo_group()) - buffer_size = 0 - named_tensors = [] - for name, param in named_params_and_buffers(self.args, self.model): - # transfer expert tensors - if name not in self._tensor_names or ".experts." 
not in name: - continue - buffer_size = self._update_expert_weight_from_remote( - name, param, named_tensors, buffer_size, pbar=pbar - ) - - if named_tensors: - self._update_expert_bucket_weights_from_remote(named_tensors, pbar=pbar) + with timer("update_weights_implementation"): + for transfer_task in self.transfer_plan.get_transfer_tasks(): + # Update non-expert or expert weights + if transfer_task.tensor_type == "non-expert": + self._update_weights(transfer_task.named_params_and_buffers, transfer_task.session) + elif transfer_task.tensor_type == "expert": + self._update_expert_weights(transfer_task.named_params_and_buffers, transfer_task.session) + else: + raise ValueError(f"Unknown tensor type {transfer_task.tensor_type} in transfer task.") + dist.barrier(group=get_gloo_group()) dist.barrier(group=get_gloo_group()) if dist.get_rank() == 0: - ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + self.leader_post_update() dist.barrier(group=get_gloo_group()) + def leader_post_update(self) -> None: + ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + return + + def _update_expert_weights( + self, named_params_and_buffers: Sequence[tuple[str, torch.Tensor]], session_id: str + ) -> None: + pbar = tqdm(desc=f"[{session_id}] Update Expert Weights", total=0) if self._is_source else None + buffer_size = 0 + named_tensors = [] + for name, param in named_params_and_buffers: + # transfer expert tensors + assert ".experts." in name, "Function intended for expert params only." + buffer_size = self._update_expert_weight_from_remote( + name, param, named_tensors, buffer_size, session_id, pbar=pbar + ) + + if named_tensors: + self._update_expert_bucket_weights_from_remote(named_tensors, session_id, pbar=pbar) + + def _update_weights(self, named_params_and_buffers: Sequence[tuple[str, torch.Tensor]], session_id: str) -> None: + pbar = tqdm(desc=f"[{session_id}] Update Weights", total=0) if self._is_source else None + buffer_size = 0 + converted_named_tensors = [] + # non expert params + for name, param in named_params_and_buffers: + # transfer tp tensors + assert ".experts." not in name, "Function intended for non-expert params only." 
+ buffer_size = self._update_weight_from_remote( + name, param, converted_named_tensors, buffer_size, session_id, pbar=pbar + ) + + if converted_named_tensors: + self._update_bucket_weights_from_remote(converted_named_tensors, session_id, pbar=pbar) + def _update_weight_from_remote( self, name: str, param: torch.nn.Parameter, converted_named_tensors: list[tuple[str, torch.Tensor]], buffer_size: int, + session_id: str, pbar: tqdm | None = None, ) -> int | None: """ @@ -124,7 +144,7 @@ def _update_weight_from_remote( param_size = param.numel() * param.element_size() if buffer_size + param_size > self.args.update_weight_buffer_size: - self._update_bucket_weights_from_remote(converted_named_tensors, pbar=pbar) + self._update_bucket_weights_from_remote(converted_named_tensors, session_id, pbar=pbar) buffer_size = 0 converted_named_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) buffer_size += param_size @@ -136,6 +156,7 @@ def _update_expert_weight_from_remote( param: torch.nn.Parameter, named_tensors: list[tuple[str, torch.Tensor]], buffer_size: int, + session_id: str, pbar: tqdm | None = None, ) -> int: """ @@ -147,7 +168,7 @@ def _update_expert_weight_from_remote( if ( buffer_size + param_size ) * mpu.get_expert_model_parallel_world_size() > self.args.update_weight_buffer_size and named_tensors: - self._update_expert_bucket_weights_from_remote(named_tensors, pbar=pbar) + self._update_expert_bucket_weights_from_remote(named_tensors, session_id, pbar=pbar) buffer_size = 0 named_tensors.append((name, param)) @@ -155,7 +176,7 @@ def _update_expert_weight_from_remote( return buffer_size def _update_expert_bucket_weights_from_remote( - self, named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None + self, named_tensors: list[tuple[str, torch.Tensor]], session_id: str, pbar: tqdm | None = None ) -> None: """ Gather EP → HF → broadcast. Clears buffer. 
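Both bucket helpers above follow the same accumulate-and-flush pattern: parameters are appended to a bucket until the configured byte budget would be exceeded, then the bucket is flushed and a new one started. A standalone sketch of that pattern (hypothetical names: `flush` stands in for the `_update_*bucket_weights_from_remote` methods, `budget_bytes` for `args.update_weight_buffer_size`):

    def bucketed_send(named_params, budget_bytes, flush):
        bucket, used = [], 0
        for name, param in named_params:
            size = param.numel() * param.element_size()
            if used + size > budget_bytes and bucket:
                flush(bucket)  # NCCL broadcast or RDMA write for the full bucket
                bucket, used = [], 0
            bucket.append((name, param))
            used += size
        if bucket:
            flush(bucket)  # send the remainder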
@@ -190,4 +211,4 @@ def _update_expert_bucket_weights_from_remote( for name, param in all_gathered_params: converted_hf_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) - self._update_bucket_weights_from_remote(converted_hf_tensors, pbar) + self._update_bucket_weights_from_remote(converted_hf_tensors, session_id, pbar) diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index c2c12d06fa..5b7adae347 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -350,6 +350,14 @@ def continue_generation(self): response.raise_for_status() return response + def update_weight_version(self, weight_version: str): + response = requests.post( + f"http://{self.server_host}:{self.server_port}/update_weight_version", + json={"new_version": weight_version}, + ) + response.raise_for_status() + return response + def start_profile( self, # The output directory diff --git a/tests/test_weight_transfer.py b/tests/test_weight_transfer.py index 443d285a3d..2c92f2971e 100644 --- a/tests/test_weight_transfer.py +++ b/tests/test_weight_transfer.py @@ -1,24 +1,29 @@ -import os - -import slime.utils.external_utils.command_utils as U +from dataclasses import dataclass +from typing import Literal +import typer -TIGHT_HOST_MEMORY = bool(int(os.environ.get("SLIME_TEST_TIGHT_HOST_MEMORY", "1"))) +import slime.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-4B" MODEL_TYPE = "qwen3-4B" NUM_GPUS = 3 +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["nccl", "rdma"] = "nccl" + # TODO: Add diverse parallelism settings; imbalance training/inference instances, etc for benchmark. + + def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") U.exec_command("hf download Qwen/Qwen3-4B --local-dir /root/models/Qwen3-4B") U.hf_download_dataset("zhuzilin/dapo-math-17k") - U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) -def execute(): +def execute(args: ScriptArgs): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/{MODEL_NAME}_torch_dist " rollout_args = ( @@ -69,6 +74,8 @@ def execute(): ) sglang_args = "--rollout-num-gpus-per-engine 1 " "--rollout-num-gpus 2 " "--sglang-mem-fraction-static 0.8 " + if args.mode == "rdma": + sglang_args += "--remote-instance-weight-loader-support-transfer-engine " # ci_args = "--ci-test " @@ -87,6 +94,8 @@ def execute(): f"--update-weight-buffer-size {1 * 1024 ** 3} " # "--check-weight-update-equal " ) + if args.mode == "rdma": + misc_args += "--update-weight-transfer-mode rdma " train_args = ( f"{ckpt_args} " @@ -109,6 +118,11 @@ def execute(): ) +@U.dataclass_cli +def main(args: ScriptArgs): + prepare(args) + execute(args) + + if __name__ == "__main__": - prepare() - execute() + typer.run(main) From f333c519b63f8b886f212abe124162df8b38dfb2 Mon Sep 17 00:00:00 2001 From: JD-ETH Date: Tue, 16 Dec 2025 06:44:06 +0000 Subject: [PATCH 04/11] fixing name mismatch issue --- .../megatron_utils/update_weight/common.py | 2 +- .../update_weight/update_weight_from_rdma.py | 39 ++++++++++++------- slime/backends/sglang_utils/sglang_engine.py | 9 +++++ tests/test_weight_transfer.py | 22 +++++++---- 4 files changed, 51 insertions(+), 21 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/common.py b/slime/backends/megatron_utils/update_weight/common.py index 51fa0759ca..63dda9a8b7 100644 --- 
a/slime/backends/megatron_utils/update_weight/common.py +++ b/slime/backends/megatron_utils/update_weight/common.py @@ -324,7 +324,7 @@ def register_memory_transfer_engine(named_param_with_buffers: Sequence[tuple[str # Register merged memory blocks that hold weights. for weight_block in weight_blocks_for_reg_mr: address, size = weight_block - ret = engine.register_memory(address, size) + ret = engine.register(address, size) if ret != 0: raise RuntimeError( f"register memory failed for weight block at address {address} with size {size}, error: {ret}" diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py index 9cdc9f8f63..8b7e845f2a 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py @@ -6,7 +6,9 @@ import ray import torch from ray.actor import ActorHandle -from srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine + +# from mooncake.engine import TransferEngine +from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from tqdm import tqdm from .common import register_memory_transfer_engine, split_expert_and_non_expert_param_names @@ -35,7 +37,6 @@ def __init__( *, model_name: str, quantization_config: dict[str, int | str | list[str]] | None, - vocab_size: int, ) -> None: """ Initialize transfer engine. @@ -47,7 +48,6 @@ def __init__( weights_getter, model_name=model_name, quantization_config=quantization_config, - vocab_size=vocab_size, weight_update_mode="rdma", ) @@ -68,8 +68,11 @@ def connect_rollout_engines( if self._is_source: # Get master address and port for P2P communication local_ip = ray._private.services.get_node_ip_address() - self.transfer_engine = MooncakeTransferEngine(hostname=local_ip) - logger.info(f"Transfer Engine initialized at {self.transfer_engine.session_id}") + self.transfer_engine = MooncakeTransferEngine(local_ip, None, None) + logger.info(f"[RDMA] Transfer Engine initialized at port {self.transfer_engine.session_id}") + # breakpoint() + # self.transfer_engine = TransferEngine() + # logger.info(f"[RDMA] Transfer Engine initialized at port {self.transfer_engine.get_rpc_port()}") # Query Engine session and weight info from rollout instances according to the transfer plan self.remote_weight_infos_by_session_id = {} @@ -77,11 +80,16 @@ def connect_rollout_engines( targets_to_session_id = {} for engine_ind, engine_rank in targets_to_query: session_id, weights_info = ray.get( - self.rollout_engines[engine_ind].get_remote_instance_transfer_engine_info.remote(rank=engine_rank)[ - "remote_instance_transfer_engine_info" - ] + self.rollout_engines[engine_ind].get_remote_instance_transfer_engine_info.remote(rank=engine_rank) + ) + assert ( + session_id is not None + ), f"Failed to get session id from rollout engine {engine_ind} rank {engine_rank}" + logger.info( + f"[RDMA] Obtained remote {session_id} info from rollout engine {engine_ind} rank {engine_rank}" ) - logger.info(f"Obtained remote session info from rollout engine {engine_ind} rank {engine_rank}") + logger.info(f"[RDMA] Remote weight info has {len(weights_info)} tensors.") + logger.info(list(weights_info.keys())) self.remote_weight_infos_by_session_id[session_id] = weights_info targets_to_session_id[(engine_ind, engine_rank)] = session_id @@ -138,8 +146,11 @@ def _update_bucket_weights_from_remote( # TODO: skip the forced all-gather for same shard tensors and 
instead convert directly. # TODO: finer granularity weight transfer where a multiple source instance can update a singular target instance. - _ = register_memory_transfer_engine(converted_named_tensors, self.engine) - + _ = register_memory_transfer_engine(converted_named_tensors, self.transfer_engine) + logger.info( + f"[RDMA] Registered {len(converted_named_tensors)} tensors with transfer engine for session {session_id}." + ) + logger.info(f"[RDMA] Transfering {list(name for name, _ in converted_named_tensors)}") # Verify the 1-to-1 mapping between registered weights and remote weights expected. source_ptrs, target_ptrs, source_lens = [], [], [] for name, tensor in converted_named_tensors: @@ -159,10 +170,12 @@ def _update_bucket_weights_from_remote( source_lens.append(tensor.numel() * tensor.element_size()) # Batch transfer weights through RDMA - ret = self.transfer_engine.batch_transfer_sync_write(session_id, source_ptrs, target_ptrs, source_lens) + ret = self.transfer_engine.batch_transfer_sync(session_id, source_ptrs, target_ptrs, source_lens) + logger.info(f"[RDMA] Batch transferred {len(converted_named_tensors)} tensors to session {session_id}.") if ret < 0: raise RuntimeError(f"Batch transfer weights via RDMA failed with error code {ret}.") - self.transfer_engine.deregister_memory_batch(source_ptrs) + self.transfer_engine.batch_deregister(source_ptrs) + logger.info(f"[RDMA] Batch deregistered {len(converted_named_tensors)} tensors.") converted_named_tensors.clear() finally: diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 5b7adae347..594262829b 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -274,6 +274,15 @@ def shutdown(self): response.raise_for_status() kill_process_tree(self.process.pid) + def get_remote_instance_transfer_engine_info(self, rank: int): + response = requests.get( + f"http://{self.server_host}:{self.server_port}/get_remote_instance_transfer_engine_info", + params={"rank": rank}, + timeout=5.0, + ) + response.raise_for_status() + return response.json()["remote_instance_transfer_engine_info"] + def get_weight_version(self): if self.node_rank != 0: return diff --git a/tests/test_weight_transfer.py b/tests/test_weight_transfer.py index 2c92f2971e..92d039004c 100644 --- a/tests/test_weight_transfer.py +++ b/tests/test_weight_transfer.py @@ -7,23 +7,27 @@ MODEL_NAME = "Qwen3-4B" MODEL_TYPE = "qwen3-4B" -NUM_GPUS = 3 @dataclass class ScriptArgs(U.ExecuteTrainConfig): mode: Literal["nccl", "rdma"] = "nccl" + # Right now tp=ep=pp=1 + num_train_gpus: int = 1 + num_rollout_gpus: int = 1 # TODO: Add diverse parallelism settings; imbalance training/inference instances, etc for benchmark. 
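Assuming `U.dataclass_cli` exposes the dataclass fields as typer options with the usual underscore-to-dash renaming, the test can then be launched per mode, for example:

    python tests/test_weight_transfer.py --mode rdma --num-train-gpus 1 --num-rollout-gpus 1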
-def prepare(): +def prepare(args: ScriptArgs): U.exec_command("mkdir -p /root/models /root/datasets") U.exec_command("hf download Qwen/Qwen3-4B --local-dir /root/models/Qwen3-4B") U.hf_download_dataset("zhuzilin/dapo-math-17k") - U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + num_gpus = args.num_train_gpus + args.num_rollout_gpus + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=num_gpus) def execute(args: ScriptArgs): + num_gpus = args.num_train_gpus + args.num_rollout_gpus ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/{MODEL_NAME}_torch_dist " rollout_args = ( @@ -73,9 +77,13 @@ def execute(args: ScriptArgs): "--adam-beta2 0.98 " ) - sglang_args = "--rollout-num-gpus-per-engine 1 " "--rollout-num-gpus 2 " "--sglang-mem-fraction-static 0.8 " + sglang_args = ( + f"--rollout-num-gpus-per-engine 1 " + f"--rollout-num-gpus {args.num_rollout_gpus} " + "--sglang-mem-fraction-static 0.8 " + ) if args.mode == "rdma": - sglang_args += "--remote-instance-weight-loader-support-transfer-engine " + sglang_args += "--sglang-remote-instance-weight-loader-support-transfer-engine " # ci_args = "--ci-test " @@ -111,10 +119,10 @@ def execute(args: ScriptArgs): U.execute_train( train_args=train_args, - num_gpus_per_node=NUM_GPUS, + num_gpus_per_node=num_gpus, megatron_model_type=MODEL_TYPE, train_script="train_async.py", - # extra_env_vars={"RAY_DEBUG": "1"}, + extra_env_vars={"RAY_DEBUG": "1"}, ) From 9b47b111689b0bee9709e1839118628625e414e1 Mon Sep 17 00:00:00 2001 From: JD-ETH Date: Wed, 17 Dec 2025 06:35:35 +0000 Subject: [PATCH 05/11] rework with local copy --- .../megatron_utils/update_weight/common.py | 2 +- .../update_weight/remote_transfer_plan.py | 11 +- .../update_weight/update_weight_from_rdma.py | 170 +++++++++++------- .../update_weight_from_remote.py | 6 +- 4 files changed, 119 insertions(+), 70 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/common.py b/slime/backends/megatron_utils/update_weight/common.py index 63dda9a8b7..01c606e780 100644 --- a/slime/backends/megatron_utils/update_weight/common.py +++ b/slime/backends/megatron_utils/update_weight/common.py @@ -278,7 +278,7 @@ def _named_params_and_buffers_global( yield f"module.module.decoder.layers.{layer_idx}.{rest}", buffer -def register_memory_transfer_engine(named_param_with_buffers: Sequence[tuple[str, torch.Tensor]], engine) -> None: +def register_memory_transfer_engine(named_param_with_buffers: Sequence[tuple[str, torch.Tensor]], engine) -> dict: """ Efficient memory registration for transfer engine that reduce total registration count by batching continuous memory regions. """ diff --git a/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py b/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py index 9806b6f9ef..d93e5f4d8f 100644 --- a/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py +++ b/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py @@ -101,8 +101,8 @@ def _get_parallel_info(self, args: Namespace) -> None: self._rollout_ep_size = args.sglang_ep_size # PP sizes are not supported currently. 
self._rollout_pp_size = args.sglang_pp_size - if self._rollout_ep_size != 1: - raise NotImplementedError("Rollout expert parallelism is not supported yet.") + if self._rollout_ep_size != 1 or self._rollout_pp_size != 1: + raise NotImplementedError("Rollout expert and pipeline parallelisms are not supported yet.") num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node) self._rollout_engine_count = args.rollout_num_gpus // num_gpu_per_engine logger.info( @@ -141,6 +141,8 @@ def plan( for target_ind in range(num_targets): for target_rank in range(num_rank_in_target): if cur_active_rank % source_size == source_rank: + # TODO(jd): instead of doing round robin, we should prioritize reusing source ranks with same target_rank to + # avoid duplicating local copies. transfer_tasks.append( TransferTaskP2PMeta(engine_ind=target_ind, engine_rank=target_rank, group=params) ) @@ -150,6 +152,8 @@ cur_active_rank += 1 return transfer_tasks + # TODO(JD): Due to the local replica design, the plan should prioritize reusing existing copies and merge sources. + non_expert_plan = plan( source_size=self._gathered_dp_size, source_rank=self._gathered_dp_rank, num_rank_in_target=self._rollout_dp_size * self._rollout_ep_size, num_targets=self._rollout_engine_count, params="non-expert", ) - offset = len(non_expert_plan) - # Offset the current active rank by the number of non-expert transfer tasks to avoid overloading first few ranks. return non_expert_plan + plan( source_size=self._gathered_expert_dp_size, source_rank=self._gathered_expert_dp_rank, num_rank_in_target=self._rollout_dp_size * self._rollout_ep_size, num_targets=self._rollout_engine_count, params="expert", - cur_active_rank=offset, ) def is_source(self) -> bool: diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py index 8b7e845f2a..973492b243 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py @@ -1,16 +1,22 @@ import logging -import time from argparse import Namespace from collections.abc import Callable, Mapping, Sequence import ray import torch from ray.actor import ActorHandle +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.model_config import ModelConfig # from mooncake.engine import TransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine +from sglang.srt.model_loader.loader import get_model_loader from tqdm import tqdm +from slime.backends.megatron_utils.update_weight.remote_transfer_plan import TransferTask +from slime.utils.memory_utils import print_memory + -from .common import register_memory_transfer_engine, split_expert_and_non_expert_param_names +from .common import register_memory_transfer_engine from .update_weight_from_remote import UpdateWeightFromRemote logger = logging.getLogger(__name__) @@ -62,26 +68,21 @@ def connect_rollout_engines( # Store rollout engines and lock self.rollout_engines = rollout_engines self.rollout_engine_lock = rollout_engine_lock - self._is_source = self.transfer_plan.is_source() - # Initialize P2PTrainingTransferEngine on source rank if self._is_source: - # Get master address and port for P2P communication local_ip = ray._private.services.get_node_ip_address() self.transfer_engine = MooncakeTransferEngine(local_ip, None, None) logger.info(f"[RDMA] Transfer Engine initialized at port {self.transfer_engine.session_id}") - # breakpoint() - # 
self.transfer_engine = TransferEngine() - # logger.info(f"[RDMA] Transfer Engine initialized at port {self.transfer_engine.get_rpc_port()}") # Query Engine session and weight info from rollout instances according to the transfer plan self.remote_weight_infos_by_session_id = {} targets_to_query = set((target.engine_ind, target.engine_rank) for target in self.transfer_plan.targets) - targets_to_session_id = {} + targets_to_session_id, self.session_id_to_engine_rank = {}, {} for engine_ind, engine_rank in targets_to_query: session_id, weights_info = ray.get( self.rollout_engines[engine_ind].get_remote_instance_transfer_engine_info.remote(rank=engine_rank) ) + self.session_id_to_engine_rank[session_id] = engine_rank assert ( session_id is not None ), f"Failed to get session id from rollout engine {engine_ind} rank {engine_rank}" @@ -93,6 +94,11 @@ def connect_rollout_engines( self.remote_weight_infos_by_session_id[session_id] = weights_info targets_to_session_id[(engine_ind, engine_rank)] = session_id + print_memory("[RDMA] After obtaining remote weight info") + + # Local model with identical shape to remote. Create at most one copy per target rank, and link + # them by session id. + self.engines, self.session_id_to_local_replicas = {}, {} # Associate transfer tasks based on obtained session and weight info for target in self.transfer_plan.targets: session_id = targets_to_session_id[(target.engine_ind, target.engine_rank)] @@ -107,6 +113,76 @@ def connect_rollout_engines( logger.info( f"Added transfer task for session {session_id} with {len(params)} tensors in group {target.group}." ) + # Instantiate the local model replicas and a corresponding transfer engine with memory registry for each type of rollout shard. + if target.engine_rank not in self.engines: + model_replica = self._create_inference_replica( + self.args.hf_checkpoint, + target_tp=self.args.rollout_num_gpus_per_engine, + target_rank=target.engine_rank, + ) + transfer_engine = self._create_transfer_engine() + weight_memory_registry = self._register_replica_memory( + model_replica, self.remote_weight_infos_by_session_id[session_id], transfer_engine + ) + self.engines[target.engine_rank] = (model_replica, transfer_engine, weight_memory_registry) + + print_memory("[RDMA] After Local Engine Replicas and engine Creation") + + def _register_replica_memory(self, model_replica, remote_weight_info, transfer_engine) -> dict: + to_register_named_tensors = [] + named_tensors = dict(model_replica.named_parameters()) + # Verify the 1-to-1 mapping between registered weights and remote weights expected. + for name, (_, remote_numel, remote_ele_size) in remote_weight_info: + if name not in named_tensors: + raise RuntimeError(f"Remote replica parameter {name} not found in local replica.") + tensor = named_tensors[name] + if tensor.numel() != remote_numel or tensor.element_size() != remote_ele_size: + raise RuntimeError( + f"Local replica parameter {name} numel {tensor.numel()} size {tensor.element_size()} does not match remote numel {remote_numel} size {remote_ele_size}." + ) + to_register_named_tensors.append((name, tensor)) + weight_memory_registry = register_memory_transfer_engine(to_register_named_tensors, transfer_engine) + logger.info( + f"[RDMA] Registered {len(to_register_named_tensors)} tensors of total {len(named_tensors)} from replica with transfer engine." 
+ ) + return weight_memory_registry + + def _create_transfer_engine(self) -> MooncakeTransferEngine: + local_ip = ray._private.services.get_node_ip_address() + transfer_engine = MooncakeTransferEngine(local_ip, None, None) + logger.info(f"[RDMA] Local replica Transfer Engine initialized at port {transfer_engine.session_id}") + return transfer_engine + + def _create_inference_replica(self, model_path, target_tp, target_rank): + # Create model replica for target rank with correct tp settings. + # FIXME: how to validate this works? + model_config = ModelConfig(model_path) + loader = get_model_loader( + load_config=LoadConfig(load_format="direct", tp_rank=target_rank), + model_config=model_config, + ) + return loader.load_model( + model_config=model_config, + device_config=DeviceConfig(self.device, self.gpu_id), + ) + + def _execute_transfer(self, session_id: str) -> None: + """ + Execute weight transfer for a single transfer task using RDMA P2P transfer engine. + """ + _, engine, weight_memory_registry = self.engines[self.session_id_to_engine_rank[session_id]] + remote_weight_info = self.remote_weight_infos_by_session_id[session_id] + source_ptrs, target_ptrs, source_lens = [], [], [] + for name, tensor in weight_memory_registry.items(): + source_ptrs.append(tensor.data_ptr()) + target_ptrs.append(remote_weight_info[name][0]) # remote address + source_lens.append(tensor.numel() * tensor.element_size()) + + # Batch transfer weights through RDMA + ret = engine.batch_transfer_sync(session_id, source_ptrs, target_ptrs, source_lens) + logger.info(f"[RDMA] Batch transferred {len(weight_memory_registry)} tensors to session {session_id}.") + if ret < 0: + raise RuntimeError(f"Batch transfer weights via RDMA failed with error code {ret}.") def leader_post_update(self) -> None: ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) @@ -128,58 +204,26 @@ def _update_bucket_weights_from_remote( if not self._is_source or not converted_named_tensors: return - # Lock the rollout engines to prevent concurrent operations (same as parent) - while not ray.get(self.rollout_engine_lock.acquire.remote()): - time.sleep(0.1) - - try: - # Features still missing for MVP: - # TODO(jd): Implement resharding logic, right now it's not handled. - # TODO: Some model may need target size handling like post_load_weights, currently not handled. - # TODO(jd): Need a correctness test of the model weights similar to the test:https://github.com/sgl-project/sglang/pull/14997/changes#diff-6efab5fd819ef0efa7a1f43989320bb28231702f8840897eb7acacf174f6e71f - - # Potential optimization not implemented: - # TODO: currently implementation does not guarantee single traversal of model dict and submits to mutliple targets in order. - # If there are more targets than source, we are registering/all-gather/resharding/deregistering multiple times for same weights. - # TODO: maybe pin the memory for GPUs instead of register + deregester each time after resharding. - # TODO: increase concurrency with non-blocking transfers somehow. Note the reshaped tensors are temporary. - # TODO: skip the forced all-gather for same shard tensors and instead convert directly. - # TODO: finer granularity weight transfer where a multiple source instance can update a singular target instance. - - _ = register_memory_transfer_engine(converted_named_tensors, self.transfer_engine) - logger.info( - f"[RDMA] Registered {len(converted_named_tensors)} tensors with transfer engine for session {session_id}." 
- ) - logger.info(f"[RDMA] Transfering {list(name for name, _ in converted_named_tensors)}") - # Verify the 1-to-1 mapping between registered weights and remote weights expected. - source_ptrs, target_ptrs, source_lens = [], [], [] - for name, tensor in converted_named_tensors: - if name not in self.remote_weight_infos_by_session_id[session_id]: - raise RuntimeError( - f"Registered weight {name} not found in remote weight info for session {session_id}." - ) - remote_ptr, remote_numel, remote_element_size = self.remote_weight_infos_by_session_id[session_id][ - name - ] - if tensor.numel() != remote_numel or tensor.element_size() != remote_element_size: - raise RuntimeError( - f"Registered weight {name} numel {tensor.numel()} size {tensor.element_size()} does not match remote numel {remote_numel} size {remote_element_size}." - ) - source_ptrs.append(tensor.data_ptr()) - target_ptrs.append(remote_ptr) - source_lens.append(tensor.numel() * tensor.element_size()) - - # Batch transfer weights through RDMA - ret = self.transfer_engine.batch_transfer_sync(session_id, source_ptrs, target_ptrs, source_lens) - logger.info(f"[RDMA] Batch transferred {len(converted_named_tensors)} tensors to session {session_id}.") - if ret < 0: - raise RuntimeError(f"Batch transfer weights via RDMA failed with error code {ret}.") - self.transfer_engine.batch_deregister(source_ptrs) - logger.info(f"[RDMA] Batch deregistered {len(converted_named_tensors)} tensors.") - converted_named_tensors.clear() - - finally: - # Release the lock (same as parent) - ray.get(self.rollout_engine_lock.release.remote()) - if pbar: - pbar.update(1) + # Refactoring needed: + # TODO: refactor update_weight to still do a single traversal of the model dict; session_id should be per weight instead. + # TODO: There is probably enough difference to the UpdateFromNCCL that we should just rebuild from scratch maybe? + # Functionality missing: + # TODO: Fix learner PP, right now we still send all weights from any source. + # TODO: Support engine expert parallel + # TODO: Extensive tests on different pp/ep/tp settings --> tp/dp/ep settings. + # TODO: Need a correctness test of the model weights similar to the test:https://github.com/sgl-project/sglang/pull/14997/changes#diff-6efab5fd819ef0efa7a1f43989320bb28231702f8840897eb7acacf174f6e71f + # TODO: Memory profiling. + # TODO: Design of experiments --- what other configurations do we need to enable. + # Optimizations: + # TODO: optimize the remote transfer plan to reduce local-copy memory usage. + # TODO: pipeline the all-gather/reshard with transfer engine transfer calls for performance. + # TODO: increase concurrency with non-blocking transfers to multiple targets. + # TODO: memory offloading if the replica becomes a bottleneck. + + # Load weights into local replica matching the target session; this handles sharding and reshaping.
+ self.session_id_to_local_replicas[session_id].load_weights(converted_named_tensors) + converted_named_tensors.clear() + + def finish_transfer_task(self, task: TransferTask) -> None: + self._execute_transfer(task.session) + return diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py index af8f096c77..7b9bc14fb0 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py @@ -14,7 +14,7 @@ from ..megatron_to_hf import convert_to_hf from .common import all_gather_param -from .remote_transfer_plan import RemoteTransferPlan +from .remote_transfer_plan import RemoteTransferPlan, TransferTask class UpdateWeightFromRemote: @@ -84,6 +84,7 @@ def update_weights(self) -> None: else: raise ValueError(f"Unknown tensor type {transfer_task.tensor_type} in transfer task.") dist.barrier(group=get_gloo_group()) + self.finish_transfer_task(transfer_task) dist.barrier(group=get_gloo_group()) if dist.get_rank() == 0: @@ -94,6 +95,9 @@ def leader_post_update(self) -> None: ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) return + def finish_transfer_task(self, task: TransferTask) -> None: + return + def _update_expert_weights( self, named_params_and_buffers: Sequence[tuple[str, torch.Tensor]], session_id: str ) -> None: From dd5a2cec368c49d3702155288fe9b9487ecc1110 Mon Sep 17 00:00:00 2001 From: JD-ETH Date: Thu, 18 Dec 2025 06:44:01 +0000 Subject: [PATCH 06/11] fail on engine registration --- .../megatron_utils/update_weight/common.py | 14 ++ .../update_weight/update_weight_from_rdma.py | 137 +++++++++++++++--- slime/backends/sglang_utils/sglang_engine.py | 8 + 3 files changed, 137 insertions(+), 22 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/common.py b/slime/backends/megatron_utils/update_weight/common.py index 01c606e780..2f9b0bde5e 100644 --- a/slime/backends/megatron_utils/update_weight/common.py +++ b/slime/backends/megatron_utils/update_weight/common.py @@ -330,3 +330,17 @@ def register_memory_transfer_engine(named_param_with_buffers: Sequence[tuple[str f"register memory failed for weight block at address {address} with size {size}, error: {ret}" ) return weight_mr_dict + + +def register_memory_region_v1(named_param_with_buffers: Sequence[tuple[str, torch.Tensor]], transfer_engine): + weight_mr_dict = {} + for name, weight in named_param_with_buffers: + ret = transfer_engine.register(weight.data_ptr(), weight.numel() * weight.element_size()) + if ret != 0: + raise RuntimeError(f"register memory failed for weight {name}, error: {ret}") + weight_mr_dict[name] = ( + weight.data_ptr(), + weight.numel(), + weight.element_size(), + ) + return weight_mr_dict diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py index 973492b243..5ada388479 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py @@ -1,17 +1,21 @@ +import dataclasses import logging from argparse import Namespace from collections.abc import Callable, Mapping, Sequence import ray +import sglang.srt.layers.dp_attention as sglang_dp_attention +import sglang.srt.server_args as sglang_server_args import torch from ray.actor import ActorHandle from 
sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine +from sglang.srt.model_loader import get_model # from mooncake.engine import TransferEngine -from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine -from sglang.srt.model_loader.loader import get_model_loader +from sglang.srt.server_args import ServerArgs from tqdm import tqdm from slime.backends.megatron_utils.update_weight.remote_transfer_plan import TransferTask @@ -23,6 +27,12 @@ logger = logging.getLogger(__name__) +def create_server_args_from_dict(data_dict: dict) -> ServerArgs: + valid_fields = {f.name for f in dataclasses.fields(ServerArgs)} + filtered_data = {k: v for k, v in data_dict.items() if k in valid_fields} + return ServerArgs(**filtered_data) + + class UpdateWeightFromRDMA(UpdateWeightFromRemote): """ Update weights from RDMA using Transfer Engine. @@ -70,19 +80,19 @@ def connect_rollout_engines( self.rollout_engine_lock = rollout_engine_lock if self._is_source: - local_ip = ray._private.services.get_node_ip_address() - self.transfer_engine = MooncakeTransferEngine(local_ip, None, None) - logger.info(f"[RDMA] Transfer Engine initialized at port {self.transfer_engine.session_id}") - # Query Engine session and weight info from rollout instances according to the transfer plan self.remote_weight_infos_by_session_id = {} targets_to_query = set((target.engine_ind, target.engine_rank) for target in self.transfer_plan.targets) targets_to_session_id, self.session_id_to_engine_rank = {}, {} + self.session_id_to_server_args = {} for engine_ind, engine_rank in targets_to_query: session_id, weights_info = ray.get( self.rollout_engines[engine_ind].get_remote_instance_transfer_engine_info.remote(rank=engine_rank) ) self.session_id_to_engine_rank[session_id] = engine_rank + self.session_id_to_server_args[session_id] = create_server_args_from_dict( + ray.get(self.rollout_engines[engine_ind].get_server_info.remote()) + ) assert ( session_id is not None ), f"Failed to get session id from rollout engine {engine_ind} rank {engine_rank}" @@ -90,7 +100,7 @@ def connect_rollout_engines( f"[RDMA] Obtained remote {session_id} info from rollout engine {engine_ind} rank {engine_rank}" ) logger.info(f"[RDMA] Remote weight info has {len(weights_info)} tensors.") - logger.info(list(weights_info.keys())) + # logger.info(list(weights_info.keys())) self.remote_weight_infos_by_session_id[session_id] = weights_info targets_to_session_id[(engine_ind, engine_rank)] = session_id @@ -115,12 +125,14 @@ def connect_rollout_engines( ) # Instantiate the local model replicas and a corresponding transfer engine with memory registry for each type of rollout shard. 
if target.engine_rank not in self.engines: + transfer_engine = self._create_transfer_engine() model_replica = self._create_inference_replica( self.args.hf_checkpoint, - target_tp=self.args.rollout_num_gpus_per_engine, target_rank=target.engine_rank, + target_tp=self.args.rollout_num_gpus_per_engine, + server_args=self.session_id_to_server_args[session_id], ) - transfer_engine = self._create_transfer_engine() + print_memory("[RDMA] After model replica") weight_memory_registry = self._register_replica_memory( model_replica, self.remote_weight_infos_by_session_id[session_id], transfer_engine ) @@ -132,7 +144,8 @@ def _register_replica_memory(self, model_replica, remote_weight_info, transfer_e to_register_named_tensors = [] named_tensors = dict(model_replica.named_parameters()) # Verify the 1-to-1 mapping between registered weights and remote weights expected. - for name, (_, remote_numel, remote_ele_size) in remote_weight_info: + for name, info in remote_weight_info.items(): + (_, remote_numel, remote_ele_size) = info if name not in named_tensors: raise RuntimeError(f"Remote replica parameter {name} not found in local replica.") tensor = named_tensors[name] @@ -140,6 +153,8 @@ def _register_replica_memory(self, model_replica, remote_weight_info, transfer_e raise RuntimeError( f"Local replica parameter {name} numel {tensor.numel()} size {tensor.element_size()} does not match remote numel {remote_numel} size {remote_ele_size}." ) + if tensor.device.type != "cuda": + raise RuntimeError(f"Local replica parameter {name} is not on CUDA device.") to_register_named_tensors.append((name, tensor)) weight_memory_registry = register_memory_transfer_engine(to_register_named_tensors, transfer_engine) logger.info( @@ -153,18 +168,27 @@ def _create_transfer_engine(self) -> MooncakeTransferEngine: logger.info(f"[RDMA] Local replica Transfer Engine initialized at port {transfer_engine.session_id}") return transfer_engine - def _create_inference_replica(self, model_path, target_tp, target_rank): - # Create model replica for target rank with correct tp settings. - # FIXME: how to validate this works? + def _create_inference_replica(self, model_path: str, target_rank: int, target_tp: int, server_args: ServerArgs): + """ + Create model replica for target rank with correct tp settings. + + Uses MockSglangDistributedContext to avoid initializing actual distributed environment + while ensuring the model weights have the correct shape for the target rank. + """ model_config = ModelConfig(model_path) - loader = get_model_loader( - load_config=LoadConfig(load_format="direct", tp_rank=target_rank), - model_config=model_config, - ) - return loader.load_model( - model_config=model_config, - device_config=DeviceConfig(self.device, self.gpu_id), - ) + load_config = LoadConfig(load_format="auto") + device_config = DeviceConfig() + + # Mock the distributed environment to get correct weight shapes + with MockSglangDistributedContext(tp_size=target_tp, tp_rank=target_rank, server_args=server_args): + model = get_model( + model_config=model_config, + load_config=load_config, + device_config=device_config, + ) + device = next(model.parameters()).device + logger.info(f" Model {device}, params: {sum(p.numel() for p in model.parameters())} ") + return model def _execute_transfer(self, session_id: str) -> None: """ @@ -209,7 +233,7 @@ def _update_bucket_weights_from_remote( # TODO: There is probably enough difference to the UpdateFromNCCL that we should just rebuild from scratch maybe? 
# Functionality missing: # TODO: Fix learner PP, right now we still send all weights from any source. - # TODO: Support engine expert parallel + # TODO: Support engine expert parallel, which has a bunch of details like dp_attention_tp etc. # TODO: Extensive tests on different pp/ep/tp settings --> tp/dp/ep settings. # TODO: Need a correctness test of the model weights similar to the test:https://github.com/sgl-project/sglang/pull/14997/changes#diff-6efab5fd819ef0efa7a1f43989320bb28231702f8840897eb7acacf174f6e71f # TODO: Memory profiling. @@ -219,6 +243,8 @@ def _update_bucket_weights_from_remote( # TODO: pipeline the all-gather/reshard with transfer engine transfer calls for performance. # TODO: increase concurrency with non-blocking transfers to multiple targets. # TODO: memory offloading if the replica becomes a bottleneck. + # Question: + # 1. Do we really want to support sglang pipeline parallel? # Load weights into local replica matching the target session; this handles sharding and reshaping. self.session_id_to_local_replicas[session_id].load_weights(converted_named_tensors) @@ -227,3 +253,70 @@ def _update_bucket_weights_from_remote( def finish_transfer_task(self, task: TransferTask) -> None: self._execute_transfer(task.session) return + + +class MockSglangDistributedContext: + def __init__(self, tp_size: int, tp_rank: int, server_args: ServerArgs): + """ + TODO: Extend this to support ep, and dp attention? + """ + self.tp_size = tp_size + self.tp_rank = tp_rank + self.pp_size = 1 + self.pp_rank = 0 + self.attn_tp_size = tp_size + self.attn_tp_rank = tp_rank + self.server_args = server_args + # Store active patches for cleanup + self._patches = [] + + def __enter__(self): + """Apply function-level mocks using unittest.mock.patch.""" + from unittest.mock import MagicMock, patch + + # Mock TP group + mock_group = MagicMock() + mock_group.world_size = self.tp_size + mock_group.rank_in_group = self.tp_rank + + # Mock PP group with proper attributes + mock_pp_group = MagicMock() + mock_pp_group.rank_in_group = self.pp_rank + mock_pp_group.world_size = self.pp_size + # Mock underlying global variables + sglang_server_args._global_server_args = self.server_args + sglang_dp_attention._ATTN_TP_RANK = self.attn_tp_rank + sglang_dp_attention._ATTN_TP_SIZE = self.attn_tp_size + sglang_dp_attention._ATTN_DP_RANK = None + sglang_dp_attention._ATTN_DP_SIZE = 1 + # Mock parallelism getter + self._patches = [ + patch("sglang.srt.distributed.parallel_state.get_tp_group", return_value=mock_group), + patch("sglang.srt.distributed.get_pp_group", return_value=mock_pp_group), + patch( + "sglang.srt.distributed.parallel_state.get_tensor_model_parallel_world_size", return_value=self.tp_size + ), + patch("sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank", return_value=self.tp_rank), + patch("sglang.srt.layers.dp_attention.get_attention_tp_rank", return_value=self.attn_tp_rank), + patch("sglang.srt.layers.dp_attention.get_attention_tp_size", return_value=self.attn_tp_size), + ] + + for p in self._patches: + p.start() + + logger.info( + f"[MockDist] Activated: TP={self.tp_rank}/{self.tp_size}, " + f"PP={self.pp_rank}/{self.pp_size}, AttnTP={self.attn_tp_rank}/{self.attn_tp_size}" + ) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop all patches and restore original functions.""" + for p in self._patches: + p.stop() + sglang_server_args._global_server_args = None + self._patches.clear() + logger.info("[MockDist] Deactivated") + + return False # Don't suppress 
exceptions diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 594262829b..ee2db84d75 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -283,6 +283,14 @@ def get_remote_instance_transfer_engine_info(self, rank: int): response.raise_for_status() return response.json()["remote_instance_transfer_engine_info"] + def get_server_info(self): + response = requests.get( + f"http://{self.server_host}:{self.server_port}/server_info", + timeout=5.0, + ) + response.raise_for_status() + return response.json() + def get_weight_version(self): if self.node_rank != 0: return From bddc95dc93361e564664eb40cb14abe75c72d3a4 Mon Sep 17 00:00:00 2001 From: JD-ETH Date: Thu, 18 Dec 2025 19:05:04 +0000 Subject: [PATCH 07/11] runs --- .../megatron_utils/update_weight/common.py | 2 +- .../update_weight/remote_transfer_plan.py | 9 +-- .../update_weight/update_weight_from_rdma.py | 63 +++++++++++-------- .../update_weight_from_remote.py | 4 +- slime/backends/sglang_utils/sglang_engine.py | 8 +-- 5 files changed, 47 insertions(+), 39 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/common.py b/slime/backends/megatron_utils/update_weight/common.py index 2f9b0bde5e..9eaf017824 100644 --- a/slime/backends/megatron_utils/update_weight/common.py +++ b/slime/backends/megatron_utils/update_weight/common.py @@ -324,7 +324,7 @@ def register_memory_transfer_engine(named_param_with_buffers: Sequence[tuple[str # Register merged memory blocks that hold weights. for weight_block in weight_blocks_for_reg_mr: address, size = weight_block - ret = engine.register(address, size) + ret = engine.register_memory(address, size) if ret != 0: raise RuntimeError( f"register memory failed for weight block at address {address} with size {size}, error: {ret}" diff --git a/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py b/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py index d93e5f4d8f..23598429c4 100644 --- a/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py +++ b/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py @@ -189,10 +189,11 @@ def add_transfer_task(self, session: str, param_group: Literal["expert", "non-ex Add a transfer task to the plan using remote instance session and tensor names. 
""" params = self.non_expert_params_buffers if param_group == "non-expert" else self.expert_params_buffers - self.transfer_tasks.append( - TransferTask(session=session, named_params_and_buffers=params, tensor_type=param_group) - ) - logger.info(f"Added {param_group} parameter transfer task: session={session}, num_tensors={len(params)}") + if params: + self.transfer_tasks.append( + TransferTask(session=session, named_params_and_buffers=params, tensor_type=param_group) + ) + logger.info(f"Added {param_group} parameter transfer task: session={session}, num_tensors={len(params)}") def clear_transfer_tasks(self) -> None: self.transfer_tasks = [] diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py index 5ada388479..7acd9ed30d 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py @@ -1,5 +1,6 @@ import dataclasses import logging +import os from argparse import Namespace from collections.abc import Callable, Mapping, Sequence @@ -7,21 +8,20 @@ import sglang.srt.layers.dp_attention as sglang_dp_attention import sglang.srt.server_args as sglang_server_args import torch +from mooncake.engine import TransferEngine from ray.actor import ActorHandle from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine -from sglang.srt.model_loader import get_model -# from mooncake.engine import TransferEngine +# from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine +from sglang.srt.model_loader import get_model from sglang.srt.server_args import ServerArgs from tqdm import tqdm -from slime.backends.megatron_utils.update_weight.remote_transfer_plan import TransferTask from slime.utils.memory_utils import print_memory -from .common import register_memory_transfer_engine, split_expert_and_non_expert_param_names +from .common import register_memory_transfer_engine from .update_weight_from_remote import UpdateWeightFromRemote logger = logging.getLogger(__name__) @@ -108,21 +108,14 @@ def connect_rollout_engines( # Local model with identical shape to remote. Create at most one copy per target rank, and link # them by session id. - self.engines, self.session_id_to_local_replicas = {}, {} + self.engines = {} # Associate transfer tasks based on obtained session and weight info for target in self.transfer_plan.targets: session_id = targets_to_session_id[(target.engine_ind, target.engine_rank)] - expert_params, non_expert_params = split_expert_and_non_expert_param_names( - self.remote_weight_infos_by_session_id[session_id].keys() - ) - params = expert_params if target.group == "expert" else non_expert_params self.transfer_plan.add_transfer_task( session=session_id, param_group=target.group, ) - logger.info( - f"Added transfer task for session {session_id} with {len(params)} tensors in group {target.group}." - ) # Instantiate the local model replicas and a corresponding transfer engine with memory registry for each type of rollout shard. 
if target.engine_rank not in self.engines: transfer_engine = self._create_transfer_engine() @@ -141,6 +134,9 @@ def connect_rollout_engines( print_memory("[RDMA] After Local Engine Replicas and engine Creation") def _register_replica_memory(self, model_replica, remote_weight_info, transfer_engine) -> dict: + + old_cuda_alloc_value = os.environ.get("PYTORCH_ALLOC_CONF", "") + os.environ["PYTORCH_ALLOC_CONF"] = "" to_register_named_tensors = [] named_tensors = dict(model_replica.named_parameters()) # Verify the 1-to-1 mapping between registered weights and remote weights expected. @@ -157,15 +153,21 @@ def _register_replica_memory(self, model_replica, remote_weight_info, transfer_e raise RuntimeError(f"Local replica parameter {name} is not on CUDA device.") to_register_named_tensors.append((name, tensor)) weight_memory_registry = register_memory_transfer_engine(to_register_named_tensors, transfer_engine) + + os.environ["PYTORCH_ALLOC_CONF"] = old_cuda_alloc_value logger.info( f"[RDMA] Registered {len(to_register_named_tensors)} tensors of total {len(named_tensors)} from replica with transfer engine." ) return weight_memory_registry - def _create_transfer_engine(self) -> MooncakeTransferEngine: + def _create_transfer_engine(self) -> TransferEngine: + # local_ip = ray._private.services.get_node_ip_address() + # transfer_engine = MooncakeTransferEngine(local_ip, None, None) + transfer_engine = TransferEngine() local_ip = ray._private.services.get_node_ip_address() - transfer_engine = MooncakeTransferEngine(local_ip, None, None) - logger.info(f"[RDMA] Local replica Transfer Engine initialized at port {transfer_engine.session_id}") + transfer_engine.initialize(local_ip, "P2PHANDSHAKE", "rdma", "") + + logger.info(f"[RDMA] Local replica Transfer Engine initialized at port {transfer_engine.get_rpc_port()}") return transfer_engine def _create_inference_replica(self, model_path: str, target_rank: int, target_tp: int, server_args: ServerArgs): @@ -176,10 +178,17 @@ def _create_inference_replica(self, model_path: str, target_rank: int, target_tp while ensuring the model weights have the correct shape for the target rank. """ model_config = ModelConfig(model_path) - load_config = LoadConfig(load_format="auto") + load_config = LoadConfig( + load_format="auto", + tp_rank=target_rank, + model_loader_extra_config=server_args.model_loader_extra_config, + rl_quant_profile=server_args.rl_quant_profile, + ) device_config = DeviceConfig() # Mock the distributed environment to get correct weight shapes + # TODO: Reuse https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py#L845 + # For memory pinning and CPU offloading? 
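+ # The mock below matters because get_model() derives each parameter's local shard
+ # shape from the active TP/PP groups; faking tp_rank/tp_size lets the replica
+ # materialize exactly the target rollout rank's weight shapes without real process groups.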
with MockSglangDistributedContext(tp_size=target_tp, tp_rank=target_rank, server_args=server_args): model = get_model( model_config=model_config, @@ -197,14 +206,15 @@ def _execute_transfer(self, session_id: str) -> None: _, engine, weight_memory_registry = self.engines[self.session_id_to_engine_rank[session_id]] remote_weight_info = self.remote_weight_infos_by_session_id[session_id] source_ptrs, target_ptrs, source_lens = [], [], [] - for name, tensor in weight_memory_registry.items(): - source_ptrs.append(tensor.data_ptr()) + for name, tensor_register in weight_memory_registry.items(): + data_ptr, numel, ele_size = tensor_register + source_ptrs.append(data_ptr) target_ptrs.append(remote_weight_info[name][0]) # remote address - source_lens.append(tensor.numel() * tensor.element_size()) + source_lens.append(numel * ele_size) # Batch transfer weights through RDMA - ret = engine.batch_transfer_sync(session_id, source_ptrs, target_ptrs, source_lens) - logger.info(f"[RDMA] Batch transferred {len(weight_memory_registry)} tensors to session {session_id}.") + ret = engine.batch_transfer_sync_write(session_id, source_ptrs, target_ptrs, source_lens) + # logger.info(f"[RDMA] Batch transferred {len(weight_memory_registry)} tensors to session {session_id}.") if ret < 0: raise RuntimeError(f"Batch transfer weights via RDMA failed with error code {ret}.") @@ -212,7 +222,7 @@ def leader_post_update(self) -> None: ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) ray.get( [ - engine.update_weight_version.remote(weight_version=self.weight_version) + engine.update_weight_version.remote(weight_version=str(self.weight_version)) for engine in self.rollout_engines ] ) @@ -247,11 +257,11 @@ def _update_bucket_weights_from_remote( # 1. Do we really want to support sglang pipeline paralell? # Load weights into local replica matching the target session, this handles sharding and reshaping. 
- self.session_id_to_local_replicas[session_id].load_weights(converted_named_tensors) + self.engines[self.session_id_to_engine_rank[session_id]][0].load_weights(converted_named_tensors) converted_named_tensors.clear() - def finish_transfer_task(self, task: TransferTask) -> None: - self._execute_transfer(task.session) + def finish_transfer_task(self, session: str) -> None: + self._execute_transfer(session) return @@ -318,5 +328,4 @@ def __exit__(self, exc_type, exc_val, exc_tb): sglang_server_args._global_server_args = None self._patches.clear() logger.info("[MockDist] Deactivated") - return False # Don't suppress exceptions diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py index 7b9bc14fb0..9773cdf62d 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py @@ -14,7 +14,7 @@ from ..megatron_to_hf import convert_to_hf from .common import all_gather_param -from .remote_transfer_plan import RemoteTransferPlan, TransferTask +from .remote_transfer_plan import RemoteTransferPlan class UpdateWeightFromRemote: @@ -95,7 +95,7 @@ def leader_post_update(self) -> None: ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) return - def finish_transfer_task(self, task: TransferTask) -> None: + def finish_transfer_task(self, session: str) -> None: return def _update_expert_weights( diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index ee2db84d75..3cfff3aefd 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -368,12 +368,10 @@ def continue_generation(self): return response def update_weight_version(self, weight_version: str): - response = requests.post( - f"http://{self.server_host}:{self.server_port}/update_weight_version", - json={"new_version": weight_version}, + return self._make_request( + "update_weight_version", + {"new_version": weight_version}, ) - response.raise_for_status() - return response def start_profile( self, From c270f61bb8b552f2b39a5ffa6a4bcd1afda476ae Mon Sep 17 00:00:00 2001 From: JD-ETH Date: Sat, 20 Dec 2025 18:03:11 +0000 Subject: [PATCH 08/11] refactored --- .../update_weight/remote_transfer_plan.py | 190 +++++++++--------- .../update_weight_from_distributed.py | 9 +- .../update_weight/update_weight_from_rdma.py | 131 ++++++------ .../update_weight_from_remote.py | 65 +++--- 4 files changed, 191 insertions(+), 204 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py b/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py index 23598429c4..211653b657 100644 --- a/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py +++ b/slime/backends/megatron_utils/update_weight/remote_transfer_plan.py @@ -7,6 +7,7 @@ import logging from argparse import Namespace +from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass from typing import Literal @@ -14,25 +15,9 @@ import torch from megatron.core import mpu -from .common import expert_named_params_and_buffers, non_expert_named_params_and_buffers - logger = logging.getLogger(__name__) -@dataclass -class TransferTask: - """ - Attributes: - session: Session identifier (e.g., NCCL group name or Transfer Engine Session Id) - named_params_and_buffers: tensors to be 
transferred from this rank.
- tensor_type: "expert" or "non-expert" are two diverse types of tasks.
- """
-
- named_params_and_buffers: list[tuple[str, torch.Tensor]]
- session: str # NCCL group name or target entity id.
- tensor_type: Literal["expert", "non-expert"]
-
-
@dataclass
class TransferTaskP2PMeta:
"""
@@ -41,7 +26,7 @@ class TransferTaskP2PMeta:
engine_ind: int # The index of the target rollout engine.
engine_rank: int # The shard within the target rollout engine.
- group: Literal["expert", "non-expert"]
+ source_shard: int = 0 # The source pp shard index.
class RemoteTransferPlan:
@@ -73,13 +58,14 @@ def __init__(
mode: Transfer backend mode - either "nccl" or "rdma"
"""
self.mode = mode
- self._get_parallel_info(args)
- self.targets: list[TransferTaskP2PMeta] = self._plan_p2p() if mode == "rdma" else []
- self.transfer_tasks: list[TransferTask] = []
- self.non_expert_params_buffers = list(non_expert_named_params_and_buffers(args, model))
- self.expert_params_buffers = list(expert_named_params_and_buffers(args, model))
+ self._get_parallelism(args)
+
+ def _get_parallelism(self, args: Namespace) -> None:
+ """
+ Collect and print parallelism information for both the source (trainer) and the target (rollout engines).
+ Also print the parallelism information after the ep/tp all-gather for the two parameter groups.
+ """
- def _get_parallel_info(self, args: Namespace) -> None:
# Gather the source (current trainer) information.
self._pp_rank, self._pp_size = (
mpu.get_pipeline_model_parallel_rank(),
mpu.get_pipeline_model_parallel_world_size(),
)
@@ -99,12 +85,13 @@ def _get_parallel_info(self, args: Namespace) -> None:
self._rollout_tp_size = args.sglang_tp_size
self._rollout_dp_size = args.sglang_dp_size
self._rollout_ep_size = args.sglang_ep_size
- # PP sizes are not supported currently.
+ # EP and PP sizes are untested and likely missing functionality.
self._rollout_pp_size = args.sglang_pp_size
if self._rollout_ep_size != 1 or self._rollout_pp_size != 1:
raise NotImplementedError("Rollout expert and pipeline parallelisms are not supported yet.")
- num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node)
- self._rollout_engine_count = args.rollout_num_gpus // num_gpu_per_engine
+ self._num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node)
+ self._rollout_engine_count = args.rollout_num_gpus // self._num_gpu_per_engine
+ self._rollout_num_gpus = args.rollout_num_gpus
logger.info(
f"RemoteTransferPlan initialized: mode={self.mode}, pp_rank={self._pp_rank}/{self._pp_size}, tp_rank={self._tp_rank}/{self._tp_size}, "
f"ep_rank={self._ep_rank}/{self._ep_size}, etp_rank={self._etp_rank}/{self._etp_size}, dp_rank={self._dp_rank}/{self._dp_size}"
@@ -113,9 +100,9 @@ def _get_parallel_info(self, args: Namespace) -> None:
f"Rollout engine count: {self._rollout_engine_count}, tp_size={self._rollout_tp_size}, ep_size={self._rollout_ep_size}, dp_size={self._rollout_dp_size}"
)
- # Expert and non expert parameters can have different parallel groups after all-gather.
self._gathered_dp_size = self._dp_size * self._tp_size
self._gathered_dp_rank = self._dp_rank * self._tp_size + self._tp_rank
+ # TODO: If I understand correctly, the final sizes should be the same, since only the pp and dp dimensions remain for both param groups?
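+ # Worked example (illustrative): with dp_size=2 and tp_size=2, gathered_dp_size = 4,
+ # and the rank with dp_rank=1, tp_rank=0 gets gathered_dp_rank = 1 * 2 + 0 = 2.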
expert_tp_size = self._ep_size * self._etp_size
self._gathered_expert_dp_size = self._dp_size * expert_tp_size
self._gathered_expert_dp_rank = (
@@ -128,46 +115,81 @@ def _get_parallel_info(self, args: Namespace) -> None:
f"Gathered dp_rank={self._gathered_dp_rank}, gathered expert dp_rank={self._gathered_expert_dp_rank}"
)
- def _plan_p2p(self) -> list[TransferTaskP2PMeta]:
- def plan(
- source_size: int,
- source_rank: int,
- num_rank_in_target: int,
- num_targets: int,
- params: str,
- cur_active_rank: int = 0,
- ) -> list[TransferTaskP2PMeta]:
- transfer_tasks = []
- for target_ind in range(num_targets):
- for target_rank in range(num_rank_in_target):
- if cur_active_rank % source_size == source_rank:
- # TODO(jd): instead of doing round robin, we should prioritize reusing source ranks with same target_rank to
- # avoid duplicating local copies.
- transfer_tasks.append(
- TransferTaskP2PMeta(engine_ind=target_ind, engine_rank=target_rank, group=params)
- )
- logger.info(
- f"Planned P2P transfer task: source_rank={source_rank} -> target_engine_ind={target_ind}, target_engine_rank={target_rank}, group={params}"
- )
- cur_active_rank += 1
- return transfer_tasks
-
- # TODO(JD): Due to the local replica design, the plan should proritize reusing existing copies and merge sources.
-
- non_expert_plan = plan(
- source_size=self._gathered_dp_size,
- source_rank=self._gathered_dp_rank,
- num_rank_in_target=self._rollout_dp_size * self._rollout_tp_size,
- num_targets=self._rollout_engine_count,
- params="non-expert",
- )
- return non_expert_plan + plan(
- source_size=self._gathered_expert_dp_size,
- source_rank=self._gathered_expert_dp_rank,
- num_rank_in_target=self._rollout_dp_size * self._rollout_ep_size,
- num_targets=self._rollout_engine_count,
- params="expert",
- )
+ self._rank = self._gathered_dp_rank
+
+ def get_nccl_group(self) -> str:
+ """
+ Get the NCCL group name for weight transfer.
+
+ Returns:
+ str - NCCL group name
+ """
+ assert self.mode == "nccl", "NCCL group only applicable for NCCL mode."
+ return f"slime-pp_{self._pp_rank}"
+
+ def plan_p2p(self) -> list[TransferTaskP2PMeta]:
+ """
+ For each pp shard, we plan the mapping between n source dp ranks and m target rollout engines with k ranks each.
+ The transfer plan mapping heuristic works as follows:
+ 1. for each target engine (idx, rank), assign source ranks in a round-robin manner until the source ranks or the targets are exhausted.
+ 2. for the remainder targets (idx, rank), assign them to source ranks, prioritizing a source that already has an assignment for the same rank.
+
+ For example, 4 source ranks (0,1,2,3), 2 target engines with 3 ranks each (0,0),(0,1),(0,2),(1,0),(1,1),(1,2).
+ The first round of assignment:
+ source_rank=0 -> target (0,0)
+ source_rank=1 -> target (0,1)
+ source_rank=2 -> target (0,2)
+ source_rank=3 -> target (1,0)
+ The remainder assignment:
+ source_rank=1 -> target (1,1) # prioritize source_rank=1 as it had (0,1) assigned already.
+ source_rank=2 -> target (1,2)
+
+ Finally extract the transfer tasks matching the current dp_rank.
+ """ + + all_targets = [ + (m_idx, k_idx) for m_idx in range(self._rollout_engine_count) for k_idx in range(self._num_gpu_per_engine) + ] + # Assignments: source_rank -> {engin_rank: [engine_indices]} + assignements = defaultdict(lambda: defaultdict(list)) + # First round robin assignment + i = -1 + for source_rank, (idx, target) in zip(range(self._gathered_dp_size), enumerate(all_targets), strict=False): + i = idx + m_idx, k_idx = target + assignements[source_rank][k_idx].append(m_idx) + + def count_engine_index_assignments(k_idx: int) -> int: + return [len(assignements[source][k_idx]) for source in range(self._gathered_dp_size)] + + # Reminder assignment by least_assigned_source + cur_source_index = 0 + if i < len(all_targets) - 1: + for target in all_targets[i + 1 :]: + m_idx, k_idx = target + # count current assignments for source who has k_idx + counted = count_engine_index_assignments(k_idx) + # If any source has existing assignment for k_idx, assign it. + if max(counted) > 0: + _, select_source = min((val, idx) for (idx, val) in enumerate(counted) if val > 0) + # Else go back to round robin. + else: + select_source = cur_source_index % self._gathered_dp_size + cur_source_index += 1 + assignements[select_source][k_idx].append(m_idx) + + # Extract transfer tasks for current rank. + logger.info(f"[TransferPlanner] Full transfer assignments: {dict(assignements)}") + transfer_tasks = [] + for engine_rank, engine_indices in assignements[self._rank].items(): + for engine_ind in engine_indices: + logger.info( + f"[TransferPlanner] New task: source_rank={self._rank} pp_shard={self._pp_rank} -> target_engine_ind={engine_ind}, target_engine_rank={engine_rank}" + ) + transfer_tasks.append( + TransferTaskP2PMeta(source_shard=self._pp_rank, engine_ind=engine_ind, engine_rank=engine_rank) + ) + return transfer_tasks def is_source(self) -> bool: """ @@ -182,35 +204,5 @@ def is_source(self) -> bool: mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0 ) - return len(self.targets) > 0 - - def add_transfer_task(self, session: str, param_group: Literal["expert", "non-expert"]) -> None: - """ - Add a transfer task to the plan using remote instance session and tensor names. - """ - params = self.non_expert_params_buffers if param_group == "non-expert" else self.expert_params_buffers - if params: - self.transfer_tasks.append( - TransferTask(session=session, named_params_and_buffers=params, tensor_type=param_group) - ) - logger.info(f"Added {param_group} parameter transfer task: session={session}, num_tensors={len(params)}") - - def clear_transfer_tasks(self) -> None: - self.transfer_tasks = [] - - def get_transfer_tasks(self) -> list[TransferTask]: - # Generate session identifier based on mode - if self.mode == "nccl": - session = f"slime-pp_{self._pp_rank}" - # In NCCL mode, the transfer is simply a broadcast from DP=TP=0 to all rollout engines. - return [ - TransferTask( - session=session, named_params_and_buffers=self.non_expert_params_buffers, tensor_type="non-expert" - ), - TransferTask( - session=session, named_params_and_buffers=self.expert_params_buffers, tensor_type="expert" - ), - ] - if self.targets and not self.transfer_tasks: - raise RuntimeError("RDMA need to query target engine information for transfer task generations.") - return self.transfer_tasks + # Only case where RDMA P2P is not sending is when the current DP rank is >= total number of rollout GPUs. 
+ return self._rank < self._rollout_num_gpus
diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py
index efae91eb49..eadc61919c 100644
--- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py
+++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py
@@ -43,9 +43,8 @@ def __init__(
)
if self._is_source:
- transfer_tasks = self.transfer_plan.get_transfer_tasks()
assert self.transfer_plan.mode == "nccl", "Only NCCL supported currently."
- assert len(transfer_tasks) == 2, "Only two transfer tasks supported currently."
+ self._group_name = self.transfer_plan.get_nccl_group()
# Indicates if the nccl group has been established.
self._model_update_groups = None
@@ -62,8 +61,6 @@ def connect_rollout_engines(
# 1. AllGather parameters to rank 0
# 2. Broadcast parameters from rank 0 to all sglang engines
if self._is_source:
- transfer_tasks = self.transfer_plan.get_transfer_tasks()
- self._group_name = transfer_tasks[0].session
if self._model_update_groups is not None:
# Reestablish group if already connected, e.g. new instance has joined.
disconnect_rollout_engines_from_distributed(
@@ -74,7 +71,7 @@ def connect_rollout_engines(
)
def _update_bucket_weights_from_remote(
- self, converted_named_tensors: list[tuple[str, torch.Tensor]], session_id: str, pbar: tqdm | None = None
+ self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None
) -> None:
"""
Lock → broadcast → clear → unlock → pbar++. Lock prevents NCCL deadlock.
@@ -83,7 +80,7 @@ def _update_bucket_weights_from_remote(
while not ray.get(self.rollout_engine_lock.acquire.remote()):
time.sleep(0.1)
refs = update_weights_from_distributed(
- session_id,
+ self._group_name,
self._model_update_groups,
self.weight_version,
self.rollout_engines,
diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py
index 7acd9ed30d..fee32d277d 100644
--- a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py
+++ b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py
@@ -1,6 +1,5 @@
import dataclasses
import logging
-import os
from argparse import Namespace
from collections.abc import Callable, Mapping, Sequence
@@ -14,7 +13,6 @@
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
-# from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import ServerArgs
from tqdm import tqdm
@@ -28,11 +26,47 @@
def create_server_args_from_dict(data_dict: dict) -> ServerArgs:
+ # Reconstruct sglang ServerArgs from the sglang HTTP query response.
valid_fields = {f.name for f in dataclasses.fields(ServerArgs)}
filtered_data = {k: v for k, v in data_dict.items() if k in valid_fields}
return ServerArgs(**filtered_data)
+@dataclasses.dataclass
+class RemoteWeightInfo:
+ # Remote session and weight registration info.
+ session_id: str
+ weights_info: dict[str, tuple[int, int, int]] # name -> (remote_address, numel, element_size)
+
+
+@dataclasses.dataclass
+class TransferBundle:
+ model_replica: Sequence[torch.nn.Module]
+ engine: TransferEngine
+ weight_memory_registry: dict
+ remote_weight_infos: list[RemoteWeightInfo]
+
+ def add_remote_session(self, remote_info: RemoteWeightInfo) -> None:
+ self.remote_weight_infos.append(remote_info)
+
+ def execute(self) -> None:
+ # Execute transfer for each target session using this replica.
+ for remote_session in self.remote_weight_infos:
+ session_id, remote_weights_info = remote_session.session_id, remote_session.weights_info
+ source_ptrs, target_ptrs, source_lens = [], [], []
+ for name, tensor_register in self.weight_memory_registry.items():
+ data_ptr, numel, ele_size = tensor_register
+ source_ptrs.append(data_ptr)
+ target_ptrs.append(remote_weights_info[name][0]) # remote address
+ source_lens.append(numel * ele_size)
+
+ # Batch transfer weights through RDMA
+ ret = self.engine.batch_transfer_sync_write(session_id, source_ptrs, target_ptrs, source_lens)
+ # logger.info(f"[RDMA] Batch transferred {len(weight_memory_registry)} tensors to session {session_id}.")
+ if ret < 0:
+ raise RuntimeError(f"Batch transfer weights via RDMA failed with error code {ret}.")
+
+
class UpdateWeightFromRDMA(UpdateWeightFromRemote):
"""
Update weights from RDMA using Transfer Engine.
@@ -40,6 +74,7 @@ class UpdateWeightFromRDMA(UpdateWeightFromRemote):
Similar to UpdateWeightFromNCCL but uses P2P RDMA transfer engine for the underlying weight transfer.
The workflow consists of the following steps:
1. Based on the transfer plan, query the target rollout engines for remote session and weight info during connect_rollout_engines.
+ 2. Construct the local model replicas according to the plan and attach target session ids and weight memory registries.
3. Do TP-EP all-gather for bucketed weights on parameters needing transfer from local just as in the NCCL case.
4. Convert the gathered HF tensors into the target shape and register them with the Engine.
5. Call the engine to batch transfer weights for each transfer task.
@@ -54,10 +89,6 @@ def __init__(
model_name: str,
quantization_config: dict[str, int | str | list[str]] | None,
) -> None:
- """
- Initialize transfer engine.
- """
- # Call parent constructor to initialize all base attributes
super().__init__(
args,
model,
@@ -67,8 +98,6 @@ def __init__(
weight_update_mode="rdma",
)
- self.transfer_engine = None
def connect_rollout_engines(
self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle
) -> None:
@@ -82,7 +111,8 @@ def connect_rollout_engines(
if self._is_source:
# Query Engine session and weight info from rollout instances according to the transfer plan
self.remote_weight_infos_by_session_id = {}
- targets_to_query = set((target.engine_ind, target.engine_rank) for target in self.transfer_plan.targets)
+ targets = self.transfer_plan.plan_p2p()
+ targets_to_query = set((target.engine_ind, target.engine_rank) for target in targets)
targets_to_session_id, self.session_id_to_engine_rank = {}, {}
self.session_id_to_server_args = {}
for engine_ind, engine_rank in targets_to_query:
@@ -106,45 +136,42 @@ def connect_rollout_engines(
print_memory("[RDMA] After obtaining remote weight info")
- # Local model with identical shape to remote. Create at most one copy per target rank, and link
- # them by session id.
+ # Create local model replicas and transfer engines for each target rollout shard
self.engines = {}
# Associate transfer tasks based on obtained session and weight info
- for target in self.transfer_plan.targets:
+ for target in targets:
session_id = targets_to_session_id[(target.engine_ind, target.engine_rank)]
- self.transfer_plan.add_transfer_task(
- session=session_id,
- param_group=target.group,
- )
+ remote_info = RemoteWeightInfo(session_id, self.remote_weight_infos_by_session_id[session_id])
# Instantiate the local model replicas and a corresponding transfer engine with memory registry for each type of rollout shard.
if target.engine_rank not in self.engines:
transfer_engine = self._create_transfer_engine()
model_replica = self._create_inference_replica(
self.args.hf_checkpoint,
+ pp_shard=target.source_shard,
target_rank=target.engine_rank,
target_tp=self.args.rollout_num_gpus_per_engine,
server_args=self.session_id_to_server_args[session_id],
)
- print_memory("[RDMA] After model replica")
+ print_memory(f"[RDMA] After model replica at {target.engine_rank}")
weight_memory_registry = self._register_replica_memory(
model_replica, self.remote_weight_infos_by_session_id[session_id], transfer_engine
)
- self.engines[target.engine_rank] = (model_replica, transfer_engine, weight_memory_registry)
+ self.engines[target.engine_rank] = TransferBundle(
+ model_replica, transfer_engine, weight_memory_registry, [remote_info]
+ )
+ else:
+ self.engines[target.engine_rank].add_remote_session(remote_info)
print_memory("[RDMA] After Local Engine Replicas and engine Creation")
def _register_replica_memory(self, model_replica, remote_weight_info, transfer_engine) -> dict:
-
- old_cuda_alloc_value = os.environ.get("PYTORCH_ALLOC_CONF", "")
- os.environ["PYTORCH_ALLOC_CONF"] = ""
to_register_named_tensors = []
named_tensors = dict(model_replica.named_parameters())
- # Verify the 1-to-1 mapping between registered weights and remote weights expected.
- for name, info in remote_weight_info.items():
- (_, remote_numel, remote_ele_size) = info
- if name not in named_tensors:
- raise RuntimeError(f"Remote replica parameter {name} not found in local replica.")
- tensor = named_tensors[name]
+ # Verify the expected 1-to-1 mapping between local replica parameters and remote weights.
+ for name, tensor in named_tensors.items():
+ if name not in remote_weight_info:
+ raise RuntimeError(f"Local replica parameter {name} not found in remote replica.")
+ remote_numel, remote_ele_size = remote_weight_info[name][1], remote_weight_info[name][2]
if tensor.numel() != remote_numel or tensor.element_size() != remote_ele_size:
raise RuntimeError(
f"Local replica parameter {name} numel {tensor.numel()} size {tensor.element_size()} does not match remote numel {remote_numel} size {remote_ele_size}."
@@ -154,15 +181,12 @@ def _register_replica_memory(self, model_replica, remote_weight_info, transfer_e
to_register_named_tensors.append((name, tensor))
weight_memory_registry = register_memory_transfer_engine(to_register_named_tensors, transfer_engine)
- os.environ["PYTORCH_ALLOC_CONF"] = old_cuda_alloc_value
logger.info(
f"[RDMA] Registered {len(to_register_named_tensors)} tensors of total {len(named_tensors)} from replica with transfer engine."
)
return weight_memory_registry
def _create_transfer_engine(self) -> TransferEngine:
- # local_ip = ray._private.services.get_node_ip_address()
- # transfer_engine = MooncakeTransferEngine(local_ip, None, None)
transfer_engine = TransferEngine()
local_ip = ray._private.services.get_node_ip_address()
transfer_engine.initialize(local_ip, "P2PHANDSHAKE", "rdma", "")
@@ -170,7 +194,9 @@ def _create_transfer_engine(self) -> TransferEngine:
logger.info(f"[RDMA] Local replica Transfer Engine initialized at port {transfer_engine.get_rpc_port()}")
return transfer_engine
- def _create_inference_replica(self, model_path: str, target_rank: int, target_tp: int, server_args: ServerArgs):
+ def _create_inference_replica(
+ self, model_path: str, pp_shard: int, target_rank: int, target_tp: int, server_args: ServerArgs
+ ):
"""
Create model replica for target rank with correct tp settings.
@@ -188,7 +214,10 @@ def _create_inference_replica(
# Mock the distributed environment to get correct weight shapes
# TODO: Reuse https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py#L845
- # For memory pinning and CPU offloading?
+ # For memory pinning and CPU offloading
+ logger.error(
+ f"Engine replica: {target_rank} tp {target_tp} pp_shard {pp_shard}; model pp sharding not implemented"
+ )
with MockSglangDistributedContext(tp_size=target_tp, tp_rank=target_rank, server_args=server_args):
model = get_model(
model_config=model_config,
@@ -199,27 +228,9 @@ def _create_inference_replica(
logger.info(f" Model {device}, params: {sum(p.numel() for p in model.parameters())} ")
return model
- def _execute_transfer(self, session_id: str) -> None:
- """
- Execute weight transfer for a single transfer task using RDMA P2P transfer engine.
- """
- _, engine, weight_memory_registry = self.engines[self.session_id_to_engine_rank[session_id]]
- remote_weight_info = self.remote_weight_infos_by_session_id[session_id]
- source_ptrs, target_ptrs, source_lens = [], [], []
- for name, tensor_register in weight_memory_registry.items():
- data_ptr, numel, ele_size = tensor_register
- source_ptrs.append(data_ptr)
- target_ptrs.append(remote_weight_info[name][0]) # remote address
- source_lens.append(numel * ele_size)
-
- # Batch transfer weights through RDMA
- ret = engine.batch_transfer_sync_write(session_id, source_ptrs, target_ptrs, source_lens)
- # logger.info(f"[RDMA] Batch transferred {len(weight_memory_registry)} tensors to session {session_id}.")
- if ret < 0:
- raise RuntimeError(f"Batch transfer weights via RDMA failed with error code {ret}.")
-
def leader_post_update(self) -> None:
ray.get([engine.continue_generation.remote() for engine in self.rollout_engines])
+ # Update the weight version explicitly, since the transfer itself is write-only.
ray.get(
[
engine.update_weight_version.remote(weight_version=str(self.weight_version))
@@ -229,7 +240,7 @@ def leader_post_update(self) -> None:
return
def _update_bucket_weights_from_remote(
- self, converted_named_tensors: list[tuple[str, torch.Tensor]], session_id: str, pbar: tqdm | None = None
+ self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None
) -> None:
"""
The RDMA P2P weight update is implemented as a single-sided write, meaning the trainer writes its weights directly to the rollout engines' memory.
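The one-sided path composes the three mooncake Transfer Engine calls used in this series: initialize, register_memory, and batch_transfer_sync_write. A minimal standalone sketch of that flow, where the local IP, remote session id, and remote address are hypothetical placeholders for the values queried from the rollout engine during connect:

```python
import torch
from mooncake.engine import TransferEngine

# Source-side engine, mirroring _create_transfer_engine().
engine = TransferEngine()
engine.initialize("10.0.0.1", "P2PHANDSHAKE", "rdma", "")  # local IP is illustrative

# Register the CUDA buffer backing a weight so the NIC can read it directly,
# mirroring register_memory_transfer_engine() in common.py.
weight = torch.zeros(4096, 4096, dtype=torch.bfloat16, device="cuda")
nbytes = weight.numel() * weight.element_size()
ret = engine.register_memory(weight.data_ptr(), nbytes)
if ret != 0:
    raise RuntimeError(f"register_memory failed with error: {ret}")

# One-sided write into the rollout engine's pre-registered buffer; the session id
# and remote address below are placeholders for what connect_rollout_engines() obtains.
session_id = "10.0.0.2:16000"  # hypothetical remote session
remote_addr = 0x7F0000000000   # hypothetical registered remote address
ret = engine.batch_transfer_sync_write(session_id, [weight.data_ptr()], [remote_addr], [nbytes])
if ret < 0:
    raise RuntimeError(f"batch_transfer_sync_write failed with error code {ret}")
```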
@@ -238,30 +249,26 @@ def _update_bucket_weights_from_remote(
if not self._is_source or not converted_named_tensors:
return
- # Refactoring needed:
- # TODO: refactor update_weight to still do a single traversal of the model dict; session_id should be per weight instead.
- # TODO: There is probably enough difference to the UpdateFromNCCL that we should just rebuild from scratch maybe?
# Functionality missing:
# TODO: Fix learner PP, right now we still send all weights from any source.
# TODO: Support engine expert parallel, which has a bunch of details like dp_attention_tp etc.
# TODO: Extensive tests on different pp/ep/tp settings --> tp/dp/ep settings.
# TODO: Need a correctness test of the model weights similar to the test:https://github.com/sgl-project/sglang/pull/14997/changes#diff-6efab5fd819ef0efa7a1f43989320bb28231702f8840897eb7acacf174f6e71f
- # TODO: Memory profiling.
# TODO: Design of experiments --- what other configurations do we need to enable.
# Optimizations:
- # TODO: remote transfer plan optimizes for reduce local copy memory usage
# TODO: pipeline the all-gather/reshard with transfer engine transfer calls for performance.
# TODO: increase concurrency with non-blocking transfers to multiple targets.
# TODO: memory offloading if the replica becomes a bottleneck.
- # Question:
- # 1. Do we really want to support sglang pipeline parallel?
# Load weights into local replica matching the target session, this handles sharding and reshaping.
- self.engines[self.session_id_to_engine_rank[session_id]][0].load_weights(converted_named_tensors)
+ for transfer_bundle in self.engines.values():
+ transfer_bundle.model_replica.load_weights(converted_named_tensors)
converted_named_tensors.clear()
- def finish_transfer_task(self, session: str) -> None:
- self._execute_transfer(session)
+ def finish_transfer_task(self) -> None:
+ # Execute transfer for each engine replica.
+ for transfer_bundle in self.engines.values():
+ transfer_bundle.execute()
return
diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py
index 9773cdf62d..5ca4d7eef1 100644
--- a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py
+++ b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py
@@ -2,6 +2,7 @@
from argparse import Namespace
from collections.abc import Callable, Mapping, Sequence
from typing import Literal
+
import ray
import torch
import torch.distributed as dist
@@ -13,7 +14,7 @@
from slime.utils.timer import timer
from ..megatron_to_hf import convert_to_hf
-from .common import all_gather_param
+from .common import all_gather_param, expert_named_params_and_buffers, non_expert_named_params_and_buffers
from .remote_transfer_plan import RemoteTransferPlan
@@ -54,18 +55,17 @@ def connect_rollout_engines(
@abstractmethod
def _update_bucket_weights_from_remote(
- self, converted_named_tensors: list[tuple[str, torch.Tensor]], session_id: str, pbar: tqdm | None = None
+ self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None
) -> None:
"""
- Implementation of the bucketed parameter update from remote. session_id is used as the identifier
- for the operation, either NCCL group name or Transfer Engine session id.
- TODO(jd): to avoid traversing the model dict multiple times we need session_id to be a list.
+ Implementation of the bucketed parameter update from remote.
""" @torch.no_grad() def update_weights(self) -> None: """ - Pause → flush → non-expert (TP) → expert (EP) → continue. Progress on PP source. + For each named parameter in the model, do bucketed weight update by all-gather EP/TP, convert and quantize, + and relies on underlying implementation to do the transfer. """ self.weight_version += 1 @@ -75,16 +75,15 @@ def update_weights(self) -> None: dist.barrier(group=get_gloo_group()) with timer("update_weights_implementation"): - for transfer_task in self.transfer_plan.get_transfer_tasks(): - # Update non-expert or expert weights - if transfer_task.tensor_type == "non-expert": - self._update_weights(transfer_task.named_params_and_buffers, transfer_task.session) - elif transfer_task.tensor_type == "expert": - self._update_expert_weights(transfer_task.named_params_and_buffers, transfer_task.session) - else: - raise ValueError(f"Unknown tensor type {transfer_task.tensor_type} in transfer task.") - dist.barrier(group=get_gloo_group()) - self.finish_transfer_task(transfer_task.session) + # A single traversal through all parameters to update weights. Update happens first to the + # non-expert weights, then to expert weights. + non_expert_params_and_buffers = non_expert_named_params_and_buffers(self.args, self.model) + expert_params_and_buffers = expert_named_params_and_buffers(self.args, self.model) + self._update_weights(non_expert_params_and_buffers) + dist.barrier(group=get_gloo_group()) + self._update_expert_weights(expert_params_and_buffers) + dist.barrier(group=get_gloo_group()) + self.finish_transfer_task() dist.barrier(group=get_gloo_group()) if dist.get_rank() == 0: @@ -95,39 +94,33 @@ def leader_post_update(self) -> None: ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) return - def finish_transfer_task(self, session: str) -> None: + def finish_transfer_task(self) -> None: return - def _update_expert_weights( - self, named_params_and_buffers: Sequence[tuple[str, torch.Tensor]], session_id: str - ) -> None: - pbar = tqdm(desc=f"[{session_id}] Update Expert Weights", total=0) if self._is_source else None + def _update_expert_weights(self, named_params_and_buffers: Sequence[tuple[str, torch.Tensor]]) -> None: + pbar = tqdm(desc="[Update Expert Weights]", total=0) if self._is_source else None buffer_size = 0 named_tensors = [] for name, param in named_params_and_buffers: # transfer expert tensors assert ".experts." in name, "Function intended for expert params only." - buffer_size = self._update_expert_weight_from_remote( - name, param, named_tensors, buffer_size, session_id, pbar=pbar - ) + buffer_size = self._update_expert_weight_from_remote(name, param, named_tensors, buffer_size, pbar=pbar) if named_tensors: - self._update_expert_bucket_weights_from_remote(named_tensors, session_id, pbar=pbar) + self._update_expert_bucket_weights_from_remote(named_tensors, pbar=pbar) - def _update_weights(self, named_params_and_buffers: Sequence[tuple[str, torch.Tensor]], session_id: str) -> None: - pbar = tqdm(desc=f"[{session_id}] Update Weights", total=0) if self._is_source else None + def _update_weights(self, named_params_and_buffers: Sequence[tuple[str, torch.Tensor]]) -> None: + pbar = tqdm(desc="[Update Weights]", total=0) if self._is_source else None buffer_size = 0 converted_named_tensors = [] # non expert params for name, param in named_params_and_buffers: # transfer tp tensors assert ".experts." not in name, "Function intended for non-expert params only." 
- buffer_size = self._update_weight_from_remote( - name, param, converted_named_tensors, buffer_size, session_id, pbar=pbar - ) + buffer_size = self._update_weight_from_remote(name, param, converted_named_tensors, buffer_size, pbar=pbar) if converted_named_tensors: - self._update_bucket_weights_from_remote(converted_named_tensors, session_id, pbar=pbar) + self._update_bucket_weights_from_remote(converted_named_tensors, pbar=pbar) def _update_weight_from_remote( self, @@ -135,7 +128,6 @@ def _update_weight_from_remote( param: torch.nn.Parameter, converted_named_tensors: list[tuple[str, torch.Tensor]], buffer_size: int, - session_id: str, pbar: tqdm | None = None, ) -> int | None: """ @@ -148,7 +140,7 @@ def _update_weight_from_remote( param_size = param.numel() * param.element_size() if buffer_size + param_size > self.args.update_weight_buffer_size: - self._update_bucket_weights_from_remote(converted_named_tensors, session_id, pbar=pbar) + self._update_bucket_weights_from_remote(converted_named_tensors, pbar=pbar) buffer_size = 0 converted_named_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) buffer_size += param_size @@ -160,7 +152,6 @@ def _update_expert_weight_from_remote( param: torch.nn.Parameter, named_tensors: list[tuple[str, torch.Tensor]], buffer_size: int, - session_id: str, pbar: tqdm | None = None, ) -> int: """ @@ -172,7 +163,7 @@ def _update_expert_weight_from_remote( if ( buffer_size + param_size ) * mpu.get_expert_model_parallel_world_size() > self.args.update_weight_buffer_size and named_tensors: - self._update_expert_bucket_weights_from_remote(named_tensors, session_id, pbar=pbar) + self._update_expert_bucket_weights_from_remote(named_tensors, pbar=pbar) buffer_size = 0 named_tensors.append((name, param)) @@ -180,7 +171,7 @@ def _update_expert_weight_from_remote( return buffer_size def _update_expert_bucket_weights_from_remote( - self, named_tensors: list[tuple[str, torch.Tensor]], session_id: str, pbar: tqdm | None = None + self, named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None ) -> None: """ Gather EP → HF → broadcast. Clears buffer. 
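Both helpers above follow the same accumulate-then-flush pattern: convert and append parameters until the next one would push the bucket past --update-weight-buffer-size, then flush the bucket through the backend transfer. A minimal sketch of the pattern in isolation, with flush_fn standing in for _update_bucket_weights_from_remote:

```python
from collections.abc import Callable, Iterable

import torch


def bucketed_update(
    named_params: Iterable[tuple[str, torch.Tensor]],
    buffer_size_bytes: int,
    flush_fn: Callable[[list[tuple[str, torch.Tensor]]], None],
) -> None:
    """Accumulate params up to a byte budget, flushing each full bucket (sketch)."""
    bucket: list[tuple[str, torch.Tensor]] = []
    used = 0
    for name, param in named_params:
        size = param.numel() * param.element_size()
        if bucket and used + size > buffer_size_bytes:
            flush_fn(bucket)  # e.g. NCCL broadcast or RDMA write of the converted bucket
            bucket, used = [], 0
        bucket.append((name, param))
        used += size
    if bucket:
        flush_fn(bucket)  # flush the trailing partial bucket
```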
@@ -215,4 +206,4 @@ def _update_expert_bucket_weights_from_remote( for name, param in all_gathered_params: converted_hf_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) - self._update_bucket_weights_from_remote(converted_hf_tensors, session_id, pbar) + self._update_bucket_weights_from_remote(converted_hf_tensors, pbar) From 52767eb96443a29c09d3b04874d171fca9760b4d Mon Sep 17 00:00:00 2001 From: Risc-lt <1291903308rlt@sjtu.edu.cn> Date: Wed, 24 Dec 2025 18:54:39 +0000 Subject: [PATCH 09/11] fix: modify check weight equality --- slime/backends/sglang_utils/sglang_engine.py | 2 +- tests/test_weight_transfer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 3cfff3aefd..202c3f1b02 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -313,7 +313,7 @@ def resume_memory_occupation(self, tags: list[str] = None): ) def check_weights(self, action: str): - return self._make_request("check_weights", {"action": action}) + return self._make_request("weights_checker", {"action": action}) def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): return self._make_request( diff --git a/tests/test_weight_transfer.py b/tests/test_weight_transfer.py index 92d039004c..27a9a95e51 100644 --- a/tests/test_weight_transfer.py +++ b/tests/test_weight_transfer.py @@ -100,7 +100,7 @@ def execute(args: ScriptArgs): "--actor-num-gpus-per-node 1 " # 1GB buffer for weight update f"--update-weight-buffer-size {1 * 1024 ** 3} " - # "--check-weight-update-equal " + f"--check-weight-update-equal " ) if args.mode == "rdma": misc_args += "--update-weight-transfer-mode rdma " From af9bb9d1ee7cc1e1bce70c749631739408a4a416 Mon Sep 17 00:00:00 2001 From: Risc-lt <1291903308rlt@sjtu.edu.cn> Date: Fri, 26 Dec 2025 05:15:07 +0000 Subject: [PATCH 10/11] feat: offload model replica and transfer engine --- .../update_weight/update_weight_from_rdma.py | 61 ++++++++++++------- .../update_weight_from_remote.py | 1 + slime/ray/actor_group.py | 12 ++++ tests/test_weight_transfer.py | 1 + 4 files changed, 53 insertions(+), 22 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py index fee32d277d..4788b88e54 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py @@ -7,6 +7,7 @@ import sglang.srt.layers.dp_attention as sglang_dp_attention import sglang.srt.server_args as sglang_server_args import torch +from torch_memory_saver import torch_memory_saver from mooncake.engine import TransferEngine from ray.actor import ActorHandle from sglang.srt.configs.device_config import DeviceConfig @@ -98,6 +99,10 @@ def __init__( weight_update_mode="rdma", ) + # For torch memory saver tagging + self.tag = f"Model Replica {self.global_rank}" + self._is_paused = False + def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle ) -> None: @@ -139,28 +144,29 @@ def connect_rollout_engines( # Create local model replicas and transfer engines for each target rollout shard self.engines = {} # Associate transfer tasks based on obtained session and weight info - for target in targets: - session_id = 
targets_to_session_id[(target.engine_ind, target.engine_rank)] - remote_info = RemoteWeightInfo(session_id, self.remote_weight_infos_by_session_id[session_id]) - # Instantiate the local model replicas and a corresponding transfer engine with memory registry for each type of rollout shard. - if target.engine_rank not in self.engines: - transfer_engine = self._create_transfer_engine() - model_replica = self._create_inference_replica( - self.args.hf_checkpoint, - pp_shard=target.source_shard, - target_rank=target.engine_rank, - target_tp=self.args.rollout_num_gpus_per_engine, - server_args=self.session_id_to_server_args[session_id], - ) - print_memory(f"[RDMA] After model replica at {target.engine_rank}") - weight_memory_registry = self._register_replica_memory( - model_replica, self.remote_weight_infos_by_session_id[session_id], transfer_engine - ) - self.engines[target.engine_rank] = TransferBundle( - model_replica, transfer_engine, weight_memory_registry, [remote_info] - ) - else: - self.engines[target.engine_rank].add_remote_session(remote_info) + with torch_memory_saver.region(tag=self.tag): + for target in targets: + session_id = targets_to_session_id[(target.engine_ind, target.engine_rank)] + remote_info = RemoteWeightInfo(session_id, self.remote_weight_infos_by_session_id[session_id]) + # Instantiate the local model replicas and a corresponding transfer engine with memory registry for each type of rollout shard. + if target.engine_rank not in self.engines: + transfer_engine = self._create_transfer_engine() + model_replica = self._create_inference_replica( + self.args.hf_checkpoint, + pp_shard=target.source_shard, + target_rank=target.engine_rank, + target_tp=self.args.rollout_num_gpus_per_engine, + server_args=self.session_id_to_server_args[session_id], + ) + print_memory(f"[RDMA] After model replica at {target.engine_rank}") + weight_memory_registry = self._register_replica_memory( + model_replica, self.remote_weight_infos_by_session_id[session_id], transfer_engine + ) + self.engines[target.engine_rank] = TransferBundle( + model_replica, transfer_engine, weight_memory_registry, [remote_info] + ) + else: + self.engines[target.engine_rank].add_remote_session(remote_info) print_memory("[RDMA] After Local Engine Replicas and engine Creation") @@ -261,6 +267,10 @@ def _update_bucket_weights_from_remote( # TODO: memory offloading if the replica becomes a bottleneck. # Load weights into local replica matching the target session, this handles sharding and reshaping. + if self._is_paused: + torch_memory_saver.resume(self.tag) + self._is_paused = False + for transfer_bundle in self.engines.values(): transfer_bundle.model_replica.load_weights(converted_named_tensors) converted_named_tensors.clear() @@ -269,6 +279,13 @@ def finish_transfer_task(self) -> None: # Execute transfer for each engine replica. for transfer_bundle in self.engines.values(): transfer_bundle.execute() + + # Offload model replicas from memory after transfer. 
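+ # The replicas are live only inside the update window: _update_bucket_weights_from_remote
+ # resumes the tagged region before load_weights, and the pause below releases the
+ # replicas' GPU memory back between weight syncs.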
+ if not self._is_paused: + print_memory("[RDMA] Before offloading model replica") + torch_memory_saver.pause(self.tag) + self._is_paused = True + print_memory("[RDMA] After offloading model replica") return diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py index 5ca4d7eef1..31fe188058 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_remote.py @@ -44,6 +44,7 @@ def __init__( self.weight_version = 0 self.transfer_plan = RemoteTransferPlan(args, model, weight_update_mode) self._is_source = self.transfer_plan.is_source() + self.global_rank = dist.get_rank(group=get_gloo_group()) @abstractmethod def connect_rollout_engines( diff --git a/slime/ray/actor_group.py b/slime/ray/actor_group.py index ce5ca97e08..01e2b0029e 100644 --- a/slime/ray/actor_group.py +++ b/slime/ray/actor_group.py @@ -71,6 +71,18 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): env_vars["LD_PRELOAD"] = dynlib_path env_vars["TMS_INIT_ENABLE"] = "1" env_vars["TMS_INIT_ENABLE_CPU_BACKUP"] = "1" + elif self.args.update_weight_transfer_mode == "rdma": + import torch_memory_saver + + dynlib_path = os.path.join( + os.path.dirname(os.path.dirname(torch_memory_saver.__file__)), + "torch_memory_saver_hook_mode_preload.abi3.so", + ) + assert os.path.exists(dynlib_path), f"LD_PRELOAD so file {dynlib_path} does not exist." + + env_vars["LD_PRELOAD"] = dynlib_path + # env_vars["TMS_INIT_ENABLE"] = "1" + env_vars["TMS_INIT_ENABLE_CPU_BACKUP"] = "1" if self.args.use_routing_replay: env_vars["ENABLE_ROUTING_REPLAY"] = "1" diff --git a/tests/test_weight_transfer.py b/tests/test_weight_transfer.py index 27a9a95e51..c1939efc51 100644 --- a/tests/test_weight_transfer.py +++ b/tests/test_weight_transfer.py @@ -100,6 +100,7 @@ def execute(args: ScriptArgs): "--actor-num-gpus-per-node 1 " # 1GB buffer for weight update f"--update-weight-buffer-size {1 * 1024 ** 3} " + # enable correctness check f"--check-weight-update-equal " ) if args.mode == "rdma": From 2f637cb494139d8ded4fcd2e7bd5e64e1499027e Mon Sep 17 00:00:00 2001 From: Risc-lt <1291903308rlt@sjtu.edu.cn> Date: Sun, 28 Dec 2025 04:48:13 +0000 Subject: [PATCH 11/11] feat: add load-transfer pipelining --- .../update_weight/update_weight_from_rdma.py | 37 +++++++++++++++++-- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py index 4788b88e54..563206cd11 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py @@ -49,6 +49,26 @@ class TransferBundle: def add_remote_session(self, remote_info: RemoteWeightInfo) -> None: self.remote_weight_infos.append(remote_info) + + def execute_each(self, names: Sequence[str]) -> list[int]: + # FIXME: Execute transfer for updated weight. 
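+ # Issues one async batched RDMA write per remote session, restricted to the weights
+ # named in `names` (the ones loaded in this bucket), and returns the batch ids so the
+ # caller can poll get_batch_transfer_status() for completion.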
+ batch_ids = []
+ for remote_session in self.remote_weight_infos:
+ session_id, remote_weights_info = remote_session.session_id, remote_session.weights_info
+ source_ptrs, target_ptrs, source_lens = [], [], []
+ for name in names:
+ tensor_register = self.weight_memory_registry[name]
+ data_ptr, numel, ele_size = tensor_register
+ source_ptrs.append(data_ptr)
+ target_ptrs.append(remote_weights_info[name][0]) # remote address
+ source_lens.append(numel * ele_size)
+
+ # Batch transfer weights through RDMA
+ batch_id = self.engine.batch_transfer_async_write(session_id, source_ptrs, target_ptrs, source_lens)
+ batch_ids.append(batch_id)
+ # logger.info(f"[RDMA] Batch transferred {len(self.weight_memory_registry)} tensors to session {session_id}.")
+ return batch_ids
+
def execute(self) -> None:
# Execute transfer for each target session using this replica.
@@ -271,14 +291,23 @@ def _update_bucket_weights_from_remote(
torch_memory_saver.resume(self.tag)
self._is_paused = False
- for transfer_bundle in self.engines.values():
- transfer_bundle.model_replica.load_weights(converted_named_tensors)
+ ret_batch = {}
+ for engine_rank, transfer_bundle in self.engines.items():
+ updated_names = transfer_bundle.model_replica.load_weights(converted_named_tensors)
+ ret_batch[engine_rank] = transfer_bundle.execute_each(updated_names) # FIXME: Do not wait for finish here.
+
+ # FIXME: Add sync barrier for all transfers to finish.
+ for engine_rank, batch_ids in ret_batch.items():
+ result = self.engines[engine_rank].engine.get_batch_transfer_status(batch_ids)
+ if result < 0:
+ raise RuntimeError(f"Batch transfer weights via RDMA failed with error code {result}.")
+
converted_named_tensors.clear()
def finish_transfer_task(self) -> None:
# Execute transfer for each engine replica.
- for transfer_bundle in self.engines.values():
- transfer_bundle.execute()
+ # for transfer_bundle in self.engines.values():
+ # transfer_bundle.execute()
# Offload model replicas from memory after transfer.
if not self._is_paused: