Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 91 additions & 10 deletions tests/unit_tests/test_bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,9 +530,9 @@ def test_real_scenario_fallback_ctx_7408_not_truncated():
def test_exponential_decode_block_limit_uncapped(monkeypatch):
"""Verify that decode block limit is computed from log2(max_decode_blocks).

With the new approach, excessive warmup buckets are controlled by
filters in generate_buckets() (num_ctx_tokens_less_or_equal_batched_max_model_len)
rather than by capping the block limit in get_decode_cfgs().
For contiguous PA, max_decode_blocks = min(max_blocks, ceil(max_model_len/block_size)*max_num_seqs).
The block range is already bounded by max_blocks, so no additional
ctx filter is applied to contiguous PA decode buckets.
"""
monkeypatch.setenv("VLLM_EXPONENTIAL_BUCKETING", "true")
monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "true")
Expand Down Expand Up @@ -669,7 +669,6 @@ def test_padding_aware_decode_cfgs_contiguous_pa_clamps_block_range(mock_get_con
# --- Tests that num_ctx_tokens_less_or_equal_batched_max_model_len filter is applied ---


@pytest.mark.parametrize("use_contiguous_pa", [True, False], ids=["contiguous_pa", "non_contiguous_pa"])
@pytest.mark.parametrize(
("max_model_len", "block_size", "max_num_seqs", "max_blocks", "max_num_batched_tokens"),
[
Expand All @@ -679,13 +678,15 @@ def test_padding_aware_decode_cfgs_contiguous_pa_clamps_block_range(mock_get_con
],
ids=["qwen3_32b", "small_model", "long_ctx"],
)
def test_decode_buckets_satisfy_ctx_filter(monkeypatch, use_contiguous_pa, max_model_len, block_size, max_num_seqs,
max_blocks, max_num_batched_tokens):
"""Every decode bucket returned by generate_buckets must satisfy
num_ctx_tokens_less_or_equal_batched_max_model_len:
ctx <= ceil(max_model_len / block_size) * bs (when ctx > ctx_range[0])
def test_decode_buckets_satisfy_ctx_filter_non_contiguous_pa(monkeypatch, max_model_len, block_size, max_num_seqs,
max_blocks, max_num_batched_tokens):
"""For non-contiguous PA, every decode bucket returned by generate_buckets
must satisfy ctx <= ceil(max_model_len / block_size) * bs (when ctx > ctx_range[0]).

The filter is only applied to non-contiguous PA; contiguous PA decode
buckets are not filtered since their block range is already bounded by max_blocks.
"""
monkeypatch.setenv("VLLM_CONTIGUOUS_PA", str(use_contiguous_pa).lower())
monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "false")
clear_config()
get_config()

Expand Down Expand Up @@ -727,3 +728,83 @@ def test_decode_buckets_satisfy_ctx_filter(monkeypatch, use_contiguous_pa, max_m
f"ctx <= ceil(max_model_len/block_size) * bs "
f"(max_blocks_per_seq={max_blocks_per_seq}):\n" +
"\n".join(f" bs={bs}, query={query}, ctx={ctx}" for bs, query, ctx in violations[:20]))


def test_contiguous_pa_decode_buckets_not_filtered_by_ctx(monkeypatch):
"""For contiguous PA, the ctx filter must NOT be applied to decode buckets.

Reproduces std_out.txt issue: with max_model_len=2048, block_size=256,
max_num_seqs=256, the bucket (256, 1, 2112) was incorrectly filtered
because 2112 > ceil(2048/256)*256 = 2048.
"""
monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "true")
clear_config()
get_config()

max_model_len = 2048
block_size = 256
max_num_seqs = 256
max_blocks = 2113
max_num_batched_tokens = 1048832

bs_range = [256]
query_range = [1]
ctx_range = list(range(1280, 2113, 64)) # 1280, 1344, ..., 2048, 2112
ctx_range.append(max_blocks) # append num_hpu_blocks as done in generate_decode_buckets

buckets = generate_buckets(
bs_range=bs_range,
query_range=query_range,
ctx_range=ctx_range,
is_prompt=False,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
max_num_prefill_seqs=1,
max_num_batched_tokens=max_num_batched_tokens,
block_size=block_size,
max_blocks=max_blocks,
)

bucket_ctxs = [ctx for _, _, ctx in buckets]
assert 2112 in bucket_ctxs, (f"Bucket ctx=2112 was incorrectly filtered out. "
f"Max ctx in buckets: {max(bucket_ctxs)}")
assert max_blocks in bucket_ctxs, (f"Bucket ctx={max_blocks} (num_hpu_blocks) was incorrectly filtered out.")


def test_file_buckets_bypass_filters(monkeypatch):
"""File-based bucketing (VLLM_BUCKETING_FROM_FILE) skips all filters.

Buckets (1,1,256) and (2,1,512) would normally be rejected by the
batch_size_smaller_than_blocks or ctx filters in non-file mode.
Since file buckets bypass filters entirely, all provided buckets
must appear in the output unchanged.
"""
monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "true")
clear_config()
get_config()

max_model_len = 2048
block_size = 256
max_num_seqs = 32
max_blocks = 2424

# (512,1,256) would be rejected by batch_size_smaller_than_blocks (bs > ctx)
# All buckets pass through because file_buckets bypass filters entirely
file_buckets = [(1, 1, 256), (1, 1, 512), (2, 1, 256), (512, 1, 256), (32, 1, 2424)]

buckets = generate_buckets(
bs_range=[],
query_range=[],
ctx_range=[],
is_prompt=False,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
max_num_prefill_seqs=1,
max_num_batched_tokens=8192,
block_size=block_size,
max_blocks=max_blocks,
file_buckets=file_buckets,
)

assert set(buckets) == set(file_buckets), (f"All file buckets should pass through unfiltered.\n"
f"Expected: {sorted(file_buckets)}\nGot: {sorted(buckets)}")
9 changes: 5 additions & 4 deletions vllm_gaudi/extension/bucketing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,11 @@ def batch_size_smaller_than_blocks(bs, query, ctx):
return bs <= ctx

def num_ctx_tokens_less_or_equal_batched_max_model_len(bs, query, ctx):
is_valid = ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_range[0] else True
ctx_min = ctx_range[0] if ctx_range else 0
is_valid = ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_min else True
if not is_valid:
omitted_buckets.add(
("condition: ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_range[0] else True",
("condition: ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_min else True",
"-> bs, query, ctx: ", bs, query, ctx))
return is_valid

Expand All @@ -463,7 +464,7 @@ def num_ctx_tokens_less_or_equal_batched_max_model_len(bs, query, ctx):
},
"decode": {
# depends only on contiguous PA
True: [num_ctx_tokens_less_or_equal_batched_max_model_len],
True: [],
False: [batch_size_smaller_than_blocks, num_ctx_tokens_less_or_equal_batched_max_model_len],
}
}
Expand All @@ -490,7 +491,7 @@ def is_ctx_allowed(ctx):
buckets = set()
buckets_2d = set()
omitted_buckets = set()
filters = get_filters(is_prompt, use_merged_prefill, use_contiguous_pa)
filters = [] if file_buckets else get_filters(is_prompt, use_merged_prefill, use_contiguous_pa)
corrector = get_corrector(is_prompt, use_contiguous_pa)

if file_buckets:
Expand Down
Loading