From d83cb9c96aa55111b2ed420d97acb21e7dcb7e21 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Tue, 24 Mar 2026 20:34:30 +0800 Subject: [PATCH 1/3] Fix async talker handoff decode progression Signed-off-by: Sy03 <1370724210@qq.com> --- .../models/qwen3_omni/qwen3_omni.py | 49 +++++++++++++++---- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 6dcd278acae..edf3f21f564 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -611,15 +611,31 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, else: # decode if not info_dict.get("decode_flag", False): - info_dict["num_processed_tokens"] = 0 + # Prefill already consumed the first text token via the + # assistant bootstrap path, so decode starts from the + # remaining-text boundary rather than cumulative index 0. 
+ prefill_consumed_text_tokens = info_dict.get("prefill_consumed_text_tokens") + if prefill_consumed_text_tokens is None: + raise RuntimeError("Missing prefill_consumed_text_tokens for talker decode handoff.") + info_dict["num_processed_tokens"] = prefill_consumed_text_tokens update_dict["decode_flag"] = True + update_dict["prefill_consumed_text_tokens"] = prefill_consumed_text_tokens last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode( input_ids, input_embeds, update_dict, **info_dict ) update_dict["mtp_inputs"] = last_talker_hidden, text_step + if "run_talker_mtp" not in update_dict: + update_dict["run_talker_mtp"] = True + if not update_dict["run_talker_mtp"]: + update_dict["code_predictor_codes"] = torch.zeros( + (1, self.talker.num_code_groups), + device=self._module_device(self.talker), + dtype=torch.long, + ) - update_dict["num_processed_tokens"] = info_dict.get("num_processed_tokens", 0) + span_len + processed_delta = update_dict.pop("num_processed_tokens_delta", span_len) + update_dict["num_processed_tokens"] = info_dict.get("num_processed_tokens", 0) + processed_delta return input_ids, input_embeds, update_dict def talker_mtp( @@ -778,6 +794,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch update_dict["tts_pad_embed_projected"] = pad_proj.detach() except Exception: pass + update_dict["prefill_consumed_text_tokens"] = 1 self._talker_cache_thinker_decode_embeds(info_dict, update_dict) return req_input_ids[start_index:end_index], req_embeds[start_index:end_index], update_dict @@ -905,20 +922,34 @@ def _thinker_decode_to_talker_decode( start_index = info_dict.get("num_processed_tokens", 0) thinker_output_token_ids = info_dict.get("thinker_output_token_ids", []) if start_index >= len(thinker_output_token_ids) - 1: + update_dict["num_processed_tokens_delta"] = 0 if info_dict.get("finished_flag"): + update_dict["run_talker_mtp"] = False return self.tts_pad_embed.to(device) + 
update_dict["run_talker_mtp"] = True update_dict["finished_flag"] = True return self.tts_eos_embed.to(device) - if cached_thinker_decode_embeds is not None and start_index < cached_thinker_decode_embeds.shape[0]: - cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device) + + cached_len = 0 + if cached_thinker_decode_embeds is not None: + if cached_thinker_decode_embeds.device != device: + cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device) + cached_len = cached_thinker_decode_embeds.shape[0] + + if cached_thinker_decode_embeds is not None and start_index < cached_len: thinker_embed = cached_thinker_decode_embeds[start_index] - if thinker_decode_embed is not None: - thinker_decode_embed = thinker_decode_embed.to(device) - cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, thinker_decode_embed], dim=0) - update_dict["cached_thinker_decode_embeddings"] = cached_thinker_decode_embeds else: - thinker_embed = thinker_decode_embed.to(device) + if thinker_decode_embed is None: + update_dict["run_talker_mtp"] = False + update_dict["num_processed_tokens_delta"] = 0 + return self.tts_pad_embed.to(device) + thinker_embed = thinker_decode_embed + if thinker_embed.device != device: + thinker_embed = thinker_embed.to(device) + update_dict["thinker_decode_embeddings"] = None + update_dict["run_talker_mtp"] = True + update_dict["num_processed_tokens_delta"] = 1 return self.talker.text_projection(thinker_embed).to(device) def talker_preprocess_decode( From f918ebfa8b098298aaa14f3c93d31c520b2aa89d Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Tue, 24 Mar 2026 23:00:03 +0800 Subject: [PATCH 2/3] Remove model-side talker MTP gating Signed-off-by: Sy03 <1370724210@qq.com> --- .../models/qwen3_omni/qwen3_omni.py | 32 +++++-------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 
edf3f21f564..e5f8e524062 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -608,6 +608,7 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, dtype=torch.long, ) update_dict["code_predictor_codes"] = code_predictor_codes + update_dict["num_processed_tokens_delta"] = span_len else: # decode if not info_dict.get("decode_flag", False): @@ -619,22 +620,13 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, raise RuntimeError("Missing prefill_consumed_text_tokens for talker decode handoff.") info_dict["num_processed_tokens"] = prefill_consumed_text_tokens update_dict["decode_flag"] = True - update_dict["prefill_consumed_text_tokens"] = prefill_consumed_text_tokens last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode( input_ids, input_embeds, update_dict, **info_dict ) update_dict["mtp_inputs"] = last_talker_hidden, text_step - if "run_talker_mtp" not in update_dict: - update_dict["run_talker_mtp"] = True - if not update_dict["run_talker_mtp"]: - update_dict["code_predictor_codes"] = torch.zeros( - (1, self.talker.num_code_groups), - device=self._module_device(self.talker), - dtype=torch.long, - ) - processed_delta = update_dict.pop("num_processed_tokens_delta", span_len) + processed_delta = update_dict.pop("num_processed_tokens_delta") update_dict["num_processed_tokens"] = info_dict.get("num_processed_tokens", 0) + processed_delta return input_ids, input_embeds, update_dict @@ -924,31 +916,23 @@ def _thinker_decode_to_talker_decode( if start_index >= len(thinker_output_token_ids) - 1: update_dict["num_processed_tokens_delta"] = 0 if info_dict.get("finished_flag"): - update_dict["run_talker_mtp"] = False return self.tts_pad_embed.to(device) - update_dict["run_talker_mtp"] = True update_dict["finished_flag"] = True return self.tts_eos_embed.to(device) - cached_len = 0 - if cached_thinker_decode_embeds is 
not None: - if cached_thinker_decode_embeds.device != device: - cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device) - cached_len = cached_thinker_decode_embeds.shape[0] - - if cached_thinker_decode_embeds is not None and start_index < cached_len: + if cached_thinker_decode_embeds is not None and start_index < cached_thinker_decode_embeds.shape[0]: + cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device) thinker_embed = cached_thinker_decode_embeds[start_index] + if thinker_decode_embed is not None: + thinker_decode_embed = thinker_decode_embed.to(device) + cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, thinker_decode_embed], dim=0) + update_dict["cached_thinker_decode_embeddings"] = cached_thinker_decode_embeds else: - if thinker_decode_embed is None: - update_dict["run_talker_mtp"] = False - update_dict["num_processed_tokens_delta"] = 0 - return self.tts_pad_embed.to(device) thinker_embed = thinker_decode_embed if thinker_embed.device != device: thinker_embed = thinker_embed.to(device) update_dict["thinker_decode_embeddings"] = None - update_dict["run_talker_mtp"] = True update_dict["num_processed_tokens_delta"] = 1 return self.talker.text_projection(thinker_embed).to(device)

From d2d3c9a7d8e2fb6128dee9afbb3a2d04bd8de85c Mon Sep 17 00:00:00 2001
From: Sy03 <1370724210@qq.com>
Date: Wed, 25 Mar 2026 17:46:07 +0800
Subject: [PATCH 3/3] Add async decode handoff unit tests

Signed-off-by: Sy03 <1370724210@qq.com>
---
Review notes (commentary placed after the three-dash marker; it is not part
of the commit message and not part of the diff):

* _make_model() allocates the model with object.__new__, so __init__ does
  not run. Only three attributes are set on it: talker (a SimpleNamespace
  with num_code_groups=2 and text_projection = x + 10), tts_eos_embed and
  tts_pad_embed.
* test_talker_preprocess_decode_starts_from_prefill_consumed_boundary
  replaces talker_preprocess_decode with a stub that records the
  num_processed_tokens it receives and reports a delta of 1. It expects the
  stub to see 1 (the prefill_consumed_text_tokens value passed in),
  decode_flag to be True and the returned num_processed_tokens to be 2.
* test_async_decode_terminal_steps_do_not_advance_processed_tokens calls
  _thinker_decode_to_talker_decode with num_processed_tokens=1 and two
  thinker output ids. The first call expects tts_eos_embed, finished_flag
  True and a delta of 0; the second call (finished_flag already True)
  expects tts_pad_embed and a delta of 0.
* test_async_decode_consumes_cached_embedding_and_appends_new_runtime_embed
  expects the projection of cached row 0 ([11, 12]), the runtime embedding
  appended to cached_thinker_decode_embeddings, and a delta of 1.
* test_async_decode_consumes_runtime_embedding_when_cache_is_empty expects
  the projection of the runtime embedding ([15, 16]),
  thinker_decode_embeddings reset to None, and a delta of 1.

 .../models/test_qwen3_omni_async_decode.py    | 122 ++++++++++++++++++
 1 file changed, 122 insertions(+)
 create mode 100644 tests/model_executor/models/test_qwen3_omni_async_decode.py

diff --git a/tests/model_executor/models/test_qwen3_omni_async_decode.py b/tests/model_executor/models/test_qwen3_omni_async_decode.py
new file mode 100644
index 00000000000..2ba2331c3e8
--- /dev/null
+++ b/tests/model_executor/models/test_qwen3_omni_async_decode.py
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+
+import torch
+
+from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni import Qwen3OmniMoeForConditionalGeneration
+
+
+def _make_model() -> Qwen3OmniMoeForConditionalGeneration:
+    model = object.__new__(Qwen3OmniMoeForConditionalGeneration)
+    model.talker = SimpleNamespace(
+        num_code_groups=2,
+        text_projection=lambda x: x + 10,
+    )
+    model.tts_eos_embed = torch.tensor([[100.0, 101.0]], dtype=torch.bfloat16)
+    model.tts_pad_embed = torch.tensor([[200.0, 201.0]], dtype=torch.bfloat16)
+    return model
+
+
+def test_talker_preprocess_decode_starts_from_prefill_consumed_boundary():
+    model = _make_model()
+    observed = {}
+
+    def fake_decode(input_ids, input_embeds, update_dict, **info_dict):
+        observed["num_processed_tokens"] = info_dict["num_processed_tokens"]
+        update_dict["num_processed_tokens_delta"] = 1
+        return torch.zeros((1, 2), dtype=torch.bfloat16), torch.ones((1, 2), dtype=torch.bfloat16), update_dict
+
+    model.talker_preprocess_decode = fake_decode
+
+    _, _, update_dict = model.talker_preprocess(
+        input_ids=torch.tensor([1], dtype=torch.long),
+        input_embeds=torch.ones((1, 2), dtype=torch.bfloat16),
+        decode_flag=False,
+        num_processed_tokens=0,
+        prefill_consumed_text_tokens=1,
+    )
+
+    assert observed["num_processed_tokens"] == 1
+    assert update_dict["decode_flag"] is True
+    assert update_dict["num_processed_tokens"] == 2
+
+
+def test_async_decode_terminal_steps_do_not_advance_processed_tokens():
+    model = _make_model()
+    device = torch.device("cpu")
+
+    update_dict = {}
+    text_step = model._thinker_decode_to_talker_decode(
+        {
+            "cached_thinker_decode_embeddings": torch.tensor([[1.0, 2.0]], dtype=torch.bfloat16),
+            "num_processed_tokens": 1,
+            "thinker_output_token_ids": [11, 12],
+        },
+        device,
+        update_dict,
+    )
+
+    assert torch.equal(text_step, model.tts_eos_embed)
+    assert update_dict["finished_flag"] is True
+    assert update_dict["num_processed_tokens_delta"] == 0
+
+    update_dict = {}
+    text_step = model._thinker_decode_to_talker_decode(
+        {
+            "cached_thinker_decode_embeddings": torch.tensor([[1.0, 2.0]], dtype=torch.bfloat16),
+            "num_processed_tokens": 1,
+            "thinker_output_token_ids": [11, 12],
+            "finished_flag": True,
+        },
+        device,
+        update_dict,
+    )
+
+    assert torch.equal(text_step, model.tts_pad_embed)
+    assert update_dict["num_processed_tokens_delta"] == 0
+
+
+def test_async_decode_consumes_cached_embedding_and_appends_new_runtime_embed():
+    model = _make_model()
+    device = torch.device("cpu")
+    update_dict = {}
+
+    text_step = model._thinker_decode_to_talker_decode(
+        {
+            "cached_thinker_decode_embeddings": torch.tensor([[1.0, 2.0]], dtype=torch.bfloat16),
+            "thinker_decode_embeddings": torch.tensor([[3.0, 4.0]], dtype=torch.bfloat16),
+            "num_processed_tokens": 0,
+            "thinker_output_token_ids": [11, 12, 13],
+        },
+        device,
+        update_dict,
+    )
+
+    assert torch.equal(text_step, torch.tensor([11.0, 12.0], dtype=torch.bfloat16))
+    assert torch.equal(
+        update_dict["cached_thinker_decode_embeddings"],
+        torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.bfloat16),
+    )
+    assert update_dict["num_processed_tokens_delta"] == 1
+
+
+def test_async_decode_consumes_runtime_embedding_when_cache_is_empty():
+    model = _make_model()
+    device = torch.device("cpu")
+    update_dict = {}
+
+    text_step = model._thinker_decode_to_talker_decode(
+        {
+            "thinker_decode_embeddings": torch.tensor([[5.0, 6.0]], dtype=torch.bfloat16),
+            "num_processed_tokens": 0,
+            "thinker_output_token_ids": [11, 12, 13],
+        },
+        device,
+        update_dict,
+    )
+
+    assert torch.equal(text_step, torch.tensor([15.0, 16.0], dtype=torch.bfloat16))
+    assert update_dict["thinker_decode_embeddings"] is None
+    assert update_dict["num_processed_tokens_delta"] == 1