Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions nemo_rl/models/policy/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ def __init__(

# vars used for refit
## will be initialized in prepare_refit_info
self.refit_param_info_hf = None
self.refit_param_info_mcore = None
self.local_key_to_global_keys = None
## used for streaming update inference engine weights
self._held_gather_buffer = None
Expand Down Expand Up @@ -1278,18 +1278,18 @@ def report_device_id(self) -> str:
@torch.no_grad()
def prepare_refit_info(self) -> None:
# Get parameter info for refit
## param_info: list of ((name, shape, dtype), size_in_bytes) tuples
refit_param_info_mcore = get_param_info(self.model, self.dtype)
# param_info: list of ((name, shape, dtype), size_in_bytes) tuples
self.refit_param_info_mcore = get_param_info(self.model, self.dtype)

# Create a map that maps any local parameter name to a list of global parameter names.
# This map is used repeatedly by the parameter gathering phase during the refit of every step.
self.local_key_to_global_keys = get_local_key_to_global_keys(
self.model, state_dict_info=refit_param_info_mcore
self.model, state_dict_info=self.refit_param_info_mcore
)

# Collect tensor metadata for refit
self.refit_param_info_hf = {}
for key, _ in refit_param_info_mcore:
refit_param_info_hf = {}
for key, _ in self.refit_param_info_mcore:
# gather megatron params
gathered_megatron_params = gather_params(
self.model,
Expand All @@ -1302,15 +1302,14 @@ def prepare_refit_info(self) -> None:
)
# collect tensor metadata
for name, tensor in gathered_hf_params.items():
self.refit_param_info_hf[name] = (
refit_param_info_hf[name] = (
tensor.shape,
tensor.dtype,
tensor.numel(),
)

return self.refit_param_info_hf
return refit_param_info_hf

@torch.no_grad()
def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
"""Prepare Megatron model weights for IPC transfer to vLLM.

Expand All @@ -1319,10 +1318,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
"""
from nemo_rl.utils.nvml import get_free_memory_bytes

# Get parameter info for refit
## param_info: list of ((name, shape, dtype), size_in_bytes) tuples
refit_param_info_mcore = get_param_info(self.model, self.dtype)

# Collect current available memory for refit
## Get current device index from torch
device_idx = torch.cuda.current_device()
Expand All @@ -1332,7 +1327,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.2")
total_available_bytes *= float(memory_ratio)

return refit_param_info_mcore, total_available_bytes
return self.refit_param_info_mcore, total_available_bytes

def get_handle_from_tensor(self, tensor: torch.Tensor) -> tuple[str, Any]:
"""Get IPC handle from a tensor."""
Expand Down