17 changes: 13 additions & 4 deletions docs/nsys-profiling.md
@@ -17,7 +17,7 @@ NeMo RL supports Nsight profiling for Ray workers through environment variable p
Set the `NRL_NSYS_WORKER_PATTERNS` environment variable with a comma-separated list of patterns to match worker names:

```bash
export NRL_NSYS_WORKER_PATTERNS="*policy*,*other-worker*"
export NRL_NSYS_WORKER_PATTERNS="*policy*,*vllm*"
```

Set the `NRL_NSYS_PROFILE_STEP_RANGE` environment variable to control which training steps the profiler captures. Its
@@ -40,7 +40,7 @@ export NRL_NSYS_PROFILE_STEP_RANGE=3:5

The supported worker types are:
- **DTensorPolicyWorker**: Pattern matched against `"dtensor_policy_worker"`
- **MegatronPolicyWorker**: Pattern matched against `"megatron_policy_worker"`
- **VllmGenerationWorker**: Pattern matched against `"vllm_generation_worker"`
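
Pattern matching here is plain glob matching against the worker names above. A minimal sketch of the idea (illustrative only; the real check lives in `get_nsight_config_if_pattern_matches` in `nemo_rl/distributed/worker_group_utils.py` and may differ in detail):

```python
from fnmatch import fnmatch

# Illustrative: a worker is profiled when any comma-separated pattern
# from NRL_NSYS_WORKER_PATTERNS glob-matches its name.
patterns = [p.strip() for p in "*policy*,*vllm*".split(",") if p.strip()]
for worker_name in (
    "dtensor_policy_worker",
    "megatron_policy_worker",
    "vllm_generation_worker",
):
    if any(fnmatch(worker_name, pattern) for pattern in patterns):
        print(f"{worker_name}: Nsight profiling enabled")
```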

## Example Usage

@@ -49,10 +49,16 @@ The supported worker types are:
NRL_NSYS_PROFILE_STEP_RANGE=2:3 NRL_NSYS_WORKER_PATTERNS="*policy*" uv run examples/run_grpo_math.py grpo.max_num_steps=5
```

### Profile Multiple Worker Types

```bash
NRL_NSYS_PROFILE_STEP_RANGE=1:2 NRL_NSYS_WORKER_PATTERNS="*policy*,*vllm*" uv run examples/run_grpo_math.py grpo.max_num_steps=5
```

### Profile Workers with Exact Names

```bash
NRL_NSYS_PROFILE_STEP_RANGE=3:10 NRL_NSYS_WORKER_PATTERNS="dtensor_policy_worker" uv run examples/run_grpo_math.py grpo.max_num_steps=5
NRL_NSYS_PROFILE_STEP_RANGE=3:10 NRL_NSYS_WORKER_PATTERNS="dtensor_policy_worker,vllm_generation_worker" uv run examples/run_grpo_math.py grpo.max_num_steps=5
```

### Profile Megatron Workers
@@ -63,7 +69,7 @@ To profile a Megatron worker, you should set `LD_LIBRARY_PATH` as follows, other

```bash
LD_LIBRARY_PATH="/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/cuda/lib64:/usr/local/cuda/lib:/usr/local/nvidia/lib64:/usr/local/nvidia/lib:/usr/lib/x86_64-linux-gnu" \
NRL_NSYS_PROFILE_STEP_RANGE=2:3 NRL_NSYS_WORKER_PATTERNS="megatron_policy_worker" uv run examples/run_grpo_math.py --config examples/configs/grpo_math_1B_megatron.yaml grpo.max_num_steps=5
NRL_NSYS_PROFILE_STEP_RANGE=2:3 NRL_NSYS_WORKER_PATTERNS="megatron_policy_worker,vllm_generation_worker" uv run examples/run_grpo_math.py --config examples/configs/grpo_math_1B_megatron.yaml grpo.max_num_steps=5
```

## Profile Output
@@ -78,7 +84,10 @@ When profiling is enabled, it generates the following logs and files:
2. **Profile Files**: Each profiled worker generates a `.nsys-rep` file with naming pattern:
```
dtensor_policy_worker_<NRL_NSYS_PROFILE_STEP_RANGE>_<PID>.nsys-rep
vllm_generation_worker_<NRL_NSYS_PROFILE_STEP_RANGE>_<PID>.nsys-rep
worker_process_<PID>.nsys-rep
```
If you are not using model parallelism in vLLM, refer directly to `vllm_generation_worker_<NRL_NSYS_PROFILE_STEP_RANGE>_<PID>.nsys-rep` for the Nsight report. If you are using model parallelism, that file will be empty; instead, the `worker_process_<PID>.nsys-rep` files are the Nsight profiles produced by vLLM's Ray distributed executor (see https://github.com/vllm-project/vllm/blob/7e3a8dc90670fd312ce1e0d4eba9bf11c571e3ad/vllm/executor/ray_distributed_executor.py#L136 for more information).

3. **File Location**: Profile files are saved in `/tmp/ray/session*/logs/nsight/` directory on each worker node. Ensure you check both `ls /tmp/ray/session_[0-9]*/logs/nsight` and `ls /tmp/ray/session_latest/logs/nsight` for the profiles, since the "latest" pointer may be stale.
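
To skim a captured report without opening the Nsight Systems GUI, `nsys stats` prints summary tables from the command line (the report file name below is illustrative):

```bash
# Summarize kernel, memory, and NVTX timings for one report.
nsys stats /tmp/ray/session_latest/logs/nsight/dtensor_policy_worker_2:3_12345.nsys-rep
```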

1 change: 1 addition & 0 deletions nemo_rl/distributed/worker_group_utils.py
@@ -57,6 +57,7 @@ def get_nsight_config_if_pattern_matches(worker_name: str) -> dict[str, Any]:
# Profile will only start/stop when torch.cuda.profiler.start()/stop() is called
"capture-range": "cudaProfilerApi",
"capture-range-end": "stop",
"cuda-graph-trace": "node",
}
}

21 changes: 21 additions & 0 deletions nemo_rl/models/generation/vllm.py
@@ -53,6 +53,7 @@
)
from nemo_rl.models.huggingface.common import ModelFlag
from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled
from nemo_rl.utils.nsys import wrap_with_nvtx_name


class VllmSpecificArgs(TypedDict):
@@ -323,6 +324,18 @@ def _patch_vllm_init_workers_ray():
if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(self.model_name):
load_format = "auto"

if (
len(get_nsight_config_if_pattern_matches("vllm_generation_worker")) > 0
and vllm_kwargs["distributed_executor_backend"] == "ray"
):
logger.warning(
"Nsight profiling is enabled for vllm generation worker through the vllm ray distributed executor. "
"The nsight command-line args and output file names are automatically picked by the ray distributed "
"executor. Refer to https://github.com/vllm-project/vllm/blob/7e3a8dc90670fd312ce1e0d4eba9bf11c571e3ad/vllm/executor/ray_distributed_executor.py#L136 "
"for more information."
)
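# Let vLLM's Ray distributed executor start its workers under nsys; those processes are not launched by NeMo RL directly.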
vllm_kwargs["ray_workers_use_nsight"] = True

llm_kwargs = dict(
model=self.model_name,
load_format=load_format,
Expand Down Expand Up @@ -436,6 +449,7 @@ def _build_sampling_params(
include_stop_str_in_output=True,
)

@wrap_with_nvtx_name("vllm_genertion_worker/generate")
def generate(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
) -> BatchedDataDict[GenerationOutputSpec]:
@@ -799,6 +813,7 @@ async def process_single_sample(sample_idx):
await asyncio.gather(*sample_tasks, return_exceptions=True)
raise e

@wrap_with_nvtx_name("vllm_genertion_worker/generate_text")
def generate_text(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
) -> BatchedDataDict[GenerationOutputSpec]:
@@ -1033,6 +1048,7 @@ async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> Non
"""Async version of prepare_refit_info."""
await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,))

@wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_ipc_handles")
def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
"""Update weights from IPC handles by delegating to the vLLM Worker implementation.

@@ -1144,6 +1160,7 @@ async def update_weights_from_ipc_handles_async(
traceback.print_exc()
return False

@wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_collective")
def update_weights_from_collective(self) -> bool:
"""Update the model weights from collective communication."""
try:
@@ -1317,10 +1334,14 @@ async def wake_up_async(self, **kwargs):
def start_gpu_profiling(self) -> None:
"""Start GPU profiling."""
torch.cuda.profiler.start()
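# Propagate the start to vLLM's worker processes via collective RPC.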
if self.llm is not None:
self.llm.collective_rpc("start_gpu_profiling", args=tuple())

def stop_gpu_profiling(self) -> None:
"""Stop GPU profiling."""
torch.cuda.profiler.stop()
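# Likewise propagate the stop to vLLM's worker processes.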
if self.llm is not None:
self.llm.collective_rpc("stop_gpu_profiling", args=tuple())


class VllmGeneration(GenerationInterface):
19 changes: 19 additions & 0 deletions nemo_rl/models/generation/vllm_backend.py
@@ -17,6 +17,8 @@
import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor

from nemo_rl.utils.nsys import wrap_with_nvtx_name

try:
import vllm # noqa: F401
except ImportError:
@@ -66,6 +68,9 @@ def prepare_refit_info(
"""
self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored

@wrap_with_nvtx_name(
"vllm_internal_worker_extension/update_weights_from_global_ipc_handles"
)
def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
"""Update weights from global IPC handles.

@@ -79,6 +84,9 @@ def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
local_device_ipc_handles = global_device_ipc_handles[device_uuid]
return self.update_weights_from_local_ipc_handles(local_device_ipc_handles)

@wrap_with_nvtx_name(
"vllm_internal_worker_extension/update_weights_from_local_ipc_handles"
)
def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
"""Update weights from local IPC handles.

@@ -155,6 +163,9 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
)
return False

@wrap_with_nvtx_name(
"vllm_internal_worker_extension/update_weights_from_collective"
)
def update_weights_from_collective(self) -> bool:
"""Update the model weights from collective communication."""
assert self.state_dict_info is not None, (
@@ -174,3 +185,11 @@ def update_weights_from_collective(self) -> bool:
return False

return True

def start_gpu_profiling(self) -> None:
"""Start GPU profiling."""
torch.cuda.profiler.start()

def stop_gpu_profiling(self) -> None:
"""Stop GPU profiling."""
torch.cuda.profiler.stop()
9 changes: 9 additions & 0 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -77,6 +77,7 @@
load_checkpoint,
save_checkpoint,
)
from nemo_rl.utils.nsys import wrap_with_nvtx_name


@contextmanager
@@ -513,6 +514,7 @@ def get_gpu_info(self) -> dict[str, Any]:
"""Return information about the GPU being used by this worker."""
return get_gpu_info(self.model)

@wrap_with_nvtx_name("dtensor_policy_worker/train")
def train(
self,
data: BatchedDataDict[Any],
@@ -855,6 +857,7 @@ def train(

return metrics

@wrap_with_nvtx_name("dtensor_policy_worker/get_logprobs")
def get_logprobs(
self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
) -> BatchedDataDict[LogprobOutputSpec]:
@@ -1137,6 +1140,7 @@ def use_reference_model(self) -> Generator[None, None, None]:
val = to_local_if_dtensor(v)
val.copy_(curr_state_dict[k])

@wrap_with_nvtx_name("dtensor_policy_worker/get_reference_policy_logprobs")
def get_reference_policy_logprobs(
self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
@@ -1234,6 +1238,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
return self.refit_param_info, total_available_bytes

@torch.no_grad()
@wrap_with_nvtx_name("dtensor_policy_worker/get_weights_ipc_handles")
def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
assert self._held_sharded_state_dict_reference is not None, (
"prepare_weights_for_ipc must be called before get_weights_ipc_handles"
@@ -1296,6 +1301,7 @@ def broadcast_weights_for_collective(self) -> None:
if self.cpu_offload:
self.model = self.move_to_cpu(self.model)

@wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_lp_inference")
def prepare_for_lp_inference(self) -> None:
if not self.cpu_offload:
self.move_to_cuda(self.model)
@@ -1305,6 +1311,7 @@ def prepare_for_training(self, *args, **kwargs) -> None:
self.model.eval()
self.offload_before_refit()

@wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_training")
def prepare_for_training(self, *args, **kwargs) -> None:
# onload models and optimizer state to cuda
if not self.cpu_offload:
@@ -1329,6 +1336,7 @@ def prepare_for_training(self, *args, **kwargs) -> None:
torch.cuda.empty_cache()

@torch.no_grad()
@wrap_with_nvtx_name("dtensor_policy_worker/offload_before_refit")
def offload_before_refit(self) -> None:
"""Offload the optimizer to the CPU."""
torch.randn(1).cuda() # wake up torch allocator
@@ -1342,6 +1350,7 @@ def offload_before_refit(self) -> None:
torch.cuda.empty_cache()

@torch.no_grad()
@wrap_with_nvtx_name("dtensor_policy_worker/offload_after_refit")
def offload_after_refit(self) -> None:
# Offload as much as possible on the CPU
self.model = self.move_to_cpu(self.model)
10 changes: 10 additions & 0 deletions nemo_rl/models/policy/megatron_policy_worker.py
@@ -124,6 +124,7 @@
get_megatron_checkpoint_dir,
get_runtime_env_for_policy_worker,
)
from nemo_rl.utils.nsys import wrap_with_nvtx_name

TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)

@@ -765,6 +766,7 @@ def disable_forward_pre_hook(self, param_sync=True):
assert isinstance(self.model, DistributedDataParallel)
self.model.disable_forward_pre_hook(param_sync=param_sync)

@wrap_with_nvtx_name("megatron_policy_worker/train")
def train(
self,
data: BatchedDataDict,
@@ -1010,6 +1012,7 @@ def train(
}
return metrics

@wrap_with_nvtx_name("megatron_policy_worker/get_logprobs")
def get_logprobs(
self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
) -> BatchedDataDict[LogprobOutputSpec]:
@@ -1240,6 +1243,7 @@ def use_reference_model(self):
self.enable_forward_pre_hook()

# Temporary fix, 'data' is a kwarg due to some sort of ray bug
@wrap_with_nvtx_name("megatron_policy_worker/get_reference_policy_logprobs")
def get_reference_policy_logprobs(
self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
@@ -1262,6 +1266,7 @@ def get_reference_policy_logprobs(
return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu()
return return_data

@wrap_with_nvtx_name("megatron_policy_worker/generate")
def generate(
self, *, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
) -> BatchedDataDict[GenerationOutputSpec]:
@@ -1405,6 +1410,7 @@ def report_device_id(self) -> str:
return get_device_uuid(device_idx)

@torch.no_grad()
@wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info")
def prepare_refit_info(self) -> None:
# Get parameter info for refit
# param_info: list of ((name, shape, dtype), size_in_bytes) tuples
@@ -1439,6 +1445,7 @@ def prepare_refit_info(self) -> None:

return refit_param_info_hf

@wrap_with_nvtx_name("megatron_policy_worker/prepare_weights_for_ipc")
def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
"""Prepare Megatron model weights for IPC transfer to vLLM.

@@ -1460,6 +1467,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:

# Temporary fix, 'keys' is a kwarg due to some sort of ray bug
@torch.no_grad()
@wrap_with_nvtx_name("megatron_policy_worker/get_weights_ipc_handles")
def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
"""Get IPC handles for the requested Megatron model weights.

@@ -1592,6 +1600,7 @@ def prepare_for_training(self, *args, **kwargs):

torch.cuda.empty_cache()

@wrap_with_nvtx_name("megatron_policy_worker/offload_before_refit")
def offload_before_refit(self):
"""Offload the optimizer and buffers to the CPU."""
no_grad = torch.no_grad()
@@ -1630,6 +1639,7 @@ def offload_before_refit(self):
)
no_grad.__exit__(None, None, None)

@wrap_with_nvtx_name("megatron_policy_worker/offload_after_refit")
def offload_after_refit(self):
no_grad = torch.no_grad()
no_grad.__enter__()
16 changes: 16 additions & 0 deletions nemo_rl/utils/nsys.py
@@ -16,6 +16,7 @@
import functools
from typing import Protocol

import rich
import torch

NRL_NSYS_WORKER_PATTERNS = os.environ.get("NRL_NSYS_WORKER_PATTERNS", "")
NRL_NSYS_PROFILE_STEP_RANGE = os.environ.get("NRL_NSYS_PROFILE_STEP_RANGE", "")
@@ -76,3 +77,18 @@ def stop_profiler_on_exit():
)
policy.stop_gpu_profiling()
policy.__NRL_PROFILE_STARTED = False


def wrap_with_nvtx_name(name: str):
    """A decorator to wrap a function with an NVTX range with the given name."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            torch.cuda.nvtx.range_push(name)
            try:
                return func(*args, **kwargs)
            finally:
                # Pop even when func raises so NVTX ranges stay balanced.
                torch.cuda.nvtx.range_pop()

        return wrapper

    return decorator
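

# Example usage (illustrative; function name is hypothetical):
#
#     @wrap_with_nvtx_name("example_worker/do_step")
#     def do_step() -> None:
#         ...  # runs inside an "example_worker/do_step" NVTX range on the nsys timeline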