diff --git a/.buildkite/test_areas/engine.yaml b/.buildkite/test_areas/engine.yaml index be83bab8fa29..ed0df3e4d879 100644 --- a/.buildkite/test_areas/engine.yaml +++ b/.buildkite/test_areas/engine.yaml @@ -70,3 +70,15 @@ steps: device: mi325_4 depends_on: - image-build-amd + +- label: V1 e2e (4xH100) + timeout_in_minutes: 60 + device: h100 + num_devices: 4 + optional: true + source_file_dependencies: + - vllm/v1/attention/backends/utils.py + - vllm/v1/worker/gpu_model_runner.py + - tests/v1/e2e/test_hybrid_chunked_prefill.py + commands: + - pytest -v -s v1/e2e/test_hybrid_chunked_prefill.py diff --git a/tests/v1/attention/test_batch_reordering.py b/tests/v1/attention/test_batch_reordering.py index 6265e12f9a7d..f59740238da7 100644 --- a/tests/v1/attention/test_batch_reordering.py +++ b/tests/v1/attention/test_batch_reordering.py @@ -10,9 +10,10 @@ class MockInputBatch: - def __init__(self, req_ids, num_computed_tokens_cpu): + def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens): self.req_ids = req_ids self.num_computed_tokens_cpu = num_computed_tokens_cpu + self.num_prompt_tokens = num_prompt_tokens def swap_states(self, i, j): self.req_ids[i], self.req_ids[j] = self.req_ids[j], self.req_ids[i] @@ -20,6 +21,10 @@ def swap_states(self, i, j): self.num_computed_tokens_cpu[j], self.num_computed_tokens_cpu[i], ) + self.num_prompt_tokens[i], self.num_prompt_tokens[j] = ( + self.num_prompt_tokens[j], + self.num_prompt_tokens[i], + ) class MockSchedulerOutput: @@ -29,96 +34,139 @@ def __init__(self, num_scheduled_tokens): @dataclass class ReorderTestCase: - requests: list[tuple[int, int]] # (num_scheduled_tokens, num_computed_tokens) + # (num_scheduled_tokens, num_computed_tokens, num_prompt_tokens) + requests: list[tuple[int, int, int]] expected_order: list[int] expected_modified: bool decode_threshold: int = 1 # Test cases for batch reordering +# Format: (num_scheduled, num_computed, num_prompt) REORDER_TEST_CASES = { "all_decodes": ReorderTestCase( 
- requests=[(1, 10), (1, 20), (1, 30)], + requests=[(1, 10, 10), (1, 20, 20), (1, 30, 30)], expected_order=[0, 1, 2], expected_modified=False, ), - "all_prefills": ReorderTestCase( - requests=[(100, 100), (200, 200), (300, 300)], + "all_long_extends": ReorderTestCase( + requests=[(100, 100, 100), (200, 200, 200), (300, 300, 300)], expected_order=[0, 1, 2], expected_modified=False, ), - "mixed_interleaved": ReorderTestCase( - requests=[(100, 100), (1, 10), (200, 200), (1, 20)], - expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place + "mixed_decodes_long_extends": ReorderTestCase( + requests=[(100, 100, 100), (1, 10, 10), (200, 200, 200), (1, 20, 20)], + expected_order=[3, 1, 2, 0], expected_modified=True, ), "already_ordered": ReorderTestCase( - requests=[(1, 10), (1, 20), (100, 100), (200, 0)], + requests=[(1, 10, 10), (1, 20, 20), (100, 100, 100), (200, 0, 200)], expected_order=[0, 1, 2, 3], expected_modified=False, ), "single_request": ReorderTestCase( - requests=[(1, 10)], + requests=[(1, 10, 10)], expected_order=[0], expected_modified=False, ), "higher_threshold": ReorderTestCase( - requests=[(2, 10), (3, 20), (5, 30), (6, 40)], + requests=[(2, 10, 10), (3, 20, 20), (5, 30, 30), (6, 40, 40)], expected_order=[0, 1, 2, 3], expected_modified=False, decode_threshold=4, ), "decodes_at_end": ReorderTestCase( - requests=[(100, 100), (200, 200), (1, 10), (1, 20)], + requests=[(100, 100, 100), (200, 200, 200), (1, 10, 10), (1, 20, 20)], expected_order=[2, 3, 0, 1], expected_modified=True, ), - "decode_extend_prefill": ReorderTestCase( - requests=[(100, 0), (10, 50), (1, 10)], + "decode_long_extend_prefill": ReorderTestCase( + requests=[(100, 0, 100), (10, 50, 50), (1, 10, 10)], expected_order=[2, 1, 0], expected_modified=True, ), - "extend_prefill_only": ReorderTestCase( - requests=[(100, 0), (10, 50), (200, 0), (20, 75)], - expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place + "long_extend_prefill_only": ReorderTestCase( + requests=[(100, 
0, 100), (10, 50, 50), (200, 0, 200), (20, 75, 75)], + expected_order=[3, 1, 2, 0], expected_modified=True, ), - "complicated_mixed_interleaved": ReorderTestCase( + "complicated_mixed": ReorderTestCase( requests=[ - (1, 20), - (1, 50), - (374, 0), - (300, 20), - (1, 20), - (256, 0), - (1, 5), - (27, 0), - (1, 4), + (1, 20, 20), # decode + (1, 50, 50), # decode + (374, 0, 374), # prefill + (300, 20, 20), # long_extend + (1, 20, 20), # decode + (256, 0, 256), # prefill + (1, 5, 5), # decode + (27, 0, 27), # prefill + (1, 4, 4), # decode ], expected_order=[0, 1, 6, 8, 4, 3, 2, 7, 5], expected_modified=True, ), "new_request_single_token_prefill": ReorderTestCase( requests=[ - (100, 0), - (1, 0), # New request with only 1 token (STILL prefill) - (50, 100), - (1, 10), + (100, 0, 100), # prefill + (1, 0, 1), # prefill (single token, still prefill) + (50, 100, 100), # long_extend + (1, 10, 10), # decode ], - # Only index 3 is a true decode (has num_computed_tokens > 0) expected_order=[3, 2, 0, 1], expected_modified=True, ), "multiple_new_requests_single_token_prefill": ReorderTestCase( requests=[ - (1, 0), # New prefill (1 token, no computed) - (1, 0), # New prefill (1 token, no computed) - (1, 50), - (200, 0), + (1, 0, 1), # prefill + (1, 0, 1), # prefill + (1, 50, 50), # decode + (200, 0, 200), # prefill ], expected_order=[2, 1, 0, 3], expected_modified=True, ), + "four_way_already_ordered": ReorderTestCase( + requests=[ + (1, 100, 100), # decode + (1, 50, 100), # short_extend + (10, 50, 100), # long_extend + (100, 0, 100), # prefill + ], + expected_order=[0, 1, 2, 3], + expected_modified=False, + ), + "four_way_needs_reorder": ReorderTestCase( + requests=[ + (100, 0, 100), # prefill + (1, 50, 100), # short_extend + (1, 100, 100), # decode + (10, 50, 100), # long_extend + ], + expected_order=[2, 1, 3, 0], + expected_modified=True, + ), + "four_way_multiple_short_extends": ReorderTestCase( + requests=[ + (2, 100, 100), # decode + (2, 50, 200), # short_extend + (2, 75, 
150), # short_extend + (2, 200, 200), # decode + ], + expected_order=[0, 3, 2, 1], + expected_modified=True, + decode_threshold=2, + ), + "four_way_spec_decode_threshold": ReorderTestCase( + requests=[ + (5, 100, 100), # decode + (5, 50, 100), # short_extend + (5, 0, 100), # prefill + (10, 50, 100), # long_extend + ], + expected_order=[0, 1, 3, 2], + expected_modified=True, + decode_threshold=5, + ), } @@ -129,8 +177,9 @@ def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase) req_ids = [f"r{i}" for i in range(len(test_case.requests))] num_computed_tokens = np.array([r[1] for r in test_case.requests], dtype=np.int32) num_scheduled_tokens = {f"r{i}": r[0] for i, r in enumerate(test_case.requests)} + num_prompt_tokens = np.array([r[2] for r in test_case.requests], dtype=np.int32) - input_batch = MockInputBatch(req_ids, num_computed_tokens) + input_batch = MockInputBatch(req_ids, num_computed_tokens, num_prompt_tokens) scheduler_output = MockSchedulerOutput(num_scheduled_tokens) modified = reorder_batch_to_split_decodes_and_prefills( diff --git a/tests/v1/e2e/test_hybrid_chunked_prefill.py b/tests/v1/e2e/test_hybrid_chunked_prefill.py index 030081a38af3..1790343ca836 100644 --- a/tests/v1/e2e/test_hybrid_chunked_prefill.py +++ b/tests/v1/e2e/test_hybrid_chunked_prefill.py @@ -43,7 +43,7 @@ pytest.param("Qwen/Qwen3.5-4B", marks=[large_gpu_mark(min_gb=40)]), pytest.param( "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8", - marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=2), + marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=4), ), ], ) @@ -68,7 +68,7 @@ def test_mtp_speculative_mixed_batch_short_prefill( max_num_batched_tokens=chunk_size, max_model_len=512, enforce_eager=True, - tensor_parallel_size=2, + tensor_parallel_size=4, trust_remote_code=True, enable_chunked_prefill=True, enable_prefix_caching=enable_prefix_caching, diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 
d7283b6c846f..cd49ea30e6f4 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -362,6 +362,11 @@ class CommonAttentionMetadata: dcp_local_seq_lens_cpu: torch.Tensor | None = None """Sequence lengths of the local rank in decode context parallelism world""" + is_prefilling: torch.Tensor | None = None + """(batch_size,) bool tensor: True if request is still in prefill phase + (num_computed_tokens < num_prompt_tokens). Used by some backends to + distinguish actual decodes from short extends.""" + # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0) _seq_lens_cpu: torch.Tensor | None = None _num_computed_tokens_cpu: torch.Tensor | None = None @@ -443,6 +448,7 @@ def unpadded( encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu), dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens), dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu), + is_prefilling=maybe_slice_reqs(self.is_prefilling), ) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index bdb820eac35e..59f2e7ca51a6 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -358,7 +358,9 @@ def _compute_common_metadata( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, decode_threshold=decode_threshold + common_attn_metadata, + decode_threshold=decode_threshold, + treat_short_extends_as_decodes=False, ) ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 42459815ef9e..0f41993fc695 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -489,11 +489,15 @@ def split_decodes_and_prefills( common_attn_metadata: CommonAttentionMetadata, decode_threshold: int = 1, require_uniform: bool = False, + treat_short_extends_as_decodes: bool = True, ) -> tuple[int, int, int, int]: """ Assuming a reordered 
batch, finds the boundary between prefill and decode requests. + The batch is expected to be ordered as: + decode → short_extend → long_extend → prefill + Args: common_attn_metadata: CommonAttentionMetadata object containing the batch metadata. @@ -501,6 +505,9 @@ def split_decodes_and_prefills( require_uniform: If True, requires that all decode requests have the same query length. When set, some queries may be considered prefills even if they are <= decode_threshold, in order to ensure uniformity. + treat_short_extends_as_decodes: If True (default), short extends + (query_len <= threshold but still prefilling) are counted as + decodes. If False, they are counted as prefills. Returns: num_decodes: The number of decode requests. @@ -513,8 +520,10 @@ def split_decodes_and_prefills( num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu - if max_query_len <= decode_threshold and ( - not require_uniform or decode_threshold <= 1 + if ( + max_query_len <= decode_threshold + and (not require_uniform or decode_threshold <= 1) + and treat_short_extends_as_decodes ): return num_reqs, 0, num_tokens, 0 @@ -533,11 +542,14 @@ def split_decodes_and_prefills( else: is_prefill = query_lens > decode_threshold + if not treat_short_extends_as_decodes: + assert common_attn_metadata.is_prefilling is not None + is_prefill |= common_attn_metadata.is_prefilling + if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 first_prefill = is_prefill.int().argmax(dim=-1).item() - assert torch.all(query_lens[:first_prefill] <= decode_threshold) num_decodes = first_prefill num_prefills = num_reqs - num_decodes num_decode_tokens = query_start_loc[first_prefill].item() @@ -581,39 +593,52 @@ def reorder_batch_to_split_decodes_and_prefills( Reorders the batch to split into prefill and decode requests; places all requests with <= decode_threshold tokens at the front of the batch. 
+    The batch is reordered into 4 regions: +        decode: (num_computed > 0 AND num_scheduled <= threshold AND done prefilling) +        short_extend: (num_computed > 0 AND num_scheduled <= threshold AND still prefilling) +        long_extend: (num_computed > 0 AND num_scheduled > threshold) +        prefill: (num_computed == 0) # First chunks +      Returns:          True if the batch was modified, False otherwise.      """ -    # We now want to reorder the batch into decode → extend → prefill order -    # where: -    # decode: request with num_scheduled_tokens <= decode_threshold -    # extend: non-decode request with existing context -    # prefill: non-decode request with no existing context -    # NOTE for now we loosely use "decode" to mean requests where attention is -    # likely memory-bound and "prefill" to mean requests where attention is -    # likely compute-bound,      num_reqs = len(input_batch.req_ids)      num_scheduled_tokens = [          scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids      ]      num_scheduled_tokens_np = np.array(num_scheduled_tokens)      num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs] -      - is_prefill = num_computed_tokens_np == 0 -    is_decode = (num_scheduled_tokens_np <= decode_threshold) & (~is_prefill) -    is_extend = (num_scheduled_tokens_np > decode_threshold) & (~is_prefill) -      - # Desired order: decode → extend → prefill -    req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default -    req_regions[is_extend] = 1 -    req_regions[is_prefill] = 2 + num_prompt_tokens_np = input_batch.num_prompt_tokens[:num_reqs] + + has_context = num_computed_tokens_np > 0 + is_below_threshold = num_scheduled_tokens_np <= decode_threshold + done_prefilling = num_computed_tokens_np >= num_prompt_tokens_np + + # Mutually exclusive categories (exactly one True per request): + # 1. No context yet -> prefill + # 2. Has context, above threshold -> long_extend + # 3. Has context, below threshold, still prefilling -> short_extend + # 4. 
Has context, below threshold, done prefilling -> decode + is_pure_prefill = ~has_context + is_long_extend = has_context & ~is_below_threshold + is_short_extend = has_context & is_below_threshold & ~done_prefilling + is_decode = has_context & is_below_threshold & done_prefilling + + # Desired order: decode → short_extend → long_extend → prefill + req_regions = np.zeros(num_reqs, dtype=np.int32) # 0 = decode by default + req_regions[is_short_extend] = 1 + req_regions[is_long_extend] = 2 + req_regions[is_pure_prefill] = 3 num_decodes = int(is_decode.sum()) - num_extends = int(is_extend.sum()) + num_short_extends = int(is_short_extend.sum()) + num_long_extends = int(is_long_extend.sum()) + num_prefills = int(is_pure_prefill.sum()) - target_regions = np.zeros(num_reqs, dtype=np.int32) - target_regions[num_decodes : num_decodes + num_extends] = 1 - target_regions[num_decodes + num_extends :] = 2 + target_regions = np.repeat( + [0, 1, 2, 3], + [num_decodes, num_short_extends, num_long_extends, num_prefills], + ).astype(np.int32) needs_swap = req_regions != target_regions diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 579c9b7a5acc..0e65385d784a 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -134,7 +134,13 @@ def __init__( pin_memory=pin_memory, ) self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy() - self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_prompt_tokens_cpu_tensor = torch.zeros( + (max_num_reqs,), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.num_prompt_tokens = self.num_prompt_tokens_cpu_tensor.numpy() self.num_computed_tokens_cpu_tensor = torch.zeros( (max_num_reqs,), device="cpu", diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index af5dca71f9c0..d3ea944c67af 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -739,19 +739,6 @@ def 
__init__( self.uniform_decode_query_len = 1 + self.num_spec_tokens - # When spec decode is active, the mamba backend classifies requests - # with query_len <= reorder_batch_threshold as "decodes". Prefill - # chunks that fall under this threshold get processed via the decode - # path, which stores intermediate states at sequential slots. We must - # set num_accepted_tokens to the chunk's query_len for those requests - # so the next iteration reads from the correct final-state slot. - # Prefills that went through the actual prefill path should keep the - # default value of 1 (the prefill path stores state at slot 0 only). - self.needs_prefill_as_decode_slots: bool = False - self.prefill_as_decode_num_tokens = self._make_buffer( - self.max_num_reqs, dtype=torch.int32 - ) - # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) @@ -1368,16 +1355,6 @@ def _update_states_after_model_execute( .int() .argmax(-1) ) - spec_decode_active = bool(scheduler_output.scheduled_spec_decode_tokens) - if self.needs_prefill_as_decode_slots and spec_decode_active: - mamba_utils.update_accepted_tokens_for_prefill_as_decode( - self.input_batch, - self.prefill_as_decode_num_tokens, - self.num_accepted_tokens.gpu, - scheduler_output, - self.reorder_batch_threshold, - num_reqs, - ) if self.cache_config.mamba_cache_mode == "align": for i, num_tokens in enumerate( @@ -1981,14 +1958,23 @@ def _get_block_table(kv_cache_gid: int): attn_gid = self.routed_experts_attn_gid slot_mapping_attn = slot_mappings[attn_gid] self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy() + # Compute is_prefilling: True if request is still in prefill phase + # (num_computed_tokens < num_prompt_tokens). Used by mamba backends to + # distinguish actual decodes from short extends. 
+ num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs_padded + ] + num_prompt_tokens_cpu = self.input_batch.num_prompt_tokens_cpu_tensor[ + :num_reqs_padded + ] + is_prefilling = num_computed_tokens_cpu < num_prompt_tokens_cpu + cm_base = CommonAttentionMetadata( query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], seq_lens=self.seq_lens.gpu[:num_reqs_padded], _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], - _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ - :num_reqs_padded - ], + _num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_reqs_padded, num_actual_tokens=num_tokens_padded, max_query_len=max_query_len, @@ -1996,6 +1982,7 @@ def _get_block_table(kv_cache_gid: int): block_table_tensor=block_table_gid_0, slot_mapping=slot_mapping_gid_0, causal=True, + is_prefilling=is_prefilling, ) if self.dcp_world_size > 1: @@ -2047,8 +2034,6 @@ def _build_attn_group_metadata( else 0 ) - if isinstance(builder, Mamba2AttentionMetadataBuilder): - self.needs_prefill_as_decode_slots = True extra_attn_metadata_args = {} if use_spec_decode and isinstance( builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 68172133eb99..2bd5d2b3fea8 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -266,45 +266,3 @@ def postprocess_mamba( if src_block_idx == dest_block_idx: num_accepted_tokens_cpu[i] = 1 do_mamba_copy_block(copy_bufs) - - -def update_accepted_tokens_for_prefill_as_decode( - input_batch: GPUInputBatch, - prefill_as_decode_num_tokens: CpuGpuBuffer, - num_accepted_tokens_gpu: torch.Tensor, - scheduler_output: SchedulerOutput, - decode_qlen_threshold: int | None, - num_reqs: int, -): - """ - Adjusts num_accepted_tokens for prefill chunks processed via the decode path. 
- This ensures subsequent iterations read from the correct sequential state slot - instead of the default prefill slot 0. Not used by GDN attention, which manually - separates short prefills and short decodes when building the attention metadata. - """ - any_is_prefill = False - for i in range(num_reqs): - num_computed = input_batch.num_computed_tokens_cpu[i] - num_prompt = input_batch.num_prompt_tokens[i] - is_prefill = num_computed < num_prompt - req_id = input_batch.req_ids[i] - query_len = scheduler_output.num_scheduled_tokens[req_id] - - if is_prefill: - classified_as_decode = ( - decode_qlen_threshold is not None and query_len <= decode_qlen_threshold - ) - num_tokens = query_len if classified_as_decode else 1 - any_is_prefill = True - else: - num_tokens = -1 - prefill_as_decode_num_tokens.np[i] = num_tokens - - # We can skip the GPU transfer if there aren't any values to update - if any_is_prefill: - prefill_as_decode_num_tokens.copy_to_gpu(num_reqs) - num_accepted_tokens_gpu[:num_reqs] = torch.where( - prefill_as_decode_num_tokens.gpu[:num_reqs] != -1, - prefill_as_decode_num_tokens.gpu[:num_reqs], - num_accepted_tokens_gpu[:num_reqs], - )