Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .buildkite/test_areas/engine.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,15 @@ steps:
device: mi325_4
depends_on:
- image-build-amd

- label: V1 e2e (4xH100)
timeout_in_minutes: 60
device: h100
num_devices: 4
optional: true
source_file_dependencies:
- vllm/v1/attention/backends/utils.py
- vllm/v1/worker/gpu_model_runner.py
- tests/v1/e2e/test_hybrid_chunked_prefill.py
commands:
- pytest -v -s v1/e2e/test_hybrid_chunked_prefill.py
123 changes: 86 additions & 37 deletions tests/v1/attention/test_batch_reordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,21 @@


class MockInputBatch:
def __init__(self, req_ids, num_computed_tokens_cpu):
def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The MockInputBatch constructor now requires a third argument, num_prompt_tokens. Ensure every instantiation of this class is updated to pass it; a call that still passes only two arguments will raise a TypeError. This is a critical change because it affects every place the tests instantiate this mock class.

Suggested change
def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):
def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):
self.req_ids = req_ids
self.num_computed_tokens_cpu = num_computed_tokens_cpu
self.num_prompt_tokens = num_prompt_tokens

self.req_ids = req_ids
self.num_computed_tokens_cpu = num_computed_tokens_cpu
self.num_prompt_tokens = num_prompt_tokens

def swap_states(self, i, j):
self.req_ids[i], self.req_ids[j] = self.req_ids[j], self.req_ids[i]
self.num_computed_tokens_cpu[i], self.num_computed_tokens_cpu[j] = (
self.num_computed_tokens_cpu[j],
self.num_computed_tokens_cpu[i],
)
self.num_prompt_tokens[i], self.num_prompt_tokens[j] = (
self.num_prompt_tokens[j],
self.num_prompt_tokens[i],
)


class MockSchedulerOutput:
Expand All @@ -29,96 +34,139 @@ def __init__(self, num_scheduled_tokens):

@dataclass
class ReorderTestCase:
requests: list[tuple[int, int]] # (num_scheduled_tokens, num_computed_tokens)
# (num_scheduled_tokens, num_computed_tokens, num_prompt_tokens)
requests: list[tuple[int, int, int]]
expected_order: list[int]
expected_modified: bool
decode_threshold: int = 1


# Test cases for batch reordering
# Format: (num_scheduled, num_computed, num_prompt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Adding a comment to describe the format of the REORDER_TEST_CASES dictionary entries improves readability and maintainability. It's important to clarify what each value in the tuple represents.

Suggested change
# Format: (num_scheduled, num_computed, num_prompt)
# Format: (num_scheduled, num_computed, num_prompt)
REORDER_TEST_CASES = {

REORDER_TEST_CASES = {
"all_decodes": ReorderTestCase(
requests=[(1, 10), (1, 20), (1, 30)],
requests=[(1, 10, 10), (1, 20, 20), (1, 30, 30)],
expected_order=[0, 1, 2],
expected_modified=False,
),
"all_prefills": ReorderTestCase(
requests=[(100, 100), (200, 200), (300, 300)],
"all_long_extends": ReorderTestCase(
requests=[(100, 100, 100), (200, 200, 200), (300, 300, 300)],
expected_order=[0, 1, 2],
expected_modified=False,
),
"mixed_interleaved": ReorderTestCase(
requests=[(100, 100), (1, 10), (200, 200), (1, 20)],
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place
"mixed_decodes_long_extends": ReorderTestCase(
requests=[(100, 100, 100), (1, 10, 10), (200, 200, 200), (1, 20, 20)],
expected_order=[3, 1, 2, 0],
expected_modified=True,
),
"already_ordered": ReorderTestCase(
requests=[(1, 10), (1, 20), (100, 100), (200, 0)],
requests=[(1, 10, 10), (1, 20, 20), (100, 100, 100), (200, 0, 200)],
expected_order=[0, 1, 2, 3],
expected_modified=False,
),
"single_request": ReorderTestCase(
requests=[(1, 10)],
requests=[(1, 10, 10)],
expected_order=[0],
expected_modified=False,
),
"higher_threshold": ReorderTestCase(
requests=[(2, 10), (3, 20), (5, 30), (6, 40)],
requests=[(2, 10, 10), (3, 20, 20), (5, 30, 30), (6, 40, 40)],
expected_order=[0, 1, 2, 3],
expected_modified=False,
decode_threshold=4,
),
"decodes_at_end": ReorderTestCase(
requests=[(100, 100), (200, 200), (1, 10), (1, 20)],
requests=[(100, 100, 100), (200, 200, 200), (1, 10, 10), (1, 20, 20)],
expected_order=[2, 3, 0, 1],
expected_modified=True,
),
"decode_extend_prefill": ReorderTestCase(
requests=[(100, 0), (10, 50), (1, 10)],
"decode_long_extend_prefill": ReorderTestCase(
requests=[(100, 0, 100), (10, 50, 50), (1, 10, 10)],
expected_order=[2, 1, 0],
expected_modified=True,
),
"extend_prefill_only": ReorderTestCase(
requests=[(100, 0), (10, 50), (200, 0), (20, 75)],
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place
"long_extend_prefill_only": ReorderTestCase(
requests=[(100, 0, 100), (10, 50, 50), (200, 0, 200), (20, 75, 75)],
expected_order=[3, 1, 2, 0],
expected_modified=True,
),
"complicated_mixed_interleaved": ReorderTestCase(
"complicated_mixed": ReorderTestCase(
requests=[
(1, 20),
(1, 50),
(374, 0),
(300, 20),
(1, 20),
(256, 0),
(1, 5),
(27, 0),
(1, 4),
(1, 20, 20), # decode
(1, 50, 50), # decode
(374, 0, 374), # prefill
(300, 20, 20), # long_extend
(1, 20, 20), # decode
(256, 0, 256), # prefill
(1, 5, 5), # decode
(27, 0, 27), # prefill
(1, 4, 4), # decode
],
expected_order=[0, 1, 6, 8, 4, 3, 2, 7, 5],
expected_modified=True,
),
"new_request_single_token_prefill": ReorderTestCase(
requests=[
(100, 0),
(1, 0), # New request with only 1 token (STILL prefill)
(50, 100),
(1, 10),
(100, 0, 100), # prefill
(1, 0, 1), # prefill (single token, still prefill)
(50, 100, 100), # long_extend
(1, 10, 10), # decode
],
# Only index 3 is a true decode (1 scheduled token and num_computed_tokens > 0)
expected_order=[3, 2, 0, 1],
expected_modified=True,
),
"multiple_new_requests_single_token_prefill": ReorderTestCase(
requests=[
(1, 0), # New prefill (1 token, no computed)
(1, 0), # New prefill (1 token, no computed)
(1, 50),
(200, 0),
(1, 0, 1), # prefill
(1, 0, 1), # prefill
(1, 50, 50), # decode
(200, 0, 200), # prefill
],
expected_order=[2, 1, 0, 3],
expected_modified=True,
),
"four_way_already_ordered": ReorderTestCase(
requests=[
(1, 100, 100), # decode
(1, 50, 100), # short_extend
(10, 50, 100), # long_extend
(100, 0, 100), # prefill
],
expected_order=[0, 1, 2, 3],
expected_modified=False,
),
"four_way_needs_reorder": ReorderTestCase(
requests=[
(100, 0, 100), # prefill
(1, 50, 100), # short_extend
(1, 100, 100), # decode
(10, 50, 100), # long_extend
],
expected_order=[2, 1, 3, 0],
expected_modified=True,
),
"four_way_multiple_short_extends": ReorderTestCase(
requests=[
(2, 100, 100), # decode
(2, 50, 200), # short_extend
(2, 75, 150), # short_extend
(2, 200, 200), # decode
],
expected_order=[0, 3, 2, 1],
expected_modified=True,
decode_threshold=2,
),
"four_way_spec_decode_threshold": ReorderTestCase(
requests=[
(5, 100, 100), # decode
(5, 50, 100), # short_extend
(5, 0, 100), # prefill
(10, 50, 100), # long_extend
],
expected_order=[0, 1, 3, 2],
expected_modified=True,
decode_threshold=5,
),
}


Expand All @@ -129,8 +177,9 @@ def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase)
req_ids = [f"r{i}" for i in range(len(test_case.requests))]
num_computed_tokens = np.array([r[1] for r in test_case.requests], dtype=np.int32)
num_scheduled_tokens = {f"r{i}": r[0] for i, r in enumerate(test_case.requests)}
num_prompt_tokens = np.array([r[2] for r in test_case.requests], dtype=np.int32)

input_batch = MockInputBatch(req_ids, num_computed_tokens)
input_batch = MockInputBatch(req_ids, num_computed_tokens, num_prompt_tokens)
scheduler_output = MockSchedulerOutput(num_scheduled_tokens)

modified = reorder_batch_to_split_decodes_and_prefills(
Expand Down
4 changes: 2 additions & 2 deletions tests/v1/e2e/test_hybrid_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
pytest.param("Qwen/Qwen3.5-4B", marks=[large_gpu_mark(min_gb=40)]),
pytest.param(
"nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8",
marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=2),
marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=4),
),
],
)
Expand All @@ -68,7 +68,7 @@ def test_mtp_speculative_mixed_batch_short_prefill(
max_num_batched_tokens=chunk_size,
max_model_len=512,
enforce_eager=True,
tensor_parallel_size=2,
tensor_parallel_size=4,
trust_remote_code=True,
enable_chunked_prefill=True,
enable_prefix_caching=enable_prefix_caching,
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/attention/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,11 @@ class CommonAttentionMetadata:
dcp_local_seq_lens_cpu: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world"""

is_prefilling: torch.Tensor | None = None
"""(batch_size,) bool tensor: True if request is still in prefill phase
(num_computed_tokens < num_prompt_tokens). Used by some backends to
distinguish actual decodes from short extends."""

# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
_seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None
Expand Down Expand Up @@ -443,6 +448,7 @@ def unpadded(
encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
is_prefilling=maybe_slice_reqs(self.is_prefilling),
)


Expand Down
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mamba_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,9 @@ def _compute_common_metadata(

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=decode_threshold
common_attn_metadata,
decode_threshold=decode_threshold,
treat_short_extends_as_decodes=False,
)
)

Expand Down
Loading
Loading