Skip to content

[Bugfix] Fix problematic dummy_run & improper input_batch_size in eagle#6517

Merged
wangxiyuan merged 4 commits intovllm-project:mainfrom
slippersss:bugfix_merge
Feb 7, 2026
Merged

[Bugfix] Fix problematic dummy_run & improper input_batch_size in eagle#6517
wangxiyuan merged 4 commits intovllm-project:mainfrom
slippersss:bugfix_merge

Conversation

@slippersss
Copy link
Copy Markdown
Contributor

@slippersss slippersss commented Feb 3, 2026

What this PR does / why we need it?

This PR aims to fix problematic dummy_run that will cause excessive npu memory and to fix improper input_batch_size that will degrade running performance.

Does this PR introduce any user-facing change?

N/A

How was this patch tested?

by ci

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @slippersss, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces critical bug fixes within the eagle_proposer module, specifically targeting performance and memory efficiency. It refines the batch size calculation in the dummy_run function to prevent excessive NPU memory consumption during profiling and adjusts the input_batch_size logic in _run_merged_draft to ensure optimal performance by correctly handling batch sizes based on CUDA graph usage.

Highlights

  • Dummy Run Batch Size Calculation: The batch_size in the dummy_run function is now conditionally set to self.runner.max_num_reqs when is_profile is true. This change addresses excessive NPU memory usage that occurred during profiling with the previous calculation.
  • Input Batch Size Logic Refinement: The input_batch_size in the _run_merged_draft function is now dynamically determined. It uses num_input_tokens if self.use_cuda_graph is enabled, otherwise it defaults to batch_size, preventing performance degradation caused by improper batch sizing.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • vllm_ascend/spec_decode/eagle_proposer.py
    • Modified the dummy_run method to conditionally set the batch_size based on the is_profile flag, using self.runner.max_num_reqs when profiling.
    • Updated the _run_merged_draft method to conditionally determine input_batch_size based on the self.use_cuda_graph flag.
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Feb 3, 2026

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces two fixes for speculative decoding in eagle_proposer.py. The first change adjusts the batch_size calculation during profiling runs in dummy_run to prevent excessive memory usage. The second change corrects the input_batch_size in _run_merged_draft for non-graph execution paths to improve performance. While the second change appears correct, the first change introduces a critical bug that could lead to an index out-of-bounds error. My review includes a comment with a suggested fix for this issue.

model_previous_hidden_states = self.hidden_states[:num_tokens]

batch_size = num_tokens // (self.num_speculative_tokens + 1)
batch_size = num_tokens // (self.num_speculative_tokens + 1) if not is_profile else self.runner.max_num_reqs
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The calculation of batch_size when is_profile is true can lead to an IndexError. batch_size is set to self.runner.max_num_reqs, but num_tokens (the input to dummy_run) can be smaller than self.runner.max_num_reqs. This can cause issues later in _run_merged_draft where tensors of size num_tokens are indexed with values related to batch_size.

To fix this, batch_size should first be derived from num_tokens and then capped by self.runner.max_num_reqs during profiling to ensure it doesn't exceed the number of available tokens.

        batch_size = num_tokens // (self.num_speculative_tokens + 1)
        if is_profile:
            batch_size = min(batch_size, self.runner.max_num_reqs)

yiz-liu pushed a commit that referenced this pull request Feb 3, 2026
…e in eagle (#6518)

### What this PR does / why we need it?
This PR is cherry-picked from #6517.

This PR aims to fix problematic dummy_run that will cause excessive npu
memory and to fix improper input_batch_size that will degrade running
performance.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

---------

Signed-off-by: Zetong Li <slippersss@126.com>
@linfeng-yuan linfeng-yuan added ready read for review ready-for-test start test by label for PR labels Feb 5, 2026

if self.method == "mtp":
input_batch_size = num_input_tokens
else:
Copy link
Copy Markdown

@winson-00178005 winson-00178005 Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段逻辑可以表述为“
input_batch_size = num_input_tokens if (self.method == "mtp" or self.use_cuda_graph) else batch_size

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

slippersss and others added 4 commits February 6, 2026 18:44
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Signed-off-by: Zetong Li <slippersss@126.com>
@wangxiyuan wangxiyuan merged commit 4fa7cf6 into vllm-project:main Feb 7, 2026
27 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Feb 9, 2026
…to qwen3next_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend:
  [Patch] Remove the patch of MiniCPM (vllm-project#5975)
  [P/D] layerwise connector support recompute scheduler (vllm-project#5900)
  [CI] Add workflow support for lint image build (vllm-project#6489)
  [Bugfix] Fix problematic dummy_run & improper input_batch_size in eagle (vllm-project#6517)
  [Refactor]310p_e2e test case update (vllm-project#6539)
  [Refactor]refactor p2p connector (vllm-project#6551)
  [Refactor]refactor 310p attention impl and add ut (vllm-project#6579)
  [Refactor]refactor 310p ops and add ut (vllm-project#6591)
  [Ops][Refactor] Remove custom rotary_embedding operator (vllm-project#6523)
  [Lint]Style: Convert `vllm-ascend/` to ruff format(new Batch vllm-project#8) (vllm-project#6604)
  [Test] Add initial multi modal cases of Qwen2.5-VL-7B-Instruct for disaggregated encoder  (vllm-project#5301)
  [CI] Fix broken CI (vllm-project#6599)
  [Lint]Style: Convert `vllm-ascend/` to ruff format(Batch vllm-project#10) (vllm-project#6173)
  [Lint]Style: Convert `vllm-ascend/` to ruff format(Batch vllm-project#11) (vllm-project#6176)
  [Lint]Style: Convert `vllm-ascend/` to ruff format(Batch vllm-project#8) (vllm-project#6129)
  [Lint]Style: Convert `vllm-ascend/` to ruff format(Batch vllm-project#7) (vllm-project#6023)
  [CI][Misc] Some improvement for github action (vllm-project#6587)
  [Image] Bump mooncake version to v0.3.8.post1 (vllm-project#6428)
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Feb 12, 2026
…le (vllm-project#6517)

### What this PR does / why we need it?
This PR aims to fix problematic dummy_run that will cause excessive npu
memory and to fix improper input_batch_size that will degrade running
performance.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: Zetong Li <slippersss@126.com>
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Co-authored-by: lilinsiman <lilinsiman@gmail.com>
Signed-off-by: momochenchuw <chenchuw@huawei.com>
@wangxiyuan wangxiyuan mentioned this pull request Feb 24, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…le (vllm-project#6517)

### What this PR does / why we need it?
This PR aims to fix problematic dummy_run that will cause excessive npu
memory and to fix improper input_batch_size that will degrade running
performance.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: Zetong Li <slippersss@126.com>
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Co-authored-by: lilinsiman <lilinsiman@gmail.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
…le (vllm-project#6517)

### What this PR does / why we need it?
This PR aims to fix problematic dummy_run that will cause excessive npu
memory and to fix improper input_batch_size that will degrade running
performance.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: Zetong Li <slippersss@126.com>
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Co-authored-by: lilinsiman <lilinsiman@gmail.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…le (vllm-project#6517)

### What this PR does / why we need it?
This PR aims to fix problematic dummy_run that will cause excessive npu
memory and to fix improper input_batch_size that will degrade running
performance.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: Zetong Li <slippersss@126.com>
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Co-authored-by: lilinsiman <lilinsiman@gmail.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
…le (vllm-project#6517)

### What this PR does / why we need it?
This PR aims to fix problematic dummy_run that will cause excessive npu
memory and to fix improper input_batch_size that will degrade running
performance.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: Zetong Li <slippersss@126.com>
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Co-authored-by: lilinsiman <lilinsiman@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready read for review ready-for-test start test by label for PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants