diff --git a/tests/unit_tests/test_bucketing.py b/tests/unit_tests/test_bucketing.py index 3145c8d3c5..bc2fa92e91 100644 --- a/tests/unit_tests/test_bucketing.py +++ b/tests/unit_tests/test_bucketing.py @@ -530,9 +530,9 @@ def test_real_scenario_fallback_ctx_7408_not_truncated(): def test_exponential_decode_block_limit_uncapped(monkeypatch): """Verify that decode block limit is computed from log2(max_decode_blocks). - With the new approach, excessive warmup buckets are controlled by - filters in generate_buckets() (num_ctx_tokens_less_or_equal_batched_max_model_len) - rather than by capping the block limit in get_decode_cfgs(). + For contiguous PA, max_decode_blocks = min(max_blocks, ceil(max_model_len/block_size)*max_num_seqs). + The block range is already bounded by max_blocks, so no additional + ctx filter is applied to contiguous PA decode buckets. """ monkeypatch.setenv("VLLM_EXPONENTIAL_BUCKETING", "true") monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "true") @@ -669,7 +669,6 @@ def test_padding_aware_decode_cfgs_contiguous_pa_clamps_block_range(mock_get_con # --- Tests that num_ctx_tokens_less_or_equal_batched_max_model_len filter is applied --- -@pytest.mark.parametrize("use_contiguous_pa", [True, False], ids=["contiguous_pa", "non_contiguous_pa"]) @pytest.mark.parametrize( ("max_model_len", "block_size", "max_num_seqs", "max_blocks", "max_num_batched_tokens"), [ @@ -679,13 +678,15 @@ def test_padding_aware_decode_cfgs_contiguous_pa_clamps_block_range(mock_get_con ], ids=["qwen3_32b", "small_model", "long_ctx"], ) -def test_decode_buckets_satisfy_ctx_filter(monkeypatch, use_contiguous_pa, max_model_len, block_size, max_num_seqs, - max_blocks, max_num_batched_tokens): - """Every decode bucket returned by generate_buckets must satisfy - num_ctx_tokens_less_or_equal_batched_max_model_len: - ctx <= ceil(max_model_len / block_size) * bs (when ctx > ctx_range[0]) +def test_decode_buckets_satisfy_ctx_filter_non_contiguous_pa(monkeypatch, max_model_len, block_size, 
max_num_seqs, + max_blocks, max_num_batched_tokens): + """For non-contiguous PA, every decode bucket returned by generate_buckets + must satisfy ctx <= ceil(max_model_len / block_size) * bs (when ctx > ctx_range[0]). + + The filter is only applied to non-contiguous PA; contiguous PA decode + buckets are not filtered since their block range is already bounded by max_blocks. """ - monkeypatch.setenv("VLLM_CONTIGUOUS_PA", str(use_contiguous_pa).lower()) + monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "false") clear_config() get_config() @@ -727,3 +728,83 @@ def test_decode_buckets_satisfy_ctx_filter(monkeypatch, use_contiguous_pa, max_m f"ctx <= ceil(max_model_len/block_size) * bs " f"(max_blocks_per_seq={max_blocks_per_seq}):\n" + "\n".join(f" bs={bs}, query={query}, ctx={ctx}" for bs, query, ctx in violations[:20])) + + +def test_contiguous_pa_decode_buckets_not_filtered_by_ctx(monkeypatch): + """For contiguous PA, the ctx filter must NOT be applied to decode buckets. + + Reproduces std_out.txt issue: with max_model_len=2048, block_size=256, + max_num_seqs=256, the bucket (256, 1, 2112) was incorrectly filtered + because 2112 > ceil(2048/256)*256 = 2048. 
+ """ + monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "true") + clear_config() + get_config() + + max_model_len = 2048 + block_size = 256 + max_num_seqs = 256 + max_blocks = 2113 + max_num_batched_tokens = 1048832 + + bs_range = [256] + query_range = [1] + ctx_range = list(range(1280, 2113, 64)) # 1280, 1344, ..., 2048, 2112 + ctx_range.append(max_blocks) # append num_hpu_blocks as done in generate_decode_buckets + + buckets = generate_buckets( + bs_range=bs_range, + query_range=query_range, + ctx_range=ctx_range, + is_prompt=False, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, + max_num_prefill_seqs=1, + max_num_batched_tokens=max_num_batched_tokens, + block_size=block_size, + max_blocks=max_blocks, + ) + + bucket_ctxs = [ctx for _, _, ctx in buckets] + assert 2112 in bucket_ctxs, (f"Bucket ctx=2112 was incorrectly filtered out. " + f"Max ctx in buckets: {max(bucket_ctxs)}") + assert max_blocks in bucket_ctxs, (f"Bucket ctx={max_blocks} (num_hpu_blocks) was incorrectly filtered out.") + + +def test_file_buckets_bypass_filters(monkeypatch): + """File-based bucketing (VLLM_BUCKETING_FROM_FILE) skips all filters. + + Buckets (1,1,256) and (2,1,256) exceed the bound enforced by the ctx + filter, and (512,1,256) fails batch_size_smaller_than_blocks (bs > ctx), + so they would be rejected wherever those filters apply in non-file mode. + Since file buckets bypass filters entirely, all provided buckets + must appear in the output unchanged. 
+ """ + monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "true") + clear_config() + get_config() + + max_model_len = 2048 + block_size = 256 + max_num_seqs = 32 + max_blocks = 2424 + + # (512,1,256) would be rejected by batch_size_smaller_than_blocks (bs > ctx) + # All buckets pass through because file_buckets bypass filters entirely + file_buckets = [(1, 1, 256), (1, 1, 512), (2, 1, 256), (512, 1, 256), (32, 1, 2424)] + + buckets = generate_buckets( + bs_range=[], + query_range=[], + ctx_range=[], + is_prompt=False, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, + max_num_prefill_seqs=1, + max_num_batched_tokens=8192, + block_size=block_size, + max_blocks=max_blocks, + file_buckets=file_buckets, + ) + + assert set(buckets) == set(file_buckets), (f"All file buckets should pass through unfiltered.\n" + f"Expected: {sorted(file_buckets)}\nGot: {sorted(buckets)}") diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index 2fd85dbdae..870cdb0ac3 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -448,10 +448,11 @@ def batch_size_smaller_than_blocks(bs, query, ctx): return bs <= ctx def num_ctx_tokens_less_or_equal_batched_max_model_len(bs, query, ctx): - is_valid = ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_range[0] else True + ctx_min = ctx_range[0] if ctx_range else 0 + is_valid = ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_min else True if not is_valid: omitted_buckets.add( - ("condition: ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_range[0] else True", + ("condition: ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_min else True", "-> bs, query, ctx: ", bs, query, ctx)) return is_valid @@ -463,7 +464,7 @@ def num_ctx_tokens_less_or_equal_batched_max_model_len(bs, query, ctx): }, "decode": { # depends only on contiguous PA - True: [num_ctx_tokens_less_or_equal_batched_max_model_len], + True: [], 
False: [batch_size_smaller_than_blocks, num_ctx_tokens_less_or_equal_batched_max_model_len], } } @@ -490,7 +491,7 @@ def is_ctx_allowed(ctx): buckets = set() buckets_2d = set() omitted_buckets = set() - filters = get_filters(is_prompt, use_merged_prefill, use_contiguous_pa) + filters = [] if file_buckets else get_filters(is_prompt, use_merged_prefill, use_contiguous_pa) corrector = get_corrector(is_prompt, use_contiguous_pa) if file_buckets: