[Attention] Support distinguishing between short extends and decodes#37303
Conversation
Code Review
The pull request refines the batch reordering logic to classify requests into four types (decode, short_extend, long_extend, and prefill) based on num_scheduled_tokens, num_computed_tokens, and a newly introduced num_prompt_tokens. The MockInputBatch and GPUInputBatch classes gain a num_prompt_tokens field, CommonAttentionMetadata gains an is_prefilling flag, and split_decodes_and_prefills and reorder_batch_to_split_decodes_and_prefills are updated to use the finer-grained classification. The mamba_attn backend opts into the new behavior by explicitly setting treat_short_extends_as_decodes=False, and legacy "prefill as decode" handling is removed from gpu_model_runner.py and mamba_utils.py. The review comments stress that all usages of the MockInputBatch constructor must be updated for its changed signature, and suggest documenting the REORDER_TEST_CASES entry format for readability.
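To make the four-way split concrete, here is a hedged sketch of the classification the summary describes. The function name, argument names, and the `decode_threshold` parameter are illustrative, not the PR's actual API; the rules are inferred from the summary and the comments below (short requests split on whether the prompt is fully computed, longer requests split on whether any context exists yet).

```python
def classify_request(num_scheduled: int, num_computed: int,
                     num_prompt: int, decode_threshold: int = 1) -> str:
    """Hypothetical four-way request classification (not the PR's exact code)."""
    # Still prefilling while some prompt tokens remain uncomputed.
    is_prefilling = num_computed < num_prompt
    if num_scheduled <= decode_threshold:
        # Short runs: a true decode unless the prompt is still being filled,
        # in which case it is a short extend.
        return "short_extend" if is_prefilling else "decode"
    # Longer runs: "pure" prefills have no context yet; partially computed
    # prompts are long extends.
    return "prefill" if num_computed == 0 else "long_extend"
```

A backend that sets treat_short_extends_as_decodes=True would simply fold the "short_extend" bucket into "decode".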
```diff
  class MockInputBatch:
-     def __init__(self, req_ids, num_computed_tokens_cpu):
+     def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):
```
The MockInputBatch constructor now requires num_prompt_tokens. This is a critical change: every instantiation of this mock class in the tests must be updated to pass the new parameter, or they will fail.
```python
    def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):
        self.req_ids = req_ids
        self.num_computed_tokens_cpu = num_computed_tokens_cpu
        self.num_prompt_tokens = num_prompt_tokens
```
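Per the review note, every call site needs the new argument. A self-contained sketch of the updated mock and a hypothetical instantiation (the values are illustrative, not taken from the test suite):

```python
class MockInputBatch:
    def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):
        self.req_ids = req_ids
        self.num_computed_tokens_cpu = num_computed_tokens_cpu
        self.num_prompt_tokens = num_prompt_tokens

# Call sites must now supply per-request prompt lengths as well:
batch = MockInputBatch(
    req_ids=["r0", "r1"],
    num_computed_tokens_cpu=[16, 0],   # r0 has context, r1 has none
    num_prompt_tokens=[16, 32],        # r0 is done prefilling, r1 is not
)
```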
```python
# Test cases for batch reordering
# Format: (num_scheduled, num_computed, num_prompt)
```
Adding a comment to describe the format of the REORDER_TEST_CASES dictionary entries improves readability and maintainability. It's important to clarify what each value in the tuple represents.
```python
# Format: (num_scheduled, num_computed, num_prompt)
REORDER_TEST_CASES = {
```
does this pass the test added in #35447?
How does this handle a "short prefill <= threshold, no context" request? It's not an extend, but it is below the threshold. Does it get classified as prefill or decode?
"pure prefills" i.e. no-context are always placed at the back, this is for the AMD attention backend
the first 2 pass the second 2 OOM (I assume because im on H100s); will run it in the CI |
Ah, I might need to run that test as TP4 instead of TP2 :(
benchislett
left a comment
LGTM, thanks for the cleanup.
I would like to see that specific test case passing before we merge this, to ensure that the nemotron-h-mtp-chunkedprefill case is covered.
```python
slot_mapping_attn = slot_mappings[attn_gid]
self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
# Compute is_prefilling: True if request is still in prefill phase
# (num_computed_tokens < num_prompt_tokens). Used by mamba backends to
```
nit: non-mamba-specific logic is more general and consistent with other comments
```diff
- # (num_computed_tokens < num_prompt_tokens). Used by mamba backends to
+ # (num_computed_tokens < num_prompt_tokens). Used by some backends to
```
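The flag discussed in this comment is a per-request elementwise comparison. A minimal plain-Python sketch (the list names are assumed for illustration, not the runner's actual fields):

```python
# Per-request token counts (illustrative values).
num_computed_tokens = [16, 32, 8]
num_prompt_tokens = [16, 64, 8]

# A request is still prefilling while it has uncomputed prompt tokens.
is_prefilling = [c < p for c, p in zip(num_computed_tokens, num_prompt_tokens)]
```

In the real runner this would be a vectorized comparison over numpy arrays rather than a Python loop.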
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Force-pushed from 8ab178f to f31161d
…llm-project#37303)
Alternative to #35447, support distinguishing between short-extends/prefills and decodes via batch reordering; the batch order is now:
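Given the four categories described in the review summary, the reordering itself can be sketched as a stable sort of request indices by category rank. The ranks below are assumed from the category ordering in the summary (decodes first, pure prefills last), not copied from the PR's implementation:

```python
# Hypothetical category ranks: decodes first, pure prefills last.
CATEGORY_RANK = {"decode": 0, "short_extend": 1, "long_extend": 2, "prefill": 3}

def reorder_indices(categories):
    """Return request indices in batch order, stably sorted by category rank."""
    return sorted(range(len(categories)),
                  key=lambda i: CATEGORY_RANK[categories[i]])
```

A stable sort preserves the original relative order within each category, which keeps the permutation (and the slot-mapping bookkeeping that depends on it) minimal.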