diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 7a2e64f1312..39cd27f5dcf 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -201,7 +201,81 @@ Trace files are written to the `torch_profiler_dir` specified in your stage YAML > **Important:** Always stop the same stages you started. Stopping a stage that was never started will produce errors. -### 5. Analyzing Traces +### 5. Nsight Systems (nsys) Profiling + +vLLM-Omni supports NVIDIA Nsight Systems profiling via the `cuda` profiler backend. This uses `torch.cuda.profiler` to mark profiling ranges that nsys captures, and `torch.cuda.nvtx` for annotating trace regions. + +#### Configuration + +Set `profiler: cuda` in your profiler config. Unlike the `torch` profiler, no `torch_profiler_dir` is needed since nsys manages its own output files. + +**Stage YAML (for omni or diffusion stages):** +```yaml +stage_args: + - stage_id: 0 + stage_type: llm + engine_args: + profiler_config: + profiler: cuda +``` + +**Diffusion CLI:** +```bash +python image_to_video.py \ + --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --profiler-config '{"profiler": "cuda"}' \ + ... +``` + +#### Running with nsys + +Wrap your command with `nsys profile` using `--capture-range=cudaProfilerApi` so nsys only captures the region between `start_profile()` and `stop_profile()`: + +```bash +nsys profile \ + --capture-range=cudaProfilerApi \ + --capture-range-end=repeat \ + --trace-fork-before-exec=true \ + --cuda-graph-trace=node \ + -o my_trace \ + python your_script.py ... +``` + +| Flag | Purpose | +|---|---| +| `--capture-range=cudaProfilerApi` | Only capture between CUDA profiler start/stop calls | +| `--capture-range-end=repeat` | Allow multiple start/stop cycles | +| `--trace-fork-before-exec=true` | Trace child processes (needed for multi-worker) | +| `--cuda-graph-trace=node` | Trace individual kernels inside CUDA graphs | + +#### Online Serving with nsys + +1. 
Start the server under nsys with `profiler: cuda` in your stage YAML: +```bash +nsys profile \ + --capture-range=cudaProfilerApi \ + --capture-range-end=repeat \ + --trace-fork-before-exec=true \ + --cuda-graph-trace=node \ + -o serving_trace \ + vllm serve Qwen/Qwen2.5-Omni-7B \ + --omni \ + --stage-configs-path qwen2_5_omni.yaml \ + --port 8091 +``` + +2. Trigger profiling via HTTP endpoints: +```bash +curl -X POST http://localhost:8091/start_profile +# ... send inference requests ... +curl -X POST http://localhost:8091/stop_profile +``` + +#### Analyzing nsys Traces + +Open the `.nsys-rep` output file in the NVIDIA Nsight Systems GUI to inspect kernel-level activity, memory transfers, and NVTX annotations in detail. + +### 6. Analyzing Traces Output files are saved to the `torch_profiler_dir` specified in your stage YAML config. diff --git a/tests/diffusion/test_diffusion_worker_cuda_profiler.py b/tests/diffusion/test_diffusion_worker_cuda_profiler.py new file mode 100644 index 00000000000..c51f95f9c90 --- /dev/null +++ b/tests/diffusion/test_diffusion_worker_cuda_profiler.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Unit tests for DiffusionWorker CUDA (nsys) profiler integration. + +Verifies that DiffusionWorker correctly creates a CudaProfilerWrapper +when profiler_config.profiler == "cuda", enabling Nsight Systems profiling. 
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +@pytest.fixture +def mock_od_config(): + """Create a mock OmniDiffusionConfig with cuda profiler.""" + config = MagicMock() + config.num_gpus = 1 + config.master_port = 12345 + config.enable_sleep_mode = False + config.cache_backend = None + config.cache_config = None + config.model = "test-model" + config.diffusion_load_format = "default" + config.max_cpu_loras = 0 + config.lora_path = None + config.lora_scale = 1.0 + + # Profiler config for cuda (nsys) + profiler_config = MagicMock() + profiler_config.profiler = "cuda" + profiler_config.delay_iterations = 0 + profiler_config.max_iterations = 0 + config.profiler_config = profiler_config + + # Parallel config + parallel_config = MagicMock() + parallel_config.tensor_parallel_size = 1 + parallel_config.data_parallel_size = 1 + parallel_config.enable_expert_parallel = False + config.parallel_config = parallel_config + + return config + + +class TestDiffusionWorkerCudaProfiler: + """Test that DiffusionWorker creates CudaProfilerWrapper for nsys.""" + + def test_cuda_profiler_created(self, mock_od_config): + """DiffusionWorker should create CudaProfilerWrapper when profiler == 'cuda'.""" + from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker + + with ( + patch.object(DiffusionWorker, "init_device"), + patch.object(DiffusionWorker, "load_model"), + patch.object(DiffusionWorker, "init_lora_manager"), + ): + worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True) + + from vllm.profiler.wrapper import CudaProfilerWrapper + + assert worker.profiler is not None + assert isinstance(worker.profiler, CudaProfilerWrapper) + + def test_torch_profiler_not_created_for_cuda(self, mock_od_config): + """Cuda backend must yield a non-None profiler that is not an OmniTorchProfilerWrapper.""" + from vllm_omni.diffusion.worker.diffusion_worker import 
DiffusionWorker + + with ( + patch.object(DiffusionWorker, "init_device"), + patch.object(DiffusionWorker, "load_model"), + patch.object(DiffusionWorker, "init_lora_manager"), + ): + worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True) + + from vllm_omni.profiler import OmniTorchProfilerWrapper + + assert worker.profiler is not None and not isinstance(worker.profiler, OmniTorchProfilerWrapper) + + def test_no_profiler_when_none(self, mock_od_config): + """DiffusionWorker should have no profiler when profiler_config is None.""" + mock_od_config.profiler_config = None + + from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker + + with ( + patch.object(DiffusionWorker, "init_device"), + patch.object(DiffusionWorker, "load_model"), + patch.object(DiffusionWorker, "init_lora_manager"), + ): + worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True) + + assert worker.profiler is None + + def test_profile_start_stop_with_cuda_profiler(self, mock_od_config): + """profile() should call start/stop on CudaProfilerWrapper without errors.""" + from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker + + with ( + patch.object(DiffusionWorker, "init_device"), + patch.object(DiffusionWorker, "load_model"), + patch.object(DiffusionWorker, "init_lora_manager"), + ): + worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True) + + # Mock the underlying cuda profiler methods to avoid actual CUDA calls + worker.profiler._cuda_profiler = MagicMock() + + # Should not raise + worker.profile(is_start=True) + worker.profile(is_start=False) + + # Verify start and stop were called + worker.profiler._cuda_profiler.start.assert_called_once() + worker.profiler._cuda_profiler.stop.assert_called_once() diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index 6e1cabba0ce..6c616298eba 100644 --- 
a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -20,6 +20,7 @@ from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.logger import init_logger +from vllm.profiler.wrapper import WorkerProfiler from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.mem_utils import GiB_bytes from vllm.v1.worker.workspace import init_workspace_manager @@ -42,7 +43,7 @@ from vllm_omni.diffusion.worker.utils import RunnerOutput from vllm_omni.lora.request import LoRARequest from vllm_omni.platforms import current_omni_platform -from vllm_omni.profiler import OmniTorchProfilerWrapper, create_omni_profiler +from vllm_omni.profiler import create_omni_profiler from vllm_omni.worker.gpu_memory_utils import get_process_gpu_memory logger = init_logger(__name__) @@ -84,7 +85,7 @@ def __init__( device=self.device, ) # Initialize profiler if configured - self.profiler: OmniTorchProfilerWrapper | None = None + self.profiler: WorkerProfiler | None = None profiler_config = self.od_config.profiler_config if profiler_config and profiler_config.profiler == "torch": self.profiler = create_omni_profiler( @@ -92,6 +93,10 @@ def __init__( worker_name=f"diffusion_worker_{self.rank}", local_rank=self.local_rank, ) + elif profiler_config and profiler_config.profiler == "cuda": + from vllm.profiler.wrapper import CudaProfilerWrapper + + self.profiler = CudaProfilerWrapper(profiler_config) if not skip_load_model: self.load_model(load_format=self.od_config.diffusion_load_format) self.init_lora_manager()