
[Bugfix] Fix NemotronH MTP + Chunked Prefill #35447

Merged
tdoublep merged 16 commits into vllm-project:main from
CentML:nemotron-h-mtp-chunkedprefill-bugfix
Mar 17, 2026

Conversation

@benchislett
Collaborator

@benchislett benchislett commented Feb 26, 2026

Purpose

Functional

Test Plan

This branch adds a reproducer that produces garbage outputs with NemotronH MTP + Chunked Prefill.

It does not seem to happen with Qwen3-Next, due to differences in how the backends separate spec decodes from non-spec decodes (GDN checks num_draft_tokens_cpu and dynamically splits the decodes into spec and non-spec groups).
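The GDN-style split mentioned above can be sketched as follows. This is a minimal illustration, not vLLM's actual API: the function name and the bare list of per-request draft-token counts are assumptions for the sketch.

```python
# Hedged sketch: how a GDN-style backend might split the decode portion of a
# batch into spec-decode and non-spec-decode groups, based on per-request
# draft token counts. Names are illustrative, not vLLM internals.

def split_decodes(num_draft_tokens: list[int]) -> tuple[list[int], list[int]]:
    """Return (spec_decode_indices, non_spec_decode_indices)."""
    spec, non_spec = [], []
    for req_idx, n_draft in enumerate(num_draft_tokens):
        # A request with draft tokens verifies more than one token this step
        # (spec decode); one with zero drafts is a plain single-token decode.
        (spec if n_draft > 0 else non_spec).append(req_idx)
    return spec, non_spec

print(split_decodes([2, 0, 3, 0]))  # ([0, 2], [1, 3])
```

Because the split is recomputed from draft counts each step, a small prefill chunk never ends up in the spec-decode group, which is why the bug does not reproduce on that backend.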

Test Result

The diff in this branch fixes the reproducer, giving results consistent with the baseline. Further evaluation is required.

@mergify mergify bot added v1 bug Something isn't working labels Feb 26, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in NemotronH models when using Multi-Token Prediction (MTP) with chunked prefill. The core fix in vllm/v1/attention/backends/mamba_attn.py correctly handles cases where small prefill chunks are misclassified as decodes, preventing an assertion failure. While this fix is sound, the pull request includes some changes that require attention. There's a block of dead code in vllm/v1/worker/gpu_model_runner.py that seems to be a work-in-progress and should be cleaned up before merging. Additionally, an unrelated change in vllm/model_executor/layers/layernorm.py has been identified, which should be reverted to avoid potential side effects.

@benchislett benchislett marked this pull request as ready for review February 27, 2026 17:49
@mergify

mergify bot commented Feb 27, 2026

Hi @benchislett, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


@mergify

mergify bot commented Mar 3, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@benchislett benchislett force-pushed the nemotron-h-mtp-chunkedprefill-bugfix branch from 206986e to 7b5f9a7 Compare March 3, 2026 23:04
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>

@benchislett
Collaborator Author

Not sure why pre-commit is breaking; I'm having a hard time figuring out what the discrepancy is.

Comment on lines +1222 to +1224
# with query_len <= reorder_batch_threshold as "decodes". Prefill
# chunks that fall under this threshold get processed via the decode
# path, which stores intermediate states at sequential slots. We must
Member

Do we really want to have these prefill chunks processed by the decode path?

Collaborator Author

This is how it's done in all other attention backends. It's only GDN attention that does it differently, by manually splitting out the spec-decodes and non-spec-decodes. In my opinion, that strategy is inefficient and more difficult to maintain.
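The classification discussed above, where any request whose query length falls at or under the reorder threshold is routed down the decode path, can be sketched as a simple predicate. This is an illustrative sketch: the function name is invented, and the assumption that the threshold grows to cover spec-decode lengths (e.g. 1 + the number of speculative tokens) is mine, not taken from the source.

```python
# Hedged sketch of decode-path classification: requests with
# query_len <= reorder_batch_threshold take the "decode" path, so a small
# final prefill chunk can land there too. Names are illustrative.

def is_decode_like(query_len: int, reorder_batch_threshold: int) -> bool:
    return query_len <= reorder_batch_threshold

# Assuming the threshold is raised to cover spec-decode lengths, e.g.
# 1 + num_speculative_tokens with one speculative token:
threshold = 1 + 1
print(is_decode_like(2, threshold))   # True  -> 2-token prefill chunk takes decode path
print(is_decode_like(17, threshold))  # False -> regular prefill path
```

Under this scheme a 2-token prefill chunk is indistinguishable by length from a spec-decode step, which is why the decode path must handle such chunks correctly rather than assert on them.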

Comment on lines +694 to +695
# If this model has mamba2 layers, we handle num_accepted_tokens_cpu differently
self.is_mamba2_hybrid: bool = False
Member

I guess we'd prefer not to have mamba2-specific logic in the GPU model runner if possible

Collaborator Author

Makes sense. I can try to refactor the behaviour out of the GPU model runner and into a helper somewhere.

Collaborator Author

Moved the core logic into mamba_utils and renamed this toggle, but I'm still not sure of the cleanest way to set this flag so that we can dispatch on mamba2 vs. GDN attention.

@mergify

mergify bot commented Mar 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 4, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@mergify mergify bot removed the needs-rebase label Mar 5, 2026
Member

@tdoublep tdoublep left a comment


Generally LGTM but would like @asafgardin's eyes on it from the Mamba1 perspective

@tdoublep tdoublep added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 10, 2026
@tdoublep
Member

Seems like there is a different (?) corner case being explored here: #32716

@Josephasafg
Contributor

Generally LGTM but would like @asafgardin's eyes on it from the Mamba1 perspective

Thanks @tdoublep
Since Mamba1 does not yet support speculative decoding, this change should be OK.

@benchislett
Collaborator Author

No idea why the hybrid test is hanging in CI. It passes locally.

@benchislett
Collaborator Author

Looking into it.

-if self.use_spec_decode:
+if self.use_spec_decode and num_accepted_tokens is not None:
     assert query_start_loc_d is not None
     assert num_accepted_tokens is not None
Contributor

Is this assertion still necessary?

Collaborator Author

nope :p

@benchislett
Collaborator Author

Tagging @vadiklyutiy to help with the CI issue; I cannot reproduce it.

@tdoublep
Member

I can reproduce the hang locally on an L4 GPU (but not on H100)

@tdoublep
Member

It looks like for the test that's hanging we only have 38 blocks available on L4, while the request requires 100+ blocks (we need one block per speculative token in MTP for hybrid models), so it just sits there waiting to be scheduled. Can we change the test to require fewer blocks?
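The scheduling deadlock described above comes down to simple arithmetic: required blocks exceed the pool, so the request is never admitted. A back-of-envelope sketch, with the prompt length, block size, and speculative-token count chosen only to illustrate (the 38-block figure is the one observed on L4 in this thread; the rest are assumptions):

```python
# Hedged sketch of the block-budget check behind the hang: a request needing
# more blocks than the GPU has can never be scheduled, so the test hangs.
# All numbers except the observed 38 available blocks are illustrative.

def blocks_needed(prompt_tokens: int, block_size: int, num_spec_tokens: int) -> int:
    # KV blocks for the prompt (ceil division), plus one extra block per
    # speculative token, as required for hybrid models with MTP.
    kv_blocks = -(-prompt_tokens // block_size)
    return kv_blocks + num_spec_tokens

available = 38  # observed on L4 for this test
needed = blocks_needed(prompt_tokens=1600, block_size=16, num_spec_tokens=3)
print(needed, needed > available)  # 103 True -> request waits forever
```

Shrinking the prompt (or the speculative-token count) until `needed <= available` is what the fix below does in spirit, by restricting the test to GPUs with enough memory.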

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@benchislett
Collaborator Author

Added marks to limit the test to >= H100, and added coverage of NemotronH + MTP.

@tdoublep tdoublep merged commit 8a68046 into vllm-project:main Mar 17, 2026
62 checks passed
zhenwei-intel pushed a commit to zhenwei-intel/vllm that referenced this pull request Mar 17, 2026
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
@benchislett benchislett deleted the nemotron-h-mtp-chunkedprefill-bugfix branch March 17, 2026 22:09
andylolu2 pushed a commit to andylolu2/vllm that referenced this pull request Mar 18, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026

Labels

bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed v1
