
[Bugfix] Fix MTP edge case in split_decodes_and_prefills #32716

Closed
Josephasafg wants to merge 19 commits into vllm-project:main from Josephasafg:mtp_edge_case

Conversation

@Josephasafg (Contributor) commented Jan 20, 2026

Purpose

Follow-up to #32118. @pisceskkk raised that split_decodes_and_prefills could misclassify new requests as decodes in MTP scenarios (decode_threshold > 1). A new request with a short prompt (e.g., 3 tokens with decode_threshold=4) would be treated as a decode, skipping the state initialization needed by models like Mamba.

At the moment, requests are classified purely by query length: if max_query_len <= decode_threshold, the entire batch is treated as decode, with no concept of a "new request" at all.

Fix:

  • Add a has_context flag to CommonAttentionMetadata (num_computed_tokens > 0), avoiding the deprecated _num_computed_tokens_cpu, which interferes with async scheduling as @LucasWilkinson mentioned.
  • Use has_context in split_decodes_and_prefills to detect new requests and classify them as prefills regardless of query length.
  • Set num_computed_tokens_cpu in _dummy_run so full CUDA graph capture batches get has_context=True.
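As a rough sketch of the classification change (plain Python, not the actual vLLM code; the function signature and the decodes-before-prefills batch ordering are my assumptions for illustration):

```python
def split_decodes_and_prefills(query_lens, has_context, decode_threshold):
    """Return (num_decodes, num_prefills) for a batch ordered with
    decodes first. A request counts as a decode only if its query is
    short AND it has prior context; a brand-new request
    (has_context=False) is always a prefill, even when
    query_len <= decode_threshold (e.g., so Mamba state gets initialized).
    """
    num_decodes = 0
    for q_len, ctx in zip(query_lens, has_context):
        if q_len <= decode_threshold and ctx:
            num_decodes += 1
        else:
            break  # everything from here on is treated as prefill
    return num_decodes, len(query_lens) - num_decodes

# A 3-token new request with decode_threshold=4 is now a prefill:
print(split_decodes_and_prefills([2, 3, 10], [True, False, True], 4))  # (1, 2)
```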

Test Plan

Added unit tests.

Test Result

All tests pass.


@mergify mergify bot added the v1 label Jan 20, 2026
@gemini-code-assist bot left a comment

Code Review

This pull request correctly identifies new requests and ensures they are treated as prefills, which is a necessary change for certain models that require specific handling for new prompts. The implementation introduces an is_new_request flag and applies it consistently. However, I've identified a critical issue with a special-case heuristic for CUDA graph capture. This heuristic may misclassify legitimate batches of short prefills as decodes, potentially causing correctness issues for models like Mamba. This conflict needs to be resolved.

Signed-off-by: Josephasafg <ajgard7@gmail.com>
@Josephasafg Josephasafg changed the title from "Classify new requests as prefills regardless of query length" to "[Bugfix] Fix MTP edge case in split_decodes_and_prefills" Mar 2, 2026
@mergify mergify bot added the bug Something isn't working label Mar 2, 2026
@Josephasafg Josephasafg marked this pull request as ready for review March 2, 2026 18:35
@Josephasafg (Contributor, Author) commented

@pisceskkk I moved the PR out of draft. Sorry it took a while.

We can discuss here, let me know what you think

@pisceskkk (Contributor) commented

Thanks for the fix! That addresses my concerns!

# A new request has no prior context (num_computed_tokens == 0).
# New requests need prefill treatment even if
# query_lens <= decode_threshold (e.g., for Mamba state init).
num_computed = common_attn_metadata._num_computed_tokens_cpu
@LucasWilkinson (Collaborator) commented Mar 4, 2026

_num_computed_tokens_cpu is deprecated (it interferes with async scheduling) and should not be used; we will have to come up with something else here.

we probably need a has_context flag in common_attn_metadata or something like that instead

@Josephasafg (Contributor, Author) commented

Makes sense. I'll look into adding a has_context flag on CommonAttentionMetadata instead.

@Josephasafg (Contributor, Author) commented

@LucasWilkinson I added has_context to CommonAttentionMetadata. Can you please take a look? Thanks

Signed-off-by: Josephasafg <ajgard7@gmail.com>
@MatthewBonanni (Collaborator) commented

@Josephasafg Thanks for the contribution! Could you somehow enable this for mamba only, or figure out a different way to trigger the state initialization? Treating short prefills as decodes is intentional for performance in MLA. For example, FLASH_ATTN_MLA has reorder_batch_threshold=512, even though nobody would ever do MTP with 512 speculative tokens

Signed-off-by: Josephasafg <ajgard7@gmail.com>
@Josephasafg Josephasafg requested a review from tdoublep as a code owner March 10, 2026 16:39
@Josephasafg (Contributor, Author) commented Mar 10, 2026

> @Josephasafg Thanks for the contribution! Could you somehow enable this for mamba only, or figure out a different way to trigger the state initialization? Treating short prefills as decodes is intentional for performance in MLA. For example, FLASH_ATTN_MLA has reorder_batch_threshold=512, even though nobody would ever do MTP with 512 speculative tokens

Thanks for the feedback @MatthewBonanni! I've reworked the approach: removed has_context from CommonAttentionMetadata entirely and scoped the fix to Mamba only, per your suggestion.

The idea is that instead of modifying the generic split_decodes_and_prefills logic (which would interfere with MLA's intentional short-prefill-as-decode optimization), the Mamba metadata builder now computes a has_initial_states_d tensor that tracks which decode requests are actually new. Then a clear_stale_mamba_states() function, invoked from gpu_model_runner for Mamba models only, zeros the cached state for those slots before the kernel reads them.
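A minimal sketch of the builder-side flag (plain Python standing in for the tensor op; the function name is hypothetical, not the actual vLLM API):

```python
def build_has_initial_states(num_computed_tokens, num_decodes):
    """True for decode slots that already hold valid cached state;
    False for brand-new requests (no computed tokens yet) whose
    Mamba state must be cleared before the kernel reads it."""
    return [n > 0 for n in num_computed_tokens[:num_decodes]]

# Requests 0 and 2 are mid-generation; request 1 is new and needs clearing.
print(build_has_initial_states([7, 0, 12, 5], num_decodes=3))  # [True, False, True]
```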

Signed-off-by: Josephasafg <ajgard7@gmail.com>
@MatthewBonanni (Collaborator) commented

Hi @Josephasafg, thanks for updating the PR! In the interest of simplification and minimizing changes to gpu_model_runner.py, I proposed an alternative approach in Josephasafg#4. Please let me know your thoughts.

For further performance improvement, we could also modify causal_conv1d_update to accept a has_initial_state similar to causal_conv1d_fn

Signed-off-by: Josephasafg <ajgard7@gmail.com>
@Josephasafg (Contributor, Author) commented Mar 10, 2026

> Hi @Josephasafg, thanks for updating the PR! In the interest of simplification and minimizing changes to gpu_model_runner.py, I proposed an alternative approach in Josephasafg#4. Please let me know your thoughts.
>
> For further performance improvement, we could also modify causal_conv1d_update to accept a has_initial_state similar to causal_conv1d_fn

@MatthewBonanni Thanks! That was actually my initial approach, zeroing inside the mixer. But it didn't play well with CUDA graphs, so I added this in every mixer:

  if has_initial_states_d is not None:
      indices = state_indices_tensor_d_input

      # Multiply each gathered state by a 0/1 keep mask so that slots
      # belonging to brand-new requests are zeroed in place.
      ssm_gathered = ssm_state[indices]
      keep_ssm = has_initial_states_d.to(ssm_gathered.dtype)
      keep_ssm = keep_ssm.view(-1, *([1] * (ssm_gathered.dim() - 1)))
      ssm_state[indices] = ssm_gathered * keep_ssm

      conv_gathered = conv_state[indices]
      keep_conv = has_initial_states_d.to(conv_gathered.dtype)
      keep_conv = keep_conv.view(-1, *([1] * (conv_gathered.dim() - 1)))
      conv_state[indices] = conv_gathered * keep_conv

The reason I decided against it is that it requires changes across all the different Mamba-based models (*_mixer.py), as opposed to a single point of change covering every Mamba model.

If you think that way is better, I don't mind changing it. See all the changes here: Josephasafg#5

Signed-off-by: Josephasafg <ajgard7@gmail.com>
has_initial_states_d = (num_computed_tokens[:num_decodes] > 0) | (
common_attn_metadata.seq_lens[:num_decodes] == 0
)
if has_initial_states_d.all():
@benchislett (Collaborator) commented

I think this might cause a cpu/gpu sync because has_initial_states_d is a GPU tensor?

@Josephasafg (Contributor, Author) commented

@benchislett You're right, but to avoid the CPU<>GPU sync I would have to run this code on every decode step, regardless of whether has_initial_states_d has any false values. That could be a cheap no-op, since it's either multiplying by 1 (if we keep this version) or looping over the indices (in this PR's version), depending on which approach we go with.

Do you have another suggestion? I see other functions using .cpu(), which I could use instead of .all(), but I'm not sure it's any different sync-wise.
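For reference, the sync-free pattern under discussion can be sketched as follows (plain Python standing in for the tensor ops; in the real code has_initial_states_d is a GPU tensor and the loop is a vectorized mask multiply): always apply the 0/1 keep mask, so no host-side branch on GPU data is needed, and the common all-have-context case degenerates to multiplying by one.

```python
def zero_stale_states(ssm_state, indices, has_initial_states):
    """Unconditionally scale each selected cache slot by its keep mask;
    stale slots (keep == 0.0) are zeroed, valid slots are unchanged.
    No data-dependent branching, so it is CUDA-graph friendly."""
    for slot, keep_flag in zip(indices, has_initial_states):
        keep = 1.0 if keep_flag else 0.0
        ssm_state[slot] = [v * keep for v in ssm_state[slot]]
    return ssm_state

cache = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
zero_stale_states(cache, [0, 2], [True, False])
print(cache)  # [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
```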

@benchislett (Collaborator) commented

See also #35447 and the included test which reproduces an issue (possibly this specific issue)

@Josephasafg (Contributor, Author) commented

> See also #35447 and the included test which reproduces an issue (possibly this specific issue)

  1. I believe [BUGFIX][Mamba][Qwen3.5] Zero freed SSM cache blocks on GPU #35219 handles something different: a dtype mismatch with attention.
  2. I think the test you added in [Bugfix] Fix NemotronH MTP + Chunked Prefill #35447 could fit here. Which Mamba2 models support MTP?

@benchislett (Collaborator) commented

Nemotron Super supports MTP

@benchislett (Collaborator) commented

Regarding the zero-SSM bugfix, can you explain how the implementation of that PR differs from this one? As I understand it, both work by clearing newly allocated KV blocks, including Mamba blocks.

@Josephasafg (Contributor, Author) commented Mar 11, 2026

> Regarding the zero ssm bugfix, can you explain in what way the implementation of that PR differs from this one? As I understand, both work by clearing newly allocated KV blocks, including mamba blocks

@benchislett From what #35219 describes, it only zeros FullAttentionSpec blocks, because it relies on Mamba/SSM layers to overwrite or zero their state on their own, which they do as long as the initial step is classified as prefill. If it is misclassified as decode, the state is never zeroed, which is what this PR addresses for MTP (decode_threshold > 1). Also note that it only handles hybrid models, whereas this PR should handle pure Mamba models as well.

Happy to hear your thoughts

@benchislett (Collaborator) commented

Makes sense. Do you have a reproducer for this issue?

Josephasafg and others added 2 commits March 13, 2026 22:38
Signed-off-by: Josephasafg <ajgard7@gmail.com>
@Josephasafg (Contributor, Author) commented Mar 13, 2026

> Makes sense. Do you have a reproducer for this issue?

@benchislett I managed to reproduce it with this script (I'm not sure it should be a test, given the size of the model). I verified it by putting a breakpoint in the for loop inside the zeroing function in abstract.py, although the ssm_state isn't really holding stale values here, since I'm not filling it with many prior requests.

from vllm import LLM, SamplingParams

NUM_SPECULATIVE_TOKENS = 3
DECODE_THRESHOLD = 1 + NUM_SPECULATIVE_TOKENS

llm = LLM(
    "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8",
    speculative_config={
        "method": "mtp",
        "num_speculative_tokens": NUM_SPECULATIVE_TOKENS,
    },
    max_num_batched_tokens=2 * DECODE_THRESHOLD + 1,
    max_model_len=512,
    max_num_seqs=3,
    enforce_eager=True,
    tensor_parallel_size=2,
    trust_remote_code=True,
)

LONG = ("The quick brown fox jumps over the lazy dog. " * 25)
MEDIUM = ("A journey of a thousand miles begins with a single step. " * 4)
SHORT = "Once upon a time in a land far away, there lived a"

results = llm.generate(
    [LONG, MEDIUM, SHORT],
    SamplingParams(temperature=0.0, max_tokens=100),
)
print(results[2].outputs[0].text)

@benchislett (Collaborator) commented

@Josephasafg this is basically the same reproducer as #35447. Could you review that code and assess if there are any cases in which that fix is insufficient?

@Josephasafg (Contributor, Author) commented Mar 16, 2026

@benchislett

I believe #35447 targets the same issue, so I can close my PR in favor of it.

Mamba models should be safe from these misclassifications after it lands.


Labels

bug Something isn't working v1

5 participants