diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d33e503636..530ad4e8f9 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -347,6 +347,10 @@ def setup( # wait for all futures to complete ray.get(futures_train + futures_inference) + # prepare refit info + state_dict_info = policy.prepare_refit_info() + policy_generation.prepare_refit_info(state_dict_info) + loss_fn = ClippedPGLossFn(loss_config) print("\n" + "=" * 60) @@ -422,17 +426,15 @@ def refit_policy_generation( # do update for keys in grouped_param_keys: ipc_handles = policy.get_weights_ipc_handles(keys) - update_success = policy_generation.update_weights(ipc_handles) + update_success = policy_generation.update_weights_from_ipc_handles( + ipc_handles + ) if not update_success: break else: - # prepare info for update weights - state_dict_info = policy.prepare_info_for_collective() # update weights through nccl futures_train = policy.broadcast_weights_for_collective() - futures_inference = policy_generation.update_weights_from_collective( - state_dict_info - ) + futures_inference = policy_generation.update_weights_from_collective() # wait for all futures to complete ray.get(futures_train) results = ray.get(futures_inference) diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 9db0c357fa..665473cc03 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -228,12 +228,14 @@ def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: def finish_generation(self, *args: Any, **kwargs: Any) -> bool: pass - def update_weights(self, ipc_handles: dict[str, Any]) -> bool: + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + """Prepare the info for refit.""" + raise NotImplementedError + + def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: """Update the model weights from the given IPC handles.""" raise NotImplementedError - def update_weights_from_collective( - self, info: dict[str, Any] - ) -> list[ray.ObjectRef]: + def update_weights_from_collective(self) -> list[ray.ObjectRef]: """Update the model weights from collective communication.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index a99022e30f..7a382c0bda 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -1021,6 +1021,14 @@ async def report_device_id_async(self) -> list[str]: return cast(list[str], list_of_worker_results) + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + """Prepare the info for refit.""" + self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) + + async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> None: + """Async version of prepare_refit_info.""" + await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) + def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: """Update weights from IPC handles by delegating to the vLLM Worker implementation. 
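The reworked call order above, as a minimal self-contained sketch. The classes below are stand-ins for the real Policy / VllmGeneration types, not their implementations: refit metadata is exchanged once during setup() via prepare_refit_info(), and the per-step update calls (update_weights_from_ipc_handles / update_weights_from_collective) then carry no metadata.

# Stand-in classes (not the real nemo_rl types); illustrates the call order only.
class TrainPolicySketch:
    def __init__(self, named_info):
        # named_info: {tensor_name: (shape, dtype_str)} -- placeholder weight metadata
        self.named_info = named_info

    def prepare_refit_info(self):
        # one-time description of every tensor the inference engine will receive
        return dict(self.named_info)


class GenerationPolicySketch:
    def __init__(self):
        self.state_dict_info = None

    def prepare_refit_info(self, state_dict_info):
        # cached once; later refits no longer pass metadata per call
        self.state_dict_info = state_dict_info

    def update_weights_from_collective(self):
        # the signature no longer takes `info`; shapes/dtypes come from the cache
        assert self.state_dict_info is not None, "call prepare_refit_info first"
        return True


train = TrainPolicySketch({"lm_head.weight": ((8, 4), "bfloat16")})
gen = GenerationPolicySketch()
gen.prepare_refit_info(train.prepare_refit_info())  # once, during setup()
assert gen.update_weights_from_collective()         # every refit step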
@@ -1132,7 +1140,7 @@ async def update_weights_from_ipc_handles_async( traceback.print_exc() return False - def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + def update_weights_from_collective(self) -> bool: """Update the model weights from collective communication.""" try: assert self.llm is not None, ( @@ -1145,7 +1153,7 @@ def update_weights_from_collective(self, info: dict[str, Any]) -> bool: ) result_or_coro = self.llm.collective_rpc( - "update_weights_from_collective", args=(info,) + "update_weights_from_collective", args=tuple() ) worker_result = result_or_coro[0] @@ -1162,7 +1170,7 @@ def update_weights_from_collective(self, info: dict[str, Any]) -> bool: traceback.print_exc() return False - async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bool: + async def update_weights_from_collective_async(self) -> bool: """Async version of update_weights_from_collective.""" try: assert self.llm is not None, ( @@ -1175,7 +1183,7 @@ async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bo ) result_or_coro = await self.llm.collective_rpc( - "update_weights_from_collective", args=(info,) + "update_weights_from_collective", args=tuple() ) if asyncio.iscoroutine(result_or_coro): @@ -1908,7 +1916,26 @@ def shutdown(self) -> bool: print(f"Error during policy shutdown: {e}") return False - def update_weights(self, ipc_handles: dict[str, Any]) -> bool: + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + """Prepare the info for refit.""" + # Choose the appropriate method based on async_engine setting + method_name = ( + "prepare_refit_info_async" + if self.cfg["vllm_cfg"]["async_engine"] + else "prepare_refit_info" + ) + + # Use run_all_workers_single_data to send data to all workers + futures = self.worker_group.run_all_workers_single_data( + method_name, + state_dict_info=state_dict_info, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) + + # Wait for all futures to complete + ray.get(futures) + + def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: """Update weights of the policy using IPC handles, considering tensor parallelism. For tp > 1, only the leader in each tensor parallel tied worker group will update weights. @@ -1952,9 +1979,7 @@ def update_weights(self, ipc_handles: dict[str, Any]) -> bool: print(f"Error during update weights: {e}") return False - def update_weights_from_collective( - self, info: dict[str, Any] - ) -> list[ray.ObjectRef]: + def update_weights_from_collective(self) -> list[ray.ObjectRef]: """Update weights of the policy using collective communication.""" if not self.worker_group or not self.worker_group.workers: raise RuntimeError("Worker group is not initialized") @@ -1966,10 +1991,9 @@ def update_weights_from_collective( else "update_weights_from_collective" ) - # Use run_all_workers_single_data to send data to all workers + # Use run_all_workers_single_data for methods that don't need data futures = self.worker_group.run_all_workers_single_data( method_name, - info=info, run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], ) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 3d6ed0253c..fceea5b24f 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
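The worker fan-out above picks the sync or async variant of the same RPC based on vllm_cfg.async_engine. A simplified sketch of that dispatch follows, with a stub in place of the real Ray worker group; only the selection logic mirrors the diff.

import ray  # unused in this stub-only sketch, kept to mirror the real call site


class StubWorkerGroup:
    def run_all_workers_single_data(self, method_name, **kwargs):
        # the real implementation returns Ray futures; here we just echo the call
        return [f"{method_name}({sorted(kwargs)})"]


def fan_out_prepare_refit_info(cfg, worker_group, state_dict_info):
    method_name = (
        "prepare_refit_info_async"
        if cfg["vllm_cfg"]["async_engine"]
        else "prepare_refit_info"
    )
    return worker_group.run_all_workers_single_data(
        method_name,
        state_dict_info=state_dict_info,
        run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
    )


calls = fan_out_prepare_refit_info(
    {"vllm_cfg": {"async_engine": True}}, StubWorkerGroup(), {"w": ((2, 2), "bf16")}
)
print(calls)  # ["prepare_refit_info_async(['run_rank_0_only_axes', 'state_dict_info'])"]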
import os -from typing import Any +from typing import Any, Optional import torch @@ -51,6 +51,21 @@ def report_device_id(self) -> str: return get_device_uuid(self.device.index) + def prepare_refit_info( + self, state_dict_info: Optional[dict[str, Any]] = None + ) -> None: + """Prepare the info for refit. + + DtensorPolicyWorker: + colocated inference: state_dict_info is None + non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)} + + MegatronPolicyWorker: + colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)} + non-colocated inference: not implemented yet + """ + self.state_dict_info = state_dict_info + def update_weights_from_global_ipc_handles(self, global_device_ipc_handles): """Update weights from global IPC handles. @@ -84,6 +99,11 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): weights = [] if is_tensor_packed: + assert self.state_dict_info is not None, ( + "state_dict_info is not prepared. " + "Please call prepare_refit_info when initializing the worker." + ) + # Extract packed tensor from IPC handle dtype_to_packed_tensor = {} for dtype, tensor_handle in all_handles: @@ -95,7 +115,17 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): # Unpack tensor to weights. Here we only return a view of the tensor to avoid # using extra memory. - for key, (shape, dtype, offset, size) in tensor_metadata.items(): + for key, metadata in tensor_metadata.items(): + # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias) + if isinstance(metadata, tuple): + # use dtype of current step + offset, dtype = metadata + shape, _, size = self.state_dict_info[key] + # update record + self.state_dict_info[key] = (shape, dtype, size) + else: + offset = metadata + shape, dtype, size = self.state_dict_info[key] tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view( *shape ) @@ -118,10 +148,15 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): ) return False - def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + def update_weights_from_collective(self) -> bool: """Update the model weights from collective communication.""" + assert self.state_dict_info is not None, ( + "state_dict_info is not prepared. " + "Please call prepare_refit_info when initializing the worker." + ) + try: - for name, (shape, dtype) in info.items(): + for name, (shape, dtype) in self.state_dict_info.items(): weight = torch.empty(shape, dtype=dtype, device="cuda") self.model_update_group.broadcast(weight, src=0) self.model_runner.model.load_weights(weights=[(name, weight)]) diff --git a/nemo_rl/models/megatron/refit_utils.py b/nemo_rl/models/megatron/refit_utils.py index fb46030ce9..f96c6b7537 100644 --- a/nemo_rl/models/megatron/refit_utils.py +++ b/nemo_rl/models/megatron/refit_utils.py @@ -13,7 +13,7 @@ # limitations under the License. 
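A CPU-only sketch of the unpacking path above: the packed per-dtype buffer is sliced back into named views using the cached refit metadata, and the tuple form of the metadata covers the case where a tensor's dtype differs from the cached one (e.g. e_score_correction_bias). The tensors below are made up; in the real worker the packed buffer arrives through CUDA IPC.

import torch

state_dict_info = {           # cached by prepare_refit_info: name -> (shape, dtype, numel)
    "a": ((2, 3), torch.bfloat16, 6),
    "b": ((4,), torch.bfloat16, 4),
}
packed = {torch.bfloat16: torch.arange(10, dtype=torch.float32).to(torch.bfloat16)}
tensor_metadata = {"a": 0, "b": (6, torch.bfloat16)}  # bare offset, or (offset, dtype)

weights = []
for key, metadata in tensor_metadata.items():
    if isinstance(metadata, tuple):                   # dtype changed since the cached info
        offset, dtype = metadata
        shape, _, size = state_dict_info[key]
        state_dict_info[key] = (shape, dtype, size)   # refresh the cache
    else:
        offset = metadata
        shape, dtype, size = state_dict_info[key]
    view = packed[dtype][offset : offset + size].view(*shape)  # no copy, just a view
    weights.append((key, view))

assert weights[0][1].shape == (2, 3) and weights[1][1].shape == (4,)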
import re import time -from typing import Dict, List +from typing import Any, Dict, List, Tuple import torch from megatron.core import parallel_state @@ -28,6 +28,9 @@ RowParallelLinear, VocabParallelEmbedding, ) +from torch.distributed import get_process_group_ranks + +from nemo_rl.models.megatron.converters.common import get_global_key_from_local_key def get_tp_dim(model, param_name, named_modules_dict): @@ -155,3 +158,175 @@ def gather_params(model, keys, key_to_global_keys: Dict[str, List[str]]): print(f"Time taken to gather params: {time.perf_counter() - st}") return gathered_params + + +@torch.no_grad() +def get_param_info(model, dtype): + # Get parallel info + tp_group = parallel_state.get_tensor_model_parallel_group() + tp_world_size = torch.distributed.get_world_size(tp_group) + tp_group_rank_ids = get_process_group_ranks(tp_group) + + etp_group = parallel_state.get_expert_tensor_parallel_group() + etp_world_size = torch.distributed.get_world_size(etp_group) + etp_group_rank_ids = get_process_group_ranks(etp_group) + + pp_group = parallel_state.get_pipeline_model_parallel_group() + pp_world_size = torch.distributed.get_world_size(pp_group) + pp_group_rank_ids = get_process_group_ranks(pp_group) + pp_local_rank_id = parallel_state.get_pipeline_model_parallel_rank() + + ep_group = parallel_state.get_expert_model_parallel_group() + ep_world_size = torch.distributed.get_world_size(ep_group) + ep_group_rank_ids = get_process_group_ranks(ep_group) + + # Collect parameter info + param_info = [] + + # Dictionary of modules we can quickly look up to check if a module has TP + named_modules_dict = dict(model.named_modules()) + + # Process each parameter in the model + # state_dict includes parameters and persistent buffers + ep_pattern = re.compile(r"mlp\.experts.*\.weight\d*$") + for name, param in model.state_dict().items(): + # Skip _extra_state entries (these are metadata, not actual weights) + if "_extra_state" in name: + continue + + use_etp = True if ep_pattern.search(name) else False + if use_etp: + tensor_mp_rank_ids = etp_group_rank_ids + else: + tensor_mp_rank_ids = tp_group_rank_ids + + shape = list(param.shape) + tp_dim = get_tp_dim(model, name, named_modules_dict) + if tp_dim is not None: + tp_rank_ids = tuple(sorted(tensor_mp_rank_ids)) + shape[tp_dim] *= len(tp_rank_ids) + else: + tp_rank_ids = (torch.distributed.get_rank(),) + + pp_rank_ids = tuple(sorted(pp_group_rank_ids)) + ep_rank_ids = tuple(sorted(ep_group_rank_ids)) + + if ep_pattern.search(name): + ep_rank_ids = tuple(sorted(ep_group_rank_ids)) + else: + ep_rank_ids = (torch.distributed.get_rank(),) + + # Calculate size for this parameter + prec_to_bytes = { + torch.bfloat16: 2, + torch.float16: 2, + torch.float32: 4, + } + scale = prec_to_bytes[dtype] / prec_to_bytes[param.dtype] + size_in_bytes = ( + param.element_size() + * param.numel() + * len(tensor_mp_rank_ids) + * len(ep_rank_ids) + * scale + ) + param_info.append( + ( + ( + name, + pp_local_rank_id, + tuple(shape), + param.dtype, + ), + size_in_bytes, + ) + ) + # Gather parameter info from all pipeline parallel ranks to ensure complete coverage + pp_group = parallel_state.get_pipeline_model_parallel_group() + pp_world_size = torch.distributed.get_world_size(pp_group) + + # Gather all parameter info from all PP ranks + pp_gathered_param_infos = [None] * pp_world_size + torch.distributed.all_gather_object( + pp_gathered_param_infos, param_info, group=pp_group + ) + pp_gathered_param_infos = [x for y in pp_gathered_param_infos for x in y] # type: ignore + + # 
Gather parameter info from all expert parallel ranks to ensure complete coverage + ep_group = parallel_state.get_expert_model_parallel_group() + ep_world_size = torch.distributed.get_world_size(ep_group) + + # Gather all parameter info from all EP ranks + ep_gathered_param_infos = [None] * ep_world_size + torch.distributed.all_gather_object( + ep_gathered_param_infos, pp_gathered_param_infos, group=ep_group + ) + all_param_infos = [x for y in ep_gathered_param_infos for x in y] + + # Merge all parameter infos, keeping only unique parameter names + merged_param_info = [] + seen_params = set() + + for name, size in all_param_infos: + if name not in seen_params: + merged_param_info.append((name, size)) + seen_params.add(name) + + # Update param_info with the merged information + param_info = merged_param_info + print(f"Prepared {len(param_info)} tensors for refit") + + return param_info + + +@torch.no_grad() +def get_local_key_to_global_keys(model, state_dict_info: List[Tuple[Any, int]]): + """Get the local key to global keys mapping.""" + # Get parallel info + tp_group = parallel_state.get_tensor_model_parallel_group() + tp_world_size = torch.distributed.get_world_size(tp_group) + + pp_group = parallel_state.get_pipeline_model_parallel_group() + pp_world_size = torch.distributed.get_world_size(pp_group) + pp_global_ranks = torch.distributed.get_process_group_ranks(group=pp_group) + pp_local_rank_id = parallel_state.get_pipeline_model_parallel_rank() + + ep_group = parallel_state.get_expert_model_parallel_group() + ep_world_size = torch.distributed.get_world_size(ep_group) + + # start calculating the global key + ep_pattern = re.compile(r"mlp\.experts.*\.weight\d*$") + state_dict = model.state_dict() + final_key_to_global_keys = {} + + for param_info, size in state_dict_info: + local_key, owner_pp_local_rank_id, _, _ = param_info + + # Step 1: create global key from local key + # if: for if a parameter is sharded along PP or EP; + # else: not sharded (like embedding) + pp_gathered_objs = [None] + if local_key in state_dict and owner_pp_local_rank_id == pp_local_rank_id: + pp_gathered_objs[0] = get_global_key_from_local_key(local_key, model.config) + + # Step 2: gather global keys from ranks in PP group + src_global_rank = pp_global_ranks[owner_pp_local_rank_id] + torch.distributed.broadcast_object_list( + pp_gathered_objs, src=src_global_rank, group=pp_group + ) + + # Step 3: gather global keys from ranks in EP group + if ep_pattern.search(local_key): + ep_gathered_objs = [None] * ep_world_size + torch.distributed.all_gather_object( + ep_gathered_objs, pp_gathered_objs, group=ep_group + ) + flat_gathered_objs = [x for y in ep_gathered_objs for x in y] + else: + flat_gathered_objs = pp_gathered_objs + + final_key_to_global_keys[(local_key, owner_pp_local_rank_id)] = ( + flat_gathered_objs + ) + + return final_key_to_global_keys diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 6be85d9e0d..55bbfeb170 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -139,13 +139,13 @@ def __init__( init_reference_model: bool = True, **kwargs: Any, ): + self.is_generation_colocated = None + if "generation" in config and config["generation"] is not None: + self.is_generation_colocated = config["generation"]["colocated"]["enabled"] + # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. 
# See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. - if ( - "generation" in config - and config["generation"] is not None - and not config["generation"]["colocated"]["enabled"] - ): + if not self.is_generation_colocated: os.environ["NCCL_CUMEM_ENABLE"] = "1" # Only enable expandable_segments on Hopper and newer architectures (compute capability 9.x+) @@ -289,12 +289,6 @@ def __init__( if self.cpu_offload: self.model = self.move_to_device(self.model, "cpu") - # used for streaming update inference engine weights - self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = ( - None - ) - self._held_streamed_param_reference: Optional[dict[str, torch.Tensor]] = None - if init_reference_model: self.reference_model_state_dict = get_cpu_state_dict( self.model.state_dict().items(), pin_memory=True @@ -350,6 +344,15 @@ def __init__( "No weights path provided. Starting from scratch (default policy init)" ) + # vars used for refit + ## will be initialized in prepare_refit_info + self.refit_param_info = None + ## used for streaming update inference engine weights + self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = ( + None + ) + self._held_streamed_param_reference: Optional[dict[str, torch.Tensor]] = None + # Refer to nemo impl. Below is original comment. # based on https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/utils.py#L113 @staticmethod @@ -897,6 +900,26 @@ def report_device_id(self) -> str: # Get device UUID using NVML return get_device_uuid(device_idx) + @torch.no_grad() + def prepare_refit_info(self) -> Optional[dict[str, Any]]: + state_dict = self.model.state_dict() + + if self.is_generation_colocated: + # Collect info for streaming multiple tensors + self.refit_param_info = [] + for name, tensor in state_dict.items(): + # dtensor's numel will return complete tensor instead of only local tensor + size_in_bytes = tensor.element_size() * tensor.numel() + self.refit_param_info.append((name, size_in_bytes)) + + else: + # Collect info for collective communication + state_dict_info = {} + for name, tensor in state_dict.items(): + state_dict_info[name] = (tensor.shape, self.dtype) + + return state_dict_info + @torch.no_grad() def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """Prepare the weights for IPC. 
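The DTensor worker above produces two different refit-info layouts depending on whether generation is colocated. A small sketch with ordinary CPU tensors standing in for the sharded state dict (for DTensor, numel() already counts the full, unsharded tensor):

import torch

def prepare_refit_info_sketch(state_dict, train_dtype, colocated):
    if colocated:
        # IPC path: per-tensor byte sizes, used to bucket tensors under a memory budget
        return [(name, t.element_size() * t.numel()) for name, t in state_dict.items()]
    # collective path: shapes plus the training dtype the receiver should allocate
    return {name: (t.shape, train_dtype) for name, t in state_dict.items()}

sd = {"w": torch.zeros(4, 8, dtype=torch.bfloat16), "b": torch.zeros(8, dtype=torch.float32)}
print(prepare_refit_info_sketch(sd, torch.bfloat16, colocated=True))
# [('w', 64), ('b', 32)]
print(prepare_refit_info_sketch(sd, torch.bfloat16, colocated=False))
# {'w': (torch.Size([4, 8]), torch.bfloat16), 'b': (torch.Size([8]), torch.bfloat16)}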
@@ -917,22 +940,16 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: self.model.state_dict() ) - # Collect info for streaming multiple tensors - state_dict_info = [] - for name, tensor in self._held_sharded_state_dict_reference.items(): - # dtensor's numel will return complete tensor instead of only local tensor - size_in_bytes = tensor.element_size() * tensor.numel() - state_dict_info.append((name, size_in_bytes)) - # Collect current available memory for refit ## Get current device index from torch device_idx = torch.cuda.current_device() ## Get device free memory using NVML total_available_bytes = get_free_memory_bytes(device_idx) ## Use 80% of the free memory for safety - total_available_bytes *= 0.8 + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8") + total_available_bytes *= float(memory_ratio) - return state_dict_info, total_available_bytes + return self.refit_param_info, total_available_bytes @torch.no_grad() def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: @@ -975,24 +992,6 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: return {device_uuid: serialized} - @torch.no_grad() - def prepare_info_for_collective(self) -> dict[str, Any]: - """Prepare the info for collective communication. - - Returns: - dict: A dictionary containing the info for collective communication. - """ - # Get state_dict - self.model = self.move_to_cuda(self.model) - state_dict = self.model.state_dict() - - # Collect info for collective communication - state_dict_info = {} - for name, tensor in state_dict.items(): - state_dict_info[name] = (tensor.shape, self.dtype) - - return state_dict_info - @torch.no_grad() def broadcast_weights_for_collective(self) -> None: """Broadcast the weights for collective communication.""" diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index 614340c67b..d63b66e735 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, TypedDict +from typing import Any, Optional, TypedDict import ray import torch @@ -109,15 +109,15 @@ def offload_after_refit(self) -> None: pass @abstractmethod - def prepare_weights_for_ipc(self, *args: Any, **kwargs: Any) -> list[list[str]]: + def prepare_refit_info(self) -> Optional[dict[str, Any]]: pass @abstractmethod - def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: + def prepare_weights_for_ipc(self, *args: Any, **kwargs: Any) -> list[list[str]]: pass @abstractmethod - def prepare_info_for_collective(self) -> dict[str, Any]: + def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: pass @abstractmethod diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 5e82b61d72..c77f2460e7 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -405,6 +405,17 @@ def finish_training(self, *args: Any, **kwargs: Any) -> None: # Placeholder implementation pass + def prepare_refit_info(self) -> Optional[dict[str, Any]]: + """Prepare the info for refit. + + Returns: + dict: A dictionary containing the info for refit. 
+ """ + futures = self.worker_group.run_all_workers_single_data("prepare_refit_info") + results = ray.get(futures) + # Only get the first worker's info since all workers will have the same result + return results[0] + def prepare_weights_for_ipc( self, _refit_buffer_size_gb: Optional[int] = None ) -> list[list[str]]: @@ -469,19 +480,6 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: return all_handles - def prepare_info_for_collective(self) -> dict[str, Any]: - """Prepare the info for collective communication. - - Returns: - dict: A dictionary containing the info for collective communication. - """ - futures = self.worker_group.run_all_workers_single_data( - "prepare_info_for_collective" - ) - results = ray.get(futures) - # Only get the first worker's info since all workers will have the same result - return results[0] - def broadcast_weights_for_collective(self) -> list[ray.ObjectRef]: """Broadcast the weights for collective communication.""" futures = self.worker_group.run_all_workers_single_data( diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 691e1ce5b3..1b0b6c9b63 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -13,13 +13,12 @@ # limitations under the License. import gc import os -import re import time import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial -from typing import Any, Iterator, List, Optional, Tuple, TypeVar +from typing import Any, Iterator, Optional, TypeVar import ray import torch @@ -83,7 +82,6 @@ reduce_max_stat_across_model_parallel_group, ) from ray.util.queue import Queue -from torch.distributed import get_process_group_ranks from transformers import PreTrainedTokenizerBase from nemo_rl.algorithms.interfaces import LossFunction, LossType @@ -100,13 +98,11 @@ forward_step_arbitrary_loss, ) from nemo_rl.models.megatron.community_import import import_model_from_hf_name -from nemo_rl.models.megatron.converters.common import ( - MegatronToHFConverter, - get_global_key_from_local_key, -) +from nemo_rl.models.megatron.converters.common import MegatronToHFConverter from nemo_rl.models.megatron.refit_utils import ( gather_params, - get_tp_dim, + get_local_key_to_global_keys, + get_param_info, ) from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( @@ -679,14 +675,8 @@ def __init__( ) self.final_padded_vocab_size = tokenizer_config.padded_vocab_size self.dp_size = worker_sharding_annotations.get_axis_size("data_parallel") - self._held_gather_buffer = None self.megatron_to_hf_converter = MegatronToHFConverter(hf_model_name, self.model) - # Create a map that maps any local parameter name to a list of global parameter names. - # This map is repeatedly used by parameter gatherring phase during refit of every step. 
- self.local_key_to_global_keys = self.get_local_key_to_global_keys( - state_dict_info=self.prepare_weights_for_ipc()[0] - ) self.should_disable_forward_pre_hook = ( self.cfg["megatron_cfg"]["optimizer"]["use_distributed_optimizer"] and self.cfg["megatron_cfg"]["distributed_data_parallel_config"][ @@ -694,6 +684,13 @@ def __init__( ] ) + # vars used for refit + ## will be initialized in prepare_refit_info + self.refit_param_info_hf = None + self.local_key_to_global_keys = None + ## used for streaming update inference engine weights + self._held_gather_buffer = None + def is_alive(self): return True @@ -1261,59 +1258,43 @@ def report_device_id(self) -> str: return get_device_uuid(device_idx) @torch.no_grad() - def get_local_key_to_global_keys(self, state_dict_info: List[Tuple[Any, int]]): - """Get the local key to global keys mapping.""" - # Get parallel info - tp_group = parallel_state.get_tensor_model_parallel_group() - tp_world_size = torch.distributed.get_world_size(tp_group) - - pp_group = parallel_state.get_pipeline_model_parallel_group() - pp_world_size = torch.distributed.get_world_size(pp_group) - pp_global_ranks = torch.distributed.get_process_group_ranks(group=pp_group) - pp_local_rank_id = parallel_state.get_pipeline_model_parallel_rank() - - ep_group = parallel_state.get_expert_model_parallel_group() - ep_world_size = torch.distributed.get_world_size(ep_group) - - # start calculating the global key - ep_pattern = re.compile(r"mlp\.experts.*\.weight\d*$") - state_dict = self.model.state_dict() - final_key_to_global_keys = {} - - for param_info, size in state_dict_info: - local_key, owner_pp_local_rank_id, _, _ = param_info - - # Step 1: create global key from local key - # if: for if a parameter is sharded along PP or EP; - # else: not sharded (like embedding) - pp_gathered_objs = [None] - if local_key in state_dict and owner_pp_local_rank_id == pp_local_rank_id: - pp_gathered_objs[0] = get_global_key_from_local_key( - local_key, self.model.config - ) - - # Step 2: gather global keys from ranks in PP group - src_global_rank = pp_global_ranks[owner_pp_local_rank_id] - torch.distributed.broadcast_object_list( - pp_gathered_objs, src=src_global_rank, group=pp_group - ) + def prepare_refit_info(self) -> None: + # Get parameter info for refit + ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples + # Cannot cache refit_param_info_mcore since dtype and size_in_bytes for the 1st and 2nd steps may be different + ## e.g. e_score_correction_bias + refit_param_info_mcore = get_param_info(self.model, self.dtype) - # Step 3: gather global keys from ranks in EP group - if ep_pattern.search(local_key): - ep_gathered_objs = [None] * ep_world_size - torch.distributed.all_gather_object( - ep_gathered_objs, pp_gathered_objs, group=ep_group - ) - flat_gathered_objs = [x for y in ep_gathered_objs for x in y] - else: - flat_gathered_objs = pp_gathered_objs + # Create a map that maps any local parameter name to a list of global parameter names. + # This map is repeatedly used by parameter gatherring phase during refit of every step. 
+ self.local_key_to_global_keys = get_local_key_to_global_keys( + self.model, state_dict_info=refit_param_info_mcore + ) - final_key_to_global_keys[(local_key, owner_pp_local_rank_id)] = ( - flat_gathered_objs + # Collect tensor metadata for refit + self.refit_param_info_hf = {} + for key, _ in refit_param_info_mcore: + # gather megatron params + gathered_megatron_params = gather_params( + self.model, + [key], + key_to_global_keys=self.local_key_to_global_keys, ) + # convert to hf params + gathered_hf_params = self.megatron_to_hf_converter.convert( + gathered_megatron_params, self.model.config + ) + # collect tensor metadata + for name, tensor in gathered_hf_params.items(): + self.refit_param_info_hf[name] = ( + tensor.shape, + tensor.dtype, + tensor.numel(), + ) - return final_key_to_global_keys + return self.refit_param_info_hf + @torch.no_grad() def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """Prepare Megatron model weights for IPC transfer to vLLM. @@ -1322,139 +1303,25 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """ from nemo_rl.utils.nvml import get_free_memory_bytes - no_grad = torch.no_grad() - no_grad.__enter__() - # Ensure model is in evaluation mode - self.model.eval() - - # Get parallel info - tp_group = parallel_state.get_tensor_model_parallel_group() - tp_world_size = torch.distributed.get_world_size(tp_group) - tp_group_rank_ids = get_process_group_ranks(tp_group) - - etp_group = parallel_state.get_expert_tensor_parallel_group() - etp_world_size = torch.distributed.get_world_size(etp_group) - etp_group_rank_ids = get_process_group_ranks(etp_group) - - pp_group = parallel_state.get_pipeline_model_parallel_group() - pp_world_size = torch.distributed.get_world_size(pp_group) - pp_group_rank_ids = get_process_group_ranks(pp_group) - pp_local_rank_id = parallel_state.get_pipeline_model_parallel_rank() - - ep_group = parallel_state.get_expert_model_parallel_group() - ep_world_size = torch.distributed.get_world_size(ep_group) - ep_group_rank_ids = get_process_group_ranks(ep_group) - - # Collect parameter info - param_info = [] - - # Dictionary of modules we can quickly look up to check if a module has TP - named_modules_dict = dict(self.model.named_modules()) - - # Process each parameter in the model - # state_dict includes parameters and persistent buffers - ep_pattern = re.compile(r"mlp\.experts.*\.weight\d*$") - for name, param in self.model.state_dict().items(): - # Skip _extra_state entries (these are metadata, not actual weights) - if "_extra_state" in name: - continue - - use_etp = True if ep_pattern.search(name) else False - if use_etp: - tensor_mp_rank_ids = etp_group_rank_ids - else: - tensor_mp_rank_ids = tp_group_rank_ids - - shape = list(param.shape) - tp_dim = get_tp_dim(self.model, name, named_modules_dict) - if tp_dim is not None: - tp_rank_ids = tuple(sorted(tensor_mp_rank_ids)) - shape[tp_dim] *= len(tp_rank_ids) - else: - tp_rank_ids = (torch.distributed.get_rank(),) - - pp_rank_ids = tuple(sorted(pp_group_rank_ids)) - ep_rank_ids = tuple(sorted(ep_group_rank_ids)) - - if ep_pattern.search(name): - ep_rank_ids = tuple(sorted(ep_group_rank_ids)) - else: - ep_rank_ids = (torch.distributed.get_rank(),) - - # Calculate size for this parameter - prec_to_bytes = { - torch.bfloat16: 2, - torch.float16: 2, - torch.float32: 4, - } - scale = prec_to_bytes[self.dtype] / prec_to_bytes[param.dtype] - size_in_bytes = ( - param.element_size() - * param.numel() - * len(tensor_mp_rank_ids) - * len(ep_rank_ids) - * scale 
- ) - param_info.append( - ( - ( - name, - pp_local_rank_id, - tuple(shape), - param.dtype, - ), - size_in_bytes, - ) - ) - # Gather parameter info from all pipeline parallel ranks to ensure complete coverage - pp_group = parallel_state.get_pipeline_model_parallel_group() - pp_world_size = torch.distributed.get_world_size(pp_group) - - # Gather all parameter info from all PP ranks - pp_gathered_param_infos = [None] * pp_world_size - torch.distributed.all_gather_object( - pp_gathered_param_infos, param_info, group=pp_group - ) - pp_gathered_param_infos = [x for y in pp_gathered_param_infos for x in y] # type: ignore - - # Gather parameter info from all expert parallel ranks to ensure complete coverage - ep_group = parallel_state.get_expert_model_parallel_group() - ep_world_size = torch.distributed.get_world_size(ep_group) - - # Gather all parameter info from all EP ranks - ep_gathered_param_infos = [None] * ep_world_size - torch.distributed.all_gather_object( - ep_gathered_param_infos, pp_gathered_param_infos, group=ep_group - ) - all_param_infos = [x for y in ep_gathered_param_infos for x in y] - - # Merge all parameter infos, keeping only unique parameter names - merged_param_info = [] - seen_params = set() - - for name, size in all_param_infos: - if name not in seen_params: - merged_param_info.append((name, size)) - seen_params.add(name) - - # Update param_info with the merged information - param_info = merged_param_info - - print(f"Prepared {len(param_info)} tensors for IPC transfer") - no_grad.__exit__(None, None, None) + # Get parameter info for refit + ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples + # Cannot cache refit_param_info_mcore since dtype and size_in_bytes for the 1st and 2nd steps may be different + ## e.g. e_score_correction_bias + refit_param_info_mcore = get_param_info(self.model, self.dtype) # Collect current available memory for refit ## Get current device index from torch device_idx = torch.cuda.current_device() ## Get device free memory using NVML total_available_bytes = get_free_memory_bytes(device_idx) - # TODO: setting to low value (10%) since - # more buckets seems to have better perf - total_available_bytes *= 0.1 + ## default to 20% to get some more speedup than 10%, OOM if set to 30% + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.2") + total_available_bytes *= float(memory_ratio) - return param_info, total_available_bytes + return refit_param_info_mcore, total_available_bytes # Temporary fix, 'keys' is a kwarg due to some sort of ray bug + @torch.no_grad() def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: """Get IPC handles for the requested Megatron model weights. @@ -1496,14 +1363,23 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: type_to_total_size = defaultdict(lambda: 0) tensor_metadata = dict() + # Record offset of the tensor for key, tensor in gathered_hf_params.items(): - tensor_metadata[key] = ( - tensor.shape, # shape of the tensor - tensor.dtype, # dtype of the tensor - type_to_total_size[tensor.dtype], # offset of the tensor - # in packed buffer - tensor.numel(), # size of the tensor - ) + # dtype for the 1st and 2nd steps may be different (e.g. 
e_score_correction_bias) + if tensor.dtype == self.refit_param_info_hf[key][1]: + tensor_metadata[key] = type_to_total_size[tensor.dtype] + else: + # also send dtype if it changes + tensor_metadata[key] = ( + type_to_total_size[tensor.dtype], + tensor.dtype, + ) + # update record + self.refit_param_info_hf[key] = ( + tensor.shape, + tensor.dtype, + tensor.numel(), + ) type_to_total_size[tensor.dtype] += tensor.numel() # Allocate consolidated tensors for each dtype @@ -1519,8 +1395,11 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Copy tensors into consolidated buffers for key, tensor in gathered_hf_params.items(): - metadata = tensor_metadata[key] - _, dtype, offset, size = metadata + offset = tensor_metadata[key] + if isinstance(offset, tuple): + offset, _ = offset + dtype = tensor.dtype + size = tensor.numel() packed_tensors[dtype][offset : offset + size].copy_( tensor.detach().view(-1) ) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 626fce2b6f..8d6cff05f8 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -41,7 +41,7 @@ "name": model_name, }, "dtype": "bfloat16", - "max_new_tokens": 5, + "max_new_tokens": 5, # Small number of tokens for testing "temperature": 0.8, "top_p": 1.0, "top_k": None, @@ -133,15 +133,6 @@ def get_basic_megatron_test_config( "learning_rate": 5e-6, "logprob_batch_size": 2, "precision": precision, - "generation": { - "backend": "megatron", - "temperature": 1.0, - "max_new_tokens": 16, # Small number of tokens for testing - "top_p": 1.0, - "top_k": None, - "stop_token_ids": None, - "stop_strings": None, - }, "dtensor_cfg": { "enabled": False, # Disabled for Megatron tests }, @@ -202,6 +193,7 @@ def get_basic_megatron_test_config( "optimizer": None, # Remove default FSDP optimizer "scheduler": None, # Remove default scheduler "max_grad_norm": 1.0, + "generation": deepcopy(basic_vllm_test_config), } @@ -426,11 +418,18 @@ async def test_vllm_policy_generation_async( dtensor_config = basic_dtensor_test_config from nemo_rl.models.policy.lm_policy import Policy + print("creating vllm policy...") async_policy = VllmGeneration(cluster, vllm_config) async_policy.finish_generation() - print("creating hf policy...") + print("creating lm policy...") lm_policy = Policy(cluster, dtensor_config, tokenizer) + + print("preparing refit info...") + state_dict_info = lm_policy.prepare_refit_info() + async_policy.prepare_refit_info(state_dict_info) + + print("refitting vllm policy...") refit_policy_generation( lm_policy, async_policy, vllm_config["colocated"]["enabled"] ) @@ -528,6 +527,9 @@ def test_vllm_worker_seed_behavior(cluster, tokenizer): dtensor_config = basic_dtensor_test_config lm_policy = Policy(cluster, dtensor_config, tokenizer) + state_dict_info = lm_policy.prepare_refit_info() + policy.prepare_refit_info(state_dict_info) + print("refitting vllm policy...") refit_policy_generation(lm_policy, policy, vllm_config["colocated"]["enabled"]) @@ -682,6 +684,10 @@ async def test_vllm_generation_with_hf_training(cluster, tokenizer, async_engine print("Creating DTensor policy...") lm_policy = Policy(cluster, dtensor_config, tokenizer) + print("preparing refit info...") + state_dict_info = lm_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + print("refitting vllm policy...") refit_policy_generation( lm_policy, vllm_policy, vllm_config["colocated"]["enabled"] @@ -930,9 
+936,14 @@ def test_vllm_weight_update_and_prefix_cache_reset( try: print(f"Creating DTensor policy for TP={tensor_parallel_size}...") lm_policy = Policy(cluster, dtensor_config, tokenizer) + print(f"Creating vLLM policy for TP={tensor_parallel_size}...") vllm_policy = VllmGeneration(cluster, vllm_config) + print("preparing refit info...") + state_dict_info = lm_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + # Prepare input data (batch size 2) text = """Answer the question based on the context below. Keep the answer short and concise. Respond "Unsure about answer" if not sure about the answer. Context: Teplizumab traces its roots to a New Jersey drug company called Ortho Pharmaceutical. There, scientists generated an early version of the antibody, dubbed OKT3. Originally sourced from mice, the molecule was able to bind to the surface of T cells and limit their cell-killing potential. In 1986, it was approved to help prevent organ rejection after kidney transplants, making it the first therapeutic antibody allowed for human use.Question: What was OKT3 originally sourced from?Answer:""" test_prompt = [text, text] # Use batch size 2 @@ -970,7 +981,7 @@ def test_vllm_weight_update_and_prefix_cache_reset( grouped_param_keys = lm_policy.prepare_weights_for_ipc() for keys in grouped_param_keys: ipc_handles = lm_policy.get_weights_ipc_handles(keys) - update_success = vllm_policy.update_weights(ipc_handles) + update_success = vllm_policy.update_weights_from_ipc_handles(ipc_handles) assert update_success, "Weight update should succeed" print("vLLM weights successfully updated.") @@ -1035,6 +1046,10 @@ def test_vllm_weight_update_memory(cluster, tokenizer, enable_dtensor): dtensor_config = basic_dtensor_test_config lm_policy = Policy(cluster, dtensor_config, tokenizer) + print("preparing refit info...") + state_dict_info = lm_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + print("refitting vllm policy...") # take it outside statistics to get clean peak memory during refit lm_policy.offload_before_refit() @@ -1112,6 +1127,10 @@ def test_vllm_generation_with_stop( dtensor_config = basic_dtensor_test_config lm_policy = Policy(cluster, dtensor_config, tokenizer) + print("preparing refit info...") + state_dict_info = lm_policy.prepare_refit_info() + vllm_generation.prepare_refit_info(state_dict_info) + print("refitting vllm policy...") refit_policy_generation( lm_policy, vllm_generation, vllm_config["colocated"]["enabled"] @@ -1219,6 +1238,10 @@ async def test_vllm_refit_non_collocated_update_weights( futures_inference = vllm_generation.init_collective(ip, port, world_size=2) ray.get(futures_train + futures_inference) + # prepare refit info + state_dict_info = lm_policy.prepare_refit_info() + vllm_generation.prepare_refit_info(state_dict_info) + print("refitting vllm policy...") refit_policy_generation( lm_policy, vllm_generation, vllm_config["colocated"]["enabled"] @@ -1317,6 +1340,10 @@ def test_vllm_generation_with_megatron_training( print("Creating Megatron policy...") megatron_policy = Policy(cluster, megatron_config, test_tokenizer) + print("preparing refit info...") + state_dict_info = megatron_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + print("Refitting vLLM policy with Megatron weights...") refit_policy_generation( megatron_policy, vllm_policy, vllm_config["colocated"]["enabled"] @@ -1436,6 +1463,10 @@ def test_vllm_megatron_weight_update_memory(cluster, tokenizer): print("Creating Megatron 
policy...") megatron_policy = Policy(cluster, megatron_config, test_tokenizer) + print("preparing refit info...") + state_dict_info = megatron_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + print("Refitting vLLM policy with Megatron...") # Take it outside statistics to get clean peak memory during refit megatron_policy.offload_before_refit() @@ -1539,6 +1570,10 @@ def test_vllm_megatron_pipeline_parallel(cluster, tokenizer): vllm_policy = VllmGeneration(cluster, vllm_config) vllm_policy.finish_generation() + print("preparing refit info...") + state_dict_info = megatron_policy.prepare_refit_info() + vllm_policy.prepare_refit_info(state_dict_info) + print("Refitting vLLM with Megatron PP=2 weights...") refit_policy_generation( megatron_policy, vllm_policy, vllm_config["colocated"]["enabled"] @@ -1569,3 +1604,58 @@ def test_vllm_megatron_pipeline_parallel(cluster, tokenizer): vllm_policy.shutdown() if megatron_policy: megatron_policy.shutdown() + + +def test_vllm_megatron_weight_update_with_packing(cluster, test_input_data): + megatron_policy = None + vllm_generation = None + + try: + # Enable packing during test + os.environ["NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD"] = "1" + + # Both policies must use the same model (Qwen2.5-0.5B) for weight transfer compatibility + model_name = "Qwen/Qwen2.5-0.5B" + tokenizer = get_tokenizer({"name": model_name}) + + # Create Policy + megatron_config = get_basic_megatron_test_config( + tp=1, pp=1, precision="float32" + ) + megatron_config["model_name"] = model_name + megatron_config["tokenizer"]["name"] = model_name + megatron_policy = Policy(cluster, megatron_config, tokenizer) + + # Create VllmGeneration + vllm_config = deepcopy(basic_vllm_test_config) + vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) + vllm_config["model_name"] = model_name + vllm_config["tokenizer"]["name"] = model_name + vllm_generation = VllmGeneration(cluster, vllm_config) + + # prepare refit info + state_dict_info = megatron_policy.prepare_refit_info() + vllm_generation.prepare_refit_info(state_dict_info) + + print("refitting vllm policy...") + refit_policy_generation( + megatron_policy, vllm_generation, vllm_config["colocated"]["enabled"] + ) + + # test generate + outputs = vllm_generation.generate(test_input_data, greedy=True) + output_ids = outputs["output_ids"] + generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + assert generated_texts == [ + "Hello, my name is John. I am a", + "The capital of France is Paris. It is the", + ], "Output should be the same as the expected output" + + finally: + # Restore the original value + os.environ.pop("NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD", None) + # Clean up + if megatron_policy: + megatron_policy.shutdown() + if vllm_generation: + vllm_generation.shutdown() diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index a23c1b5559..ee1d422a3b 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -60,6 +60,13 @@ def create_megatron_test_config( "top_k": None, "stop_token_ids": None, "stop_strings": None, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, }, "dtensor_cfg": { "enabled": False, # Disabled for Megatron tests