
[Attention] Support distinguishing between short extends and decodes #37303

Merged
LucasWilkinson merged 4 commits into vllm-project:main from neuralmagic:nemotron-h-mtp-4way-batch-split on Mar 20, 2026

Conversation

@LucasWilkinson (Collaborator) commented Mar 17, 2026

Alternative to #35447: support distinguishing between short extends/prefills and decodes via batch reordering. The batch order is now:

        decode:        (num_scheduled <= threshold AND is not prefilling)
        short_extend:  (num_scheduled <= threshold AND is chunked prefilling)
        long_extend:   (num_scheduled > threshold AND is chunked prefilling)
        prefill:       (num_computed == 0)   # First chunks
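The four-way classification above can be sketched as a standalone function. This is a hypothetical illustration of the logic described in the PR, not vLLM's actual implementation; the function name and parameter names (`num_scheduled`, `num_computed`, `num_prompt`, `threshold`) are assumptions.

```python
def classify_request(num_scheduled: int, num_computed: int,
                     num_prompt: int, threshold: int) -> str:
    """Sketch of the four-way batch split described in the PR (illustrative)."""
    # First chunks (no computed context) are pure prefills, placed at the back.
    if num_computed == 0:
        return "prefill"
    # A request is still (chunked-)prefilling while its prompt is not fully computed.
    is_prefilling = num_computed < num_prompt
    if num_scheduled <= threshold:
        return "short_extend" if is_prefilling else "decode"
    return "long_extend"
```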

@mergify mergify bot added the v1 label Mar 17, 2026
@LucasWilkinson LucasWilkinson marked this pull request as ready for review March 17, 2026 14:41
@LucasWilkinson LucasWilkinson added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 17, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request refines the batch reordering logic to categorize requests into four types (decode, short_extend, long_extend, and prefill) based on num_scheduled_tokens, num_computed_tokens, and a newly introduced num_prompt_tokens. The MockInputBatch and GPUInputBatch classes gain a num_prompt_tokens field, CommonAttentionMetadata gains an is_prefilling flag, and the split_decodes_and_prefills and reorder_batch_to_split_decodes_and_prefills functions are updated to use the new classification. The mamba_attn backend opts into it by explicitly setting treat_short_extends_as_decodes=False, and legacy "prefill as decode" handling is removed from gpu_model_runner.py and mamba_utils.py. The review comments stress that all usages of the MockInputBatch constructor must be updated for its changed signature, and suggest a clarifying comment for the REORDER_TEST_CASES dictionary format.


 class MockInputBatch:
-    def __init__(self, req_ids, num_computed_tokens_cpu):
+    def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):

critical

The MockInputBatch class now requires num_prompt_tokens. Ensure all usages of this class are updated to include this parameter to avoid unexpected behavior or errors. This is a critical change as it affects the instantiation of this mock class throughout the tests.

Suggested change:

     def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):
+        self.req_ids = req_ids
+        self.num_computed_tokens_cpu = num_computed_tokens_cpu
+        self.num_prompt_tokens = num_prompt_tokens
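With the suggested constructor, call sites in the tests would pass the extra argument. A minimal sketch (the class body and argument values here are illustrative, not the actual test code):

```python
class MockInputBatch:
    def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):
        self.req_ids = req_ids
        self.num_computed_tokens_cpu = num_computed_tokens_cpu
        self.num_prompt_tokens = num_prompt_tokens

# Hypothetical call-site update: every instantiation now supplies
# num_prompt_tokens alongside the existing arguments.
batch = MockInputBatch(
    req_ids=["req-0", "req-1"],
    num_computed_tokens_cpu=[10, 0],
    num_prompt_tokens=[10, 16],
)
```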



# Test cases for batch reordering
# Format: (num_scheduled, num_computed, num_prompt)

high

Adding a comment to describe the format of the REORDER_TEST_CASES dictionary entries improves readability and maintainability. It's important to clarify what each value in the tuple represents.

Suggested change:

+# Format: (num_scheduled, num_computed, num_prompt)
 REORDER_TEST_CASES = {
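Following the commented format, an entry of that dictionary might look like the sketch below. The case name and tuple values are illustrative assumptions, not the actual test data:

```python
# Each entry maps a case name to per-request tuples of
# (num_scheduled, num_computed, num_prompt).
REORDER_TEST_CASES = {
    "mixed_batch": [
        (1, 10, 10),   # decode: one token scheduled, prompt fully computed
        (4, 8, 32),    # short_extend: few tokens, still prefilling
        (64, 8, 128),  # long_extend: many tokens, still prefilling
        (16, 0, 16),   # prefill: first chunk, no computed tokens
    ],
}
```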

@benchislett (Collaborator):

Does this pass the test added in #35447?

@benchislett (Collaborator):

How does this handle a short prefill (<= threshold, no context)? It's not an extend, but it is below the threshold. Does it get classified as prefill or decode?

@LucasWilkinson (Collaborator, Author) commented Mar 17, 2026

> How does this handle a short prefill (<= threshold, no context)? Does it get classified as prefill or decode?

"Pure prefills", i.e. no-context requests, are always placed at the back; this is for the AMD attention backend.

> Does this pass the test added in #35447?

The first two pass; the second two OOM (I assume because I'm on H100s). Will run it in the CI.

@benchislett (Collaborator) commented Mar 17, 2026

Ah, I might need to run that test as TP4 instead of TP2 :(

@benchislett (Collaborator) left a comment

LGTM, thanks for the cleanup.

I would like to see that specific test case passing before we merge this, to ensure that the nemotron-h-mtp-chunkedprefill case is covered.

slot_mapping_attn = slot_mappings[attn_gid]
self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
# Compute is_prefilling: True if request is still in prefill phase
# (num_computed_tokens < num_prompt_tokens). Used by mamba backends to

nit: non-mamba-specific logic is more general and consistent with other comments

Suggested change:

-# (num_computed_tokens < num_prompt_tokens). Used by mamba backends to
+# (num_computed_tokens < num_prompt_tokens). Used by some backends to
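The is_prefilling computation described in that comment reduces to an elementwise comparison over per-request token counts. A minimal sketch, with illustrative array values (the variable names mirror the comment, not the actual GPUModelRunner fields):

```python
import numpy as np

# Per-request counts for three requests (illustrative values).
num_computed_tokens = np.array([10, 8, 0])
num_prompt_tokens = np.array([10, 32, 16])

# A request is still in the prefill phase while its prompt
# is not fully computed.
is_prefilling = num_computed_tokens < num_prompt_tokens
```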

@mergify mergify bot added the ci/build label Mar 17, 2026
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@LucasWilkinson force-pushed the nemotron-h-mtp-4way-batch-split branch from 8ab178f to f31161d on March 18, 2026 18:52
@LucasWilkinson merged commit e1d85e5 into vllm-project:main on Mar 20, 2026
61 of 62 checks passed
chooper26 pushed a commit to intellistream/vllm-hust that referenced this pull request Mar 21, 2026

Labels

ci/build ready ONLY add when PR is ready to merge/full CI is needed v1
