Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions benchmarks/diffusion/diffusion_benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,50 @@ def calculate_metrics(
return metrics


def start_profile(base_url: str) -> bool:
    """Ask the server to begin profiling.

    Sends a POST request to ``{base_url}/start_profile`` (30-second timeout)
    and prints the outcome to stdout.

    Args:
        base_url: Base URL of the server (e.g., http://localhost:8091)

    Returns:
        True if the server answered with HTTP 200; False for any other
        status code or if the request raised a ``requests`` exception.
    """
    endpoint = f"{base_url}/start_profile"
    try:
        response = requests.post(endpoint, timeout=30)
    except requests.exceptions.RequestException as exc:
        # Connection problems, timeouts, etc. are reported, not raised.
        print(f"Failed to start profiling: {exc}")
        return False

    if response.status_code != 200:
        print(f"Failed to start profiling: HTTP {response.status_code}")
        return False

    print("Profiling started on server.")
    return True


def stop_profile(base_url: str) -> bool:
    """Ask the server to end profiling.

    Sends a POST request to ``{base_url}/stop_profile`` (60-second timeout)
    and prints the outcome to stdout.

    Args:
        base_url: Base URL of the server (e.g., http://localhost:8091)

    Returns:
        True if the server answered with HTTP 200; False for any other
        status code or if the request raised a ``requests`` exception.
    """
    endpoint = f"{base_url}/stop_profile"
    try:
        response = requests.post(endpoint, timeout=60)
    except requests.exceptions.RequestException as exc:
        # Connection problems, timeouts, etc. are reported, not raised.
        print(f"Failed to stop profiling: {exc}")
        return False

    if response.status_code != 200:
        print(f"Failed to stop profiling: HTTP {response.status_code}")
        return False

    print("Profiling stopped on server. Trace files saved.")
    return True


def wait_for_service(base_url: str, timeout: int = 120) -> None:
print(f"Waiting for service at {base_url}...")
start_time = time.time()
Expand Down Expand Up @@ -906,6 +950,10 @@ async def limited_request_func(req, session, pbar):
args=args,
)

# Start profiling if requested (after warmup, before main benchmark)
if args.profile:
start_profile(args.base_url)

start_time = time.perf_counter()
tasks = []
async for req in iter_requests(requests_list=requests_list, request_rate=args.request_rate):
Expand All @@ -915,6 +963,10 @@ async def limited_request_func(req, session, pbar):
outputs = await asyncio.gather(*tasks)
total_duration = time.perf_counter() - start_time

# Stop profiling if it was started
if args.profile:
stop_profile(args.base_url)

pbar.close()

# Calculate metrics
Expand Down Expand Up @@ -1064,6 +1116,11 @@ async def limited_request_func(req, session, pbar):
help="SLO target multiplier: slo_ms = estimated_exec_time_ms * slo_scale (default: 3).",
)
parser.add_argument("--disable-tqdm", action="store_true", help="Disable progress bar.")
parser.add_argument(
"--profile",
action="store_true",
help="Enable profiling. Calls /start_profile before benchmark and /stop_profile after.",
)

args = parser.parse_args()

Expand Down
21 changes: 21 additions & 0 deletions vllm_omni/entrypoints/async_omni_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,24 @@ async def pin_lora(self, lora_id: int) -> bool:
None,
)
return all(results) if isinstance(results, list) else results

def start_profile(self, trace_filename: str | None = None) -> None:
"""Start profiling on the diffusion engine.

Delegates to the underlying DiffusionEngine's start_profile method
which sets up torch profiling on all diffusion workers.

Args:
trace_filename: Optional base filename (without extension or rank suffix).
If None, generates one using current timestamp.
"""
self.engine.start_profile(trace_filename=trace_filename)

def stop_profile(self) -> None:
"""Stop profiling and return trace file paths.

Delegates to the underlying DiffusionEngine's stop_profile method
which stops profiling on all workers and collects trace paths.

"""
self.engine.stop_profile()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Return profiler result dict from AsyncOmniDiffusion.stop_profile

AsyncOmniDiffusion.stop_profile() invokes DiffusionEngine.stop_profile() but does not return its result, so callers always receive None. In the new async stage path, handle_profiler_task_async does result_data = stage_engine.stop_profile() or {} and forwards that to the orchestrator, which means profiling artifacts are always dropped from the response even when traces were successfully written. This breaks the new result-collection flow added in this commit.

Useful? React with 👍 / 👎.

24 changes: 18 additions & 6 deletions vllm_omni/entrypoints/omni_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,8 +1150,12 @@ async def _force_log():
await stage_engine.reset_mm_cache()
logger.debug("[Stage-%s] Engine initialized", stage_id)

async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None:
"""Handle profiler task asynchronously for both LLM and diffusion stages."""
async def handle_profiler_task_async(task_type: OmniStageTaskType) -> dict:
"""Handle profiler task asynchronously for both LLM and diffusion stages.

Returns:
dict: For PROFILER_STOP, returns the profiler result data. Empty dict otherwise.
"""
if task_type == OmniStageTaskType.PROFILER_START:
if stage_type == "diffusion":
try:
Expand All @@ -1169,14 +1173,16 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None:
logger.info("[Stage-%s] vLLM profiler started", stage_id)
except Exception as e:
logger.warning("[Stage-%s] Failed to start vLLM profiler: %s", stage_id, e)
return {}

elif task_type == OmniStageTaskType.PROFILER_STOP:
result_data = {}
if stage_type == "diffusion":
try:
trace_files = stage_engine.stop_profile()
result_data = stage_engine.stop_profile() or {}
logger.info("[Stage-%s] Diffusion Torch profiler stopped", stage_id)
if trace_files:
logger.info("Diffusion trace files: %s", trace_files)
if result_data:
logger.info("Diffusion trace files: %s", result_data)
except Exception as e:
logger.warning("[Stage-%s] Failed to stop diffusion profiler: %s", stage_id, e)
else:
Expand All @@ -1185,6 +1191,9 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None:
logger.info("[Stage-%s] vLLM profiler stopped", stage_id)
except Exception as e:
logger.warning("[Stage-%s] Failed to stop vLLM profiler: %s", stage_id, e)
return result_data

return {}

# Signal readiness to orchestrator and send vllm_config back to main process
try:
Expand Down Expand Up @@ -1286,7 +1295,10 @@ async def generation_single_request(task: dict[str, Any]):
rid = task["request_id"]
asyncio.create_task(stage_engine.abort(rid))
elif is_profiler_task(task_type):
await handle_profiler_task_async(task_type)
profiler_data = await handle_profiler_task_async(task_type)
# Send profiler result back to orchestrator for PROFILER_STOP
if task_type == OmniStageTaskType.PROFILER_STOP:
out_q.put({"type": "profiler_result", "data": profiler_data})
else:
asyncio.create_task(generation_single_request(task))

Expand Down
41 changes: 40 additions & 1 deletion vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import httpx
import vllm.envs as envs
from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import JSONResponse, Response, StreamingResponse
from PIL import Image
from starlette.datastructures import State
from starlette.routing import Route
Expand Down Expand Up @@ -929,6 +929,45 @@ async def show_available_models(raw_request: Request) -> JSONResponse:
)


# Profiling API endpoints
def _get_engine_client(raw_request: Request) -> AsyncOmni:
    """Fetch the engine client stored on the application state.

    Args:
        raw_request: Incoming request whose ``app.state`` is inspected.

    Returns:
        The ``engine_client`` attribute of ``raw_request.app.state``.

    Raises:
        HTTPException: With status 503 (Service Unavailable) when the
            attribute is missing or set to None.
    """
    client = getattr(raw_request.app.state, "engine_client", None)
    if client is not None:
        return client

    raise HTTPException(
        status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
        detail="Engine not initialized.",
    )


@router.post("/start_profile")
async def start_profile(raw_request: Request):
    """Start profiling for the running server.

    Looks up the engine client on the application state and awaits its
    ``start_profile`` coroutine. Enables torch profiling to capture CPU/CUDA
    activities, memory allocation, and other performance metrics; call
    /stop_profile afterwards to stop and save the trace.

    Returns:
        An empty response with HTTP status 200 once the engine call returns.
    """
    logger.info("Starting profiler...")
    # Client lookup raises an HTTPException if the engine is not initialized.
    client = _get_engine_client(raw_request)
    await client.start_profile()
    logger.info("Profiler started.")
    return Response(status_code=200)
Comment on lines +952 to +954

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Propagate profiling failures instead of always returning 200

The new /start_profile and /stop_profile APIs always return 200 OK once these calls complete, but worker-side profiling errors are caught and only logged in stage handlers, so failures (e.g., invalid/unwritable profiler output path) are silently reported as success to clients. This can invalidate benchmark/profiling runs because automation has no reliable signal that profiling did not actually start/stop correctly.

Useful? React with 👍 / 👎.



@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
    """Stop profiling and save the trace.

    Looks up the engine client on the application state and awaits its
    ``stop_profile`` coroutine, ending the session started by /start_profile.
    The trace location is determined by the VLLM_TORCH_PROFILER_DIR
    environment variable.

    Returns:
        An empty response with HTTP status 200 once the engine call returns.
    """
    logger.info("Stopping profiler...")
    # Client lookup raises an HTTPException if the engine is not initialized.
    client = _get_engine_client(raw_request)
    await client.stop_profile()
    logger.info("Profiler stopped.")
    return Response(status_code=200)


# Image generation API endpoints


Expand Down