Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion tests/v1/attention/test_attention_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,10 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):


def apply_split_decodes_and_prefills(
query_lens: list[int], decode_threshold: int, require_uniform: bool
query_lens: list[int],
decode_threshold: int,
require_uniform: bool,
padded_num_tokens: int | None = None,
):
"""Helper function to apply split_decodes_and_prefills and return
the results."""
Expand All @@ -165,6 +168,10 @@ def apply_split_decodes_and_prefills(
block_size=16,
device=device,
)

if padded_num_tokens is not None:
common_metadata.num_actual_tokens = padded_num_tokens

return split_decodes_and_prefills(
common_metadata,
decode_threshold=decode_threshold,
Expand Down Expand Up @@ -271,6 +278,22 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens


def test_split_decodes_and_prefills_uniform_padded_batch_all_same():
    """Uniform batch of identical query lengths plus zero-length padded requests.

    Zero-length entries stand in for padding requests (used for full cuda
    graphs), so the splitter should classify the entire batch as decodes and
    report the padded token total as the decode token count.
    """
    # Every real request has query length 2 with decode_threshold=3 (2 <= 3);
    # the trailing 0 is a padding request. This exercises the padded-uniform
    # fast path in split_decodes_and_prefills.
    lens = [2, 2, 2, 0]
    padded_tokens = 8
    split = apply_split_decodes_and_prefills(lens, 3, True, padded_tokens)
    decodes, prefills, decode_tokens, prefill_tokens = split
    # In a (padded) uniform batch, every request counts as a decode.
    assert (decodes, prefills) == (4, 0)
    assert decode_tokens == padded_tokens
    assert prefill_tokens == 0


@pytest.mark.parametrize(
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
[
Expand Down
10 changes: 7 additions & 3 deletions vllm/v1/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,11 +883,15 @@ def split_decodes_and_prefills(
return 0, num_reqs, 0, num_tokens

if require_uniform:
# check if we are in a padded uniform batch; this is used for full-CGs, some
# requests may have a query length of 0, but since they are padding it's fine
# to treat them as decodes (ensures num_decodes matches the captured size)
if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly"
return num_reqs, 0, num_tokens, 0 # all decodes
is_prefill = query_lens != query_lens[0]
else:
# 0-query len indicates a padded request; leave this at the back
# of the batch with the prefills
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
is_prefill = query_lens > decode_threshold

if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
Expand Down