[bugfix][npugraph_ex]duplicate pattern issue#6513

Merged
wangxiyuan merged 2 commits into vllm-project:main from ChenCangtao:duplicate-pattern
Feb 4, 2026

Conversation

@ChenCangtao
Contributor

@ChenCangtao ChenCangtao commented Feb 3, 2026

What this PR does / why we need it?

When the draft model also uses vllmbackend for graph compilation, fusion pass registration runs a second time, causing errors due to duplicate patterns.

Does this PR introduce any user-facing change?

How was this patch tested?

Signed-off-by: chencangtao <chencangtao@huawei.com>
@ChenCangtao ChenCangtao requested a review from yiz-liu as a code owner February 3, 2026 12:12
@gemini-code-assist
Contributor

Summary of Changes

Hello @ChenCangtao, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves an issue where graph fusion patterns were being redundantly registered, leading to errors, especially when multiple compilation passes or draft models were active. By implementing a centralized and guarded registration mechanism, the system now ensures that each fusion pattern is registered only once, improving stability and preventing runtime exceptions related to duplicate pattern definitions in the graph compilation process.

Highlights

  • Duplicate Pattern Prevention: Introduced a new utility function check_and_register_fusion_pass to prevent RuntimeError caused by attempting to register the same graph fusion pattern multiple times.
  • Refactored Fusion Pass Registration: Modified the initialization of GraphEXMatmulAllReduceAddRMSNormPass, GraphEXAddRMSNormQuantPass, and GraphEXQKNormRopeFusionPass to utilize the new check_and_register_fusion_pass for safe pattern registration.


Changelog
  • vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py
    • Imported check_and_register_fusion_pass from npugraph_ex_utils_check.
    • Updated the __init__ method of GraphEXMatmulAllReduceAddRMSNormPass to use check_and_register_fusion_pass for registering GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern and GraphEXLastLayerMatmulAllReduceAddRMSNormPattern.
  • vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py
    • Imported check_and_register_fusion_pass from npugraph_ex_utils_check.
    • Updated the __init__ method of GraphEXAddRMSNormQuantPass to use check_and_register_fusion_pass for registering various GraphEXAddRMSNormQuantPattern types.
  • vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py
    • Imported check_and_register_fusion_pass from npugraph_ex_utils_check.
    • Updated the __init__ method of GraphEXQKNormRopeFusionPass to use check_and_register_fusion_pass for registering GraphEXQKNormRopeFusionPattern and GraphEXQKNormRopeFusionPatternWithBias.
  • vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py
    • Imported VllmConfig.
    • Added a global _register_patterns set to track registered patterns.
    • Implemented check_and_register_fusion_pass function, which attempts to register a pattern only if it hasn't been registered before, catching and logging RuntimeError for duplicate patterns.
Activity
  • The pull request was opened by ChenCangtao to address an issue where duplicate pattern registrations caused errors during graph compilation, especially when a draft model used vllmbackend.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a mechanism to prevent duplicate registration of fusion passes, which is a good approach to solve the described problem. The changes correctly apply a new helper function across several files. However, the implementation of the core helper function check_and_register_fusion_pass contains a critical typo that makes the fix ineffective. I've provided a comment with a suggested fix for this issue.

Comment on lines +60 to +76

```python
def check_and_register_fusion_pass(pattern_class: type, **kwargs):
    global _resgister_patterns
    eps = kwargs.get("eps", 1e-6)
    pattern_key = str(pattern_class.__name__) + str(eps)
    if pattern_key in _resgister_patterns:
        return

    pattern = pattern_class(**kwargs)
    try:
        pattern.register()
        _register_patterns.add(pattern_key)
    except RuntimeError as e:
        if "Duplicate pattern" in str(e):
            logger.warning(f"Pattern {pattern_class.__name__} eps {eps} has been registered")
            _register_patterns.add(pattern_key)
        else:
            raise e
```
Contributor

critical

There's a typo in the global variable name _resgister_patterns on lines 61 and 64. It should be _register_patterns. This typo causes the check for already registered patterns to always fail because it refers to a different, empty set. This defeats the purpose of this function and will not prevent duplicate registration errors. This is a critical bug.

Suggested change

Before:

```python
def check_and_register_fusion_pass(pattern_class: type, **kwargs):
    global _resgister_patterns
    eps = kwargs.get("eps", 1e-6)
    pattern_key = str(pattern_class.__name__) + str(eps)
    if pattern_key in _resgister_patterns:
        return
    pattern = pattern_class(**kwargs)
    try:
        pattern.register()
        _register_patterns.add(pattern_key)
    except RuntimeError as e:
        if "Duplicate pattern" in str(e):
            logger.warning(f"Pattern {pattern_class.__name__} eps {eps} has been registered")
            _register_patterns.add(pattern_key)
        else:
            raise e
```

After:

```python
def check_and_register_fusion_pass(pattern_class: type, **kwargs):
    global _register_patterns
    eps = kwargs.get("eps", 1e-6)
    pattern_key = str(pattern_class.__name__) + str(eps)
    if pattern_key in _register_patterns:
        return
    pattern = pattern_class(**kwargs)
    try:
        pattern.register()
        _register_patterns.add(pattern_key)
    except RuntimeError as e:
        if "Duplicate pattern" in str(e):
            logger.warning(f"Pattern {pattern_class.__name__} eps {eps} has been registered")
            _register_patterns.add(pattern_key)
        else:
            raise e
```

@github-actions
Contributor

github-actions bot commented Feb 3, 2026

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write a commit message that fulfills the PR description, to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Signed-off-by: chencangtao <chencangtao@huawei.com>

Signed-off-by: chencangtao <chencangtao@huawei.com>
@ChenCangtao ChenCangtao changed the title duplicate pattern issue [bugfix][npugraph_ex]duplicate pattern issue Feb 3, 2026
@wangxiyuan wangxiyuan merged commit fa56abe into vllm-project:main Feb 4, 2026
16 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Feb 6, 2026
…to qwen3next_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend: (59 commits)
  [Feat.]: 310p support MOE models (vllm-project#6530)
  [Doc] backport 0.13.0 release note (vllm-project#6584)
  [CI] Update UT CANN version to 8.5.0 for main branch (vllm-project#6564)
  [CI] Change A2 runner (vllm-project#6557)
  [Bugfix] Fix the incorrect use of the output parameter in _forward_fia_slidingwindow (vllm-project#6469)
  [main2main] upgrade vllm main 0202 (vllm-project#6560)
  [CI][npugraph_ex]Fix npugraph ex e2e test (vllm-project#6553)
  [Feature]KV pool supports sparse attention (vllm-project#6339)
  [bugfix]Fix accuracy issue in PCP/DCP with speculative decoding (vllm-project#6491)
  perf: adaptive block size selection in linear_persistent kernel (vllm-project#6537)
  [ModelRunner][Fix] Pads query_start_loc to satisfy FIA/TND constraint (vllm-project#6475)
  [Bugfix]Fix of Pooling Code and Update of Pooling Usage Guide (vllm-project#6126)
  [Fusion] Add rmsnorm dynamic quant fusion pass (vllm-project#6274)
  [Bugfix] Synchronize only the current stream to avoid device sync (vllm-project#6432)
  [CI] Add long and short prompt tests for DeepSeek-V3.2 (vllm-project#6499)
  [Refactor] MLP weight prefetch to consistency with MoE Model's prefetching in terms of code and usage (vllm-project#6442)
  [bugfix][npugraph_ex]duplicate pattern issue (vllm-project#6513)
  [bugfix][npugraph_ex]add the extra check for allreduce rmsnorm fusion pass (vllm-project#6430)
  [Quant] GLM4.7-Flash Support W8A8 (vllm-project#6492)
  [Nightly][BugFix] Remove kv_cache nz test case for test_mla_preprocess_nq.py (vllm-project#6505)
  ...
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Feb 12, 2026
### What this PR does / why we need it?
When the draft model also uses vllmbackend for graph compilation, the
fusion pass registration occurs again, resulting in errors due to
duplicate patterns.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: chencangtao <chencangtao@huawei.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
Signed-off-by: momochenchuw <chenchuw@huawei.com>
@wangxiyuan wangxiyuan mentioned this pull request Feb 24, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
jiangyunfan1 pushed a commit to jiangyunfan1/vllm-ascend that referenced this pull request Apr 9, 2026