From 71380f5c1301394f1255cb668aa6f71c17ff30e4 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 10 Jul 2025 02:04:39 -0700 Subject: [PATCH 01/14] move some code to refit util Signed-off-by: Yuki Huang --- nemo_rl/models/megatron/refit_utils.py | 143 +++++++++++++- .../models/policy/megatron_policy_worker.py | 185 +----------------- .../models/generation/test_vllm_generation.py | 12 +- .../models/policy/test_megatron_worker.py | 7 + 4 files changed, 159 insertions(+), 188 deletions(-) diff --git a/nemo_rl/models/megatron/refit_utils.py b/nemo_rl/models/megatron/refit_utils.py index fb46030ce9..bceec907a6 100644 --- a/nemo_rl/models/megatron/refit_utils.py +++ b/nemo_rl/models/megatron/refit_utils.py @@ -13,7 +13,7 @@ # limitations under the License. import re import time -from typing import Dict, List +from typing import Any, Dict, List, Tuple import torch from megatron.core import parallel_state @@ -28,6 +28,9 @@ RowParallelLinear, VocabParallelEmbedding, ) +from torch.distributed import get_process_group_ranks + +from nemo_rl.models.megatron.converters.common import get_global_key_from_local_key def get_tp_dim(model, param_name, named_modules_dict): @@ -155,3 +158,141 @@ def gather_params(model, keys, key_to_global_keys: Dict[str, List[str]]): print(f"Time taken to gather params: {time.perf_counter() - st}") return gathered_params + + +@torch.no_grad() +def get_param_info(model, dtype): + # Get parallel info + tp_group = parallel_state.get_tensor_model_parallel_group() + tp_group_rank_ids = get_process_group_ranks(tp_group) + + pp_group = parallel_state.get_pipeline_model_parallel_group() + pp_world_size = torch.distributed.get_world_size(pp_group) + pp_group_rank_ids = get_process_group_ranks(pp_group) + + # Collect parameter info + param_info = [] + + # Dictionary of modules we can quickly look up to check if a module has TP + named_modules_dict = dict(model.named_modules()) + + # Process each parameter in the model + # state_dict includes parameters and persistent buffers + for name, param in model.state_dict().items(): + # Skip _extra_state entries (these are metadata, not actual weights) + if "_extra_state" in name: + continue + + shape = list(param.shape) + tp_dim = get_tp_dim(model, name, named_modules_dict) + if tp_dim is not None: + tp_rank_ids = tuple(sorted(tp_group_rank_ids)) + shape[tp_dim] *= len(tp_rank_ids) + else: + tp_rank_ids = (torch.distributed.get_rank(),) + + pp_rank_ids = tuple(sorted(pp_group_rank_ids)) + + # Calculate size for this parameter + prec_to_bytes = { + torch.bfloat16: 2, + torch.float16: 2, + torch.float32: 4, + } + scale = prec_to_bytes[dtype] / prec_to_bytes[param.dtype] + size_in_bytes = ( + param.element_size() + * param.numel() + * len(tp_rank_ids) + * len(pp_rank_ids) + * scale + ) + param_info.append( + ( + ( + name, + tuple(shape), + param.dtype, + ), + size_in_bytes, + ) + ) + + # Gather all parameter info from all PP ranks + pp_gathered_param_infos = [None] * pp_world_size + torch.distributed.all_gather_object( + pp_gathered_param_infos, param_info, group=pp_group + ) + pp_gathered_param_infos = [x for y in pp_gathered_param_infos for x in y] # type: ignore + + all_param_infos = pp_gathered_param_infos + + # Merge all parameter infos, keeping only unique parameter names + merged_param_info = [] + seen_params = set() + + for name, size in all_param_infos: + if name not in seen_params: + merged_param_info.append((name, size)) + seen_params.add(name) + + # Update param_info with the merged information + param_info = merged_param_info + print(f"Prepared {len(param_info)} tensors for refit") + + return param_info + + +@torch.no_grad() +def get_local_key_to_global_keys(self, state_dict_info: List[Tuple[Any, int]]): + """Get the local key to global keys mapping.""" + # Get parallel info + tp_group = parallel_state.get_tensor_model_parallel_group() + tp_world_size = torch.distributed.get_world_size(tp_group) + + pp_group = parallel_state.get_pipeline_model_parallel_group() + pp_world_size = torch.distributed.get_world_size(pp_group) + pp_global_ranks = torch.distributed.get_process_group_ranks(group=pp_group) + pp_local_rank_id = parallel_state.get_pipeline_model_parallel_rank() + + ep_group = parallel_state.get_expert_model_parallel_group() + ep_world_size = torch.distributed.get_world_size(ep_group) + + # start calculating the global key + ep_pattern = re.compile(r"mlp\.experts.*\.weight\d*$") + state_dict = self.model.state_dict() + final_key_to_global_keys = {} + + for param_info, size in state_dict_info: + local_key, owner_pp_local_rank_id, _, _ = param_info + + # Step 1: create global key from local key + # if: for if a parameter is sharded along PP or EP; + # else: not sharded (like embedding) + pp_gathered_objs = [None] + if local_key in state_dict and owner_pp_local_rank_id == pp_local_rank_id: + pp_gathered_objs[0] = get_global_key_from_local_key( + local_key, self.model.config + ) + + # Step 2: gather global keys from ranks in PP group + src_global_rank = pp_global_ranks[owner_pp_local_rank_id] + torch.distributed.broadcast_object_list( + pp_gathered_objs, src=src_global_rank, group=pp_group + ) + + # Step 3: gather global keys from ranks in EP group + if ep_pattern.search(local_key): + ep_gathered_objs = [None] * ep_world_size + torch.distributed.all_gather_object( + ep_gathered_objs, pp_gathered_objs, group=ep_group + ) + flat_gathered_objs = [x for y in ep_gathered_objs for x in y] + else: + flat_gathered_objs = pp_gathered_objs + + final_key_to_global_keys[(local_key, owner_pp_local_rank_id)] = ( + flat_gathered_objs + ) + + return final_key_to_global_keys diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 691e1ce5b3..4555430ac6 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -83,7 +83,6 @@ reduce_max_stat_across_model_parallel_group, ) from ray.util.queue import Queue -from torch.distributed import get_process_group_ranks from transformers import PreTrainedTokenizerBase from nemo_rl.algorithms.interfaces import LossFunction, LossType @@ -100,13 +99,11 @@ forward_step_arbitrary_loss, ) from nemo_rl.models.megatron.community_import import import_model_from_hf_name -from nemo_rl.models.megatron.converters.common import ( - MegatronToHFConverter, - get_global_key_from_local_key, -) +from nemo_rl.models.megatron.converters.common import MegatronToHFConverter from nemo_rl.models.megatron.refit_utils import ( gather_params, - get_tp_dim, + get_param_info, + get_local_key_to_global_keys, ) from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( @@ -684,7 +681,7 @@ def __init__( # Create a map that maps any local parameter name to a list of global parameter names. # This map is repeatedly used by parameter gatherring phase during refit of every step. - self.local_key_to_global_keys = self.get_local_key_to_global_keys( + self.local_key_to_global_keys = get_local_key_to_global_keys( state_dict_info=self.prepare_weights_for_ipc()[0] ) self.should_disable_forward_pre_hook = ( @@ -1261,59 +1258,6 @@ def report_device_id(self) -> str: return get_device_uuid(device_idx) @torch.no_grad() - def get_local_key_to_global_keys(self, state_dict_info: List[Tuple[Any, int]]): - """Get the local key to global keys mapping.""" - # Get parallel info - tp_group = parallel_state.get_tensor_model_parallel_group() - tp_world_size = torch.distributed.get_world_size(tp_group) - - pp_group = parallel_state.get_pipeline_model_parallel_group() - pp_world_size = torch.distributed.get_world_size(pp_group) - pp_global_ranks = torch.distributed.get_process_group_ranks(group=pp_group) - pp_local_rank_id = parallel_state.get_pipeline_model_parallel_rank() - - ep_group = parallel_state.get_expert_model_parallel_group() - ep_world_size = torch.distributed.get_world_size(ep_group) - - # start calculating the global key - ep_pattern = re.compile(r"mlp\.experts.*\.weight\d*$") - state_dict = self.model.state_dict() - final_key_to_global_keys = {} - - for param_info, size in state_dict_info: - local_key, owner_pp_local_rank_id, _, _ = param_info - - # Step 1: create global key from local key - # if: for if a parameter is sharded along PP or EP; - # else: not sharded (like embedding) - pp_gathered_objs = [None] - if local_key in state_dict and owner_pp_local_rank_id == pp_local_rank_id: - pp_gathered_objs[0] = get_global_key_from_local_key( - local_key, self.model.config - ) - - # Step 2: gather global keys from ranks in PP group - src_global_rank = pp_global_ranks[owner_pp_local_rank_id] - torch.distributed.broadcast_object_list( - pp_gathered_objs, src=src_global_rank, group=pp_group - ) - - # Step 3: gather global keys from ranks in EP group - if ep_pattern.search(local_key): - ep_gathered_objs = [None] * ep_world_size - torch.distributed.all_gather_object( - ep_gathered_objs, pp_gathered_objs, group=ep_group - ) - flat_gathered_objs = [x for y in ep_gathered_objs for x in y] - else: - flat_gathered_objs = pp_gathered_objs - - final_key_to_global_keys[(local_key, owner_pp_local_rank_id)] = ( - flat_gathered_objs - ) - - return final_key_to_global_keys - def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """Prepare Megatron model weights for IPC transfer to vLLM. @@ -1322,126 +1266,12 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """ from nemo_rl.utils.nvml import get_free_memory_bytes - no_grad = torch.no_grad() - no_grad.__enter__() # Ensure model is in evaluation mode self.model.eval() - # Get parallel info - tp_group = parallel_state.get_tensor_model_parallel_group() - tp_world_size = torch.distributed.get_world_size(tp_group) - tp_group_rank_ids = get_process_group_ranks(tp_group) - - etp_group = parallel_state.get_expert_tensor_parallel_group() - etp_world_size = torch.distributed.get_world_size(etp_group) - etp_group_rank_ids = get_process_group_ranks(etp_group) - - pp_group = parallel_state.get_pipeline_model_parallel_group() - pp_world_size = torch.distributed.get_world_size(pp_group) - pp_group_rank_ids = get_process_group_ranks(pp_group) - pp_local_rank_id = parallel_state.get_pipeline_model_parallel_rank() - - ep_group = parallel_state.get_expert_model_parallel_group() - ep_world_size = torch.distributed.get_world_size(ep_group) - ep_group_rank_ids = get_process_group_ranks(ep_group) - - # Collect parameter info - param_info = [] - - # Dictionary of modules we can quickly look up to check if a module has TP - named_modules_dict = dict(self.model.named_modules()) - - # Process each parameter in the model - # state_dict includes parameters and persistent buffers - ep_pattern = re.compile(r"mlp\.experts.*\.weight\d*$") - for name, param in self.model.state_dict().items(): - # Skip _extra_state entries (these are metadata, not actual weights) - if "_extra_state" in name: - continue - - use_etp = True if ep_pattern.search(name) else False - if use_etp: - tensor_mp_rank_ids = etp_group_rank_ids - else: - tensor_mp_rank_ids = tp_group_rank_ids - - shape = list(param.shape) - tp_dim = get_tp_dim(self.model, name, named_modules_dict) - if tp_dim is not None: - tp_rank_ids = tuple(sorted(tensor_mp_rank_ids)) - shape[tp_dim] *= len(tp_rank_ids) - else: - tp_rank_ids = (torch.distributed.get_rank(),) - - pp_rank_ids = tuple(sorted(pp_group_rank_ids)) - ep_rank_ids = tuple(sorted(ep_group_rank_ids)) - - if ep_pattern.search(name): - ep_rank_ids = tuple(sorted(ep_group_rank_ids)) - else: - ep_rank_ids = (torch.distributed.get_rank(),) - - # Calculate size for this parameter - prec_to_bytes = { - torch.bfloat16: 2, - torch.float16: 2, - torch.float32: 4, - } - scale = prec_to_bytes[self.dtype] / prec_to_bytes[param.dtype] - size_in_bytes = ( - param.element_size() - * param.numel() - * len(tensor_mp_rank_ids) - * len(ep_rank_ids) - * scale - ) - param_info.append( - ( - ( - name, - pp_local_rank_id, - tuple(shape), - param.dtype, - ), - size_in_bytes, - ) - ) - # Gather parameter info from all pipeline parallel ranks to ensure complete coverage - pp_group = parallel_state.get_pipeline_model_parallel_group() - pp_world_size = torch.distributed.get_world_size(pp_group) - - # Gather all parameter info from all PP ranks - pp_gathered_param_infos = [None] * pp_world_size - torch.distributed.all_gather_object( - pp_gathered_param_infos, param_info, group=pp_group - ) - pp_gathered_param_infos = [x for y in pp_gathered_param_infos for x in y] # type: ignore - - # Gather parameter info from all expert parallel ranks to ensure complete coverage - ep_group = parallel_state.get_expert_model_parallel_group() - ep_world_size = torch.distributed.get_world_size(ep_group) - - # Gather all parameter info from all EP ranks - ep_gathered_param_infos = [None] * ep_world_size - torch.distributed.all_gather_object( - ep_gathered_param_infos, pp_gathered_param_infos, group=ep_group - ) - all_param_infos = [x for y in ep_gathered_param_infos for x in y] - - # Merge all parameter infos, keeping only unique parameter names - merged_param_info = [] - seen_params = set() - - for name, size in all_param_infos: - if name not in seen_params: - merged_param_info.append((name, size)) - seen_params.add(name) - - # Update param_info with the merged information - param_info = merged_param_info - - print(f"Prepared {len(param_info)} tensors for IPC transfer") - no_grad.__exit__(None, None, None) + # Get parameter info for refit + # param_info: list of ((name, shape, dtype), size_in_bytes) tuples + param_info = get_param_info(self.model, self.dtype) # Collect current available memory for refit ## Get current device index from torch @@ -1455,6 +1285,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: return param_info, total_available_bytes # Temporary fix, 'keys' is a kwarg due to some sort of ray bug + @torch.no_grad() def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: """Get IPC handles for the requested Megatron model weights. diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 626fce2b6f..9dc948b909 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -41,7 +41,7 @@ "name": model_name, }, "dtype": "bfloat16", - "max_new_tokens": 5, + "max_new_tokens": 5, # Small number of tokens for testing "temperature": 0.8, "top_p": 1.0, "top_k": None, @@ -133,15 +133,6 @@ def get_basic_megatron_test_config( "learning_rate": 5e-6, "logprob_batch_size": 2, "precision": precision, - "generation": { - "backend": "megatron", - "temperature": 1.0, - "max_new_tokens": 16, # Small number of tokens for testing - "top_p": 1.0, - "top_k": None, - "stop_token_ids": None, - "stop_strings": None, - }, "dtensor_cfg": { "enabled": False, # Disabled for Megatron tests }, @@ -202,6 +193,7 @@ def get_basic_megatron_test_config( "optimizer": None, # Remove default FSDP optimizer "scheduler": None, # Remove default scheduler "max_grad_norm": 1.0, + "generation": deepcopy(basic_vllm_test_config), } diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index a23c1b5559..ee1d422a3b 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -60,6 +60,13 @@ def create_megatron_test_config( "top_k": None, "stop_token_ids": None, "stop_strings": None, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, }, "dtensor_cfg": { "enabled": False, # Disabled for Megatron tests From 8fef28529e3311a1d636e9c031313ca99f9f5a7b Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 10 Jul 2025 02:10:06 -0700 Subject: [PATCH 02/14] prepare refit info once Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 12 ++-- nemo_rl/models/generation/interfaces.py | 8 ++- nemo_rl/models/generation/vllm.py | 39 ++++++++++--- nemo_rl/models/generation/vllm_backend.py | 14 ++++- .../models/policy/dtensor_policy_worker.py | 57 +++++++++---------- nemo_rl/models/policy/interfaces.py | 8 +-- nemo_rl/models/policy/lm_policy.py | 24 ++++---- .../models/policy/megatron_policy_worker.py | 14 ++--- 8 files changed, 101 insertions(+), 75 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d33e503636..dc32f9bab8 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -347,6 +347,12 @@ def setup( # wait for all futures to complete ray.get(futures_train + futures_inference) + # prepare refit info + # state_dict_info: {tensor_name: (shape, dtype)} + state_dict_info = policy.prepare_refit_info() + if not colocated_inference: + policy_generation.prepare_refit_info(state_dict_info) + loss_fn = ClippedPGLossFn(loss_config) print("\n" + "=" * 60) @@ -426,13 +432,9 @@ def refit_policy_generation( if not update_success: break else: - # prepare info for update weights - state_dict_info = policy.prepare_info_for_collective() # update weights through nccl futures_train = policy.broadcast_weights_for_collective() - futures_inference = policy_generation.update_weights_from_collective( - state_dict_info - ) + futures_inference = policy_generation.update_weights_from_collective() # wait for all futures to complete ray.get(futures_train) results = ray.get(futures_inference) diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 9db0c357fa..652f3f9ca5 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -228,12 +228,14 @@ def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: def finish_generation(self, *args: Any, **kwargs: Any) -> bool: pass + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + """Prepare the info for refit.""" + raise NotImplementedError + def update_weights(self, ipc_handles: dict[str, Any]) -> bool: """Update the model weights from the given IPC handles.""" raise NotImplementedError - def update_weights_from_collective( - self, info: dict[str, Any] - ) -> list[ray.ObjectRef]: + def update_weights_from_collective(self) -> list[ray.ObjectRef]: """Update the model weights from collective communication.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index a99022e30f..7d020c05e3 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -1021,6 +1021,14 @@ async def report_device_id_async(self) -> list[str]: return cast(list[str], list_of_worker_results) + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + """Prepare the info for refit.""" + self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) + + async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> None: + """Async version of prepare_refit_info.""" + await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) + def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: """Update weights from IPC handles by delegating to the vLLM Worker implementation. @@ -1132,7 +1140,7 @@ async def update_weights_from_ipc_handles_async( traceback.print_exc() return False - def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + def update_weights_from_collective(self) -> bool: """Update the model weights from collective communication.""" try: assert self.llm is not None, ( @@ -1145,7 +1153,7 @@ def update_weights_from_collective(self, info: dict[str, Any]) -> bool: ) result_or_coro = self.llm.collective_rpc( - "update_weights_from_collective", args=(info,) + "update_weights_from_collective", args=tuple() ) worker_result = result_or_coro[0] @@ -1162,7 +1170,7 @@ def update_weights_from_collective(self, info: dict[str, Any]) -> bool: traceback.print_exc() return False - async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bool: + async def update_weights_from_collective_async(self) -> bool: """Async version of update_weights_from_collective.""" try: assert self.llm is not None, ( @@ -1175,7 +1183,7 @@ async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bo ) result_or_coro = await self.llm.collective_rpc( - "update_weights_from_collective", args=(info,) + "update_weights_from_collective", args=tuple() ) if asyncio.iscoroutine(result_or_coro): @@ -1908,6 +1916,22 @@ def shutdown(self) -> bool: print(f"Error during policy shutdown: {e}") return False + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + """Prepare the info for refit.""" + # Choose the appropriate method based on async_engine setting + method_name = ( + "prepare_refit_info_async" + if self.cfg["vllm_cfg"]["async_engine"] + else "prepare_refit_info" + ) + + # Use run_all_workers_single_data to send data to all workers + self.worker_group.run_all_workers_single_data( + method_name, + state_dict_info=state_dict_info, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) + def update_weights(self, ipc_handles: dict[str, Any]) -> bool: """Update weights of the policy using IPC handles, considering tensor parallelism. @@ -1952,9 +1976,7 @@ def update_weights(self, ipc_handles: dict[str, Any]) -> bool: print(f"Error during update weights: {e}") return False - def update_weights_from_collective( - self, info: dict[str, Any] - ) -> list[ray.ObjectRef]: + def update_weights_from_collective(self) -> list[ray.ObjectRef]: """Update weights of the policy using collective communication.""" if not self.worker_group or not self.worker_group.workers: raise RuntimeError("Worker group is not initialized") @@ -1966,10 +1988,9 @@ def update_weights_from_collective( else "update_weights_from_collective" ) - # Use run_all_workers_single_data to send data to all workers + # Use run_all_workers_single_data for methods that don't need data futures = self.worker_group.run_all_workers_single_data( method_name, - info=info, run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], ) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 3d6ed0253c..9b2b97e6a6 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -51,6 +51,11 @@ def report_device_id(self) -> str: return get_device_uuid(self.device.index) + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + """Prepare the info for refit.""" + # state_dict_info: {tensor_name: (shape, dtype)} + self.state_dict_info = state_dict_info + def update_weights_from_global_ipc_handles(self, global_device_ipc_handles): """Update weights from global IPC handles. @@ -118,10 +123,15 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): ) return False - def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + def update_weights_from_collective(self) -> bool: """Update the model weights from collective communication.""" + assert self.state_dict_info is not None, ( + "state_dict_info is not prepared. " + "Please call prepare_refit_info when initializing the worker." + ) + try: - for name, (shape, dtype) in info.items(): + for name, (shape, dtype) in self.state_dict_info.items(): weight = torch.empty(shape, dtype=dtype, device="cuda") self.model_update_group.broadcast(weight, src=0) self.model_runner.model.load_weights(weights=[(name, weight)]) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 6be85d9e0d..bb95b7623b 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -139,13 +139,13 @@ def __init__( init_reference_model: bool = True, **kwargs: Any, ): + self.is_generation_colocated = None + if "generation" in config and config["generation"] is not None: + self.is_generation_colocated = config["generation"]["colocated"]["enabled"] + # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. - if ( - "generation" in config - and config["generation"] is not None - and not config["generation"]["colocated"]["enabled"] - ): + if not self.is_generation_colocated: os.environ["NCCL_CUMEM_ENABLE"] = "1" # Only enable expandable_segments on Hopper and newer architectures (compute capability 9.x+) @@ -897,6 +897,26 @@ def report_device_id(self) -> str: # Get device UUID using NVML return get_device_uuid(device_idx) + @torch.no_grad() + def prepare_refit_info(self) -> Optional[dict[str, Any]]: + state_dict = self.model.state_dict() + + if self.is_generation_colocated: + # Collect info for streaming multiple tensors + self.refit_param_info = [] + for name, tensor in self.model.state_dict(): + # dtensor's numel will return complete tensor instead of only local tensor + size_in_bytes = tensor.element_size() * tensor.numel() + self.refit_param_info.append((name, size_in_bytes)) + + else: + # Collect info for collective communication + state_dict_info = {} + for name, tensor in state_dict.items(): + state_dict_info[name] = (tensor.shape, self.dtype) + + return state_dict_info + @torch.no_grad() def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """Prepare the weights for IPC. @@ -917,13 +937,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: self.model.state_dict() ) - # Collect info for streaming multiple tensors - state_dict_info = [] - for name, tensor in self._held_sharded_state_dict_reference.items(): - # dtensor's numel will return complete tensor instead of only local tensor - size_in_bytes = tensor.element_size() * tensor.numel() - state_dict_info.append((name, size_in_bytes)) - # Collect current available memory for refit ## Get current device index from torch device_idx = torch.cuda.current_device() @@ -932,7 +945,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: ## Use 80% of the free memory for safety total_available_bytes *= 0.8 - return state_dict_info, total_available_bytes + return self.refit_param_info, total_available_bytes @torch.no_grad() def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: @@ -975,24 +988,6 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: return {device_uuid: serialized} - @torch.no_grad() - def prepare_info_for_collective(self) -> dict[str, Any]: - """Prepare the info for collective communication. - - Returns: - dict: A dictionary containing the info for collective communication. - """ - # Get state_dict - self.model = self.move_to_cuda(self.model) - state_dict = self.model.state_dict() - - # Collect info for collective communication - state_dict_info = {} - for name, tensor in state_dict.items(): - state_dict_info[name] = (tensor.shape, self.dtype) - - return state_dict_info - @torch.no_grad() def broadcast_weights_for_collective(self) -> None: """Broadcast the weights for collective communication.""" diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index 614340c67b..d63b66e735 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, TypedDict +from typing import Any, Optional, TypedDict import ray import torch @@ -109,15 +109,15 @@ def offload_after_refit(self) -> None: pass @abstractmethod - def prepare_weights_for_ipc(self, *args: Any, **kwargs: Any) -> list[list[str]]: + def prepare_refit_info(self) -> Optional[dict[str, Any]]: pass @abstractmethod - def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: + def prepare_weights_for_ipc(self, *args: Any, **kwargs: Any) -> list[list[str]]: pass @abstractmethod - def prepare_info_for_collective(self) -> dict[str, Any]: + def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: pass @abstractmethod diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 5e82b61d72..c77f2460e7 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -405,6 +405,17 @@ def finish_training(self, *args: Any, **kwargs: Any) -> None: # Placeholder implementation pass + def prepare_refit_info(self) -> Optional[dict[str, Any]]: + """Prepare the info for refit. + + Returns: + dict: A dictionary containing the info for refit. + """ + futures = self.worker_group.run_all_workers_single_data("prepare_refit_info") + results = ray.get(futures) + # Only get the first worker's info since all workers will have the same result + return results[0] + def prepare_weights_for_ipc( self, _refit_buffer_size_gb: Optional[int] = None ) -> list[list[str]]: @@ -469,19 +480,6 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: return all_handles - def prepare_info_for_collective(self) -> dict[str, Any]: - """Prepare the info for collective communication. - - Returns: - dict: A dictionary containing the info for collective communication. - """ - futures = self.worker_group.run_all_workers_single_data( - "prepare_info_for_collective" - ) - results = ray.get(futures) - # Only get the first worker's info since all workers will have the same result - return results[0] - def broadcast_weights_for_collective(self) -> list[ray.ObjectRef]: """Broadcast the weights for collective communication.""" futures = self.worker_group.run_all_workers_single_data( diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 4555430ac6..a63b39f053 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1258,6 +1258,11 @@ def report_device_id(self) -> str: return get_device_uuid(device_idx) @torch.no_grad() + def prepare_refit_info(self) -> None: + # Get parameter info for refit + # param_info: list of ((name, shape, dtype), size_in_bytes) tuples + self.refit_param_info = get_param_info(self.model, self.dtype) + def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """Prepare Megatron model weights for IPC transfer to vLLM. @@ -1266,13 +1271,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """ from nemo_rl.utils.nvml import get_free_memory_bytes - # Ensure model is in evaluation mode - self.model.eval() - - # Get parameter info for refit - # param_info: list of ((name, shape, dtype), size_in_bytes) tuples - param_info = get_param_info(self.model, self.dtype) - # Collect current available memory for refit ## Get current device index from torch device_idx = torch.cuda.current_device() @@ -1282,7 +1280,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: # more buckets seems to have better perf total_available_bytes *= 0.1 - return param_info, total_available_bytes + return self.refit_param_info, total_available_bytes # Temporary fix, 'keys' is a kwarg due to some sort of ray bug @torch.no_grad() From 42613cab689e035b13f65f594cdeb1407336a248 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 4 Jul 2025 04:35:14 +0000 Subject: [PATCH 03/14] rename update_weights Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 4 +++- nemo_rl/models/generation/interfaces.py | 2 +- nemo_rl/models/generation/vllm.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index dc32f9bab8..4252861e78 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -428,7 +428,9 @@ def refit_policy_generation( # do update for keys in grouped_param_keys: ipc_handles = policy.get_weights_ipc_handles(keys) - update_success = policy_generation.update_weights(ipc_handles) + update_success = policy_generation.update_weights_from_ipc_handles( + ipc_handles + ) if not update_success: break else: diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 652f3f9ca5..665473cc03 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -232,7 +232,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" raise NotImplementedError - def update_weights(self, ipc_handles: dict[str, Any]) -> bool: + def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: """Update the model weights from the given IPC handles.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index 7d020c05e3..bcf6140bd5 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -1932,7 +1932,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], ) - def update_weights(self, ipc_handles: dict[str, Any]) -> bool: + def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: """Update weights of the policy using IPC handles, considering tensor parallelism. For tp > 1, only the leader in each tensor parallel tied worker group will update weights. From 26af0373baf150328ff4dd755204901cbeefbc9f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 10 Jul 2025 09:52:50 +0000 Subject: [PATCH 04/14] fix rebase Signed-off-by: Yuki Huang --- nemo_rl/models/megatron/refit_utils.py | 52 +++++++++++++++---- .../models/policy/megatron_policy_worker.py | 20 ++++--- 2 files changed, 55 insertions(+), 17 deletions(-) diff --git a/nemo_rl/models/megatron/refit_utils.py b/nemo_rl/models/megatron/refit_utils.py index bceec907a6..f96c6b7537 100644 --- a/nemo_rl/models/megatron/refit_utils.py +++ b/nemo_rl/models/megatron/refit_utils.py @@ -164,11 +164,21 @@ def gather_params(model, keys, key_to_global_keys: Dict[str, List[str]]): def get_param_info(model, dtype): # Get parallel info tp_group = parallel_state.get_tensor_model_parallel_group() + tp_world_size = torch.distributed.get_world_size(tp_group) tp_group_rank_ids = get_process_group_ranks(tp_group) + etp_group = parallel_state.get_expert_tensor_parallel_group() + etp_world_size = torch.distributed.get_world_size(etp_group) + etp_group_rank_ids = get_process_group_ranks(etp_group) + pp_group = parallel_state.get_pipeline_model_parallel_group() pp_world_size = torch.distributed.get_world_size(pp_group) pp_group_rank_ids = get_process_group_ranks(pp_group) + pp_local_rank_id = parallel_state.get_pipeline_model_parallel_rank() + + ep_group = parallel_state.get_expert_model_parallel_group() + ep_world_size = torch.distributed.get_world_size(ep_group) + ep_group_rank_ids = get_process_group_ranks(ep_group) # Collect parameter info param_info = [] @@ -178,20 +188,33 @@ def get_param_info(model, dtype): # Process each parameter in the model # state_dict includes parameters and persistent buffers + ep_pattern = re.compile(r"mlp\.experts.*\.weight\d*$") for name, param in model.state_dict().items(): # Skip _extra_state entries (these are metadata, not actual weights) if "_extra_state" in name: continue + use_etp = True if ep_pattern.search(name) else False + if use_etp: + tensor_mp_rank_ids = etp_group_rank_ids + else: + tensor_mp_rank_ids = tp_group_rank_ids + shape = list(param.shape) tp_dim = get_tp_dim(model, name, named_modules_dict) if tp_dim is not None: - tp_rank_ids = tuple(sorted(tp_group_rank_ids)) + tp_rank_ids = tuple(sorted(tensor_mp_rank_ids)) shape[tp_dim] *= len(tp_rank_ids) else: tp_rank_ids = (torch.distributed.get_rank(),) pp_rank_ids = tuple(sorted(pp_group_rank_ids)) + ep_rank_ids = tuple(sorted(ep_group_rank_ids)) + + if ep_pattern.search(name): + ep_rank_ids = tuple(sorted(ep_group_rank_ids)) + else: + ep_rank_ids = (torch.distributed.get_rank(),) # Calculate size for this parameter prec_to_bytes = { @@ -203,20 +226,24 @@ def get_param_info(model, dtype): size_in_bytes = ( param.element_size() * param.numel() - * len(tp_rank_ids) - * len(pp_rank_ids) + * len(tensor_mp_rank_ids) + * len(ep_rank_ids) * scale ) param_info.append( ( ( name, + pp_local_rank_id, tuple(shape), param.dtype, ), size_in_bytes, ) ) + # Gather parameter info from all pipeline parallel ranks to ensure complete coverage + pp_group = parallel_state.get_pipeline_model_parallel_group() + pp_world_size = torch.distributed.get_world_size(pp_group) # Gather all parameter info from all PP ranks pp_gathered_param_infos = [None] * pp_world_size @@ -225,7 +252,16 @@ def get_param_info(model, dtype): ) pp_gathered_param_infos = [x for y in pp_gathered_param_infos for x in y] # type: ignore - all_param_infos = pp_gathered_param_infos + # Gather parameter info from all expert parallel ranks to ensure complete coverage + ep_group = parallel_state.get_expert_model_parallel_group() + ep_world_size = torch.distributed.get_world_size(ep_group) + + # Gather all parameter info from all EP ranks + ep_gathered_param_infos = [None] * ep_world_size + torch.distributed.all_gather_object( + ep_gathered_param_infos, pp_gathered_param_infos, group=ep_group + ) + all_param_infos = [x for y in ep_gathered_param_infos for x in y] # Merge all parameter infos, keeping only unique parameter names merged_param_info = [] @@ -244,7 +280,7 @@ def get_param_info(model, dtype): @torch.no_grad() -def get_local_key_to_global_keys(self, state_dict_info: List[Tuple[Any, int]]): +def get_local_key_to_global_keys(model, state_dict_info: List[Tuple[Any, int]]): """Get the local key to global keys mapping.""" # Get parallel info tp_group = parallel_state.get_tensor_model_parallel_group() @@ -260,7 +296,7 @@ def get_local_key_to_global_keys(self, state_dict_info: List[Tuple[Any, int]]): # start calculating the global key ep_pattern = re.compile(r"mlp\.experts.*\.weight\d*$") - state_dict = self.model.state_dict() + state_dict = model.state_dict() final_key_to_global_keys = {} for param_info, size in state_dict_info: @@ -271,9 +307,7 @@ def get_local_key_to_global_keys(self, state_dict_info: List[Tuple[Any, int]]): # else: not sharded (like embedding) pp_gathered_objs = [None] if local_key in state_dict and owner_pp_local_rank_id == pp_local_rank_id: - pp_gathered_objs[0] = get_global_key_from_local_key( - local_key, self.model.config - ) + pp_gathered_objs[0] = get_global_key_from_local_key(local_key, model.config) # Step 2: gather global keys from ranks in PP group src_global_rank = pp_global_ranks[owner_pp_local_rank_id] diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index a63b39f053..51dd74b118 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -13,13 +13,12 @@ # limitations under the License. import gc import os -import re import time import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial -from typing import Any, Iterator, List, Optional, Tuple, TypeVar +from typing import Any, Iterator, Optional, TypeVar import ray import torch @@ -102,8 +101,8 @@ from nemo_rl.models.megatron.converters.common import MegatronToHFConverter from nemo_rl.models.megatron.refit_utils import ( gather_params, - get_param_info, get_local_key_to_global_keys, + get_param_info, ) from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( @@ -679,11 +678,6 @@ def __init__( self._held_gather_buffer = None self.megatron_to_hf_converter = MegatronToHFConverter(hf_model_name, self.model) - # Create a map that maps any local parameter name to a list of global parameter names. - # This map is repeatedly used by parameter gatherring phase during refit of every step. - self.local_key_to_global_keys = get_local_key_to_global_keys( - state_dict_info=self.prepare_weights_for_ipc()[0] - ) self.should_disable_forward_pre_hook = ( self.cfg["megatron_cfg"]["optimizer"]["use_distributed_optimizer"] and self.cfg["megatron_cfg"]["distributed_data_parallel_config"][ @@ -691,6 +685,10 @@ def __init__( ] ) + # refit stuff, will be initialized in prepare_refit_info + self.refit_param_info = None + self.local_key_to_global_keys = None + def is_alive(self): return True @@ -1263,6 +1261,12 @@ def prepare_refit_info(self) -> None: # param_info: list of ((name, shape, dtype), size_in_bytes) tuples self.refit_param_info = get_param_info(self.model, self.dtype) + # Create a map that maps any local parameter name to a list of global parameter names. + # This map is repeatedly used by parameter gatherring phase during refit of every step. + self.local_key_to_global_keys = get_local_key_to_global_keys( + self.model, state_dict_info=self.refit_param_info + ) + def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """Prepare Megatron model weights for IPC transfer to vLLM. From d8e48f3395538b877c695d66d554b7be9326e92d Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 11 Jul 2025 04:11:11 +0000 Subject: [PATCH 05/14] add metainfo in prepare_refit_info Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 4 +- nemo_rl/models/generation/vllm_backend.py | 21 ++++++++--- .../models/policy/megatron_policy_worker.py | 37 ++++++++++++++----- 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 4252861e78..530ad4e8f9 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -348,10 +348,8 @@ def setup( ray.get(futures_train + futures_inference) # prepare refit info - # state_dict_info: {tensor_name: (shape, dtype)} state_dict_info = policy.prepare_refit_info() - if not colocated_inference: - policy_generation.prepare_refit_info(state_dict_info) + policy_generation.prepare_refit_info(state_dict_info) loss_fn = ClippedPGLossFn(loss_config) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 9b2b97e6a6..68fdd6c388 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any +from typing import Any, Optional import torch @@ -51,9 +51,19 @@ def report_device_id(self) -> str: return get_device_uuid(self.device.index) - def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: - """Prepare the info for refit.""" - # state_dict_info: {tensor_name: (shape, dtype)} + def prepare_refit_info( + self, state_dict_info: Optional[dict[str, Any]] = None + ) -> None: + """Prepare the info for refit. + + DtensorPolicyWorker: + colocated inference: state_dict_info is None + non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)} + + MegatronPolicyWorker: + colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)} + non-colocated inference: not implemented yet + """ self.state_dict_info = state_dict_info def update_weights_from_global_ipc_handles(self, global_device_ipc_handles): @@ -100,7 +110,8 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): # Unpack tensor to weights. Here we only return a view of the tensor to avoid # using extra memory. - for key, (shape, dtype, offset, size) in tensor_metadata.items(): + for key, offset in tensor_metadata.items(): + shape, dtype, size = self.state_dict_info[key] tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view( *shape ) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 51dd74b118..42014355fd 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1267,6 +1267,29 @@ def prepare_refit_info(self) -> None: self.model, state_dict_info=self.refit_param_info ) + # Collect tensor metadata for refit + state_dict_info = {} + for key, _ in self.refit_param_info: + # gather megatron params + gathered_megatron_params = gather_params( + self.model, + [key], + key_to_global_keys=self.local_key_to_global_keys, + ) + # convert to hf params + gathered_hf_params = self.megatron_to_hf_converter.convert( + gathered_megatron_params, self.model.config + ) + # collect tensor metadata + for name, tensor in gathered_hf_params.items(): + state_dict_info[name] = ( + tensor.shape, + tensor.dtype, + tensor.numel(), + ) + + return state_dict_info + def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """Prepare Megatron model weights for IPC transfer to vLLM. @@ -1329,14 +1352,9 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: type_to_total_size = defaultdict(lambda: 0) tensor_metadata = dict() + # Record offset of the tensor for key, tensor in gathered_hf_params.items(): - tensor_metadata[key] = ( - tensor.shape, # shape of the tensor - tensor.dtype, # dtype of the tensor - type_to_total_size[tensor.dtype], # offset of the tensor - # in packed buffer - tensor.numel(), # size of the tensor - ) + tensor_metadata[key] = type_to_total_size[tensor.dtype] type_to_total_size[tensor.dtype] += tensor.numel() # Allocate consolidated tensors for each dtype @@ -1352,8 +1370,9 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Copy tensors into consolidated buffers for key, tensor in gathered_hf_params.items(): - metadata = tensor_metadata[key] - _, dtype, offset, size = metadata + offset = tensor_metadata[key] + dtype = tensor.dtype + size = tensor.numel() packed_tensors[dtype][offset : offset + size].copy_( tensor.detach().view(-1) ) From eb6b910bbb4508a818d98c9b6e24c637b5a8dce4 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 11 Jul 2025 07:21:01 +0000 Subject: [PATCH 06/14] some fix and update unit test with prepare_refit_info Signed-off-by: Yuki Huang --- nemo_rl/models/generation/vllm.py | 5 +- .../models/policy/dtensor_policy_worker.py | 2 +- .../models/generation/test_vllm_generation.py | 47 ++++++++++++++++++- 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index bcf6140bd5..7a382c0bda 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -1926,12 +1926,15 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: ) # Use run_all_workers_single_data to send data to all workers - self.worker_group.run_all_workers_single_data( + futures = self.worker_group.run_all_workers_single_data( method_name, state_dict_info=state_dict_info, run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], ) + # Wait for all futures to complete + ray.get(futures) + def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: """Update weights of the policy using IPC handles, considering tensor parallelism. diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index bb95b7623b..78297ae3e6 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -904,7 +904,7 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]: if self.is_generation_colocated: # Collect info for streaming multiple tensors self.refit_param_info = [] - for name, tensor in self.model.state_dict(): + for name, tensor in state_dict.items(): # dtensor's numel will return complete tensor instead of only local tensor size_in_bytes = tensor.element_size() * tensor.numel() self.refit_param_info.append((name, size_in_bytes)) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 9dc948b909..b7159f4129 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -418,11 +418,18 @@ async def test_vllm_policy_generation_async( dtensor_config = basic_dtensor_test_config from nemo_rl.models.policy.lm_policy import Policy + print("creating vllm policy...") async_policy = VllmGeneration(cluster, vllm_config) async_policy.finish_generation() - print("creating hf policy...") + print("creating lm policy...") lm_policy = Policy(cluster, dtensor_config, tokenizer) + + print("preparing refit info...") + state_dict_info = lm_policy.prepare_refit_info() + async_policy.prepare_refit_info(state_dict_info) + + print("refitting vllm policy...") refit_policy_generation( lm_policy, async_policy, vllm_config["colocated"]["enabled"] ) @@ -520,6 +527,9 @@ def test_vllm_worker_seed_behavior(cluster, tokenizer): dtensor_config = basic_dtensor_test_config lm_policy = Policy(cluster, dtensor_config, tokenizer) + state_dict_info = lm_policy.prepare_refit_info() + policy.prepare_refit_info(state_dict_info) + print("refitting vllm policy...") refit_policy_generation(lm_policy, policy, vllm_config["colocated"]["enabled"]) @@ -674,6 +684,10 @@ async def test_vllm_generation_with_hf_training(cluster, tokenizer, async_engine print("Creating DTensor policy...") lm_policy = Policy(cluster, dtensor_config, tokenizer) + print("preparing refit info...") + state_dict_info = lm_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + print("refitting vllm policy...") refit_policy_generation( lm_policy, vllm_policy, vllm_config["colocated"]["enabled"] @@ -922,9 +936,14 @@ def test_vllm_weight_update_and_prefix_cache_reset( try: print(f"Creating DTensor policy for TP={tensor_parallel_size}...") lm_policy = Policy(cluster, dtensor_config, tokenizer) + print(f"Creating vLLM policy for TP={tensor_parallel_size}...") vllm_policy = VllmGeneration(cluster, vllm_config) + print("preparing refit info...") + state_dict_info = lm_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + # Prepare input data (batch size 2) text = """Answer the question based on the context below. Keep the answer short and concise. Respond "Unsure about answer" if not sure about the answer. Context: Teplizumab traces its roots to a New Jersey drug company called Ortho Pharmaceutical. There, scientists generated an early version of the antibody, dubbed OKT3. Originally sourced from mice, the molecule was able to bind to the surface of T cells and limit their cell-killing potential. In 1986, it was approved to help prevent organ rejection after kidney transplants, making it the first therapeutic antibody allowed for human use.Question: What was OKT3 originally sourced from?Answer:""" test_prompt = [text, text] # Use batch size 2 @@ -962,7 +981,7 @@ def test_vllm_weight_update_and_prefix_cache_reset( grouped_param_keys = lm_policy.prepare_weights_for_ipc() for keys in grouped_param_keys: ipc_handles = lm_policy.get_weights_ipc_handles(keys) - update_success = vllm_policy.update_weights(ipc_handles) + update_success = vllm_policy.update_weights_from_ipc_handles(ipc_handles) assert update_success, "Weight update should succeed" print("vLLM weights successfully updated.") @@ -1027,6 +1046,10 @@ def test_vllm_weight_update_memory(cluster, tokenizer, enable_dtensor): dtensor_config = basic_dtensor_test_config lm_policy = Policy(cluster, dtensor_config, tokenizer) + print("preparing refit info...") + state_dict_info = lm_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + print("refitting vllm policy...") # take it outside statistics to get clean peak memory during refit lm_policy.offload_before_refit() @@ -1104,6 +1127,10 @@ def test_vllm_generation_with_stop( dtensor_config = basic_dtensor_test_config lm_policy = Policy(cluster, dtensor_config, tokenizer) + print("preparing refit info...") + state_dict_info = lm_policy.prepare_refit_info() + vllm_generation.prepare_refit_info(state_dict_info) + print("refitting vllm policy...") refit_policy_generation( lm_policy, vllm_generation, vllm_config["colocated"]["enabled"] @@ -1211,6 +1238,10 @@ async def test_vllm_refit_non_collocated_update_weights( futures_inference = vllm_generation.init_collective(ip, port, world_size=2) ray.get(futures_train + futures_inference) + # prepare refit info + state_dict_info = lm_policy.prepare_refit_info() + vllm_generation.prepare_refit_info(state_dict_info) + print("refitting vllm policy...") refit_policy_generation( lm_policy, vllm_generation, vllm_config["colocated"]["enabled"] @@ -1309,6 +1340,10 @@ def test_vllm_generation_with_megatron_training( print("Creating Megatron policy...") megatron_policy = Policy(cluster, megatron_config, test_tokenizer) + print("preparing refit info...") + state_dict_info = megatron_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + print("Refitting vLLM policy with Megatron weights...") refit_policy_generation( megatron_policy, vllm_policy, vllm_config["colocated"]["enabled"] @@ -1428,6 +1463,10 @@ def test_vllm_megatron_weight_update_memory(cluster, tokenizer): print("Creating Megatron policy...") megatron_policy = Policy(cluster, megatron_config, test_tokenizer) + print("preparing refit info...") + state_dict_info = megatron_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + print("Refitting vLLM policy with Megatron...") # Take it outside statistics to get clean peak memory during refit megatron_policy.offload_before_refit() @@ -1531,6 +1570,10 @@ def test_vllm_megatron_pipeline_parallel(cluster, tokenizer): vllm_policy = VllmGeneration(cluster, vllm_config) vllm_policy.finish_generation() + print("preparing refit info...") + state_dict_info = megatron_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + print("Refitting vLLM with Megatron PP=2 weights...") refit_policy_generation( megatron_policy, vllm_policy, vllm_config["colocated"]["enabled"] From 23c5d3a1f21f379e5a50a19014ecd85ba44a7d8f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sat, 12 Jul 2025 05:33:27 -0700 Subject: [PATCH 07/14] support update dtype record during training Signed-off-by: Yuki Huang --- nemo_rl/models/generation/vllm_backend.py | 14 ++++++++-- .../models/policy/megatron_policy_worker.py | 26 ++++++++++++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 68fdd6c388..124cbb2af5 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -110,8 +110,18 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): # Unpack tensor to weights. Here we only return a view of the tensor to avoid # using extra memory. - for key, offset in tensor_metadata.items(): - shape, dtype, size = self.state_dict_info[key] + for key, metadata in tensor_metadata.items(): + # dtype for the 1st and 2nd steps may be different + # e.g. model.layers.4.mlp.gate.e_score_correction_bias + if isinstance(metadata, tuple): + # use dtype of current step + offset, dtype = metadata + shape, _, size = self.state_dict_info[key] + # update record + self.state_dict_info[key] = (shape, dtype, size) + else: + offset = metadata + shape, dtype, size = self.state_dict_info[key] tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view( *shape ) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 42014355fd..cca8d62d29 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -687,6 +687,7 @@ def __init__( # refit stuff, will be initialized in prepare_refit_info self.refit_param_info = None + self.state_dict_info = None self.local_key_to_global_keys = None def is_alive(self): @@ -1268,7 +1269,7 @@ def prepare_refit_info(self) -> None: ) # Collect tensor metadata for refit - state_dict_info = {} + self.state_dict_info = {} for key, _ in self.refit_param_info: # gather megatron params gathered_megatron_params = gather_params( @@ -1282,13 +1283,13 @@ def prepare_refit_info(self) -> None: ) # collect tensor metadata for name, tensor in gathered_hf_params.items(): - state_dict_info[name] = ( + self.state_dict_info[name] = ( tensor.shape, tensor.dtype, tensor.numel(), ) - return state_dict_info + return self.state_dict_info def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """Prepare Megatron model weights for IPC transfer to vLLM. @@ -1354,7 +1355,22 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Record offset of the tensor for key, tensor in gathered_hf_params.items(): - tensor_metadata[key] = type_to_total_size[tensor.dtype] + # dtype for the 1st and 2nd steps may be different + # e.g. model.layers.4.mlp.gate.e_score_correction_bias + if tensor.dtype == self.state_dict_info[key][1]: + tensor_metadata[key] = type_to_total_size[tensor.dtype] + else: + # also send dtype if it changes + tensor_metadata[key] = ( + type_to_total_size[tensor.dtype], + tensor.dtype, + ) + # update record + self.state_dict_info[key] = ( + tensor.shape, + tensor.dtype, + tensor.numel(), + ) type_to_total_size[tensor.dtype] += tensor.numel() # Allocate consolidated tensors for each dtype @@ -1371,6 +1387,8 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Copy tensors into consolidated buffers for key, tensor in gathered_hf_params.items(): offset = tensor_metadata[key] + if isinstance(offset, tuple): + offset, _ = offset dtype = tensor.dtype size = tensor.numel() packed_tensors[dtype][offset : offset + size].copy_( From 687c868369371e4504093e31c820d9d36a823276 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 13 Jul 2025 00:20:20 -0700 Subject: [PATCH 08/14] not cache refit_param_info since dtype may change Signed-off-by: Yuki Huang --- nemo_rl/models/generation/vllm_backend.py | 3 +-- .../models/policy/megatron_policy_worker.py | 23 ++++++++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 124cbb2af5..02de2aacc5 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -111,8 +111,7 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): # Unpack tensor to weights. Here we only return a view of the tensor to avoid # using extra memory. for key, metadata in tensor_metadata.items(): - # dtype for the 1st and 2nd steps may be different - # e.g. model.layers.4.mlp.gate.e_score_correction_bias + # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias) if isinstance(metadata, tuple): # use dtype of current step offset, dtype = metadata diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index cca8d62d29..342476756e 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -686,7 +686,6 @@ def __init__( ) # refit stuff, will be initialized in prepare_refit_info - self.refit_param_info = None self.state_dict_info = None self.local_key_to_global_keys = None @@ -1259,18 +1258,20 @@ def report_device_id(self) -> str: @torch.no_grad() def prepare_refit_info(self) -> None: # Get parameter info for refit - # param_info: list of ((name, shape, dtype), size_in_bytes) tuples - self.refit_param_info = get_param_info(self.model, self.dtype) + ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples + # Cannot cache refit_param_info since dtype for the 1st and 2nd steps may be different + ## e.g. e_score_correction_bias + refit_param_info = get_param_info(self.model, self.dtype) # Create a map that maps any local parameter name to a list of global parameter names. # This map is repeatedly used by parameter gatherring phase during refit of every step. self.local_key_to_global_keys = get_local_key_to_global_keys( - self.model, state_dict_info=self.refit_param_info + self.model, state_dict_info=refit_param_info ) # Collect tensor metadata for refit self.state_dict_info = {} - for key, _ in self.refit_param_info: + for key, _ in refit_param_info: # gather megatron params gathered_megatron_params = gather_params( self.model, @@ -1291,6 +1292,7 @@ def prepare_refit_info(self) -> None: return self.state_dict_info + @torch.no_grad() def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """Prepare Megatron model weights for IPC transfer to vLLM. @@ -1299,6 +1301,12 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """ from nemo_rl.utils.nvml import get_free_memory_bytes + # Get parameter info for refit + ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples + # Cannot cache refit_param_info since dtype for the 1st and 2nd steps may be different + ## e.g. e_score_correction_bias + refit_param_info = get_param_info(self.model, self.dtype) + # Collect current available memory for refit ## Get current device index from torch device_idx = torch.cuda.current_device() @@ -1308,7 +1316,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: # more buckets seems to have better perf total_available_bytes *= 0.1 - return self.refit_param_info, total_available_bytes + return refit_param_info, total_available_bytes # Temporary fix, 'keys' is a kwarg due to some sort of ray bug @torch.no_grad() @@ -1355,8 +1363,7 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Record offset of the tensor for key, tensor in gathered_hf_params.items(): - # dtype for the 1st and 2nd steps may be different - # e.g. model.layers.4.mlp.gate.e_score_correction_bias + # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias) if tensor.dtype == self.state_dict_info[key][1]: tensor_metadata[key] = type_to_total_size[tensor.dtype] else: From cabac3b1ec9b51c99d6ea7c37e7a8424690e92d2 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 13 Jul 2025 08:13:11 +0000 Subject: [PATCH 09/14] add unit test Signed-off-by: Yuki Huang --- .../models/generation/test_vllm_generation.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index b7159f4129..8d6cff05f8 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -1604,3 +1604,58 @@ def test_vllm_megatron_pipeline_parallel(cluster, tokenizer): vllm_policy.shutdown() if megatron_policy: megatron_policy.shutdown() + + +def test_vllm_megatron_weight_update_with_packing(cluster, test_input_data): + megatron_policy = None + vllm_generation = None + + try: + # Enable packing during test + os.environ["NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD"] = "1" + + # Both policies must use the same model (Qwen2.5-0.5B) for weight transfer compatibility + model_name = "Qwen/Qwen2.5-0.5B" + tokenizer = get_tokenizer({"name": model_name}) + + # Create Policy + megatron_config = get_basic_megatron_test_config( + tp=1, pp=1, precision="float32" + ) + megatron_config["model_name"] = model_name + megatron_config["tokenizer"]["name"] = model_name + megatron_policy = Policy(cluster, megatron_config, tokenizer) + + # Create VllmGeneration + vllm_config = deepcopy(basic_vllm_test_config) + vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) + vllm_config["model_name"] = model_name + vllm_config["tokenizer"]["name"] = model_name + vllm_generation = VllmGeneration(cluster, vllm_config) + + # prepare refit info + state_dict_info = megatron_policy.prepare_refit_info() + vllm_generation.prepare_refit_info(state_dict_info) + + print("refitting vllm policy...") + refit_policy_generation( + megatron_policy, vllm_generation, vllm_config["colocated"]["enabled"] + ) + + # test generate + outputs = vllm_generation.generate(test_input_data, greedy=True) + output_ids = outputs["output_ids"] + generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + assert generated_texts == [ + "Hello, my name is John. I am a", + "The capital of France is Paris. It is the", + ], "Output should be the same as the expected output" + + finally: + # Restore the original value + os.environ.pop("NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD", None) + # Clean up + if megatron_policy: + megatron_policy.shutdown() + if vllm_generation: + vllm_generation.shutdown() From a88ca8dfa56bbafd294fd9b2a6b813456254edf0 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 13 Jul 2025 10:56:17 +0000 Subject: [PATCH 10/14] rename some vars Signed-off-by: Yuki Huang --- .../models/policy/dtensor_policy_worker.py | 15 +++++---- .../models/policy/megatron_policy_worker.py | 32 ++++++++++--------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 78297ae3e6..a78d2f1f4c 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -289,12 +289,6 @@ def __init__( if self.cpu_offload: self.model = self.move_to_device(self.model, "cpu") - # used for streaming update inference engine weights - self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = ( - None - ) - self._held_streamed_param_reference: Optional[dict[str, torch.Tensor]] = None - if init_reference_model: self.reference_model_state_dict = get_cpu_state_dict( self.model.state_dict().items(), pin_memory=True @@ -350,6 +344,15 @@ def __init__( "No weights path provided. Starting from scratch (default policy init)" ) + # vars used for refit + ## will be initialized in prepare_refit_info + self.refit_param_info = None + ## used for streaming update inference engine weights + self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = ( + None + ) + self._held_streamed_param_reference: Optional[dict[str, torch.Tensor]] = None + # Refer to nemo impl. Below is original comment. # based on https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/utils.py#L113 @staticmethod diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 342476756e..852c1376f6 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -675,7 +675,6 @@ def __init__( ) self.final_padded_vocab_size = tokenizer_config.padded_vocab_size self.dp_size = worker_sharding_annotations.get_axis_size("data_parallel") - self._held_gather_buffer = None self.megatron_to_hf_converter = MegatronToHFConverter(hf_model_name, self.model) self.should_disable_forward_pre_hook = ( @@ -685,9 +684,12 @@ def __init__( ] ) - # refit stuff, will be initialized in prepare_refit_info - self.state_dict_info = None + # vars used for refit + ## will be initialized in prepare_refit_info + self.refit_param_info_hf = None self.local_key_to_global_keys = None + ## used for streaming update inference engine weights + self._held_gather_buffer = None def is_alive(self): return True @@ -1259,19 +1261,19 @@ def report_device_id(self) -> str: def prepare_refit_info(self) -> None: # Get parameter info for refit ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples - # Cannot cache refit_param_info since dtype for the 1st and 2nd steps may be different + # Cannot cache refit_param_info_mcore since dtype and size_in_bytes for the 1st and 2nd steps may be different ## e.g. e_score_correction_bias - refit_param_info = get_param_info(self.model, self.dtype) + refit_param_info_mcore = get_param_info(self.model, self.dtype) # Create a map that maps any local parameter name to a list of global parameter names. # This map is repeatedly used by parameter gatherring phase during refit of every step. self.local_key_to_global_keys = get_local_key_to_global_keys( - self.model, state_dict_info=refit_param_info + self.model, state_dict_info=refit_param_info_mcore ) # Collect tensor metadata for refit - self.state_dict_info = {} - for key, _ in refit_param_info: + self.refit_param_info_hf = {} + for key, _ in refit_param_info_mcore: # gather megatron params gathered_megatron_params = gather_params( self.model, @@ -1284,13 +1286,13 @@ def prepare_refit_info(self) -> None: ) # collect tensor metadata for name, tensor in gathered_hf_params.items(): - self.state_dict_info[name] = ( + self.refit_param_info_hf[name] = ( tensor.shape, tensor.dtype, tensor.numel(), ) - return self.state_dict_info + return self.refit_param_info_hf @torch.no_grad() def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: @@ -1303,9 +1305,9 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: # Get parameter info for refit ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples - # Cannot cache refit_param_info since dtype for the 1st and 2nd steps may be different + # Cannot cache refit_param_info_mcore since dtype and size_in_bytes for the 1st and 2nd steps may be different ## e.g. e_score_correction_bias - refit_param_info = get_param_info(self.model, self.dtype) + refit_param_info_mcore = get_param_info(self.model, self.dtype) # Collect current available memory for refit ## Get current device index from torch @@ -1316,7 +1318,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: # more buckets seems to have better perf total_available_bytes *= 0.1 - return refit_param_info, total_available_bytes + return refit_param_info_mcore, total_available_bytes # Temporary fix, 'keys' is a kwarg due to some sort of ray bug @torch.no_grad() @@ -1364,7 +1366,7 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Record offset of the tensor for key, tensor in gathered_hf_params.items(): # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias) - if tensor.dtype == self.state_dict_info[key][1]: + if tensor.dtype == self.refit_param_info_hf[key][1]: tensor_metadata[key] = type_to_total_size[tensor.dtype] else: # also send dtype if it changes @@ -1373,7 +1375,7 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: tensor.dtype, ) # update record - self.state_dict_info[key] = ( + self.refit_param_info_hf[key] = ( tensor.shape, tensor.dtype, tensor.numel(), From f152f61c03a2173c8e75951653b8db58b5685bf8 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 15 Jul 2025 03:13:38 +0000 Subject: [PATCH 11/14] add NRL_REFIT_BUFFER_MEMORY_RATIO, update default from 10% to 20% in mcore for speedup Signed-off-by: Yuki Huang --- nemo_rl/models/policy/dtensor_policy_worker.py | 3 ++- nemo_rl/models/policy/megatron_policy_worker.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index a78d2f1f4c..55bbfeb170 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -946,7 +946,8 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: ## Get device free memory using NVML total_available_bytes = get_free_memory_bytes(device_idx) ## Use 80% of the free memory for safety - total_available_bytes *= 0.8 + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8") + total_available_bytes *= float(memory_ratio) return self.refit_param_info, total_available_bytes diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 852c1376f6..1b0b6c9b63 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1314,9 +1314,9 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: device_idx = torch.cuda.current_device() ## Get device free memory using NVML total_available_bytes = get_free_memory_bytes(device_idx) - # TODO: setting to low value (10%) since - # more buckets seems to have better perf - total_available_bytes *= 0.1 + ## default to 20% to get some more speedup than 10%, OOM if set to 30% + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.2") + total_available_bytes *= float(memory_ratio) return refit_param_info_mcore, total_available_bytes From 1da2f5e6ed8cf68dafcfc089d62ed1dace27cbe7 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 15 Jul 2025 19:43:09 -0700 Subject: [PATCH 12/14] fix: maintain fp32 mlp.router.expert_bias even with bf16 enabled Signed-off-by: Zhiyu Li --- 3rdparty/NeMo-workspace/NeMo | 2 +- nemo_rl/models/policy/megatron_policy_worker.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo index 33259f2540..8ddf438734 160000 --- a/3rdparty/NeMo-workspace/NeMo +++ b/3rdparty/NeMo-workspace/NeMo @@ -1 +1 @@ -Subproject commit 33259f2540af6eef375d43fc48bdcbd7ec490c29 +Subproject commit 8ddf4387344c6423763ec9ee0c9a755cbb5d8d35 diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 1b0b6c9b63..666a0dd704 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -27,6 +27,7 @@ from megatron.core.distributed.custom_fsdp import ( FullyShardedDataParallel as custom_FSDP, ) +from megatron.core.transformer.module import Float16Module from megatron.core.inference.engines import ( StaticInferenceEngine, ) @@ -178,11 +179,25 @@ def setup_megatron_model( if policy_cfg["megatron_cfg"]["freeze_moe_router"]: def freeze_moe_router(model_module): + # Handle both wrapped (Float16Module) and unwrapped models + if isinstance(model_module, Float16Module): + model_module = model_module.module for layer in model_module.decoder.layers: if hasattr(layer.mlp, "router"): layer.mlp.router.weight.requires_grad = False + # Re-enable float32 expert bias for moe router to avoid parameter dtype inconsistency + # see https://github.com/NVIDIA/Megatron-LM/blob/e6c510ff3c1159f8955589b26f7c395bdf0607d9/megatron/core/transformer/moe/router.py#L149 + def re_enable_float32_expert_bias(model_module): + # Handle both wrapped (Float16Module) and unwrapped models + if isinstance(model_module, Float16Module): + model_module = model_module.module + for layer in model_module.decoder.layers: + if hasattr(layer.mlp, "router"): + layer.mlp.router._maintain_float32_expert_bias() + model_post_init_fns.append(freeze_moe_router) + model_post_init_fns.append(re_enable_float32_expert_bias) # Model, optimizer, and learning rate. model = get_model_from_config( From a6018ab9538bb7d3c798afc9c76cdea7f3a5a868 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 15 Jul 2025 19:01:31 -0700 Subject: [PATCH 13/14] avoid serializing rebuild_cuda_tensor function Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm_backend.py | 7 +++++-- nemo_rl/models/policy/megatron_policy_worker.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 02de2aacc5..079165fec1 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -15,6 +15,7 @@ from typing import Any, Optional import torch +from torch.multiprocessing.reductions import rebuild_cuda_tensor try: import vllm # noqa: F401 @@ -102,7 +103,8 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): # Extract packed tensor from IPC handle dtype_to_packed_tensor = {} for dtype, tensor_handle in all_handles: - func, args = tensor_handle + func = rebuild_cuda_tensor + args = tensor_handle[0] list_args = list(args) list_args[6] = device_id tensor = func(*list_args) @@ -128,7 +130,8 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): else: # Process each handle to get the tensor for name, handle in name_and_handle_list: - func, args = handle + func = rebuild_cuda_tensor + args = handle[0] list_args = list(args) list_args[6] = device_id tensor = func(*list_args) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 666a0dd704..7148d9a7d4 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1335,6 +1335,13 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: return refit_param_info_mcore, total_available_bytes + def get_handle_from_tensor(self, tensor: torch.Tensor) -> tuple[str, Any]: + """Get IPC handle from a tensor.""" + from torch.multiprocessing.reductions import reduce_tensor + + # skip serializing the function for better refit performance + return reduce_tensor(tensor.detach())[1:] + # Temporary fix, 'keys' is a kwarg due to some sort of ray bug @torch.no_grad() def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: @@ -1361,7 +1368,6 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Get device UUID for IPC handles device_uuid = self.report_device_id() - from torch.multiprocessing.reductions import reduce_tensor # Create IPC handles for each parameter tensor_number_threshold = os.getenv( @@ -1421,7 +1427,7 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Create IPC handles for consolidated tensors all_handles = [ - (dtype, reduce_tensor(tensor.detach())) + (dtype, self.get_handle_from_tensor(tensor)) for dtype, tensor in packed_tensors.items() ] @@ -1432,7 +1438,7 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: else: all_handles = [] for key, tensor in gathered_hf_params.items(): - handle = reduce_tensor(tensor.detach()) + handle = self.get_handle_from_tensor(tensor) all_handles.append((key, handle)) self._held_gather_buffer = gathered_hf_params serialized = (False, all_handles) From ad1d3ba7228d5a6284944dc76159e64be9bed07d Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 16 Jul 2025 05:09:40 -0700 Subject: [PATCH 14/14] assert dtype Signed-off-by: Yuki Huang --- nemo_rl/models/policy/megatron_policy_worker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 7148d9a7d4..7704f8c111 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -27,7 +27,6 @@ from megatron.core.distributed.custom_fsdp import ( FullyShardedDataParallel as custom_FSDP, ) -from megatron.core.transformer.module import Float16Module from megatron.core.inference.engines import ( StaticInferenceEngine, ) @@ -50,6 +49,7 @@ ) from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import get_rerun_state_machine +from megatron.core.transformer.module import Float16Module from megatron.inference.text_generation.mcore_engine_server import ( run_mcore_engine, ) @@ -1390,6 +1390,9 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: if tensor.dtype == self.refit_param_info_hf[key][1]: tensor_metadata[key] = type_to_total_size[tensor.dtype] else: + assert False, ( + f"{key} dtype mismatch: {tensor.dtype} vs {self.refit_param_info_hf[key][1]}" + ) # also send dtype if it changes tensor_metadata[key] = ( type_to_total_size[tensor.dtype],