Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions tests/dfx/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ def _build_serve_args(serve_args: Any) -> list[str]:
def create_unique_server_params(
configs: list[dict[str, Any]],
stage_configs_dir: Path,
) -> list[tuple[str, str, str | None, str | None, tuple[str, ...]]]:
"""Return one row per unique server configuration (same 5-tuple shape as upstream).
) -> list[tuple[str, str, str | None, str | None, tuple[str, ...], bool]]:
"""Return one row per unique server configuration.

``(test_name, model, deploy_yaml_path, stage_overrides_json, extra_cli_args)``.
``(test_name, model, deploy_yaml_path, stage_overrides_json, extra_cli_args, use_omni)``.

JSON ``server_params.serve_args`` (dict/list) is expanded via ``_build_serve_args``
and **prepended** to ``extra_cli_args`` so perf / stability ``omni_server`` fixtures
stay identical to main while still honoring ``serve_args`` in benchmark JSON.
"""
unique_params: list[tuple[str, str, str | None, str | None, tuple[str, ...]]] = []
seen: set[tuple[str, str, str | None, str | None, tuple[str, ...]]] = set()
unique_params: list[tuple[str, str, str | None, str | None, tuple[str, ...], bool]] = []
seen: set[tuple[str, str, str | None, str | None, tuple[str, ...], bool]] = set()
for config in configs:
test_name = config["test_name"]
server_params = config["server_params"]
Expand All @@ -104,8 +104,16 @@ def create_unique_server_params(
serve_flat = _build_serve_args(server_params.get("serve_args"))
raw_extra = tuple(server_params.get("extra_cli_args") or ())
extra_cli_args = tuple(serve_flat) + raw_extra

server_param = (test_name, model, stage_config_path, stage_overrides_json, extra_cli_args)
use_omni = bool(server_params.get("use_omni", True))

server_param = (
test_name,
model,
stage_config_path,
stage_overrides_json,
extra_cli_args,
use_omni,
)
if server_param not in seen:
seen.add(server_param)
unique_params.append(server_param)
Expand Down
4 changes: 2 additions & 2 deletions tests/dfx/perf/scripts/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def omni_server(request):
Multi-stage initialization can take 10-20+ minutes.
"""
with _omni_server_lock:
test_name, model, stage_config_path, stage_overrides, extra_cli_args = request.param
test_name, model, stage_config_path, stage_overrides, extra_cli_args, use_omni = request.param

print(f"Starting OmniServer with test: {test_name}, model: {model}")

Expand All @@ -78,7 +78,7 @@ def omni_server(request):
server_args = ["--stage-overrides", stage_overrides] + server_args
if extra_cli_args:
server_args = list(extra_cli_args) + server_args
with OmniServer(model, server_args) as server:
with OmniServer(model, server_args, use_omni=use_omni) as server:
server.test_name = test_name
print("OmniServer started successfully")
yield server
Expand Down
6 changes: 3 additions & 3 deletions tests/dfx/stability/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def omni_server(request: pytest.FixtureRequest):
"""Start OmniServer for stability tests, with per-module timeout override."""
timeout_args = getattr(request.module, "STABILITY_SERVER_TIMEOUT_ARGS", DEFAULT_STABILITY_SERVER_TIMEOUT_ARGS)
with _omni_server_lock:
# Same 5-tuple and CLI composition as ``tests/dfx/perf/scripts/run_benchmark.py`` on main;
# Same tuple and CLI composition as ``tests/dfx/perf/scripts/run_benchmark.py``;
# ``serve_args`` from JSON are folded into ``extra_cli_args`` inside
# ``create_unique_server_params``.
test_name, model, deploy_path, stage_overrides, extra_cli_args = request.param
test_name, model, deploy_path, stage_overrides, extra_cli_args, use_omni = request.param

print(f"Starting OmniServer with test: {test_name}, model: {model}")
server_args = list(timeout_args)
Expand All @@ -48,7 +48,7 @@ def omni_server(request: pytest.FixtureRequest):
server_args = ["--stage-overrides", stage_overrides] + server_args
if extra_cli_args:
server_args = list(extra_cli_args) + server_args
with OmniServer(model, server_args) as server:
with OmniServer(model, server_args, use_omni=use_omni) as server:
server.test_name = test_name
print("OmniServer started successfully")
yield server
Expand Down
5 changes: 4 additions & 1 deletion vllm_omni/benchmarks/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,10 @@ def calculate_metrics(
total_input += outputs[i].prompt_len
tpot = 0
if output_len > 1:
latency_minus_ttft = outputs[i].text_latency - outputs[i].ttft
try:
latency_minus_ttft = outputs[i].text_latency - outputs[i].ttft
except Exception:
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
tpot = latency_minus_ttft / (output_len - 1)
tpots.append(tpot)
# Note: if output_len <= 1, we regard tpot as 0 for goodput
Expand Down
2 changes: 1 addition & 1 deletion vllm_omni/core/sched/omni_ar_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from vllm.distributed.kv_events import KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
from vllm.v1.core.sched.async_scheduler import AsyncScheduler as VLLMScheduler
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
from vllm.v1.core.sched.utils import remove_all
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.perf import PerfStats
Expand Down
14 changes: 13 additions & 1 deletion vllm_omni/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,19 @@ async def create_chat_completion(
if raw_request:
raw_request.state.request_metadata = request_metadata

output_modalities = getattr(request, "modalities", self.engine_client.output_modalities)
# NOTE:
# - OpenAI python client flattens extra_body fields into model_extra.
# - Raw HTTP requests may keep them under request.extra_body.
# Keep modalities resolution tolerant so `--extra_body '{"modalities":["text"]}'`
# can reliably drive multi-stage routing.
request_extra_body = getattr(request, "extra_body", None) or request.model_extra or {}
output_modalities = getattr(request, "modalities", None)
if output_modalities is None:
output_modalities = request_extra_body.get("modalities")
if isinstance(output_modalities, str):
output_modalities = [output_modalities]
if output_modalities is not None:
output_modalities = [str(m).lower() for m in output_modalities]
request.modalities = (
output_modalities if output_modalities is not None else self.engine_client.output_modalities
)
Expand Down
81 changes: 66 additions & 15 deletions vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@
_HIDDEN_LAYER_KEY = "24"


def _layer_tensor(layers: dict[Any, Any], key: str) -> torch.Tensor | None:
"""Fetch layer tensor with tolerant key lookup (str/int)."""
if not isinstance(layers, dict):
return None
key_int = int(key)
val = layers.get(key_int)
if val is None:
val = layers.get(key)
return val if isinstance(val, torch.Tensor) else None


def _compute_talker_prompt_ids_length(info: OmniPayload, device: torch.device | str = "cuda") -> int:
im_start_token_id = 151644
system_token_id = 8948
Expand Down Expand Up @@ -300,9 +311,23 @@ def thinker2talker_async_chunk(

request_id = request.external_req_id
chunk_id = transfer_manager.put_req_chunk[request_id]
if not isinstance(pooling_output, dict):
logger.debug("thinker2talker_async_chunk: skip non-dict pooling_output for req=%s", request_id)
return None

thinker_hs = pooling_output.get("hidden_states", {})
thinker_layers = thinker_hs.get("layers", {})
thinker_embed = pooling_output.get("embed", {})
thinker_layers = thinker_hs.get("layers", {}) if isinstance(thinker_hs, dict) else {}
thinker_embed = pooling_output.get("embed", {}) if isinstance(pooling_output.get("embed", {}), dict) else {}
thinker_emb = _layer_tensor(thinker_layers, _EMBED_LAYER_KEY)
thinker_hid = _layer_tensor(thinker_layers, _HIDDEN_LAYER_KEY)
if thinker_emb is None or thinker_hid is None:
logger.debug(
"thinker2talker_async_chunk: missing thinker layers for req=%s (embed=%s hidden=%s)",
request_id,
thinker_emb is not None,
thinker_hid is not None,
)
return None

if chunk_id == 0:
all_token_ids = request.all_token_ids # prefill + decode
Expand All @@ -312,13 +337,19 @@ def thinker2talker_async_chunk(
prompt_token_ids = _ensure_list(prompt_token_ids)
payload: OmniPayload = {
"embed": {
"prefill": thinker_layers[int(_EMBED_LAYER_KEY)].detach().cpu(),
"prefill": thinker_emb.detach().cpu(),
# Provide thinker-side TTS token embeddings for talker projection
"tts_bos": thinker_embed["tts_bos"].detach().cpu(),
"tts_eos": thinker_embed["tts_eos"].detach().cpu(),
"tts_pad": thinker_embed["tts_pad"].detach().cpu(),
"tts_bos": thinker_embed.get("tts_bos").detach().cpu()
if isinstance(thinker_embed.get("tts_bos"), torch.Tensor)
else None,
"tts_eos": thinker_embed.get("tts_eos").detach().cpu()
if isinstance(thinker_embed.get("tts_eos"), torch.Tensor)
else None,
"tts_pad": thinker_embed.get("tts_pad").detach().cpu()
if isinstance(thinker_embed.get("tts_pad"), torch.Tensor)
else None,
},
"hidden_states": {"output": thinker_layers[int(_HIDDEN_LAYER_KEY)].detach().cpu()},
"hidden_states": {"output": thinker_hid.detach().cpu()},
"ids": {"all": all_token_ids, "prompt": prompt_token_ids},
"meta": {"finished": torch.tensor(is_finished, dtype=torch.bool)},
}
Expand Down Expand Up @@ -366,12 +397,12 @@ def thinker2talker_async_chunk(

if output_token_ids:
talker_additional_info["meta"]["override_keys"] = [("embed", "decode"), ("ids", "output")]
talker_additional_info["embed"] = {"decode": thinker_layers[int(_EMBED_LAYER_KEY)].detach().cpu()}
talker_additional_info["embed"] = {"decode": thinker_emb.detach().cpu()}
talker_additional_info["ids"] = {"output": output_token_ids}
else:
# When prefilling a chunked thinker, thinker_hidden_states needs to be updated.
talker_additional_info["embed"] = {"prefill": thinker_layers[0].detach().cpu()}
talker_additional_info["hidden_states"] = {"output": thinker_layers[24].detach().cpu()}
talker_additional_info["embed"] = {"prefill": thinker_emb.detach().cpu()}
talker_additional_info["hidden_states"] = {"output": thinker_hid.detach().cpu()}
return talker_additional_info


Expand Down Expand Up @@ -431,11 +462,20 @@ def thinker2talker(
thinker_sequences = prompt_token_ids + output_ids
thinker_input_ids = prompt_token_ids
new_seq_length = len(prompt_token_ids + output_ids) - 1
thinker_mm: OmniPayload = output.multimodal_output
thinker_mm_raw = getattr(output, "multimodal_output", None)
if not isinstance(thinker_mm_raw, dict):
logger.debug("thinker2talker: skip req=%s due to empty multimodal_output", req_id)
continue
thinker_mm: OmniPayload = thinker_mm_raw
mm_hs = thinker_mm.get("hidden_states", {})
mm_layers = mm_hs.get("layers", {})
thinker_emb = mm_layers[int(_EMBED_LAYER_KEY)].detach().to(device=device, dtype=torch.float)[-new_seq_length:]
thinker_hid = mm_layers[int(_HIDDEN_LAYER_KEY)].detach().to(device=device, dtype=torch.float)[-new_seq_length:]
mm_layers = mm_hs.get("layers", {}) if isinstance(mm_hs, dict) else {}
emb_layer = _layer_tensor(mm_layers, _EMBED_LAYER_KEY)
hid_layer = _layer_tensor(mm_layers, _HIDDEN_LAYER_KEY)
if emb_layer is None or hid_layer is None:
logger.debug("thinker2talker: skip req=%s due to missing hidden-state layers", req_id)
continue
thinker_emb = emb_layer.detach().to(device=device, dtype=torch.float)[-new_seq_length:]
thinker_hid = hid_layer.detach().to(device=device, dtype=torch.float)[-new_seq_length:]

prefill_mm: dict[str, Any] | None = None
if prefill_stage is not None:
Expand Down Expand Up @@ -507,7 +547,11 @@ def talker2code2wav_async_chunk(
"""
Pooling version.
"""
if not isinstance(pooling_output, dict):
return None
talker_codes = pooling_output.get("codes", {})
if not isinstance(talker_codes, dict):
return None
code_predictor_codes = talker_codes.get("audio")
if code_predictor_codes is None:
return None
Expand Down Expand Up @@ -600,7 +644,14 @@ def talker2code2wav(
is_streaming_session = bool(getattr(streaming_context, "enabled", False))
if is_streaming_session:
seq_len = _get_streaming_codec_delta_len(cur_seq_len, req_id, talker_output, streaming_context)
mm: OmniPayload = output.multimodal_output
mm_raw = getattr(output, "multimodal_output", None)
if not isinstance(mm_raw, dict):
logger.debug("talker2code2wav: skip req=%s due to empty multimodal_output", req_id)
continue
mm: OmniPayload = mm_raw
if "codes" not in mm or not isinstance(mm.get("codes"), dict) or "audio" not in mm["codes"]:
logger.debug("talker2code2wav: skip req=%s due to missing codes.audio", req_id)
continue
# Extract codec codes from talker output
# Expected shape: [8, seq_len] (8-layer RVQ codes)
codec_codes = (
Expand Down
Loading
Loading