
[BugFix] Fix V2 model runner profile_run crash for Omni Talker stages#2819

Merged
tzhouam merged 2 commits into vllm-project:dev/migrate-MR-v2 from Sy0307:fix/v2-omni-model-compat
Apr 15, 2026
Conversation

Contributor

@Sy0307 Sy0307 commented Apr 15, 2026

Purpose

Fix four V2 model runner bugs that crash during profile_run for Omni Talker stages (Qwen3-Omni and Qwen2.5-Omni). These bugs block V2 runner adoption for all Omni models with Talker stages.

Reproduces CI failures in buildkite builds #6704 (H100) and #6691 (L4).

Root causes and fixes:

  1. encoder_runner buffer shape mismatch — Talker shares Qwen3OmniMoeForConditionalGeneration architecture with Thinker, so multimodal registry marks supports_mm_inputs=True. encoder_runner allocates inputs_embeds buffer using hf_text_config.hidden_size (2048) but Talker's actual embedding dim is 1024 → RuntimeError: expanded size (2048) must match existing size (1024). Fix: disable MM path for models with has_preprocess (embeddings injected via run_preprocess(), not encoder_runner).

  2. OmniOutput not handled in execute_model — OmniOutput is a NamedTuple (len=4), but the tuple intercept only checks len==2, leaving hidden_states as a non-tensor → AssertionError. Fix: check isinstance(OmniOutput) first and extract text_hidden_states.

  3. Qwen2.5-Omni M-RoPE signature mismatch — V2 RopeState calls get_mrope_input_positions(input_tokens, mm_features) but the old signature required keyword-only args (hf_config, image_grid_thw, etc.) → TypeError. Fix: new signature accepts all params as optional, delegates to thinker.get_mrope_input_positions() which extracts grid data from mm_features via gather_kwargs. Old implementation preserved as _get_mrope_input_positions_v1.

  4. Qwen2.5-Omni not in _OMNI_ARCHITECTURES — V2 fell back to DefaultModelState (no run_preprocess) → AttributeError. Fix: register Qwen2_5OmniForConditionalGeneration.

Test Plan

  • Qwen3-Omni 30B V1 text-only (H20 2xGPU) — pass
  • Qwen3-Omni 30B V2 text-only (H20 2xGPU) — pass
  • Qwen2.5-Omni 7B V2 text-only (H20 2xGPU) — pass
  • Qwen3-TTS V1 end2end Base (H20) — pass
  • Qwen3-TTS V2 end2end Base (H20) — pass

Test Result

All 5 test configurations pass with returncode=0 and Processed prompts: 100%.

V2 runner previously crashed at profile_run for all Omni Talker stages; after fix all stages initialize and run inference successfully.

cc @Fattysand @tzhouam

@Sy0307 Sy0307 marked this pull request as ready for review April 15, 2026 08:55
@Sy0307 Sy0307 requested a review from hsliuustc0106 as a code owner April 15, 2026 08:55

@hsliuustc0106
Collaborator

For bug 2 (OmniOutput handling), you now set self._last_aux_output = None for OmniOutput but preserve it for the (tensor, dict) tuple path. Is this intentional? If OmniOutput returns aux_dict, should we store it like the tuple path does?

@hsliuustc0106
Collaborator

BLOCKER scan:

| Category | Result |
| --- | --- |
| Correctness | PASS |
| Reliability/Safety | PASS |
| Breaking Changes | PASS |
| Test Coverage | PASS - Local test results provided for 5 configurations |
| Documentation | PASS - PR description is comprehensive |
| Security | PASS |

OVERALL: NO BLOCKERS

VERDICT: COMMENT


This is a focused bug fix that addresses four specific issues with the V2 model runner for Omni Talker stages:

  1. encoder_runner buffer shape mismatch - Correctly disabled MM path for preprocess models to avoid embedding dimension mismatch
  2. OmniOutput handling - Added isinstance check for OmniOutput before the tuple intercept
  3. Qwen2.5-Omni M-RoPE signature - Updated get_mrope_input_positions to accept optional parameters
  4. Qwen2.5-Omni registration - Added Qwen2_5OmniForConditionalGeneration to _OMNI_ARCHITECTURES

The fixes are well-reasoned and the local test results show all 5 test configurations passing successfully.
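Fix 3's dual-signature approach could be sketched roughly like this. Names and structure are assumptions inferred from the PR description, not the actual diff: every parameter is optional so the V2 positional call `(input_tokens, mm_features)` and the legacy keyword-only call both resolve, with the legacy path preserved as `_get_mrope_input_positions_v1`.

```python
class _ThinkerStub:
    """Stand-in for the thinker; the real code extracts grid data
    from mm_features via gather_kwargs."""
    def get_mrope_input_positions(self, input_tokens, mm_features):
        return list(range(len(input_tokens or [])))

class Qwen2_5OmniSketch:
    def __init__(self):
        self.thinker = _ThinkerStub()

    def get_mrope_input_positions(self, input_tokens=None, mm_features=None,
                                  *, hf_config=None, image_grid_thw=None,
                                  **legacy_kwargs):
        if hf_config is not None:
            # Legacy keyword-only path, kept for old callers.
            return self._get_mrope_input_positions_v1(
                hf_config=hf_config, image_grid_thw=image_grid_thw,
                **legacy_kwargs)
        # V2 RopeState path: delegate to the thinker implementation.
        return self.thinker.get_mrope_input_positions(input_tokens, mm_features)

    def _get_mrope_input_positions_v1(self, *, hf_config,
                                      image_grid_thw=None, **kwargs):
        return []  # placeholder for the preserved old implementation
```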

Minor suggestion: Consider adding a regression test for these specific bugs to prevent them from recurring. While the existing tests cover the happy path, a targeted test that verifies the fix (e.g., checking that profile_run succeeds for Talker stages with has_preprocess=True) would be valuable.

Note: DCO check is currently failing - please add commit sign-off (e.g., git commit --amend --signoff and force push).

…tages

Fix four V2 model runner bugs that crash during profile_run for
Omni Talker stages:

1. Disable encoder_runner MM path for preprocess models — Talker
   shares the same architecture as Thinker, so multimodal registry
   marks it supports_mm_inputs=True. But encoder_runner allocates
   inputs_embeds buffer using hf_text_config.hidden_size (2048)
   while Talker's actual embedding dim is 1024, causing shape
   mismatch during dummy_run.

2. Handle OmniOutput (NamedTuple, len=4) in execute_model — the
   existing tuple intercept only checks len==2, missing OmniOutput
   and leaving hidden_states as a non-tensor that fails the final
   assert.

3. Adapt Qwen2.5-Omni get_mrope_input_positions signature for V2
   RopeState — V2 calls (input_tokens, mm_features) but the old
   signature required keyword-only args. Delegate to thinker's
   implementation which extracts grid data from mm_features.

4. Register Qwen2_5OmniForConditionalGeneration in
   _OMNI_ARCHITECTURES so V2 uses OmniModelState instead of
   DefaultModelState.

Signed-off-by: Sy03 <1370724210@qq.com>
@Sy0307 Sy0307 force-pushed the fix/v2-omni-model-compat branch from c614190 to 2c16696 Compare April 15, 2026 09:54
# make_omni_output can retrieve them.
if hasattr(self.model, "_last_captured_layers"):
    self.model._last_captured_layers = second
hidden_states = first
Collaborator


for the hidden state extraction logic can we do like the following to handle last_aux_output?

```python
if isinstance(model_output, tuple) and len(model_output) == 2:
    first, second = model_output
    if isinstance(first, torch.Tensor):
        self._last_aux_output = second
        if hasattr(self.model, "_last_captured_layers"):
            self.model._last_captured_layers = second
elif isinstance(model_output, (OmniOutput, torch.Tensor)):
    first = model_output
else:
    raise TypeError("Error Message")
hidden_states = first if not isinstance(model_output, OmniOutput) else model_output.text_hidden_states
```

- Hoist _last_aux_output=None before branches to prevent state leakage
- Remove over-defensive isinstance(first, torch.Tensor) inner check
- Replace silent fallback with raise TypeError for unexpected output types

Signed-off-by: Sy03 <1370724210@qq.com>
@tzhouam tzhouam merged commit 8e6a879 into vllm-project:dev/migrate-MR-v2 Apr 15, 2026
2 checks passed
Sy0307 added a commit to Sy0307/vllm-omni that referenced this pull request Apr 16, 2026
PR vllm-project#2819 added Qwen2_5OmniForConditionalGeneration to
_OMNI_ARCHITECTURES but did not update the corresponding unit
test, causing test_omni_architectures_set_contains_expected to
fail on both simple-unit-test and modelrunner-v2-unit-test CI jobs.

Signed-off-by: Sy03 <1370724210@qq.com>
Sy0307 added a commit to Sy0307/vllm-omni that referenced this pull request Apr 19, 2026
