Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions tests/model_executor/models/test_qwen3_omni_async_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import torch

from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni import Qwen3OmniMoeForConditionalGeneration


def _make_model() -> Qwen3OmniMoeForConditionalGeneration:
model = object.__new__(Qwen3OmniMoeForConditionalGeneration)
model.talker = SimpleNamespace(
num_code_groups=2,
text_projection=lambda x: x + 10,
)
model.tts_eos_embed = torch.tensor([[100.0, 101.0]], dtype=torch.bfloat16)
model.tts_pad_embed = torch.tensor([[200.0, 201.0]], dtype=torch.bfloat16)
return model


def test_talker_preprocess_decode_starts_from_prefill_consumed_boundary():
model = _make_model()
observed = {}

def fake_decode(input_ids, input_embeds, update_dict, **info_dict):
observed["num_processed_tokens"] = info_dict["num_processed_tokens"]
update_dict["num_processed_tokens_delta"] = 1
return torch.zeros((1, 2), dtype=torch.bfloat16), torch.ones((1, 2), dtype=torch.bfloat16), update_dict

model.talker_preprocess_decode = fake_decode

_, _, update_dict = model.talker_preprocess(
input_ids=torch.tensor([1], dtype=torch.long),
input_embeds=torch.ones((1, 2), dtype=torch.bfloat16),
decode_flag=False,
num_processed_tokens=0,
prefill_consumed_text_tokens=1,
)

assert observed["num_processed_tokens"] == 1
assert update_dict["decode_flag"] is True
assert update_dict["num_processed_tokens"] == 2


def test_async_decode_terminal_steps_do_not_advance_processed_tokens():
model = _make_model()
device = torch.device("cpu")

update_dict = {}
text_step = model._thinker_decode_to_talker_decode(
{
"cached_thinker_decode_embeddings": torch.tensor([[1.0, 2.0]], dtype=torch.bfloat16),
"num_processed_tokens": 1,
"thinker_output_token_ids": [11, 12],
},
device,
update_dict,
)

assert torch.equal(text_step, model.tts_eos_embed)
assert update_dict["finished_flag"] is True
assert update_dict["num_processed_tokens_delta"] == 0

update_dict = {}
text_step = model._thinker_decode_to_talker_decode(
{
"cached_thinker_decode_embeddings": torch.tensor([[1.0, 2.0]], dtype=torch.bfloat16),
"num_processed_tokens": 1,
"thinker_output_token_ids": [11, 12],
"finished_flag": True,
},
device,
update_dict,
)

assert torch.equal(text_step, model.tts_pad_embed)
assert update_dict["num_processed_tokens_delta"] == 0


def test_async_decode_consumes_cached_embedding_and_appends_new_runtime_embed():
model = _make_model()
device = torch.device("cpu")
update_dict = {}

text_step = model._thinker_decode_to_talker_decode(
{
"cached_thinker_decode_embeddings": torch.tensor([[1.0, 2.0]], dtype=torch.bfloat16),
"thinker_decode_embeddings": torch.tensor([[3.0, 4.0]], dtype=torch.bfloat16),
"num_processed_tokens": 0,
"thinker_output_token_ids": [11, 12, 13],
},
device,
update_dict,
)

assert torch.equal(text_step, torch.tensor([11.0, 12.0], dtype=torch.bfloat16))
assert torch.equal(
update_dict["cached_thinker_decode_embeddings"],
torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.bfloat16),
)
assert update_dict["num_processed_tokens_delta"] == 1


def test_async_decode_consumes_runtime_embedding_when_cache_is_empty():
model = _make_model()
device = torch.device("cpu")
update_dict = {}

text_step = model._thinker_decode_to_talker_decode(
{
"thinker_decode_embeddings": torch.tensor([[5.0, 6.0]], dtype=torch.bfloat16),
"num_processed_tokens": 0,
"thinker_output_token_ids": [11, 12, 13],
},
device,
update_dict,
)

assert torch.equal(text_step, torch.tensor([15.0, 16.0], dtype=torch.bfloat16))
assert update_dict["thinker_decode_embeddings"] is None
assert update_dict["num_processed_tokens_delta"] == 1
21 changes: 18 additions & 3 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,18 +602,26 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor,
dtype=torch.long,
)
update_dict["code_predictor_codes"] = code_predictor_codes
update_dict["num_processed_tokens_delta"] = span_len
else:
# decode
if not info_dict.get("decode_flag", False):
info_dict["num_processed_tokens"] = 0
# Prefill already consumed the first text token via the
# assistant bootstrap path, so decode starts from the
# remaining-text boundary rather than cumulative index 0.
prefill_consumed_text_tokens = info_dict.get("prefill_consumed_text_tokens")
if prefill_consumed_text_tokens is None:
raise RuntimeError("Missing prefill_consumed_text_tokens for talker decode handoff.")
info_dict["num_processed_tokens"] = prefill_consumed_text_tokens
update_dict["decode_flag"] = True

last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode(
input_ids, input_embeds, update_dict, **info_dict
)
update_dict["mtp_inputs"] = last_talker_hidden, text_step

update_dict["num_processed_tokens"] = info_dict.get("num_processed_tokens", 0) + span_len
processed_delta = update_dict.pop("num_processed_tokens_delta")
update_dict["num_processed_tokens"] = info_dict.get("num_processed_tokens", 0) + processed_delta
return input_ids, input_embeds, update_dict

def talker_mtp(
Expand Down Expand Up @@ -782,6 +790,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
update_dict["tts_pad_embed_projected"] = pad_proj.detach()
except Exception:
pass
update_dict["prefill_consumed_text_tokens"] = 1
self._talker_cache_thinker_decode_embeds(info_dict, update_dict)

return req_input_ids[start_index:end_index], req_embeds[start_index:end_index], update_dict
Expand Down Expand Up @@ -909,10 +918,12 @@ def _thinker_decode_to_talker_decode(
start_index = info_dict.get("num_processed_tokens", 0)
thinker_output_token_ids = info_dict.get("thinker_output_token_ids", [])
if start_index >= len(thinker_output_token_ids) - 1:
update_dict["num_processed_tokens_delta"] = 0
if info_dict.get("finished_flag"):
return self.tts_pad_embed.to(device)
update_dict["finished_flag"] = True
return self.tts_eos_embed.to(device)

if cached_thinker_decode_embeds is not None and start_index < cached_thinker_decode_embeds.shape[0]:
cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device)
thinker_embed = cached_thinker_decode_embeds[start_index]
Expand All @@ -921,8 +932,12 @@ def _thinker_decode_to_talker_decode(
cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, thinker_decode_embed], dim=0)
update_dict["cached_thinker_decode_embeddings"] = cached_thinker_decode_embeds
else:
thinker_embed = thinker_decode_embed.to(device)
thinker_embed = thinker_decode_embed
if thinker_embed.device != device:
thinker_embed = thinker_embed.to(device)

update_dict["thinker_decode_embeddings"] = None
update_dict["num_processed_tokens_delta"] = 1
return self.talker.text_projection(thinker_embed).to(device)

def talker_preprocess_decode(
Expand Down
Loading