diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index ac9335bcd4..fdd60c1d9b 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -20,6 +20,7 @@ import torch from torchdata.stateful_dataloader import StatefulDataLoader from transformers import PreTrainedTokenizerBase +from contextlib import nullcontext from nemo_rl.algorithms.interfaces import LossFunction from nemo_rl.algorithms.loss_functions import ( @@ -362,6 +363,9 @@ def setup( print(" " * 18 + "SETUP COMPLETE") print("=" * 60 + "\n") + if os.getenv("RAY_PROFILING", None) == "1": + ray.timeline(filename=os.getenv("NEMO_RL_RAY_TIMELINE_FILE", "/tmp/ray_timeline.json")) + return ( policy, policy_generation, @@ -386,6 +390,7 @@ def refit_policy_generation( policy_generation: GenerationInterface, colocated_inference: bool, _refit_buffer_size_gb: Optional[int] = None, + timer: Optional[Timer] = None, ) -> None: """Refit the policy generation interface with the latest policy weights. @@ -410,8 +415,10 @@ def refit_policy_generation( print(f"[Refit] Number of splits: {len(grouped_param_keys)}") # do update for keys in grouped_param_keys: - ipc_handles = policy.get_weights_ipc_handles(keys) - update_success = policy_generation.update_weights(ipc_handles) + with timer.time("prepare_for_generation/get_weights_ipc_handles") if timer else nullcontext(): + ipc_handles = policy.get_weights_ipc_handles(keys) + with timer.time("prepare_for_generation/update_weights") if timer else nullcontext(): + update_success = policy_generation.update_weights(ipc_handles) if not update_success: break else: @@ -528,7 +535,7 @@ def grpo_train( with timer.time("prepare_for_generation"): if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( - policy, policy_generation, colocated_inference + policy, policy_generation, colocated_inference, timer=timer ) POLICY_GENERATION_STALE = False else: diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index cbb603f74e..49249eb685 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -348,6 +348,22 @@ def _patch_vllm_init_workers_ray(): else: self.llm = vllm.LLM(**llm_kwargs) + # torch profiler + import socket + if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.is_model_owner and (0 in bundle_indices): + self.profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + "grpo_refit_trace/update_weights_from_ipc_handles", + worker_name=f"{socket.gethostname()}_vllm_worker_{self.rank}", + use_gzip=True, + ), + ) + else: + self.profiler = None + self.maybe_profile_refit_times = 0 + def init_collective(self, data: int, ip: str, port: int, world_size: int) -> None: self.llm.collective_rpc( "init_collective", @@ -879,11 +895,20 @@ def update_weights_from_ipc_handles(self, data: dict[str, Any]) -> bool: "update_weights_from_ipc_handles cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead." ) + if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None: + self.maybe_profile_refit_times += 1 + if self.maybe_profile_refit_times == 3: + self.profiler.start() + result_or_coro = self.llm.collective_rpc( "update_weights_from_ipc_handles", args=(data,) ) worker_result = result_or_coro[0] + if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None: + if self.maybe_profile_refit_times == 3: + self.profiler.stop() + if not worker_result: print( f"Error: Worker failed to update weights. Result: {worker_result}" @@ -916,6 +941,11 @@ async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> b "update_weights_from_ipc_handles_async can only be used with async_engine=True. Use update_weights_from_ipc_handles instead." ) + if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None: + self.maybe_profile_refit_times += 1 + if self.maybe_profile_refit_times == 3: + self.profiler.start() + result_or_coro = await self.llm.collective_rpc( "update_weights_from_ipc_handles", args=(data,) ) @@ -927,6 +957,10 @@ async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> b worker_result = worker_results[0] + if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None: + if self.maybe_profile_refit_times == 3: + self.profiler.stop() + if not worker_result: print( f"Error: Worker failed to update weights. Result: {worker_result}" diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 16f51b06ac..378f40742d 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -59,17 +59,38 @@ def update_weights_from_ipc_handles(self, ipc_handles): # Get handles for this device device_uuid = self.report_device_id() handles = ipc_handles[device_uuid] + is_tensor_packed = handles[0] + if is_tensor_packed: + _, all_handles, tensor_metadata = handles + else: + _, name_and_handle_list = handles + device_id = self.device.index weights = [] - # Process each handle to get the tensor - for name, handle in handles: - func, args = handle - list_args = list(args) - # Update device ID to match the current device - list_args[6] = device_id - tensor = func(*list_args) - weights.append((name, tensor)) + if is_tensor_packed: + # Extract packed tensor from IPC handle + dtype_to_packed_tensor = {} + for dtype, tensor_handle in all_handles: + func, args = tensor_handle + list_args = list(args) + list_args[6] = device_id + tensor = func(*list_args) + dtype_to_packed_tensor[dtype] = tensor + + # Unpack tensor to weights. Here we only return a view of the tensor to avoid + # using extra memory. + for key, (shape, dtype, offset, size) in tensor_metadata.items(): + tensor = dtype_to_packed_tensor[dtype][offset:offset+size].view(*shape) + weights.append((key, tensor)) + else: + # Process each handle to get the tensor + for name, handle in name_and_handle_list: + func, args = handle + list_args = list(args) + list_args[6] = device_id + tensor = func(*list_args) + weights.append((name, tensor)) # Load weights into the model self.model_runner.model.load_weights(weights=weights) diff --git a/nemo_rl/models/megatron/refit_utils.py b/nemo_rl/models/megatron/refit_utils.py index e6ca825e0a..07081b4532 100644 --- a/nemo_rl/models/megatron/refit_utils.py +++ b/nemo_rl/models/megatron/refit_utils.py @@ -169,5 +169,4 @@ def gather_params( if k is not None: gathered_params[k] = p - print(f"Time taken to gather params: {time.perf_counter() - st}") return gathered_params diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 6805bea08c..6eea16d018 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -669,6 +669,22 @@ def __init__( state_dict_info=self.prepare_weights_for_ipc()[0] ) + # torch profiler + import socket + if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and get_rank_safe() == 0: + self.profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + "grpo_refit_trace/get_weight_ipc_handles", + worker_name=f"{socket.gethostname()}_megatron_policy_worker_0", + use_gzip=True, + ), + ) + else: + self.profiler = None + self.maybe_profile_refit_times = 0 + def configure_worker(self, num_gpus: int, bundle_indices: Optional[tuple] = None): USE_EXPANDABLE_SEGMENTS = False # Disabling this right now as it seems to cause vLLM refit issues with Ampere if USE_EXPANDABLE_SEGMENTS: @@ -1436,6 +1452,11 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: Returns: Dict mapping device UUID to list of (mapped_key, handle) tuples """ + if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None: + self.maybe_profile_refit_times += 1 + if self.maybe_profile_refit_times == 3: + self.profiler.start() + if self._held_gather_buffer is not None: del self._held_gather_buffer self._held_gather_buffer = None @@ -1455,18 +1476,72 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: from torch.multiprocessing.reductions import reduce_tensor # Create IPC handles for each parameter - all_handles = [] - for key, tensor in gathered_hf_params.items(): - handle = reduce_tensor(tensor.detach()) - all_handles.append((key, handle)) - - # Store references to avoid premature garbage collection - self._held_gather_buffer = gathered_hf_params - shapes = {} - for key, tensor in gathered_hf_params.items(): - shapes[key] = tensor.shape - - return {device_uuid: all_handles} + tensor_number_threshold = os.getenv( + "NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD", "32" + ) # an arbitrary threshold + if len(gathered_hf_params) >= int(tensor_number_threshold): + pack_tensor_for_ipc = True + else: + pack_tensor_for_ipc = False + + if pack_tensor_for_ipc: + # Pack tensors in gathered_hf_params into consolidated tensors by dtype + # First calculate total size needed for each dtype + type_to_total_size = defaultdict(lambda: 0) + tensor_metadata = dict() + + for key, tensor in gathered_hf_params.items(): + tensor_metadata[key] = ( + tensor.shape, # shape of the tensor + tensor.dtype, # dtype of the tensor + type_to_total_size[tensor.dtype], # offset of the tensor + # in packed buffer + tensor.numel() # size of the tensor + ) + type_to_total_size[tensor.dtype] += tensor.numel() + + # Allocate consolidated tensors for each dtype + packed_tensors = { + dtype: torch.empty( + total_size, + device=next(iter(gathered_hf_params.values())).device, + dtype=dtype, + requires_grad=False + ) + for dtype, total_size in type_to_total_size.items() + } + + # Copy tensors into consolidated buffers + for key, tensor in gathered_hf_params.items(): + metadata = tensor_metadata[key] + _, dtype, offset, size = metadata + packed_tensors[dtype][offset:offset + size].copy_( + tensor.detach().view(-1) + ) + + # Create IPC handles for consolidated tensors + all_handles = [ + (dtype, reduce_tensor(tensor.detach())) + for dtype, tensor in packed_tensors.items() + ] + + # Store reference to prevent garbage collection + self._held_gather_buffer = packed_tensors + + serialized = (pack_tensor_for_ipc, all_handles, tensor_metadata) + else: + all_handles = [] + for key, tensor in gathered_hf_params.items(): + handle = reduce_tensor(tensor.detach()) + all_handles.append((key, handle)) + self._held_gather_buffer = gathered_hf_params + serialized = (False, all_handles) + + if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None: + if self.maybe_profile_refit_times == 3: + self.profiler.stop() + + return {device_uuid: serialized} def prepare_for_lp_inference(self): self.model = self.move_model(self.model, "cuda", move_grads=False)