Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .buildkite/scripts/simple_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ VENV_PYTHON="${VENV_DIR}/bin/python"
"${VENV_PYTHON}" -m pytest -v -s tests/entrypoints/
"${VENV_PYTHON}" -m pytest -v -s tests/diffusion/cache/
"${VENV_PYTHON}" -m pytest -v -s tests/model_executor/models/qwen2_5_omni/test_audio_length.py
"${VENV_PYTHON}" -m pytest -v -s tests/worker/
123 changes: 123 additions & 0 deletions tests/worker/test_omni_gpu_model_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from contextlib import contextmanager

import torch

from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner


class DummyBuffer:
    """Tiny stand-in for a runner buffer: it only carries a `.gpu` tensor."""

    def __init__(self, t: torch.Tensor):
        # Store the very same tensor object (no copy), so in-place writes made
        # through `.gpu` are visible to whoever created the tensor.
        self.gpu = t


class DummyInputBatch:
    """Bare-bones input batch whose only attribute is `req_ids`."""

    def __init__(self, req_ids):
        # Kept as-is (no copy); the position of an id in this sequence is the
        # request index.
        self.req_ids = req_ids


class DummyReqState:
    """Empty per-request state object; attributes are attached to it on demand."""


class DummyTalkerMTP(torch.nn.Module):
    """Deterministic CPU replacement for the real talker_mtp module."""

    def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step):
        # Only `req_embeds` is read; the other arguments are accepted so the
        # call signature matches what the runner passes in.
        #   embeds out: every element of the input shifted by +1
        #   codes out:  a column vector [[0], [1], ..., [batch - 1]]
        batch = req_embeds.size(0)
        shifted_embeds = req_embeds + 1.0
        row_codes = torch.arange(batch, dtype=torch.int64).unsqueeze(1)
        return shifted_embeds, row_codes


@contextmanager
def _noop_forward_context(*_args, **_kwargs):
    """Do-nothing context manager swapped in for `set_forward_context` in CPU tests.

    Accepts and ignores any arguments, and yields ``None``.
    """
    yield None


def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
    """Build an OmniGPUModelRunner shell with just the state `_talker_mtp_forward` reads.

    `OmniGPUModelRunner.__init__` is never executed; every attribute is set by
    hand with CPU tensors and dummy helpers.
    """
    # Allocate the instance while bypassing OmniGPUModelRunner.__init__.
    runner = object.__new__(OmniGPUModelRunner)

    # Request bookkeeping: ordered ids plus one empty state object per id.
    runner.input_batch = DummyInputBatch(list(req_ids))
    runner.requests = {rid: DummyReqState() for rid in req_ids}

    # `query_start_loc.cpu[req_index]` gives the row of each request inside the
    # flattened `inputs_embeds`. Offsets here: r1 -> 0, r2 -> 3.
    start_loc = type("QSL", (), {})()
    start_loc.cpu = torch.tensor([0, 3], dtype=torch.int32)
    runner.query_start_loc = start_loc

    # Persistent MTP buffers, one row per request, all zero-initialised.
    num_reqs = len(req_ids)
    hidden_shape = (num_reqs, hidden_size)
    runner.talker_mtp_input_ids = DummyBuffer(torch.zeros((num_reqs,), dtype=torch.int64))
    runner.talker_mtp_inputs_embeds = DummyBuffer(torch.zeros(hidden_shape, dtype=torch.float32))
    runner.last_talker_hidden = DummyBuffer(torch.zeros(hidden_shape, dtype=torch.float32))
    runner.text_step = DummyBuffer(torch.zeros(hidden_shape, dtype=torch.float32))

    runner.talker_mtp = DummyTalkerMTP()
    runner.vllm_config = object()

    # Stub that ignores its keyword arguments and returns the 4-tuple shape the
    # caller unpacks.
    def _fake_batch_execution_and_padding(**_kwargs):
        return None, object(), None, None

    runner._determine_batch_execution_and_padding = _fake_batch_execution_and_padding

    # The merge method is intentionally not stubbed: the real one from
    # OmniGPUModelRunner is used.
    return runner


def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch):
    """Two decode requests: embeds are written back and per-request codes are stored."""
    # `set_forward_context` must be patched on the module that defines
    # OmniGPUModelRunner, because `_talker_mtp_forward` looks the symbol up there.
    import vllm_omni.worker.gpu_model_runner as runner_module

    monkeypatch.setattr(runner_module, "set_forward_context", _noop_forward_context)

    decode_ids = ("r1", "r2")
    runner = _make_runner(req_ids=decode_ids, hidden_size=4)

    # Seed the batch-major MTP embeds buffer: row 0 belongs to r1, row 1 to r2.
    runner.talker_mtp_inputs_embeds.gpu.copy_(
        torch.tensor(
            [
                [1.0, 2.0, 3.0, 4.0],
                [10.0, 20.0, 30.0, 40.0],
            ]
        )
    )

    # Flattened `inputs_embeds`; only rows 0 and 3 are expected to be overwritten.
    inputs_embeds = torch.zeros((6, 4), dtype=torch.float32)

    # Exercise the real method on the hand-built runner (no re-implementation).
    OmniGPUModelRunner._talker_mtp_forward(runner, list(decode_ids), inputs_embeds)

    # The dummy talker adds 1 to every embed, and results land at each
    # request's start offset.
    expected_rows = {
        0: [2.0, 3.0, 4.0, 5.0],
        3: [11.0, 21.0, 31.0, 41.0],
    }
    for offset, row in expected_rows.items():
        assert torch.allclose(inputs_embeds[offset], torch.tensor(row))

    # Each request's additional_information_cpu carries its own code, which
    # equals its position in the decode batch.
    for expected_code, rid in enumerate(decode_ids):
        info = runner.requests[rid].additional_information_cpu
        assert int(info["code_predictor_codes"][0, 0]) == expected_code


def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
    """With no decode requests, `_talker_mtp_forward` leaves `inputs_embeds` alone."""
    import vllm_omni.worker.gpu_model_runner as runner_module

    monkeypatch.setattr(runner_module, "set_forward_context", _noop_forward_context)

    inputs_embeds = torch.randn((2, 4))
    snapshot = inputs_embeds.clone()

    runner = _make_runner(req_ids=("r1",), hidden_size=4)
    OmniGPUModelRunner._talker_mtp_forward(runner, [], inputs_embeds)

    # The tensor must match the snapshot taken before the call.
    assert torch.allclose(inputs_embeds, snapshot)
2 changes: 0 additions & 2 deletions vllm_omni/entrypoints/omni_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,6 @@ async def _stage_worker_async(
except Exception as e:
logger.warning("Device setup failed: %s", e)

max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1)
engine_args["max_num_seqs"] = max_batch_size
# Initialize OmniConnectors if configured to match sync worker behavior
connectors: dict[Any, Any] = {}
if connectors_config:
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/entrypoints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None
# Update base_engine_args with stage-specific engine_args if they exist
if hasattr(stage_arg, "engine_args") and stage_arg.engine_args is not None:
base_engine_args_tmp = OmegaConf.merge(base_engine_args_tmp, stage_arg.engine_args)
if hasattr(stage_arg, "runtime") and stage_arg.runtime is not None:
runtime_cfg = stage_arg.runtime
max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1)
base_engine_args_tmp["max_num_seqs"] = max_batch_size
stage_arg.engine_args = base_engine_args_tmp
return stage_args

Expand Down
47 changes: 17 additions & 30 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,30 +573,22 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor,
if input_embeds is None and input_ids is not None:
input_embeds = self.talker.embed_input_ids(input_ids)

text_step = torch.zeros(
1,
self.talker_config.text_config.hidden_size,
device=self._module_device(self.talker),
dtype=torch.bfloat16,
)
last_talker_hidden = torch.zeros(
1,
1,
self.talker_config.text_config.hidden_size,
device=self._module_device(self.talker),
dtype=torch.bfloat16,
)

span_len = input_ids.shape[0]
if span_len > 1:
# prefill
input_ids, input_embeds, update_dict = self.talker_preprocess_prefill(input_ids, input_embeds, **info_dict)
code_predictor_codes = torch.zeros(
(input_embeds.shape[0], self.talker.num_code_groups),
device=self._module_device(self.talker),
dtype=torch.long,
)
update_dict["code_predictor_codes"] = code_predictor_codes
else:
last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode(
input_ids, input_embeds, **info_dict
)

update_dict["mtp_inputs"] = last_talker_hidden, text_step
update_dict["mtp_inputs"] = last_talker_hidden, text_step

return input_ids, input_embeds, update_dict

Expand All @@ -608,24 +600,19 @@ def talker_mtp(
text_step: torch.Tensor,
):
# TODO(Peiqi): not support intermediate_tensors now
input_ids = safe_tensor_reshape(input_ids, (1, -1))
input_ids = safe_tensor_reshape(input_ids, (input_ids.shape[0], -1))
inputs_embeds = safe_tensor_reshape(input_embeds, (-1, self.talker_config.text_config.hidden_size))
text_step = safe_tensor_reshape(text_step, (1, -1))
last_talker_hidden = safe_tensor_reshape(last_talker_hidden, (1, 1, self.talker_config.text_config.hidden_size))
text_step = safe_tensor_reshape(text_step, (-1, self.talker_config.text_config.hidden_size))
last_talker_hidden = safe_tensor_reshape(
last_talker_hidden, (-1, 1, self.talker_config.text_config.hidden_size)
)
# for profiling
if inputs_embeds.shape[-1] == 2048:
inputs_embeds = self.text_projection(inputs_embeds)
if inputs_embeds.shape[0] == 1:
code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward(
input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden
)
inputs_embeds = summed_embeddings.clone()
else:
code_predictor_codes = torch.zeros(
(inputs_embeds.shape[0], self.talker.num_code_groups),
device=self._module_device(self.talker),
dtype=torch.long,
)
code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward(
input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden
)
inputs_embeds = summed_embeddings.clone()
inputs_embeds = (inputs_embeds + text_step).reshape(-1, self.talker_config.text_config.hidden_size)
return inputs_embeds, code_predictor_codes.squeeze(-1)

Expand Down Expand Up @@ -848,7 +835,7 @@ def talker_preprocess_decode(self, input_ids: torch.Tensor, input_embeds: torch.
use_vec = q_tail[0:1, :]
new_q_tail = (
q_tail[1:, :].detach().to("cpu").contiguous()
if q_tail.shape[1] > 1
if q_tail.shape[0] > 1
else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype)
)
text_step = use_vec.to(input_embeds.device, dtype=input_embeds.dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,7 @@ def code_predictor_forward(
# Use the corresponding lm_head for this layer
logits = self.code_predictor.lm_head[layer_idx](hidden_state[:, -1:, :]) # [batch, 1, vocab_size]

if len(pos_codes) > 1:
input_ids_for_logits_processors = torch.cat(pos_codes[1:], dim=1).to(
device=logits.device, dtype=torch.long
)
else:
input_ids_for_logits_processors = self.empty_code
logits = logits_processors(input_ids_for_logits_processors, logits.squeeze(0)).unsqueeze(0)
logits = logits_processors(None, logits[:, -1])

# Sample from the filtered distribution
probs = F.softmax(logits, dim=-1)
Expand Down Expand Up @@ -288,7 +282,7 @@ def code_predictor_forward(
all_summed_embeddings.append(pos_summed)

# Concatenate across positions: [batch, seq_len, hidden_size]
summed_embeddings = torch.cat(all_summed_embeddings, dim=1)
summed_embeddings = torch.cat(all_summed_embeddings, dim=1).squeeze(1)

return result_codes, summed_embeddings

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def talker2code2wav(
# Process each talker output
for i, talker_output in enumerate(talker_outputs):
output = talker_output.outputs[0]
seq_len = len(output.token_ids)
seq_len = len(output.token_ids) - 1
# Extract codec codes from talker output
# Expected shape: [8, seq_len] (8-layer RVQ codes)
codec_codes = (
Expand Down
72 changes: 41 additions & 31 deletions vllm_omni/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,11 +563,7 @@ def _dummy_run(
ubatch_slices=ubatch_slices,
),
):
if (
getattr(self.model, "talker", None) is not None
and hasattr(self.model, "talker_mtp")
and num_tokens_padded == 1
):
if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"):
outputs = self.talker_mtp(
self.talker_mtp_input_ids.gpu[:num_tokens_padded],
self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded],
Comment on lines +566 to 569
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Guard talker_mtp call to decode-populated buffers

This call now runs for every batch, but the talker_mtp_* buffers are only populated for decode requests in _prepare_inputs when span_len == 1. For prefill-only batches (or any batch where num_tokens_padded exceeds the number of decode requests), the slice [:num_tokens_padded] pulls stale/uninitialized entries (and can exceed the max_num_reqs-sized buffers), which can drive talker_mtp with invalid token IDs/hidden states and trigger out-of-range embedding errors or random sampling on garbage data. This should be guarded by the actual decode batch size (or removed) and sliced by that size rather than num_tokens_padded.

Useful? React with 👍 / 👎.

Expand Down Expand Up @@ -884,6 +880,7 @@ def _preprocess(
if hasattr(self.model, "has_preprocess") and self.model.has_preprocess:
# Overlay custom prompt_embeds per request for the prompt portion;
# collect additional_information (tensor/list) for prefill portion only
decode_req_ids = []
for req_index, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests.get(req_id)
req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
Expand All @@ -897,33 +894,14 @@ def _preprocess(
req_input_ids, req_embeds, update_dict = self.model.preprocess(
input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos
)
# run talker mtp decode
if hasattr(self.model, "talker_mtp"):
_cudagraph_mode, batch_desc, _, _ = self._determine_batch_execution_and_padding(
num_tokens=span_len,
num_reqs=1,
num_scheduled_tokens_np=num_scheduled_tokens_np[req_index],
max_num_scheduled_tokens=1,
force_eager=span_len > 1,
use_cascade_attn=False,
)
if hasattr(self.model, "talker_mtp") and span_len == 1:
last_talker_hidden, text_step = update_dict.pop("mtp_inputs")
if _cudagraph_mode != CUDAGraphMode.NONE:
self.talker_mtp_input_ids.gpu[:span_len].copy_(req_input_ids)
self.talker_mtp_inputs_embeds.gpu[:span_len].copy_(req_embeds)
self.last_talker_hidden.gpu[:span_len].copy_(last_talker_hidden)
self.text_step.gpu[:span_len].copy_(text_step)
req_input_ids = self.talker_mtp_input_ids.gpu[:span_len]
req_embeds = self.talker_mtp_inputs_embeds.gpu[:span_len]
last_talker_hidden = self.last_talker_hidden.gpu[:span_len]
text_step = self.text_step.gpu[:span_len]
with set_forward_context(
None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc
):
req_embeds, code_predictor_codes = self.talker_mtp(
req_input_ids, req_embeds, last_talker_hidden, text_step
)
update_dict["code_predictor_codes"] = code_predictor_codes
decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1)
self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids)
self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds)
self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden)
self.text_step.gpu[decode_slice].copy_(text_step)
decode_req_ids.append(req_id)

# TODO(Peiqi): the merge stage could move out from the critical path
self._merge_additional_information_update(req_id, update_dict)
Expand All @@ -934,6 +912,10 @@ def _preprocess(
if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len:
input_ids[s : s + seg_len] = req_input_ids

# run talker mtp decode
if hasattr(self.model, "talker_mtp"):
self._talker_mtp_forward(decode_req_ids, inputs_embeds)

return (
input_ids,
inputs_embeds,
Expand All @@ -943,6 +925,34 @@ def _preprocess(
ec_connector_output,
)

def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None:
decode_batch_size = len(decode_req_ids)
if decode_batch_size == 0:
return
_cudagraph_mode, batch_desc, _, _ = self._determine_batch_execution_and_padding(
num_tokens=decode_batch_size,
num_reqs=decode_batch_size,
num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32),
max_num_scheduled_tokens=1,
use_cascade_attn=False,
)
req_input_ids = self.talker_mtp_input_ids.gpu[:decode_batch_size]
req_embeds = self.talker_mtp_inputs_embeds.gpu[:decode_batch_size]
last_talker_hidden = self.last_talker_hidden.gpu[:decode_batch_size]
text_step = self.text_step.gpu[:decode_batch_size]
with set_forward_context(
None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc
):
req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step)
# update the inputs_embeds and code_predictor_codes
code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous()
for idx, req_id in enumerate(decode_req_ids):
req_index = self.input_batch.req_ids.index(req_id)
start_offset = int(self.query_start_loc.cpu[req_index])
inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this code, can we directly use inputs_embeds = req_embeds and so avoid iterating over each request? I can remove the update of "code_predictor_codes" and the "merge update" here. If so, we may be able to place batch_mtp_forward directly inside the model forward.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

req_embeds does not include prefill requests. Assigning inputs_embeds = req_embeds directly may be incorrect for a mixed prefill and decode batch.

update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]}
self._merge_additional_information_update(req_id, update_dict)

def _model_forward(
self,
input_ids: torch.Tensor | None = None,
Expand Down