From cacbad88bf52b2ec1c259b87056d5e6394467163 Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Fri, 18 Jul 2025 18:39:45 -0700
Subject: [PATCH] call get_param_info once

Signed-off-by: Yuki Huang
---
 .../models/policy/megatron_policy_worker.py | 23 ++++++++-----------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 9e48127411..481096c4ae 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -703,7 +703,7 @@ def __init__(
 
         # vars used for refit
         ## will be initialized in prepare_refit_info
-        self.refit_param_info_hf = None
+        self.refit_param_info_mcore = None
         self.local_key_to_global_keys = None
         ## used for streaming update inference engine weights
         self._held_gather_buffer = None
@@ -1278,18 +1278,18 @@ def report_device_id(self) -> str:
     @torch.no_grad()
     def prepare_refit_info(self) -> None:
         # Get parameter info for refit
-        ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples
-        refit_param_info_mcore = get_param_info(self.model, self.dtype)
+        # param_info: list of ((name, shape, dtype), size_in_bytes) tuples
+        self.refit_param_info_mcore = get_param_info(self.model, self.dtype)
 
         # Create a map that maps any local parameter name to a list of global parameter names.
         # This map is repeatedly used by parameter gatherring phase during refit of every step.
         self.local_key_to_global_keys = get_local_key_to_global_keys(
-            self.model, state_dict_info=refit_param_info_mcore
+            self.model, state_dict_info=self.refit_param_info_mcore
         )
 
         # Collect tensor metadata for refit
-        self.refit_param_info_hf = {}
-        for key, _ in refit_param_info_mcore:
+        refit_param_info_hf = {}
+        for key, _ in self.refit_param_info_mcore:
             # gather megatron params
             gathered_megatron_params = gather_params(
                 self.model,
@@ -1302,15 +1302,14 @@ def prepare_refit_info(self) -> None:
             )
             # collect tensor metadata
             for name, tensor in gathered_hf_params.items():
-                self.refit_param_info_hf[name] = (
+                refit_param_info_hf[name] = (
                     tensor.shape,
                     tensor.dtype,
                     tensor.numel(),
                 )
 
-        return self.refit_param_info_hf
+        return refit_param_info_hf
 
-    @torch.no_grad()
     def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         """Prepare Megatron model weights for IPC transfer to vLLM.
 
@@ -1319,10 +1318,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         """
         from nemo_rl.utils.nvml import get_free_memory_bytes
 
-        # Get parameter info for refit
-        ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples
-        refit_param_info_mcore = get_param_info(self.model, self.dtype)
-
         # Collect current available memory for refit
         ## Get current device index from torch
         device_idx = torch.cuda.current_device()
@@ -1332,7 +1327,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.2")
         total_available_bytes *= float(memory_ratio)
 
-        return refit_param_info_mcore, total_available_bytes
+        return self.refit_param_info_mcore, total_available_bytes
 
     def get_handle_from_tensor(self, tensor: torch.Tensor) -> tuple[str, Any]:
         """Get IPC handle from a tensor."""
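
For illustration, the caching pattern introduced by this patch reduces to the following minimal sketch (plain Python, not the actual NeMo-RL worker): get_param_info() is called once in prepare_refit_info(), the result is stored on the worker, and prepare_weights_for_ipc() returns the cached list instead of recomputing it. The helper expensive_param_info() and the placeholder free-memory value are hypothetical stand-ins for get_param_info(self.model, self.dtype) and the NVML free-memory query.

    # Minimal sketch of the caching pattern from this patch; names marked
    # "hypothetical" are stand-ins, not the real NeMo-RL helpers.
    from typing import Any, Optional


    def expensive_param_info() -> list[tuple[str, int]]:
        """Hypothetical stand-in for get_param_info(): (name, size_in_bytes) pairs."""
        return [("layer.weight", 4096), ("layer.bias", 64)]


    class WorkerSketch:
        def __init__(self) -> None:
            # Initialized once in prepare_refit_info, reused afterwards.
            self.refit_param_info_mcore: Optional[list[tuple[str, int]]] = None

        def prepare_refit_info(self) -> dict[str, Any]:
            # Before the patch this was a local variable and prepare_weights_for_ipc
            # called the expensive helper again; now the result is cached on self.
            self.refit_param_info_mcore = expensive_param_info()
            # ... gather params and build the HF-side metadata dict here ...
            return {name: size for name, size in self.refit_param_info_mcore}

        def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
            # Reuse the cached info instead of recomputing it on every refit step.
            assert self.refit_param_info_mcore is not None, "call prepare_refit_info first"
            total_available_bytes = 1.0e9  # placeholder for the NVML free-memory query
            return self.refit_param_info_mcore, total_available_bytes


    if __name__ == "__main__":
        worker = WorkerSketch()
        worker.prepare_refit_info()
        print(worker.prepare_weights_for_ipc())

The assert makes the ordering requirement explicit: prepare_refit_info() must run before the first prepare_weights_for_ipc(), mirroring how the worker initializes refit_param_info_mcore to None in __init__ and fills it in prepare_refit_info().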