diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 99b07cb9869..3e54aed8f9a 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -131,9 +131,6 @@ python image_to_video.py \ 2. **Wan-AI/Wan2.2-I2V-A14B-Diffusers**: [https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video) -> **Note:** -As of now, asynchronous (online) profiling is not fully supported in vLLM-Omni. While start_profile() and stop_profile() methods exist, they are only reliable in offline inference scripts (e.g., the provided end2end.py examples). Do not use them in server-mode or streaming scenarios—traces may be incomplete or fail to flush. - ### 4. Analyzing Omni Traces Output files are saved to your configured ```VLLM_TORCH_PROFILER_DIR```. diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 11a3c07e135..816a7099e49 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -298,3 +298,29 @@ async def pin_lora(self, lora_id: int) -> bool: None, ) return all(results) if isinstance(results, list) else results + + async def start_profile(self, trace_filename: str | None = None) -> None: + """Start profiling for the diffusion model. + + Args: + trace_filename: Optional base filename for trace files. + If None, a timestamp-based name will be generated. + """ + loop = asyncio.get_event_loop() + await loop.run_in_executor( + self._executor, + self.engine.start_profile, + trace_filename, + ) + + async def stop_profile(self) -> dict: + """Stop profiling and return profiling results. + + Returns: + Dictionary containing paths to trace and table files. + """ + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.engine.stop_profile, + ) diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index f30cd7d368e..7dccae6c08e 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -395,6 +395,22 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: logger.warning(f"[{self._name}] Stage initialization timeout. Troubleshooting Steps:\n{formatted_suggestions}") + def _is_profiler_enabled(self, stage_id: int) -> bool: + """Check if profiler config is set for a given stage.""" + stage = self.stage_list[stage_id] + # For diffusion stages, profiling is controlled by VLLM_TORCH_PROFILER_DIR env var + if stage.stage_type == "diffusion": + return True + # For LLM stages, check if profiler_config is set in engine_args + engine_args = getattr(stage.stage_config, "engine_args", None) + if engine_args is None: + return False + profiler_config = getattr(engine_args, "profiler_config", None) + if profiler_config is None: + return False + profiler = getattr(profiler_config, "profiler", None) + return profiler is not None + def start_profile(self, stages: list[int] | None = None) -> None: """Start profiling for specified stages. @@ -419,6 +435,13 @@ def start_profile(self, stages: list[int] | None = None) -> None: for stage_id in stages: if stage_id < len(self.stage_list): + if not self._is_profiler_enabled(stage_id): + logger.info( + "[%s] Skipping start_profile for stage-%s: profiler config not set", + self._name, + stage_id, + ) + continue try: self.stage_list[stage_id].submit({"type": OmniStageTaskType.PROFILER_START}) logger.info("[%s] Sent start_profile to stage-%s", self._name, stage_id) @@ -442,6 +465,13 @@ def stop_profile(self, stages: list[int] | None = None) -> dict: for stage_id in stages: if stage_id < len(self.stage_list): + if not self._is_profiler_enabled(stage_id): + logger.info( + "[%s] Skipping stop_profile for stage-%s: profiler config not set", + self._name, + stage_id, + ) + continue stage = self.stage_list[stage_id] # Check if the stage object has our new bridge method diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 6c9723b6b5b..2f78355f0b6 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -1150,16 +1150,15 @@ async def _force_log(): await stage_engine.reset_mm_cache() logger.debug("[Stage-%s] Engine initialized", stage_id) - async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: + async def handle_profiler_task_async(task_type: OmniStageTaskType) -> dict: """Handle profiler task asynchronously for both LLM and diffusion stages.""" if task_type == OmniStageTaskType.PROFILER_START: if stage_type == "diffusion": try: - # Sync call is safe here — diffusion profiling is lightweight profile_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles") os.makedirs(profile_dir, exist_ok=True) trace_filename = f"stage_{stage_id}_diffusion_{int(time.time())}" - stage_engine.start_profile(trace_filename=trace_filename) + await stage_engine.start_profile(trace_filename=trace_filename) logger.info("[Stage-%s] Diffusion Torch profiler started", stage_id) except Exception as e: logger.warning("[Stage-%s] Failed to start diffusion profiler: %s", stage_id, e) @@ -1169,14 +1168,17 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: logger.info("[Stage-%s] vLLM profiler started", stage_id) except Exception as e: logger.warning("[Stage-%s] Failed to start vLLM profiler: %s", stage_id, e) + return {} elif task_type == OmniStageTaskType.PROFILER_STOP: + result_data: dict = {} if stage_type == "diffusion": try: - trace_files = stage_engine.stop_profile() + trace_files = await stage_engine.stop_profile() logger.info("[Stage-%s] Diffusion Torch profiler stopped", stage_id) if trace_files: logger.info("Diffusion trace files: %s", trace_files) + result_data = trace_files except Exception as e: logger.warning("[Stage-%s] Failed to stop diffusion profiler: %s", stage_id, e) else: @@ -1185,6 +1187,8 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: logger.info("[Stage-%s] vLLM profiler stopped", stage_id) except Exception as e: logger.warning("[Stage-%s] Failed to stop vLLM profiler: %s", stage_id, e) + return result_data + return {} # Signal readiness to orchestrator and send vllm_config back to main process try: @@ -1286,7 +1290,10 @@ async def generation_single_request(task: dict[str, Any]): rid = task["request_id"] asyncio.create_task(stage_engine.abort(rid)) elif is_profiler_task(task_type): - await handle_profiler_task_async(task_type) + profiler_data = await handle_profiler_task_async(task_type) + # Send result back to orchestrator for STOP command + if task_type == OmniStageTaskType.PROFILER_STOP: + out_q.put({"type": "profiler_result", "data": profiler_data}) else: asyncio.create_task(generation_single_request(task)) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 3898da3081e..84a81a86139 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -22,6 +22,7 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse from PIL import Image +from pydantic import BaseModel, Field from starlette.datastructures import State from starlette.routing import Route from vllm import SamplingParams @@ -105,6 +106,30 @@ logger = init_logger(__name__) router = APIRouter() +profiler_router = APIRouter() + + +def _should_enable_profiler_endpoints(args: Namespace) -> bool: + # Check upstream vLLM's profiler_config + profiler_config = getattr(args, "profiler_config", None) + if profiler_config is not None: + # profiler_config exists, check if profiler is set + profiler = getattr(profiler_config, "profiler", None) + if profiler is not None: + return True + + # TODO: remove this env after refactoring torch profiler to CLI args + env_value = os.environ.get("VLLM_TORCH_PROFILER_DIR") + return env_value is not None + + +class ProfileRequest(BaseModel): + """Request model for profiling endpoints.""" + + stages: list[int] | None = Field( + default=None, + description="List of stage IDs to profile. If None, profiles all stages.", + ) def _remove_route_from_router( @@ -227,6 +252,11 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, await omni_init_app_state(engine_client, app.state, args) + # Conditionally register profiler endpoints based on config or env var + if _should_enable_profiler_endpoints(args): + logger.warning("Profiler endpoints are enabled. This should ONLY be used for local development!") + app.include_router(profiler_router) + vllm_config = await engine_client.get_vllm_config() # Check if pure diffusion mode (vllm_config will be None) @@ -1488,6 +1518,58 @@ def apply_stage_default_sampling_params( setattr(sampling_params, param_name, param_value) +@profiler_router.post("/start_profile") +async def start_profile(raw_request: Request, request: ProfileRequest | None = None): + """Start profiling for the engine. + + Args: + request: Optional request body with stages to profile. + - stages: List of stage IDs to profile. If None, profiles all stages. + + Example: + POST /start_profile + {"stages": [0, 1]} # Profile only stages 0 and 1 + """ + try: + stages = request.stages if request else None + logger.info("Starting profiler for stages: %s", stages if stages else "all") + engine_client = raw_request.app.state.engine_client + result = await engine_client.start_profile(stages=stages) + logger.info("Profiler started.") + return JSONResponse(content=result) + except Exception as e: + logger.exception("Failed to start profiler: %s", e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to start profiler: {str(e)}" + ) + + +@profiler_router.post("/stop_profile") +async def stop_profile(raw_request: Request, request: ProfileRequest | None = None): + """Stop profiling for the engine. + + Args: + request: Optional request body with stages to stop profiling. + - stages: List of stage IDs to stop profiling. If None, stops all stages. + + Example: + POST /stop_profile + {"stages": [0, 1]} # Stop profiling only stages 0 and 1 + """ + try: + stages = request.stages if request else None + logger.info("Stopping profiler for stages: %s", stages if stages else "all") + engine_client = raw_request.app.state.engine_client + result = await engine_client.stop_profile(stages=stages) + logger.info("Profiler stopped.") + return JSONResponse(content=result) + except Exception as e: + logger.exception("Failed to stop profiler: %s", e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to stop profiler: {str(e)}" + ) + + async def _run_video_generation( request: VideoGenerationRequest, raw_request: Request,