diff --git a/docs/nsys-profiling.md b/docs/nsys-profiling.md
index db3f6a768a..951251420c 100644
--- a/docs/nsys-profiling.md
+++ b/docs/nsys-profiling.md
@@ -17,7 +17,7 @@ NeMo RL supports Nsight profiling for Ray workers through environment variable p
 Set the `NRL_NSYS_WORKER_PATTERNS` environment variable with a comma-separated list of patterns to match worker names:
 
 ```bash
-export NRL_NSYS_WORKER_PATTERNS="*policy*,*other-worker*"
+export NRL_NSYS_WORKER_PATTERNS="*policy*,*vllm*"
 ```
 
 Set the `NRL_NSYS_PROFILE_STEP_RANGE` environment variable to control which training steps the profiler captures. Its
@@ -40,7 +40,7 @@ export NRL_NSYS_PROFILE_STEP_RANGE=3:5
 
 The supported worker types are:
 
 - **DTensorPolicyWorker**: Pattern matched against `"dtensor_policy_worker"`
-- **MegatronPolicyWorker**: Pattern matched against `"megatron_policy_worker"`
+- **VllmGenerationWorker**: Pattern matched against `"vllm_generation_worker"`
 
 ## Example Usage
@@ -49,10 +49,16 @@ The supported worker types are:
 NRL_NSYS_PROFILE_STEP_RANGE=2:3 NRL_NSYS_WORKER_PATTERNS="*policy*" uv run examples/run_grpo_math.py grpo.max_num_steps=5
 ```
 
+### Profile Multiple Worker Types
+
+```bash
+NRL_NSYS_PROFILE_STEP_RANGE=1:2 NRL_NSYS_WORKER_PATTERNS="*policy*,*vllm*" uv run examples/run_grpo_math.py grpo.max_num_steps=5
+```
+
 ### Profile Workers with Exact Names
 
 ```bash
-NRL_NSYS_PROFILE_STEP_RANGE=3:10 NRL_NSYS_WORKER_PATTERNS="dtensor_policy_worker" uv run examples/run_grpo_math.py grpo.max_num_steps=5
+NRL_NSYS_PROFILE_STEP_RANGE=3:10 NRL_NSYS_WORKER_PATTERNS="dtensor_policy_worker,vllm_generation_worker" uv run examples/run_grpo_math.py grpo.max_num_steps=5
 ```
 
 ### Profile Megatron Workers
@@ -63,7 +69,7 @@ To profile a Megatron worker, you should set `LD_LIBRARY_PATH` as follows, other
 
 ```bash
 LD_LIBRARY_PATH="/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/cuda/lib64:/usr/local/cuda/lib:/usr/local/nvidia/lib64:/usr/local/nvidia/lib:/usr/lib/x86_64-linux-gnu" \
-NRL_NSYS_PROFILE_STEP_RANGE=2:3 NRL_NSYS_WORKER_PATTERNS="megatron_policy_worker" uv run examples/run_grpo_math.py --config examples/configs/grpo_math_1B_megatron.yaml grpo.max_num_steps=5
+NRL_NSYS_PROFILE_STEP_RANGE=2:3 NRL_NSYS_WORKER_PATTERNS="megatron_policy_worker,vllm_generation_worker" uv run examples/run_grpo_math.py --config examples/configs/grpo_math_1B_megatron.yaml grpo.max_num_steps=5
 ```
 
 ## Profile Output
@@ -78,7 +84,10 @@ When profiling is enabled, it generates the following logs and files:
 2. **Profile Files**: Each profiled worker generates a `.nsys-rep` file with naming pattern:
    ```
    dtensor_policy_worker__.nsys-rep
+   vllm_generation_worker__.nsys-rep
+   worker_process_.nsys-rep
    ```
+If you are not using model parallelism in vLLM, refer directly to `vllm_generation_worker__.nsys-rep` for the Nsight report. If you are using model parallelism, `vllm_generation_worker__.nsys-rep` will be empty, and the `worker_process_.nsys-rep` files are the Nsight profiles produced by vLLM's Ray distributed executor (see https://github.com/vllm-project/vllm/blob/7e3a8dc90670fd312ce1e0d4eba9bf11c571e3ad/vllm/executor/ray_distributed_executor.py#L136 for more information).
 
 3. **File Location**: Profile files are saved in `/tmp/ray/session*/logs/nsight/` directory on each worker node. Ensure you check both `ls /tmp/ray/session_[0-9]*/logs/nsight` and `ls /tmp/ray/session_latest/logs/nsight` for the profiles, since the "latest" pointer may be stale.
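The `NRL_NSYS_WORKER_PATTERNS` globs documented above are matched against worker names such as `dtensor_policy_worker` and `vllm_generation_worker`. As an illustrative sketch only (this is not the NeMo RL implementation, and `matches_any_pattern` is a hypothetical helper), a comma-separated pattern value can be interpreted with `fnmatch`-style matching:

```python
# Illustrative sketch only (not the NeMo RL implementation): how a
# comma-separated NRL_NSYS_WORKER_PATTERNS value can be matched against worker
# names with shell-style globs, which is what patterns like "*policy*,*vllm*" imply.
# `matches_any_pattern` is a hypothetical helper name.
import os
from fnmatch import fnmatch


def matches_any_pattern(worker_name: str, patterns_env: str) -> bool:
    """Return True if worker_name matches any comma-separated glob pattern."""
    patterns = [p.strip() for p in patterns_env.split(",") if p.strip()]
    return any(fnmatch(worker_name, pattern) for pattern in patterns)


if __name__ == "__main__":
    value = os.environ.get("NRL_NSYS_WORKER_PATTERNS", "*policy*,*vllm*")
    for name in ("dtensor_policy_worker", "vllm_generation_worker", "reward_worker"):
        print(f"{name}: {matches_any_pattern(name, value)}")
```

With the default value in the sketch, the first two names match and `reward_worker` does not.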
diff --git a/nemo_rl/distributed/worker_group_utils.py b/nemo_rl/distributed/worker_group_utils.py
index fe2a9a03be..c51d3b8a7f 100644
--- a/nemo_rl/distributed/worker_group_utils.py
+++ b/nemo_rl/distributed/worker_group_utils.py
@@ -57,6 +57,7 @@ def get_nsight_config_if_pattern_matches(worker_name: str) -> dict[str, Any]:
             # Profile will only start/stop when torch.cuda.profiler.start()/stop() is called
             "capture-range": "cudaProfilerApi",
             "capture-range-end": "stop",
+            "cuda-graph-trace": "node",
         }
     }
diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index 8c9b6b99c9..ace30de017 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -53,6 +53,7 @@
 )
 from nemo_rl.models.huggingface.common import ModelFlag
 from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled
+from nemo_rl.utils.nsys import wrap_with_nvtx_name
 
 
 class VllmSpecificArgs(TypedDict):
@@ -323,6 +324,18 @@ def _patch_vllm_init_workers_ray():
         if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(self.model_name):
             load_format = "auto"
 
+        if (
+            len(get_nsight_config_if_pattern_matches("vllm_generation_worker")) > 0
+            and vllm_kwargs["distributed_executor_backend"] == "ray"
+        ):
+            logger.warning(
+                "Nsight profiling is enabled for the vLLM generation worker through the vLLM Ray distributed executor. "
+                "The nsight command-line args and output file names are automatically picked by the Ray distributed "
+                "executor. Refer to https://github.com/vllm-project/vllm/blob/7e3a8dc90670fd312ce1e0d4eba9bf11c571e3ad/vllm/executor/ray_distributed_executor.py#L136 "
+                "for more information."
+            )
+            vllm_kwargs["ray_workers_use_nsight"] = True
+
         llm_kwargs = dict(
             model=self.model_name,
             load_format=load_format,
@@ -436,6 +449,7 @@ def _build_sampling_params(
             include_stop_str_in_output=True,
         )
 
+    @wrap_with_nvtx_name("vllm_generation_worker/generate")
     def generate(
         self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
     ) -> BatchedDataDict[GenerationOutputSpec]:
@@ -799,6 +813,7 @@ async def process_single_sample(sample_idx):
             await asyncio.gather(*sample_tasks, return_exceptions=True)
             raise e
 
+    @wrap_with_nvtx_name("vllm_generation_worker/generate_text")
     def generate_text(
         self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
     ) -> BatchedDataDict[GenerationOutputSpec]:
@@ -1033,6 +1048,7 @@ async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> None
         """Async version of prepare_refit_info."""
         await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,))
 
+    @wrap_with_nvtx_name("vllm_generation_worker/update_weights_from_ipc_handles")
     def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
         """Update weights from IPC handles by delegating to the vLLM Worker implementation.
@@ -1144,6 +1160,7 @@ async def update_weights_from_ipc_handles_async(
             traceback.print_exc()
             return False
 
+    @wrap_with_nvtx_name("vllm_generation_worker/update_weights_from_collective")
     def update_weights_from_collective(self) -> bool:
         """Update the model weights from collective communication."""
         try:
@@ -1317,10 +1334,14 @@ async def wake_up_async(self, **kwargs):
     def start_gpu_profiling(self) -> None:
         """Start GPU profiling."""
         torch.cuda.profiler.start()
+        if self.llm is not None:
+            self.llm.collective_rpc("start_gpu_profiling", args=tuple())
 
     def stop_gpu_profiling(self) -> None:
         """Stop GPU profiling."""
         torch.cuda.profiler.stop()
+        if self.llm is not None:
+            self.llm.collective_rpc("stop_gpu_profiling", args=tuple())
 
 
 class VllmGeneration(GenerationInterface):
diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py
index 1861f7643d..92c7916c43 100644
--- a/nemo_rl/models/generation/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm_backend.py
@@ -17,6 +17,8 @@
 import torch
 from torch.multiprocessing.reductions import rebuild_cuda_tensor
 
+from nemo_rl.utils.nsys import wrap_with_nvtx_name
+
 try:
     import vllm  # noqa: F401
 except ImportError:
@@ -66,6 +68,9 @@ def prepare_refit_info(
         """
         self.state_dict_info = state_dict_info  # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored
 
+    @wrap_with_nvtx_name(
+        "vllm_internal_worker_extension/update_weights_from_global_ipc_handles"
+    )
     def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
         """Update weights from global IPC handles.
 
@@ -79,6 +84,9 @@ def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
         local_device_ipc_handles = global_device_ipc_handles[device_uuid]
         return self.update_weights_from_local_ipc_handles(local_device_ipc_handles)
 
+    @wrap_with_nvtx_name(
+        "vllm_internal_worker_extension/update_weights_from_local_ipc_handles"
+    )
     def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
         """Update weights from local IPC handles.
@@ -155,6 +163,9 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
             )
             return False
 
+    @wrap_with_nvtx_name(
+        "vllm_internal_worker_extension/update_weights_from_collective"
+    )
     def update_weights_from_collective(self) -> bool:
         """Update the model weights from collective communication."""
         assert self.state_dict_info is not None, (
@@ -174,3 +185,11 @@ def update_weights_from_collective(self) -> bool:
             return False
 
         return True
+
+    def start_gpu_profiling(self) -> None:
+        """Start GPU profiling."""
+        torch.cuda.profiler.start()
+
+    def stop_gpu_profiling(self) -> None:
+        """Stop GPU profiling."""
+        torch.cuda.profiler.stop()
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index aba599923f..e0d13da319 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -77,6 +77,7 @@
     load_checkpoint,
     save_checkpoint,
 )
+from nemo_rl.utils.nsys import wrap_with_nvtx_name
 
 
 @contextmanager
@@ -513,6 +514,7 @@ def get_gpu_info(self) -> dict[str, Any]:
         """Return information about the GPU being used by this worker."""
         return get_gpu_info(self.model)
 
+    @wrap_with_nvtx_name("dtensor_policy_worker/train")
     def train(
         self,
         data: BatchedDataDict[Any],
@@ -855,6 +857,7 @@ def train(
 
         return metrics
 
+    @wrap_with_nvtx_name("dtensor_policy_worker/get_logprobs")
     def get_logprobs(
         self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
     ) -> BatchedDataDict[LogprobOutputSpec]:
@@ -1137,6 +1140,7 @@ def use_reference_model(self) -> Generator[None, None, None]:
                     val = to_local_if_dtensor(v)
                     val.copy_(curr_state_dict[k])
 
+    @wrap_with_nvtx_name("dtensor_policy_worker/get_reference_policy_logprobs")
     def get_reference_policy_logprobs(
         self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
     ) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
@@ -1234,6 +1238,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         return self.refit_param_info, total_available_bytes
 
     @torch.no_grad()
+    @wrap_with_nvtx_name("dtensor_policy_worker/get_weights_ipc_handles")
     def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
         assert self._held_sharded_state_dict_reference is not None, (
             "prepare_weights_for_ipc must be called before get_weights_ipc_handles"
         )
@@ -1296,6 +1301,7 @@ def broadcast_weights_for_collective(self) -> None:
         if self.cpu_offload:
             self.model = self.move_to_cpu(self.model)
 
+    @wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_lp_inference")
     def prepare_for_lp_inference(self) -> None:
         if not self.cpu_offload:
             self.move_to_cuda(self.model)
@@ -1305,6 +1311,7 @@ def prepare_for_lp_inference(self) -> None:
         self.model.eval()
         self.offload_before_refit()
 
+    @wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_training")
     def prepare_for_training(self, *args, **kwargs) -> None:
         # onload models and optimizer state to cuda
         if not self.cpu_offload:
@@ -1329,6 +1336,7 @@ def prepare_for_training(self, *args, **kwargs) -> None:
         torch.cuda.empty_cache()
 
     @torch.no_grad()
+    @wrap_with_nvtx_name("dtensor_policy_worker/offload_before_refit")
     def offload_before_refit(self) -> None:
         """Offload the optimizer to the CPU."""
         torch.randn(1).cuda()  # wake up torch allocator
@@ -1342,6 +1350,7 @@ def offload_before_refit(self) -> None:
         torch.cuda.empty_cache()
 
     @torch.no_grad()
+    @wrap_with_nvtx_name("dtensor_policy_worker/offload_after_refit")
     def offload_after_refit(self) -> None:
         # Offload as much as possible on the CPU
         self.model = self.move_to_cpu(self.model)
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 07c9594322..3a94ffa98c 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -124,6 +124,7 @@
     get_megatron_checkpoint_dir,
     get_runtime_env_for_policy_worker,
 )
+from nemo_rl.utils.nsys import wrap_with_nvtx_name
 
 TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)
@@ -765,6 +766,7 @@ def disable_forward_pre_hook(self, param_sync=True):
         assert isinstance(self.model, DistributedDataParallel)
         self.model.disable_forward_pre_hook(param_sync=param_sync)
 
+    @wrap_with_nvtx_name("megatron_policy_worker/train")
     def train(
         self,
         data: BatchedDataDict,
@@ -1010,6 +1012,7 @@ def train(
         }
         return metrics
 
+    @wrap_with_nvtx_name("megatron_policy_worker/get_logprobs")
     def get_logprobs(
         self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
     ) -> BatchedDataDict[LogprobOutputSpec]:
@@ -1240,6 +1243,7 @@ def use_reference_model(self):
                 self.enable_forward_pre_hook()
 
     # Temporary fix, 'data' is a kwarg due to some sort of ray bug
+    @wrap_with_nvtx_name("megatron_policy_worker/get_reference_policy_logprobs")
     def get_reference_policy_logprobs(
         self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
     ) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
@@ -1262,6 +1266,7 @@ def get_reference_policy_logprobs(
         return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu()
         return return_data
 
+    @wrap_with_nvtx_name("megatron_policy_worker/generate")
     def generate(
         self, *, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
     ) -> BatchedDataDict[GenerationOutputSpec]:
@@ -1405,6 +1410,7 @@ def report_device_id(self) -> str:
         return get_device_uuid(device_idx)
 
     @torch.no_grad()
+    @wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info")
     def prepare_refit_info(self) -> None:
         # Get parameter info for refit
         # param_info: list of ((name, shape, dtype), size_in_bytes) tuples
@@ -1439,6 +1445,7 @@ def prepare_refit_info(self) -> None:
 
         return refit_param_info_hf
 
+    @wrap_with_nvtx_name("megatron_policy_worker/prepare_weights_for_ipc")
     def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         """Prepare Megatron model weights for IPC transfer to vLLM.
 
@@ -1460,6 +1467,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
 
     # Temporary fix, 'keys' is a kwarg due to some sort of ray bug
     @torch.no_grad()
+    @wrap_with_nvtx_name("megatron_policy_worker/get_weights_ipc_handles")
     def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
         """Get IPC handles for the requested Megatron model weights.
@@ -1592,6 +1600,7 @@ def prepare_for_training(self, *args, **kwargs):
 
         torch.cuda.empty_cache()
 
+    @wrap_with_nvtx_name("megatron_policy_worker/offload_before_refit")
     def offload_before_refit(self):
         """Offload the optimizer and buffers to the CPU."""
         no_grad = torch.no_grad()
@@ -1630,6 +1639,7 @@ def offload_before_refit(self):
         )
         no_grad.__exit__(None, None, None)
 
+    @wrap_with_nvtx_name("megatron_policy_worker/offload_after_refit")
     def offload_after_refit(self):
         no_grad = torch.no_grad()
         no_grad.__enter__()
diff --git a/nemo_rl/utils/nsys.py b/nemo_rl/utils/nsys.py
index b5609f8c41..d9282970ab 100644
--- a/nemo_rl/utils/nsys.py
+++ b/nemo_rl/utils/nsys.py
@@ -16,6 +16,7 @@
 from typing import Protocol
 
 import rich
+import torch
 
 NRL_NSYS_WORKER_PATTERNS = os.environ.get("NRL_NSYS_WORKER_PATTERNS", "")
 NRL_NSYS_PROFILE_STEP_RANGE = os.environ.get("NRL_NSYS_PROFILE_STEP_RANGE", "")
@@ -76,3 +77,18 @@ def stop_profiler_on_exit():
             )
             policy.stop_gpu_profiling()
             policy.__NRL_PROFILE_STARTED = False
+
+
+def wrap_with_nvtx_name(name: str):
+    """A decorator to wrap a function with an NVTX range with the given name."""
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            torch.cuda.nvtx.range_push(name)
+            ret = func(*args, **kwargs)
+            torch.cuda.nvtx.range_pop()
+            return ret
+
+        return wrapper
+
+    return decorator
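The new `wrap_with_nvtx_name` decorator in `nemo_rl/utils/nsys.py` brackets each decorated call with `torch.cuda.nvtx.range_push`/`range_pop`, so the wrapped worker methods show up as named ranges in the Nsight timeline once `torch.cuda.profiler.start()` has been called. Below is a minimal usage sketch (the `refit_weights` function is hypothetical), together with an exception-safe variant using `try`/`finally` and `functools.wraps` that the decorator could adopt so a raised exception does not leave an NVTX range unbalanced:

```python
# Usage sketch for the decorator added in nemo_rl/utils/nsys.py.
# `refit_weights` is a hypothetical function used only for illustration.
import functools

import torch

from nemo_rl.utils.nsys import wrap_with_nvtx_name


@wrap_with_nvtx_name("example_worker/refit_weights")
def refit_weights() -> None:
    # Work done here shows up under the "example_worker/refit_weights" NVTX
    # range in the Nsight timeline while capture is active.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def wrap_with_nvtx_name_safe(name: str):
    """Exception-safe variant: pops the NVTX range even if the call raises."""

    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            torch.cuda.nvtx.range_push(name)
            try:
                return func(*args, **kwargs)
            finally:
                torch.cuda.nvtx.range_pop()

        return wrapper

    return decorator
```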