-
Notifications
You must be signed in to change notification settings - Fork 951
[Profiler] Support online profiling #1136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b40fff2
8604491
85bdc96
3194760
77807d9
0b3c4cb
6e6207e
238ed30
9d2f1d1
74642f1
012ef73
2d24711
bf9a5e9
0be1acb
718a43e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -298,3 +298,29 @@ async def pin_lora(self, lora_id: int) -> bool: | |
| None, | ||
| ) | ||
| return all(results) if isinstance(results, list) else results | ||
|
|
||
| async def start_profile(self, trace_filename: str | None = None) -> None: | ||
| """Start profiling for the diffusion model. | ||
|
|
||
| Args: | ||
| trace_filename: Optional base filename for trace files. | ||
| If None, a timestamp-based name will be generated. | ||
| """ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's going to be consistent with the other async APIs in the current codebase. Maybe we should consider leaving it to a follow-up PR, which replaces |
||
| loop = asyncio.get_event_loop() | ||
| await loop.run_in_executor( | ||
| self._executor, | ||
| self.engine.start_profile, | ||
| trace_filename, | ||
| ) | ||
|
|
||
| async def stop_profile(self) -> dict: | ||
| """Stop profiling and return profiling results. | ||
|
|
||
| Returns: | ||
| Dictionary containing paths to trace and table files. | ||
| """ | ||
| loop = asyncio.get_event_loop() | ||
| return await loop.run_in_executor( | ||
| self._executor, | ||
| self.engine.stop_profile, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1150,16 +1150,15 @@ async def _force_log(): | |
| await stage_engine.reset_mm_cache() | ||
| logger.debug("[Stage-%s] Engine initialized", stage_id) | ||
|
|
||
| async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: | ||
| async def handle_profiler_task_async(task_type: OmniStageTaskType) -> dict: | ||
| """Handle profiler task asynchronously for both LLM and diffusion stages.""" | ||
| if task_type == OmniStageTaskType.PROFILER_START: | ||
| if stage_type == "diffusion": | ||
| try: | ||
| # Sync call is safe here — diffusion profiling is lightweight | ||
| profile_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| os.makedirs(profile_dir, exist_ok=True) | ||
| trace_filename = f"stage_{stage_id}_diffusion_{int(time.time())}" | ||
| stage_engine.start_profile(trace_filename=trace_filename) | ||
| await stage_engine.start_profile(trace_filename=trace_filename) | ||
| logger.info("[Stage-%s] Diffusion Torch profiler started", stage_id) | ||
| except Exception as e: | ||
| logger.warning("[Stage-%s] Failed to start diffusion profiler: %s", stage_id, e) | ||
|
|
@@ -1169,14 +1168,17 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: | |
| logger.info("[Stage-%s] vLLM profiler started", stage_id) | ||
| except Exception as e: | ||
| logger.warning("[Stage-%s] Failed to start vLLM profiler: %s", stage_id, e) | ||
| return {} | ||
|
|
||
| elif task_type == OmniStageTaskType.PROFILER_STOP: | ||
| result_data: dict = {} | ||
| if stage_type == "diffusion": | ||
| try: | ||
| trace_files = stage_engine.stop_profile() | ||
| trace_files = await stage_engine.stop_profile() | ||
| logger.info("[Stage-%s] Diffusion Torch profiler stopped", stage_id) | ||
| if trace_files: | ||
| logger.info("Diffusion trace files: %s", trace_files) | ||
| result_data = trace_files | ||
| except Exception as e: | ||
| logger.warning("[Stage-%s] Failed to stop diffusion profiler: %s", stage_id, e) | ||
| else: | ||
|
|
@@ -1185,6 +1187,8 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: | |
| logger.info("[Stage-%s] vLLM profiler stopped", stage_id) | ||
| except Exception as e: | ||
| logger.warning("[Stage-%s] Failed to stop vLLM profiler: %s", stage_id, e) | ||
| return result_data | ||
| return {} | ||
|
|
||
| # Signal readiness to orchestrator and send vllm_config back to main process | ||
| try: | ||
|
|
@@ -1286,7 +1290,10 @@ async def generation_single_request(task: dict[str, Any]): | |
| rid = task["request_id"] | ||
| asyncio.create_task(stage_engine.abort(rid)) | ||
| elif is_profiler_task(task_type): | ||
| await handle_profiler_task_async(task_type) | ||
| profiler_data = await handle_profiler_task_async(task_type) | ||
| # Send result back to orchestrator for STOP command | ||
| if task_type == OmniStageTaskType.PROFILER_STOP: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is the orchestrator in
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It looks like we don't need to add another handling method for it. It has checked |
||
| out_q.put({"type": "profiler_result", "data": profiler_data}) | ||
| else: | ||
| asyncio.create_task(generation_single_request(task)) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile | ||
| from fastapi.responses import JSONResponse, StreamingResponse | ||
| from PIL import Image | ||
| from pydantic import BaseModel, Field | ||
| from starlette.datastructures import State | ||
| from starlette.routing import Route | ||
| from vllm import SamplingParams | ||
|
|
@@ -105,6 +106,30 @@ | |
|
|
||
| logger = init_logger(__name__) | ||
| router = APIRouter() | ||
|
hsliuustc0106 marked this conversation as resolved.
|
||
| profiler_router = APIRouter() | ||
|
|
||
|
|
||
| def _should_enable_profiler_endpoints(args: Namespace) -> bool: | ||
| # Check upstream vLLM's profiler_config | ||
| profiler_config = getattr(args, "profiler_config", None) | ||
| if profiler_config is not None: | ||
| # profiler_config exists, check if profiler is set | ||
| profiler = getattr(profiler_config, "profiler", None) | ||
| if profiler is not None: | ||
| return True | ||
|
|
||
| # TODO: remove this env after refactoring torch profiler to CLI args | ||
| env_value = os.environ.get("VLLM_TORCH_PROFILER_DIR") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The vLLM repository uses profiler-config. After 985 was merged, will there be any issues with using the VLLM_TORCH_PROFILER_DIR environment variable?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To be honest, I'm not really sure. For now, we have two ways to profile:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. vLLM has fully moved away from VLLM_TORCH_PROFILER_DIR in the latest HEAD. I think we should just commit to one way if we're aiming to land this PR with the current vLLM version. |
||
| return env_value is not None | ||
|
|
||
|
|
||
| class ProfileRequest(BaseModel): | ||
| """Request model for profiling endpoints.""" | ||
|
|
||
| stages: list[int] | None = Field( | ||
| default=None, | ||
| description="List of stage IDs to profile. If None, profiles all stages.", | ||
| ) | ||
|
|
||
|
|
||
| def _remove_route_from_router( | ||
|
|
@@ -227,6 +252,11 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, | |
|
|
||
| await omni_init_app_state(engine_client, app.state, args) | ||
|
|
||
| # Conditionally register profiler endpoints based on config or env var | ||
| if _should_enable_profiler_endpoints(args): | ||
| logger.warning("Profiler endpoints are enabled. This should ONLY be used for local development!") | ||
| app.include_router(profiler_router) | ||
|
|
||
| vllm_config = await engine_client.get_vllm_config() | ||
|
|
||
| # Check if pure diffusion mode (vllm_config will be None) | ||
|
|
@@ -1488,6 +1518,58 @@ def apply_stage_default_sampling_params( | |
| setattr(sampling_params, param_name, param_value) | ||
|
|
||
|
|
||
| @profiler_router.post("/start_profile") | ||
| async def start_profile(raw_request: Request, request: ProfileRequest | None = None): | ||
| """Start profiling for the engine. | ||
|
|
||
| Args: | ||
| request: Optional request body with stages to profile. | ||
| - stages: List of stage IDs to profile. If None, profiles all stages. | ||
|
|
||
| Example: | ||
| POST /start_profile | ||
| {"stages": [0, 1]} # Profile only stages 0 and 1 | ||
| """ | ||
| try: | ||
| stages = request.stages if request else None | ||
| logger.info("Starting profiler for stages: %s", stages if stages else "all") | ||
| engine_client = raw_request.app.state.engine_client | ||
| result = await engine_client.start_profile(stages=stages) | ||
| logger.info("Profiler started.") | ||
| return JSONResponse(content=result) | ||
| except Exception as e: | ||
| logger.exception("Failed to start profiler: %s", e) | ||
| raise HTTPException( | ||
| status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to start profiler: {str(e)}" | ||
| ) | ||
|
|
||
|
|
||
| @profiler_router.post("/stop_profile") | ||
| async def stop_profile(raw_request: Request, request: ProfileRequest | None = None): | ||
| """Stop profiling for the engine. | ||
|
|
||
| Args: | ||
| request: Optional request body with stages to stop profiling. | ||
| - stages: List of stage IDs to stop profiling. If None, stops all stages. | ||
|
|
||
| Example: | ||
| POST /stop_profile | ||
| {"stages": [0, 1]} # Stop profiling only stages 0 and 1 | ||
| """ | ||
| try: | ||
| stages = request.stages if request else None | ||
| logger.info("Stopping profiler for stages: %s", stages if stages else "all") | ||
| engine_client = raw_request.app.state.engine_client | ||
| result = await engine_client.stop_profile(stages=stages) | ||
| logger.info("Profiler stopped.") | ||
| return JSONResponse(content=result) | ||
| except Exception as e: | ||
| logger.exception("Failed to stop profiler: %s", e) | ||
| raise HTTPException( | ||
| status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to stop profiler: {str(e)}" | ||
| ) | ||
|
|
||
|
|
||
| async def _run_video_generation( | ||
| request: VideoGenerationRequest, | ||
| raw_request: Request, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.