[Bugfix] Fix MTP edge case in split_decodes_and_prefills (#32716)
Josephasafg wants to merge 19 commits into vllm-project:main
Conversation
Code Review
This pull request correctly identifies new requests and ensures they are treated as prefills, which is a necessary change for certain models that require specific handling for new prompts. The implementation introduces an is_new_request flag and applies it consistently. However, I've identified a critical issue with a special-case heuristic for CUDA graph capture. This heuristic may misclassify legitimate batches of short prefills as decodes, potentially causing correctness issues for models like Mamba. This conflict needs to be resolved.
Signed-off-by: Josephasafg <ajgard7@gmail.com>
Force-pushed from 88cf726 to 59b96ff.
@pisceskkk I moved the PR out of draft; sorry it took time. We can discuss here, let me know what you think.
Thanks for the fix! That addresses my concerns!
vllm/v1/attention/backends/utils.py (Outdated)

```python
# A new request has no prior context (num_computed_tokens == 0).
# New requests need prefill treatment even if
# query_lens <= decode_threshold (e.g., for Mamba state init).
num_computed = common_attn_metadata._num_computed_tokens_cpu
```
`_num_computed_tokens_cpu` is deprecated (it interferes with async scheduling) and should not be used; we will have to come up with something else here.
We probably need a `has_context` flag in `common_attn_metadata` or something like that instead.
Makes sense. I'll look into adding a `has_context` flag on `CommonAttentionMetadata` instead.
@LucasWilkinson I added `has_context` to `CommonAttentionMetadata`. Can you please take a look? Thanks!
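For reference, such a flag can be derived on the host side when the metadata is built, since per-request computed-token counts are already known there. A minimal sketch (the helper name and plumbing are hypothetical, not the actual vLLM code):

```python
def build_has_context(num_computed_tokens_cpu):
    # has_context[i] is True when request i already has computed tokens,
    # i.e. it is a continuation rather than a brand-new request.
    # Derived from host-side counts, so no GPU sync is needed.
    return [n > 0 for n in num_computed_tokens_cpu]

# A new request (0 computed tokens) has no context; the others do.
print(build_has_context([0, 5, 1]))  # [False, True, True]
```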
@Josephasafg Thanks for the contribution! Could you somehow enable this for Mamba only, or figure out a different way to trigger the state initialization? Treating short prefills as decodes is intentional for performance in MLA. For example, [...]
Thanks for the feedback @MatthewBonanni! I've reworked the approach - removed [...]. The idea is that instead of modifying the generic [...]
Hi @Josephasafg, thanks for updating the PR! In the interest of simplification and minimizing changes to [...]. For further performance improvement, we could also modify [...]
@MatthewBonanni Thanks! That was actually my initial approach - zeroing inside the mixer. But it didn't play well with CUDA graphs, so I added this in every mixer:

```python
if has_initial_states_d is not None:
    indices = state_indices_tensor_d_input
    # Mask out the SSM state for requests with no initial state
    ssm_gathered = ssm_state[indices]
    keep_ssm = has_initial_states_d.to(ssm_gathered.dtype)
    keep_ssm = keep_ssm.view(-1, *([1] * (ssm_gathered.dim() - 1)))
    ssm_state[indices] = ssm_gathered * keep_ssm
    # Same masking for the conv state
    conv_gathered = conv_state[indices]
    keep_conv = has_initial_states_d.to(conv_gathered.dtype)
    keep_conv = keep_conv.view(-1, *([1] * (conv_gathered.dim() - 1)))
    conv_state[indices] = conv_gathered * keep_conv
```

The reason I thought not to add it was that this requires changes across the different Mamba-based models ([...]). If you think this way is better, I don't mind changing it. See all changes here - Josephasafg#5
```python
has_initial_states_d = (num_computed_tokens[:num_decodes] > 0) | (
    common_attn_metadata.seq_lens[:num_decodes] == 0
)
if has_initial_states_d.all():
```
I think this might cause a CPU/GPU sync, because `has_initial_states_d` is a GPU tensor?
@benchislett You're right, but to avoid the CPU<->GPU sync I would have to always run this code on every decode step, regardless of whether `has_initial_states_d` has any false values. That could either be a cheap "no-op" (since we'd just be multiplying by 1), or we loop over the indices as in this PR's current version; it depends on which we decide to go with.
Do you maybe have another suggestion? I see other functions using `.cpu()`, which I could use instead of `.all()`, but I'm not sure it's any different.
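To illustrate the tradeoff discussed above, with plain Python lists standing in for GPU tensors (the helper names are hypothetical, not the PR's actual code): branching on the mask requires reading its value on the host, which on real GPU tensors forces a device-to-host sync, while the unconditional masked multiply never reads the mask on the host and degenerates to a multiply-by-one no-op when every request already has state.

```python
def reset_states_branchy(states, has_initial_states):
    # Option A: data-dependent branch. On real GPU tensors, evaluating
    # all(mask) on the host forces a device->host sync every decode step.
    if all(has_initial_states):
        return list(states)  # nothing to clear
    return [s if keep else 0.0 for s, keep in zip(states, has_initial_states)]

def reset_states_branchless(states, has_initial_states):
    # Option B: always apply the mask. When every entry is True this is a
    # cheap multiply-by-1 no-op, and the mask is never read on the host.
    return [s * (1.0 if keep else 0.0)
            for s, keep in zip(states, has_initial_states)]
```

Both variants produce the same result; the difference is only in when (and whether) the mask must be materialized on the CPU.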
See also #35447 and the included test, which reproduces an issue (possibly this specific issue).
Nemotron Super supports MTP.
Regarding the zero-SSM bugfix, can you explain how the implementation of that PR differs from this one? As I understand it, both work by clearing newly allocated KV blocks, including Mamba blocks.
@benchislett From what #35219 describes, it only zeros [...]. Happy to hear your thoughts.
Makes sense. Do you have a reproducer for this issue?
Mtp edge case mod
@benchislett I managed to reproduce it with this script (not sure it should be a test due to the size of the model):

```python
from vllm import LLM, SamplingParams

NUM_SPECULATIVE_TOKENS = 3
DECODE_THRESHOLD = 1 + NUM_SPECULATIVE_TOKENS

llm = LLM(
    "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8",
    speculative_config={
        "method": "mtp",
        "num_speculative_tokens": NUM_SPECULATIVE_TOKENS,
    },
    max_num_batched_tokens=2 * DECODE_THRESHOLD + 1,
    max_model_len=512,
    max_num_seqs=3,
    enforce_eager=True,
    tensor_parallel_size=2,
    trust_remote_code=True,
)

LONG = "The quick brown fox jumps over the lazy dog. " * 25
MEDIUM = "A journey of a thousand miles begins with a single step. " * 4
SHORT = "Once upon a time in a land far away, there lived a"

results = llm.generate(
    [LONG, MEDIUM, SHORT],
    SamplingParams(temperature=0.0, max_tokens=100),
)
print(results[2].outputs[0].text)
```
@Josephasafg this is basically the same reproducer as #35447. Could you review that code and assess whether there are any cases in which that fix is insufficient?
I believe #35447 targets the same issue, so I can close my PR in favor of it. I believe Mamba models should be good in terms of misclassifications after it lands.
Purpose

Followup to #32118. @pisceskkk raised that `split_decodes_and_prefills` could misclassify new requests as decodes in MTP scenarios (`decode_threshold > 1`). A new request with a short prompt (e.g., 3 tokens with `decode_threshold=4`) would be treated as decode, skipping state initialization needed by models like Mamba.

At the moment, requests are classified purely by `query_lens`: if `max_query_len <= decode_threshold`, the entire batch is treated as decode and has no concept of a "new request" at all.

Fix:

- Add `has_context` to `CommonAttentionMetadata` (`num_computed_tokens > 0`), avoiding the deprecated `_num_computed_tokens_cpu`, which interferes with async scheduling as @LucasWilkinson mentioned.
- Use `has_context` in `split_decodes_and_prefills` to detect new requests and classify them as prefill regardless of query length.
- Set `num_computed_tokens_cpu` in `_dummy_run` so full CUDA graph capture batches get `has_context=True`.

Test Plan
Added unittests
Test Result
Tests passing
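For illustration, the classification rule described in the Purpose section can be sketched in pure Python as follows (a simplified sketch: the real `split_decodes_and_prefills` operates on `CommonAttentionMetadata` and also returns per-category token counts):

```python
def split_decodes_and_prefills(query_lens, has_context, decode_threshold):
    # Requests are assumed reordered so decodes come first. A request is a
    # decode only if its query is short enough AND it has prior context;
    # a brand-new request (has_context=False) is always a prefill, so
    # models like Mamba get a chance to initialize their state.
    num_decodes = 0
    for qlen, ctx in zip(query_lens, has_context):
        if qlen <= decode_threshold and ctx:
            num_decodes += 1
        else:
            break  # first prefill ends the decode run
    return num_decodes, len(query_lens) - num_decodes

# With decode_threshold=4, a 3-token brand-new request is now a prefill
# even though its query length is below the threshold.
print(split_decodes_and_prefills([1, 3, 10], [True, False, True], 4))  # (1, 2)
```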