Skip to content

[Bugfix] Fix Triton operator usage for multimodal models based on the mrope_interleaved parameter#6042

Merged
wangxiyuan merged 1 commit intovllm-project:mainfrom
zhangxinyuehfad:zxy_fix_omini_acc
Jan 22, 2026
Merged

[Bugfix] Fix Triton operator usage for multimodal models based on the mrope_interleaved parameter#6042
wangxiyuan merged 1 commit intovllm-project:mainfrom
zhangxinyuehfad:zxy_fix_omini_acc

Conversation

@zhangxinyuehfad
Copy link
Copy Markdown
Collaborator

@zhangxinyuehfad zhangxinyuehfad commented Jan 20, 2026

What this PR does / why we need it?

When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails during the profiling/warmup stage with the following error: AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035. The vector core execution is abnormal.

error log: https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412

This error is specifically triggered by the triton_mrope kernel when handling the unique mrope_section configurations of the Omni model. Other multimodal models with standard sections (e.g., [16, 24, 24]) or standard LLMs work correctly with Triton.

Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check before calling forward_triton.

  1. For standard LLMs (mrope_interleaved = True ), it continues to use Triton for acceleration.

  2. For complex configurations (like Qwen2.5-Omni mrope_interleaved = False ), it now falls back to the native super().forward_oot() path, which uses the stable torch_npu or PyTorch implementation.

Does this PR introduce any user-facing change?

How was this patch tested?

@github-actions
Copy link
Copy Markdown
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

@vllm-ascend-ci vllm-ascend-ci added ready read for review ready-for-test start test by label for PR accuracy-test enable all accuracy test for PR labels Jan 20, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a critical bug causing NPU vector core errors when running models with complex mrope_section configurations, such as Qwen2.5-Omni. The fix involves modifying AscendMRotaryEmbedding.forward_oot to bypass the triton_mrope kernel and fall back to more stable torch_npu or native PyTorch implementations.

The change correctly prevents the failing Triton kernel from being called. However, the implementation is confusing and logically flawed. The condition self.mrope_section is None is always false for AscendMRotaryEmbedding, making the if block dead code. I've provided a suggestion to make the intent of disabling the Triton path explicit and improve code clarity and maintainability.

Comment thread vllm_ascend/ops/rotary_embedding.py Outdated
@zhangxinyuehfad zhangxinyuehfad force-pushed the zxy_fix_omini_acc branch 2 times, most recently from d89f384 to 057ccbc Compare January 20, 2026 14:15
@vllm-ascend-ci vllm-ascend-ci added accuracy-test enable all accuracy test for PR and removed accuracy-test enable all accuracy test for PR labels Jan 20, 2026
@zhangxinyuehfad zhangxinyuehfad force-pushed the zxy_fix_omini_acc branch 2 times, most recently from 8cbf798 to c4c950d Compare January 21, 2026 03:29
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
@zhangxinyuehfad zhangxinyuehfad changed the title [Bugfix] skip triton mrope for complex mrope_sections to prevent NPU vector core error [Bugfix] Fix Triton operator usage for multimodal models based on the mrope_interleaved parameter Jan 22, 2026
yiz-liu pushed a commit that referenced this pull request Jan 22, 2026
…d on `the mrope_interleaved` parameter (#6074)

### What this PR does / why we need it?
picked from #6042

When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails
during the profiling/warmup stage with the following error:
`AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035.
The vector core execution is abnormal.`

error log:
https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412

This error is specifically triggered by the `triton_mrope` kernel when
handling the unique `mrope_section` configurations of the Omni model.
Other multimodal models with standard sections (e.g., [16, 24, 24]) or
standard LLMs work correctly with Triton.

Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check
before calling forward_triton.

1. For standard LLMs (mrope_interleaved = True ), it continues to use
Triton for acceleration.

2. For complex configurations (like Qwen2.5-Omni mrope_interleaved =
False ), it now falls back to the native super().forward_oot() path,
which uses the stable torch_npu or PyTorch implementation.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
@wangxiyuan wangxiyuan merged commit 9bba0a2 into vllm-project:main Jan 22, 2026
20 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Jan 22, 2026
…to qwen3next_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend: (51 commits)
  [Bugfix] Remove `use_aclgraph` in mtp_proposer and use `use_cuda_graph` (vllm-project#6032)
  [BugFix] fix 3vl dense model load quant weight (vllm-project#6100)
  [CP&SP] Integrate FIA operator in mla_cp._forward_decode (vllm-project#5641)
  [CI][Doc] Upgrade wheel building's CANN to 8.5.0 and update the Docs (vllm-project#6145)
  [CI]Install clang in dokerfile for triton ascend (vllm-project#4409)
  [Main] Upgrade PTA to 2.9.0 (vllm-project#6112)
  [Graph][Fusion] Add QKVNormRope and QKVNormRopeWithBias (vllm-project#5721)
  [P/D][PCP]bugfix pcp force free twice caused logger error (vllm-project#6124)
  [BugFix]converting pa get_workspace back to capturing (vllm-project#5833)
  [CI] optimize lint term (vllm-project#5986)
  [Bugfix] Fix Triton operator usage for multimodal models based on `the mrope_interleaved` parameter (vllm-project#6042)
  [bugfix][npugraph_ex]fix the model output type issue caused by manually modify FX graph (vllm-project#6015)
  [BugFix] Support setting tp=1 for the Eagle draft model to take effect (vllm-project#6097)
  [Misc] Bump mooncake version to v0.3.8.post1 (vllm-project#6110)
  [Feature]Enable DispatchGmmCombineDecode when eagle is moe with w8a8 or not moe [RFC: issue 5476] (vllm-project#5758)
  [bugfix] adapt_remote_request_id (vllm-project#6051)
  [Feature] Add support of new W4A4_LAOS_DYNAMIC quantization method (vllm-project#5143)
  [Feature] Support DSA-CP for Hybrid scenario (vllm-project#5702)
  [CI] Upgrade CANN to 8.5.0 (vllm-project#6070)
  Default enable MLAPO (vllm-project#5952)
  ...
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
…e mrope_interleaved` parameter (vllm-project#6042)

### What this PR does / why we need it?

When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails
during the profiling/warmup stage with the following error:
`AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035.
The vector core execution is abnormal.`

error log:
https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412

This error is specifically triggered by the `triton_mrope` kernel when
handling the unique `mrope_section` configurations of the Omni model.
Other multimodal models with standard sections (e.g., [16, 24, 24]) or
standard LLMs work correctly with Triton.

Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check
before calling forward_triton.

1. For standard LLMs (mrope_interleaved = True ), it continues to use
Triton for acceleration.

2. For complex configurations (like Qwen2.5-Omni mrope_interleaved =
False ), it now falls back to the native super().forward_oot() path,
which uses the stable torch_npu or PyTorch implementation.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@d682094

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
…d on `the mrope_interleaved` parameter (vllm-project#6074)

### What this PR does / why we need it?
picked from vllm-project#6042

When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails
during the profiling/warmup stage with the following error:
`AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035.
The vector core execution is abnormal.`

error log:
https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412

This error is specifically triggered by the `triton_mrope` kernel when
handling the unique `mrope_section` configurations of the Omni model.
Other multimodal models with standard sections (e.g., [16, 24, 24]) or
standard LLMs work correctly with Triton.

Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check
before calling forward_triton.

1. For standard LLMs (mrope_interleaved = True ), it continues to use
Triton for acceleration.

2. For complex configurations (like Qwen2.5-Omni mrope_interleaved =
False ), it now falls back to the native super().forward_oot() path,
which uses the stable torch_npu or PyTorch implementation.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
tangtiangu pushed a commit to tangtiangu/jiusi-vllm-ascend that referenced this pull request Feb 24, 2026
…d on `the mrope_interleaved` parameter (vllm-project#6074)

### What this PR does / why we need it?
picked from vllm-project#6042

When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails
during the profiling/warmup stage with the following error:
`AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035.
The vector core execution is abnormal.`

error log:
https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412

This error is specifically triggered by the `triton_mrope` kernel when
handling the unique `mrope_section` configurations of the Omni model.
Other multimodal models with standard sections (e.g., [16, 24, 24]) or
standard LLMs work correctly with Triton.

Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check
before calling forward_triton.

1. For standard LLMs (mrope_interleaved = True ), it continues to use
Triton for acceleration.

2. For complex configurations (like Qwen2.5-Omni mrope_interleaved =
False ), it now falls back to the native super().forward_oot() path,
which uses the stable torch_npu or PyTorch implementation.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
tangtiangu pushed a commit to tangtiangu/jiusi-vllm-ascend that referenced this pull request Feb 24, 2026
…d on `the mrope_interleaved` parameter (vllm-project#6074)

### What this PR does / why we need it?
picked from vllm-project#6042

When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails
during the profiling/warmup stage with the following error:
`AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035.
The vector core execution is abnormal.`

error log:
https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412

This error is specifically triggered by the `triton_mrope` kernel when
handling the unique `mrope_section` configurations of the Omni model.
Other multimodal models with standard sections (e.g., [16, 24, 24]) or
standard LLMs work correctly with Triton.

Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check
before calling forward_triton.

1. For standard LLMs (mrope_interleaved = True ), it continues to use
Triton for acceleration.

2. For complex configurations (like Qwen2.5-Omni mrope_interleaved =
False ), it now falls back to the native super().forward_oot() path,
which uses the stable torch_npu or PyTorch implementation.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…e mrope_interleaved` parameter (vllm-project#6042)

### What this PR does / why we need it?

When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails
during the profiling/warmup stage with the following error:
`AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035.
The vector core execution is abnormal.`

error log:
https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412

This error is specifically triggered by the `triton_mrope` kernel when
handling the unique `mrope_section` configurations of the Omni model.
Other multimodal models with standard sections (e.g., [16, 24, 24]) or
standard LLMs work correctly with Triton.

Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check
before calling forward_triton.

1. For standard LLMs (mrope_interleaved = True ), it continues to use
Triton for acceleration.

2. For complex configurations (like Qwen2.5-Omni mrope_interleaved =
False ), it now falls back to the native super().forward_oot() path,
which uses the stable torch_npu or PyTorch implementation.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@d682094

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
NickJudyHvv added a commit to NickJudyHvv/vllm-ascend that referenced this pull request Mar 2, 2026
Adapted from vllm-ascend PR vllm-project#5664 and PR vllm-project#6042. Adds Triton-based mRoPE
support to AscendMRotaryEmbedding, with fix to only use Triton path when
mrope_interleaved is True.
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
…e mrope_interleaved` parameter (vllm-project#6042)

### What this PR does / why we need it?

When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails
during the profiling/warmup stage with the following error:
`AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035.
The vector core execution is abnormal.`

error log:
https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412

This error is specifically triggered by the `triton_mrope` kernel when
handling the unique `mrope_section` configurations of the Omni model.
Other multimodal models with standard sections (e.g., [16, 24, 24]) or
standard LLMs work correctly with Triton.

Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check
before calling forward_triton.

1. For standard LLMs (mrope_interleaved = True ), it continues to use
Triton for acceleration.

2. For complex configurations (like Qwen2.5-Omni mrope_interleaved =
False ), it now falls back to the native super().forward_oot() path,
which uses the stable torch_npu or PyTorch implementation.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@d682094

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…e mrope_interleaved` parameter (vllm-project#6042)

### What this PR does / why we need it?

When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails
during the profiling/warmup stage with the following error:
`AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035.
The vector core execution is abnormal.`

error log:
https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412

This error is specifically triggered by the `triton_mrope` kernel when
handling the unique `mrope_section` configurations of the Omni model.
Other multimodal models with standard sections (e.g., [16, 24, 24]) or
standard LLMs work correctly with Triton.

Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check
before calling forward_triton.

1. For standard LLMs (mrope_interleaved = True ), it continues to use
Triton for acceleration.

2. For complex configurations (like Qwen2.5-Omni mrope_interleaved =
False ), it now falls back to the native super().forward_oot() path,
which uses the stable torch_npu or PyTorch implementation.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@d682094

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
…e mrope_interleaved` parameter (vllm-project#6042)

### What this PR does / why we need it?

When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails
during the profiling/warmup stage with the following error:
`AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035.
The vector core execution is abnormal.`

error log:
https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412

This error is specifically triggered by the `triton_mrope` kernel when
handling the unique `mrope_section` configurations of the Omni model.
Other multimodal models with standard sections (e.g., [16, 24, 24]) or
standard LLMs work correctly with Triton.

Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check
before calling forward_triton.

1. For standard LLMs (mrope_interleaved = True ), it continues to use
Triton for acceleration.

2. For complex configurations (like Qwen2.5-Omni mrope_interleaved =
False ), it now falls back to the native super().forward_oot() path,
which uses the stable torch_npu or PyTorch implementation.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@d682094

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
@zhangxinyuehfad zhangxinyuehfad deleted the zxy_fix_omini_acc branch March 19, 2026 02:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

accuracy-test enable all accuracy test for PR module:ops ready read for review ready-for-test start test by label for PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants