Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizerBase
from contextlib import nullcontext

from nemo_rl.algorithms.interfaces import LossFunction
from nemo_rl.algorithms.loss_functions import (
Expand Down Expand Up @@ -362,6 +363,9 @@ def setup(
print(" " * 18 + "SETUP COMPLETE")
print("=" * 60 + "\n")

if os.getenv("RAY_PROFILING", None) == "1":
ray.timeline(filename=os.getenv("NEMO_RL_RAY_TIMELINE_FILE", "/tmp/ray_timeline.json"))

return (
policy,
policy_generation,
Expand All @@ -386,6 +390,7 @@ def refit_policy_generation(
policy_generation: GenerationInterface,
colocated_inference: bool,
_refit_buffer_size_gb: Optional[int] = None,
timer: Optional[Timer] = None,
) -> None:
"""Refit the policy generation interface with the latest policy weights.

Expand All @@ -410,8 +415,10 @@ def refit_policy_generation(
print(f"[Refit] Number of splits: {len(grouped_param_keys)}")
# do update
for keys in grouped_param_keys:
ipc_handles = policy.get_weights_ipc_handles(keys)
update_success = policy_generation.update_weights(ipc_handles)
with timer.time("prepare_for_generation/get_weights_ipc_handles") if timer else nullcontext():
ipc_handles = policy.get_weights_ipc_handles(keys)
with timer.time("prepare_for_generation/update_weights") if timer else nullcontext():
update_success = policy_generation.update_weights(ipc_handles)
if not update_success:
break
else:
Expand Down Expand Up @@ -528,7 +535,7 @@ def grpo_train(
with timer.time("prepare_for_generation"):
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
policy, policy_generation, colocated_inference
policy, policy_generation, colocated_inference, timer=timer
)
POLICY_GENERATION_STALE = False
else:
Expand Down
34 changes: 34 additions & 0 deletions nemo_rl/models/generation/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,22 @@ def _patch_vllm_init_workers_ray():
else:
self.llm = vllm.LLM(**llm_kwargs)

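# Optional torch profiler for refit, enabled with NEMO_RL_TORCH_PROFILE_REFIT=1.
# Only the model-owner worker holding bundle index 0 creates one; every other worker gets None.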
import socket
if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.is_model_owner and (0 in bundle_indices):
self.profiler = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
"grpo_refit_trace/update_weights_from_ipc_handles",
worker_name=f"{socket.gethostname()}_vllm_worker_{self.rank}",
use_gzip=True,
),
)
else:
self.profiler = None
self.maybe_profile_refit_times = 0

def init_collective(self, data: int, ip: str, port: int, world_size: int) -> None:
self.llm.collective_rpc(
"init_collective",
Expand Down Expand Up @@ -879,11 +895,20 @@ def update_weights_from_ipc_handles(self, data: dict[str, Any]) -> bool:
"update_weights_from_ipc_handles cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead."
)

if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
self.maybe_profile_refit_times += 1
if self.maybe_profile_refit_times == 3:
self.profiler.start()

result_or_coro = self.llm.collective_rpc(
"update_weights_from_ipc_handles", args=(data,)
)
worker_result = result_or_coro[0]

if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
if self.maybe_profile_refit_times == 3:
self.profiler.stop()

if not worker_result:
print(
f"Error: Worker failed to update weights. Result: {worker_result}"
Expand Down Expand Up @@ -916,6 +941,11 @@ async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> b
"update_weights_from_ipc_handles_async can only be used with async_engine=True. Use update_weights_from_ipc_handles instead."
)

if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
self.maybe_profile_refit_times += 1
if self.maybe_profile_refit_times == 3:
self.profiler.start()

result_or_coro = await self.llm.collective_rpc(
"update_weights_from_ipc_handles", args=(data,)
)
Expand All @@ -927,6 +957,10 @@ async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> b

worker_result = worker_results[0]

if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
if self.maybe_profile_refit_times == 3:
self.profiler.stop()

if not worker_result:
print(
f"Error: Worker failed to update weights. Result: {worker_result}"
Expand Down
37 changes: 29 additions & 8 deletions nemo_rl/models/generation/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,38 @@ def update_weights_from_ipc_handles(self, ipc_handles):
# Get handles for this device
device_uuid = self.report_device_id()
handles = ipc_handles[device_uuid]
is_tensor_packed = handles[0]
if is_tensor_packed:
_, all_handles, tensor_metadata = handles
else:
_, name_and_handle_list = handles

device_id = self.device.index
weights = []

# Process each handle to get the tensor
for name, handle in handles:
func, args = handle
list_args = list(args)
# Update device ID to match the current device
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
if is_tensor_packed:
# Extract packed tensor from IPC handle
dtype_to_packed_tensor = {}
for dtype, tensor_handle in all_handles:
func, args = tensor_handle
list_args = list(args)
list_args[6] = device_id
tensor = func(*list_args)
dtype_to_packed_tensor[dtype] = tensor

# Unpack the packed buffers into per-parameter weights. Each weight is a view
# (slice + reshape) into the packed tensor of its dtype, so no extra memory is used.
for key, (shape, dtype, offset, size) in tensor_metadata.items():
tensor = dtype_to_packed_tensor[dtype][offset:offset+size].view(*shape)
weights.append((key, tensor))
else:
# Process each handle to get the tensor
for name, handle in name_and_handle_list:
func, args = handle
list_args = list(args)
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))

# Load weights into the model
self.model_runner.model.load_weights(weights=weights)
Expand Down
1 change: 0 additions & 1 deletion nemo_rl/models/megatron/refit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,4 @@ def gather_params(
if k is not None:
gathered_params[k] = p

print(f"Time taken to gather params: {time.perf_counter() - st}")
return gathered_params
99 changes: 87 additions & 12 deletions nemo_rl/models/policy/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,22 @@ def __init__(
state_dict_info=self.prepare_weights_for_ipc()[0]
)

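# Optional torch profiler for refit, enabled with NEMO_RL_TORCH_PROFILE_REFIT=1.
# Only rank 0 creates one; every other rank gets None.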
import socket
if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and get_rank_safe() == 0:
self.profiler = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
"grpo_refit_trace/get_weight_ipc_handles",
worker_name=f"{socket.gethostname()}_megatron_policy_worker_0",
use_gzip=True,
),
)
else:
self.profiler = None
self.maybe_profile_refit_times = 0

def configure_worker(self, num_gpus: int, bundle_indices: Optional[tuple] = None):
USE_EXPANDABLE_SEGMENTS = False # Disabling this right now as it seems to cause vLLM refit issues with Ampere
if USE_EXPANDABLE_SEGMENTS:
Expand Down Expand Up @@ -1436,6 +1452,11 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
Returns:
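Dict mapping device UUID to a tuple whose first element is the packing flag:
(True, [(dtype, handle), ...], tensor_metadata) when tensors are packed per dtype,
where tensor_metadata maps each key to (shape, dtype, offset, size);
or (False, [(mapped_key, handle), ...]) when each tensor gets its own handle.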
"""
if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
self.maybe_profile_refit_times += 1
if self.maybe_profile_refit_times == 3:
self.profiler.start()

if self._held_gather_buffer is not None:
del self._held_gather_buffer
self._held_gather_buffer = None
Expand All @@ -1455,18 +1476,72 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
from torch.multiprocessing.reductions import reduce_tensor

# Create IPC handles for each parameter
all_handles = []
for key, tensor in gathered_hf_params.items():
handle = reduce_tensor(tensor.detach())
all_handles.append((key, handle))

# Store references to avoid premature garbage collection
self._held_gather_buffer = gathered_hf_params
shapes = {}
for key, tensor in gathered_hf_params.items():
shapes[key] = tensor.shape

return {device_uuid: all_handles}
tensor_number_threshold = os.getenv(
"NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD", "32"
) # an arbitrary threshold
if len(gathered_hf_params) >= int(tensor_number_threshold):
pack_tensor_for_ipc = True
else:
pack_tensor_for_ipc = False

if pack_tensor_for_ipc:
# Pack tensors in gathered_hf_params into consolidated tensors by dtype
# First calculate total size needed for each dtype
type_to_total_size = defaultdict(lambda: 0)
tensor_metadata = dict()

for key, tensor in gathered_hf_params.items():
tensor_metadata[key] = (
tensor.shape, # shape of the tensor
tensor.dtype, # dtype of the tensor
type_to_total_size[tensor.dtype], # offset of the tensor
# in packed buffer
tensor.numel() # size of the tensor
)
type_to_total_size[tensor.dtype] += tensor.numel()

# Allocate consolidated tensors for each dtype
packed_tensors = {
dtype: torch.empty(
total_size,
device=next(iter(gathered_hf_params.values())).device,
dtype=dtype,
requires_grad=False
)
for dtype, total_size in type_to_total_size.items()
}

# Copy tensors into consolidated buffers
for key, tensor in gathered_hf_params.items():
metadata = tensor_metadata[key]
_, dtype, offset, size = metadata
packed_tensors[dtype][offset:offset + size].copy_(
tensor.detach().view(-1)
)

# Create IPC handles for consolidated tensors
all_handles = [
(dtype, reduce_tensor(tensor.detach()))
for dtype, tensor in packed_tensors.items()
]

# Store reference to prevent garbage collection
self._held_gather_buffer = packed_tensors

serialized = (pack_tensor_for_ipc, all_handles, tensor_metadata)
else:
all_handles = []
for key, tensor in gathered_hf_params.items():
handle = reduce_tensor(tensor.detach())
all_handles.append((key, handle))
self._held_gather_buffer = gathered_hf_params
serialized = (False, all_handles)

if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
if self.maybe_profile_refit_times == 3:
self.profiler.stop()

return {device_uuid: serialized}

def prepare_for_lp_inference(self):
self.model = self.move_model(self.model, "cuda", move_grads=False)
Expand Down