diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 7a2e64f1312..418fb707ae9 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -1,216 +1,192 @@ # Profiling vLLM-Omni -> **Warning:** Profiling incurs significant overhead. Use only for development and debugging, never in production. +> **Warning:** Profiling is for development and debugging only. It adds significant overhead and should not be enabled in production. -vLLM-Omni uses the PyTorch Profiler to analyze performance across both **multi-stage omni-modality models** and **diffusion models**. +vLLM-Omni supports two profiler backends through `profiler_config`: -### 1. Configure Profiling in the Stage YAML +- `torch`: detailed CPU/CUDA traces written to `torch_profiler_dir` +- `cuda`: low-overhead CUDA range control for NVIDIA Nsight Systems (`nsys`) -Enable profiling by adding `profiler_config` under `engine_args` for the stage(s) you want to profile in your stage config YAML: +## 1. Configure Profiling + +Use the same `profiler_config` shape everywhere: + +```yaml +profiler_config: + profiler: torch + torch_profiler_dir: ./perf +``` + +Supported fields: + +| Field | Description | +|---|---| +| `profiler` | Profiler backend. Supported values: `torch`, `cuda`. | +| `torch_profiler_dir` | Output directory for torch traces. Required when `profiler: torch`. | +| `delay_iterations` | Number of worker iterations to skip before profiling starts. | +| `max_iterations` | Maximum number of worker iterations to capture before auto-stop. | +| `warmup_iterations` | Torch-profiler warmup iterations. | +| `active_iterations` | Torch-profiler active iterations. | +| `wait_iterations` | Torch-profiler wait iterations before warmup. | + +For multi-stage omni pipelines, put `profiler_config` under the target stage's `engine_args`. ```yaml stage_args: - stage_id: 0 stage_type: llm engine_args: - # ... other engine args ... profiler_config: profiler: torch torch_profiler_dir: ./perf ``` -| Field | Description | -|---|---| -| `profiler` | Profiler backend to use. Currently supports `torch`. | -| `torch_profiler_dir` | Directory where trace files are saved. Created automatically if it doesn't exist. | - -> **Tip:** Only enable `profiler_config` on stages you actually need to profile. Stages without it will not start a profiler, keeping overhead minimal. - -### 2. Profiling Omni-Modality Models +For single-stage diffusion usage, pass `profiler_config` directly to `Omni(...)` or `vllm serve`. -**Selective Stage Profiling** +## 2. Profiling Omni Pipelines -It is highly recommended to profile specific stages to prevent producing overly large trace files: +It is usually best to profile only the stages you need. ```python -# Profile all stages -omni_llm.start_profile() +# Profile all stages. +omni.start_profile() -# Only profile Stage 1 -omni_llm.start_profile(stages=[1]) - -# Stage 0 (Thinker) and Stage 2 (Audio Decoder) for qwen omni -omni_llm.start_profile(stages=[0, 2]) +# Profile selected stages only. +omni.start_profile(stages=[0, 2]) +... +omni.stop_profile(stages=[0, 2]) ``` -> **Important:** Always pass the same `stages` list to both `start_profile()` and `stop_profile()`. If you omit `stages` from `stop_profile()`, it defaults to stopping all stages — including ones that were never started — which will produce errors. - -**Python Usage**: Wrap your generation logic with `start_profile()` and `stop_profile()`. +Always stop the same stage set that you started. If only some stages have `profiler_config`, pass an explicit `stages=[...]` list instead of relying on the default "all stages" behavior. -```python -profiler_stages = [0] # Only profile the stages you need +Examples: -# 1. Start profiling -omni.start_profile(stages=profiler_stages) +1. [Qwen2.5-Omni end2end](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py) +2. [Qwen3-Omni end2end](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py) -# Initialize generator -omni_generator = omni.generate(prompts, sampling_params_list, py_generator=args.py_generator) +## 3. Profiling Single-Stage Diffusion -total_requests = len(prompts) -processed_count = 0 +Single-stage diffusion models use the same `start_profile()` / `stop_profile()` controls, but you must provide `profiler_config` explicitly. -# Main Processing Loop -for stage_outputs in omni_generator: +### PyTorch profiler - # ... [Output processing logic for text/audio would go here] ... +```python +from vllm_omni import Omni + +omni = Omni( + model="Wan-AI/Wan2.2-I2V-A14B-Diffusers", + profiler_config={ + "profiler": "torch", + "torch_profiler_dir": "./perf", + }, +) + +omni.start_profile() +... +omni.stop_profile() +``` - # Update count to track when to stop profiling - processed_count += len(stage_outputs.request_output) +### Nsight Systems (`nsys`) - # 2. Check if all requests are done to stop the profiler safely - if profiler_enabled and processed_count >= total_requests: - print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...") +For Nsight Systems, use `profiler: cuda` and wrap the process with `nsys profile`. - # Stop the profiler while workers are still active - # Pass the same stages list used in start_profile() - omni_llm.stop_profile(stages=profiler_stages) +```bash +nsys profile \ + --trace-fork-before-exec=true \ + --cuda-graph-trace=node \ + --capture-range=cudaProfilerApi \ + --capture-range-end=repeat \ + -o diffusion_trace \ + python image_to_video.py ... +``` - # Wait for traces to flush to disk - print("[Info] Waiting 30s for workers to write trace files to disk...") - time.sleep(30) - print("[Info] Trace export wait time finished.") +The Python process being profiled must create the diffusion engine with: -omni_llm.close() +```python +profiler_config={"profiler": "cuda"} ``` +Then call `start_profile()` before the requests you want to capture and `stop_profile()` after them. The diffusion worker processes open and close the CUDA capture range themselves, so `nsys` sees the actual GPU work instead of only the parent process. -**CLI Usage** (using `end2end.py`): -```bash -# Profile only Stage 0 (Thinker) -python end2end.py --output-wav output_audio \ - --query-type text --enable-profiler --profiler-stages 0 +Examples: -# Profile Stage 0 and Stage 2 -python end2end.py --output-wav output_audio \ - --query-type text --enable-profiler --profiler-stages 0 2 +1. [Image edit example](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py) +2. [Image to video example](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video) -# Profile all stages (omit --profiler-stages) -python end2end.py --output-wav output_audio \ - --query-type text --enable-profiler -``` +## 4. Profiling Online Serving -**Examples**: +When any stage has `profiler_config.profiler` set, the server exposes: -1. **Qwen2.5-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py) +- `POST /start_profile` +- `POST /stop_profile` -2. **Qwen3-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py) +### Start the server -### 3. Profiling diffusion models +Multi-stage omni serving: -Diffusion profiling is End-to-End, capturing encoding, denoising loops, and decoding. Standalone diffusion scripts use `--profiler-dir` to enable profiling. - -**CLI Usage:** ```bash -python image_to_video.py \ - --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ - --image qwen-bear.png \ - --prompt "A cat playing with yarn, smooth motion" \ - --profiler-dir \ - \ - # Minimize Spatial Dimensions (Optional but helpful): - # Drastically reduces memory usage so the profiler doesn't - # crash due to overhead, though for accurate performance - # tuning you often want target resolutions. - --height 48 \ - --width 64 \ - \ - # Minimize Temporal Dimension (Frames): - # Video models process 3D tensors (Time, Height, Width). - # Reducing frames to the absolute minimum (2) keeps the - # tensor size small, ensuring the trace file doesn't become - # multi-gigabytes in size. - --num-frames 2 \ - \ - # Minimize Iteration Loop (Steps): - # This is the most critical setting for profiling. - # Diffusion models run the same loop X times. - # Profiling 2 steps gives you the exact same performance - # data as 50 steps, but saves minutes of runtime and - # prevents the trace viewer from freezing. - --num-inference-steps 2 \ - \ - --guidance-scale 5.0 \ - --guidance-scale-high 6.0 \ - --boundary-ratio 0.875 \ - --flow-shift 12.0 \ - --fps 16 \ - --output i2v_output.mp4 +vllm serve Qwen/Qwen2.5-Omni-7B \ + --omni \ + --stage-configs-path qwen2_5_omni.yaml \ + --port 8091 ``` -> **Note:** For diffusion stages within a multi-stage omni pipeline, use `profiler_config` in the stage YAML instead (see Section 1). - -**Examples**: - -1. **Qwen image edit**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py) - -2. **Wan-AI/Wan2.2-I2V-A14B-Diffusers**: [https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video) - -### 4. Profiling Online Serving - -When `profiler_config` is set in the stage YAML, the server automatically exposes `/start_profile` and `/stop_profile` HTTP endpoints. +Single-stage diffusion serving with torch profiler: -**1. Start the server** with a stage YAML that has `profiler_config` enabled: ```bash -vllm serve Qwen/Qwen2.5-Omni-7B \ - --omni \ - --stage-configs-path qwen2_5_omni.yaml \ - --port 8091 +vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --omni \ + --port 8091 \ + --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}' ``` -Or for one stage diffusion models: +Single-stage diffusion serving with Nsight Systems: ```bash -vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers --omni --port 8091 --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}' +nsys profile \ + --trace-fork-before-exec=true \ + --cuda-graph-trace=node \ + --capture-range=cudaProfilerApi \ + --capture-range-end=repeat \ + -o serving_trace \ + vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --omni \ + --port 8091 \ + --profiler-config '{"profiler": "cuda"}' ``` -**2. Start profiling** by sending a POST request: +### Control capture + ```bash -# Profile all stages that have profiler_config set +# Start profiling on all profiled stages. curl -X POST http://localhost:8091/start_profile -# Profile specific stages only +# Start profiling on selected stages. curl -X POST http://localhost:8091/start_profile \ - -H "Content-Type: application/json" \ - -d '{"stages": [0]}' -``` + -H "Content-Type: application/json" \ + -d '{"stages": [0]}' -**3. Send your inference requests** as normal while the profiler is running. - -**4. Stop profiling** and collect traces: -```bash -# Stop all stages +# Stop profiling. curl -X POST http://localhost:8091/stop_profile - -# Stop specific stages (must match the stages you started) -curl -X POST http://localhost:8091/stop_profile \ - -H "Content-Type: application/json" \ - -d '{"stages": [0]}' ``` -Trace files are written to the `torch_profiler_dir` specified in your stage YAML. +For mixed-stage pipelines, use explicit `stages` and pass the same stage list to both endpoints. + +## 5. Analyze Results -> **Important:** Always stop the same stages you started. Stopping a stage that was never started will produce errors. +Torch profiler output: -### 5. Analyzing Traces +- Chrome/Perfetto traces under `torch_profiler_dir` +- Optional aggregated CUDA-time tables under the same directory -Output files are saved to the `torch_profiler_dir` specified in your stage YAML config. +CUDA profiler / Nsight Systems output: -**Output** -**Chrome Trace** (`.json.gz`): Visual timeline of kernels and stages. Open in Perfetto UI. +- `.nsys-rep` report files written by `nsys -o ...` -**Viewing Tools:** +Recommended viewers: -- [Perfetto](https://ui.perfetto.dev/) (recommended) -- `chrome://tracing` (Chrome only) +- [Perfetto](https://ui.perfetto.dev/) for torch traces +- `nsys stats .nsys-rep` for CLI summaries +- Nsight Systems GUI for CUDA kernel timelines -**Note**: vLLM-Omni reuses the PyTorch Profiler infrastructure from vLLM. See the official vLLM profiler documentation: [vLLM Profiling Guide](https://docs.vllm.ai/en/stable/contributing/profiling/) +vLLM-Omni reuses the vLLM profiling infrastructure where possible. For the upstream reference, see the [vLLM profiling guide](https://docs.vllm.ai/en/stable/contributing/profiling/). diff --git a/tests/diffusion/test_diffusion_worker_cuda_profiler.py b/tests/diffusion/test_diffusion_worker_cuda_profiler.py new file mode 100644 index 00000000000..ddc2aed2fc2 --- /dev/null +++ b/tests/diffusion/test_diffusion_worker_cuda_profiler.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +@pytest.fixture +def mock_od_config(mocker: MockerFixture): + """Create a mock OmniDiffusionConfig with a CUDA profiler backend.""" + config = mocker.Mock() + config.profiler_config = mocker.Mock() + config.profiler_config.profiler = "cuda" + config.diffusion_load_format = "default" + return config + + +@pytest.fixture +def mock_diffusion_worker_dependencies(mocker: MockerFixture): + """Patch heavy worker dependencies for focused profiler tests.""" + mocker.patch.object(DiffusionWorker, "init_device") + mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.DiffusionModelRunner") + + +class TestDiffusionWorkerCudaProfiler: + def test_creates_cuda_profiler_wrapper( + self, + mocker: MockerFixture, + mock_od_config, + mock_diffusion_worker_dependencies, + ): + fake_profiler = mocker.Mock() + cuda_profiler = mocker.patch( + "vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper", + return_value=fake_profiler, + ) + create_omni_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.create_omni_profiler") + + worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True) + + cuda_profiler.assert_called_once_with(mock_od_config.profiler_config) + create_omni_profiler.assert_not_called() + assert worker.profiler is fake_profiler + + def test_profile_start_stop_delegates_to_cuda_profiler( + self, + mocker: MockerFixture, + mock_od_config, + mock_diffusion_worker_dependencies, + ): + fake_profiler = mocker.Mock() + fake_profiler.start = MagicMock() + fake_profiler.stop = MagicMock() + mocker.patch( + "vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper", + return_value=fake_profiler, + ) + + worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True) + + assert worker.profile(is_start=True) is None + assert worker.profile(is_start=False) is None + + fake_profiler.start.assert_called_once_with() + fake_profiler.stop.assert_called_once_with() + + def test_returns_none_when_profiler_config_is_missing( + self, + mocker: MockerFixture, + mock_od_config, + mock_diffusion_worker_dependencies, + ): + mock_od_config.profiler_config = None + cuda_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper") + create_omni_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.create_omni_profiler") + + worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True) + + cuda_profiler.assert_not_called() + create_omni_profiler.assert_not_called() + assert worker.profiler is None + + def test_cuda_backend_does_not_use_torch_profiler_factory( + self, + mocker: MockerFixture, + mock_od_config, + mock_diffusion_worker_dependencies, + ): + mocker.patch( + "vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper", + return_value=mocker.Mock(), + ) + create_omni_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.create_omni_profiler") + + DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True) + + create_omni_profiler.assert_not_called() diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 422ef479b0c..52a8f385479 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -361,15 +361,11 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus ) def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: - """Start or stop torch profiling on all diffusion workers. + """Start or stop profiling on all diffusion workers. Args: is_start: True to start profiling, False to stop. - profile_prefix: Optional prefix for trace filename (vLLM compat). - - Note: - Matches vLLM's worker.profile() signature for consistency. - Traces are saved automatically via on_trace_ready callback. + profile_prefix: Optional prefix for trace filename. """ if is_start: if profile_prefix is None: diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index ea4b9d96f71..160309e0d8d 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -20,6 +20,7 @@ from vllm.config import CompilationConfig, DeviceConfig, VllmConfig, set_current_vllm_config from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.logger import init_logger +from vllm.profiler.wrapper import CudaProfilerWrapper, WorkerProfiler from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.mem_utils import GiB_bytes from vllm.v1.worker.workspace import init_workspace_manager @@ -83,15 +84,7 @@ def __init__( od_config=self.od_config, device=self.device, ) - # Initialize profiler if configured - self.profiler: OmniTorchProfilerWrapper | None = None - profiler_config = self.od_config.profiler_config - if profiler_config and profiler_config.profiler == "torch": - self.profiler = create_omni_profiler( - profiler_config=profiler_config, - worker_name=f"diffusion_worker_{self.rank}", - local_rank=self.local_rank, - ) + self.profiler: WorkerProfiler | None = self._create_profiler() if not skip_load_model: self.load_model(load_format=self.od_config.diffusion_load_format) self.init_lora_manager() @@ -122,6 +115,7 @@ def init_device(self) -> None: vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size vllm_config.parallel_config.enable_expert_parallel = self.od_config.parallel_config.enable_expert_parallel + vllm_config.profiler_config = self.od_config.profiler_config self.vllm_config = vllm_config # Initialize distributed environment @@ -147,6 +141,24 @@ def init_device(self) -> None: ) init_workspace_manager(self.device) + def _create_profiler(self) -> WorkerProfiler | None: + profiler_config = self.od_config.profiler_config + profiler_type = getattr(profiler_config, "profiler", None) + if profiler_type == "torch": + return create_omni_profiler( + profiler_config=profiler_config, + worker_name=f"diffusion_rank{self.rank}", + local_rank=self.local_rank, + ) + if profiler_type == "cuda": + return CudaProfilerWrapper(profiler_config) + if profiler_type is not None: + logger.warning("Unknown profiler backend %r on diffusion worker %s", profiler_type, self.rank) + return None + + def _get_profiler(self) -> WorkerProfiler | None: + return getattr(self, "profiler", None) + def load_model(self, load_format: str = "default", custom_pipeline_name: str | None = None) -> None: """Load the diffusion model using DiffusionModelRunner.""" with ( @@ -192,27 +204,21 @@ def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> N Args: is_start: True to start profiling, False to stop. - profile_prefix: Optional prefix for trace filename (vLLM compat). - - Note: - Matches vLLM's worker.profile() signature for consistency. - Traces are saved automatically via on_trace_ready callback. + profile_prefix: Optional prefix for trace filename. """ - if self.profiler is None: - logger.warning("Profiler not initialized, skipping profile(%s)", is_start) + profiler = self._get_profiler() + if profiler is None: return if is_start: - from vllm_omni.profiler import OmniTorchProfilerWrapper - - if isinstance(self.profiler, OmniTorchProfilerWrapper): + if isinstance(profiler, OmniTorchProfilerWrapper): import time - filename = profile_prefix or f"diffusion_{int(time.time())}" - self.profiler.set_trace_filename(filename) - self.profiler.start() + filename = profile_prefix or f"diffusion_rank{self.rank}_{int(time.time())}" + profiler.set_trace_filename(filename) + profiler.start() else: - self.profiler.stop() + profiler.stop() def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput: """Execute a forward pass by delegating to the model runner.""" @@ -224,7 +230,13 @@ def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfi if req.sampling_params.lora_request is not None: raise logger.warning("LoRA activation skipped: %s", exc) - return self.model_runner.execute_model(req) + profiler = self._get_profiler() + ctx = profiler.annotate_context_manager("diffusion_forward") if profiler else nullcontext() + with ctx: + output = self.model_runner.execute_model(req) + if profiler: + profiler.step() + return output def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: """Execute one diffusion step by delegating to the model runner.""" @@ -236,8 +248,13 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> Runner if any(new_req.req.sampling_params.lora_request is not None for new_req in scheduler_output.scheduled_new_reqs): raise ValueError("Step mode does not support LoRA yet.") - - return self.model_runner.execute_stepwise(scheduler_output) + profiler = self._get_profiler() + ctx = profiler.annotate_context_manager("diffusion_step") if profiler else nullcontext() + with ctx: + output = self.model_runner.execute_stepwise(scheduler_output) + if profiler: + profiler.step() + return output def load_weights(self, weights) -> set[str]: """Load weights by delegating to the model runner."""