Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,12 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:

step_total_ms = (time.perf_counter() - diffusion_engine_start_time) * 1000
logger.info(
"DiffusionEngine.step breakdown: preprocess=%.2f ms, "
"add_req_and_wait=%.2f ms, postprocess=%.2f ms, total=%.2f ms",
"[StageTiming stage=%s diffusion] total=%.2fs preprocess=%.2fms exec=%.2fs postprocess=%.2fms",
self.od_config.stage_id,
step_total_ms / 1000.0,
preprocess_time * 1000,
exec_total_time * 1000,
exec_total_time,
postprocess_time * 1000,
step_total_ms,
)

# Convert to OmniRequestOutput format
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/diffusion/stage_diffusion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def _initialize_client(
self.default_sampling_params = metadata.default_sampling_params
self.custom_process_input_func = metadata.custom_process_input_func
self.engine_input_source = metadata.engine_input_source
self.model_stage = metadata.model_stage
self._proc = proc
self._owns_process = proc is not None

Expand Down
7 changes: 7 additions & 0 deletions vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,8 @@ def _build_add_request_message(
message_type: str = "add_request",
) -> dict[str, Any]:
"""Build an add_request message after stage-0 preprocessing."""
build_add_request_message_start = time.perf_counter()
input_preprocess_time_ms = 0.0
effective_sampling_params_list = (
list(sampling_params_list) if sampling_params_list is not None else list(self.default_sampling_params_list)
)
Expand All @@ -1032,6 +1034,7 @@ def _build_add_request_message(
_inject_global_id(item, request_id)

# Full input processing (tokenization, multimodal, etc.)
input_preprocess_start = time.perf_counter()
request = self.input_processor.process_inputs(
request_id=request_id,
prompt=prompt,
Expand All @@ -1045,6 +1048,7 @@ def _build_add_request_message(
data_parallel_rank=data_parallel_rank,
resumable=resumable,
)
input_preprocess_time_ms = (time.perf_counter() - input_preprocess_start) * 1000.0
# TODO (Peiqi): add this for Qwen3-TTS only. Other models don't have
# additional_information field in the prompt.
request = _upgrade_to_omni_request(request, prompt)
Expand Down Expand Up @@ -1074,13 +1078,16 @@ def _build_add_request_message(
)
prompt = request

build_add_request_message_time_ms = (time.perf_counter() - build_add_request_message_start) * 1000.0
return {
"type": message_type,
"request_id": request_id,
"prompt": prompt,
"original_prompt": original_prompt,
"sampling_params_list": effective_sampling_params_list,
"final_stage_id": final_stage_id,
"input_preprocess_time_ms": input_preprocess_time_ms,
"build_add_request_message_time_ms": build_add_request_message_time_ms,
}

def _enqueue_cfg_companions(
Expand Down
6 changes: 6 additions & 0 deletions vllm_omni/engine/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class OrchestratorRequestState:

# Metrics: timestamp when request was submitted to each stage
stage_submit_ts: dict[int, float] = field(default_factory=dict)
input_preprocess_time_ms: float = 0.0
build_add_request_message_time_ms: float = 0.0
mm_processor_kwargs: dict | None = None
mm_features: list | None = None

Expand Down Expand Up @@ -419,6 +421,8 @@ async def _route_output(
"metrics": stage_metrics,
"finished": finished and stage_id == req_state.final_stage_id,
"stage_submit_ts": submit_ts,
"input_preprocess_time_ms": req_state.input_preprocess_time_ms,
"build_add_request_message_time_ms": req_state.build_add_request_message_time_ms,
}
)
elif stage_metrics is not None:
Expand Down Expand Up @@ -877,6 +881,8 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None:
prompt=original_prompt,
sampling_params_list=sampling_params_list,
final_stage_id=final_stage_id,
input_preprocess_time_ms=float(msg.get("input_preprocess_time_ms", 0.0)),
build_add_request_message_time_ms=float(msg.get("build_add_request_message_time_ms", 0.0)),
mm_features=getattr(prompt, "mm_features", None), # Save mm_features for PD
)
req_state.streaming.enabled = is_streaming
Expand Down
5 changes: 3 additions & 2 deletions vllm_omni/engine/stage_init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
stage_id: int = stage_config.stage_id
stage_type: Literal["llm", "diffusion"] = getattr(stage_config, "stage_type", "llm")
engine_args = stage_config.engine_args
model_stage = getattr(engine_args, "model_stage", None)

if current_omni_platform.is_rocm():
if engine_args.get("attention_backend") is None:
Expand Down Expand Up @@ -333,12 +334,11 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
final_output_type=final_output_type,
default_sampling_params=default_sampling_params,
custom_process_input_func=custom_process_input_func,
model_stage=None,
model_stage=model_stage,
runtime_cfg=runtime_cfg,
cfg_kv_collect_func=cfg_kv_collect_func,
)

model_stage = getattr(engine_args, "model_stage", None)
engine_output_type = getattr(engine_args, "engine_output_type", None)
is_comprehension = getattr(stage_config, "is_comprehension", False)
requires_multimodal_data = getattr(runtime_cfg, "requires_multimodal_data", False)
Expand Down Expand Up @@ -755,6 +755,7 @@ def finalize_initialized_stages(
"final_output": stage_client.final_output,
"final_output_type": stage_client.final_output_type,
"stage_type": stage_client.stage_type,
"model_stage": getattr(stage_client, "model_stage", None),
}
for stage_client in initialized_stage_clients
]
Expand Down
10 changes: 10 additions & 0 deletions vllm_omni/entrypoints/async_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ async def generate(
wall_start_ts,
final_stage_id_for_e2e,
)
metrics.stage_labels = dict(self._stage_labels)
req_state = ClientRequestState(request_id)
req_state.metrics = metrics
self.request_states[request_id] = req_state
Expand All @@ -309,6 +310,7 @@ async def generate(
input_stream=prompt,
sampling_params_list=req_sp_list,
final_stage_id=final_stage_id_for_e2e,
req_start_ts=req_start_ts,
)
else:
await self.engine.add_request_async(
Expand Down Expand Up @@ -356,6 +358,7 @@ async def _add_streaming_input_request(
input_stream: AsyncGenerator[StreamingInput, None],
sampling_params_list: Sequence[OmniSamplingParams],
final_stage_id: int,
req_start_ts: dict[str, float],
) -> asyncio.Task:
"""Submit a streaming input generator as incremental stage-0 updates."""
if not sampling_params_list:
Expand All @@ -372,6 +375,11 @@ async def _add_streaming_input_request(

has_submitted_first_chunk = False

def mark_stage0_submit() -> None:
    """Stamp one ``time.time()`` reading as the stage-0 submit time.

    The same value is written to ``req_state.metrics.stage_first_ts[0]``
    and to ``req_start_ts[request_id]``, both taken from the enclosing scope.
    """
    # Chained assignment binds targets left to right, so the metrics entry
    # is written before the per-request start-timestamp entry.
    req_state.metrics.stage_first_ts[0] = req_start_ts[request_id] = time.time()

async def handle_inputs() -> None:
nonlocal has_submitted_first_chunk
cancelled = False
Expand All @@ -393,6 +401,7 @@ async def handle_inputs() -> None:
final_stage_id=final_stage_id,
resumable=True,
)
mark_stage0_submit()
has_submitted_first_chunk = True
else:
await self.engine.add_streaming_update_async(
Expand Down Expand Up @@ -433,6 +442,7 @@ async def handle_inputs() -> None:
final_stage_id=final_stage_id,
resumable=False,
)
mark_stage0_submit()

input_stream_task = asyncio.create_task(handle_inputs())
req_state.input_stream_task = input_stream_task
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/entrypoints/omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def _run_generation(
wall_start_ts,
final_stage_id,
)
metrics.stage_labels = dict(self._stage_labels)
req_state = ClientRequestState(req_id)
req_state.metrics = metrics
self.request_states[req_id] = req_state
Expand Down
25 changes: 24 additions & 1 deletion vllm_omni/entrypoints/omni_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def __init__(
self._stage_meta_list = [
types.SimpleNamespace(**self.engine.get_stage_metadata(i)) for i in range(self.engine.num_stages)
]
self._stage_labels = self.build_stage_labels()

logger.info(
"[%s] Initialized with %s stages for model %s",
Expand Down Expand Up @@ -247,8 +248,12 @@ def _log_summary_and_cleanup(self, request_id: str) -> None:
try:
if req_state is None or req_state.metrics is None:
return
timing_line = req_state.metrics.format_request_timing_line(request_id)
if timing_line is not None and self.num_stages > 1:
logger.info("%s", timing_line)
summary = req_state.metrics.build_and_log_summary()
logger.info("[Summary] %s", pformat(summary, sort_dicts=False))
if summary:
logger.info("[Summary] %s", pformat(summary, sort_dicts=False))
except Exception:
logger.exception(
"[%s] Failed to build/log summary for req=%s",
Expand All @@ -265,6 +270,18 @@ def _compute_final_stage_id(self, output_modalities: list[str] | None) -> int:
self._stage_meta_list,
)

def build_stage_labels(self) -> dict[int, str]:
    """Return a ``{stage_index: "<stage_index>:<name>"}`` mapping.

    One entry is produced per element of ``self._stage_meta_list``, keyed by
    its position in that list. The name part is ``"diffusion"`` when the
    entry's ``stage_type`` equals ``"diffusion"``; otherwise it is the first
    truthy value among ``model_stage``, ``stage_type`` and ``stage_<index>``.
    Missing attributes are treated as ``None``.
    """

    def _display_name(position, meta):
        # Both attributes are read up front, before choosing the name.
        kind = getattr(meta, "stage_type", None)
        variant = getattr(meta, "model_stage", None)
        if kind != "diffusion":
            return variant or kind or f"stage_{position}"
        return "diffusion"

    return {
        position: f"{position}:{_display_name(position, meta)}"
        for position, meta in enumerate(self._stage_meta_list)
    }

def _process_stage_metrics_message(self, msg: dict[str, Any]) -> None:
req_id = msg.get("request_id")
req_state = self.request_states.get(req_id)
Expand Down Expand Up @@ -377,6 +394,12 @@ def _process_single_result(
metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, now)

_m = result.get("metrics")
input_preprocess_time_ms = result.get("input_preprocess_time_ms")
if input_preprocess_time_ms is not None:
metrics.input_preprocess_time_ms = float(input_preprocess_time_ms)
build_add_request_message_time_ms = result.get("build_add_request_message_time_ms")
if build_add_request_message_time_ms is not None:
metrics.build_add_request_message_time_ms = float(build_add_request_message_time_ms)
if finished and _m is not None:
metrics.on_stage_metrics(stage_id, req_id, _m)

Expand Down
Loading
Loading