From 5ac1801039991fc0ea219e2b4753b8125c3828e4 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 12 Mar 2026 10:23:30 +0800 Subject: [PATCH 1/8] fix max decode blocks for non-contiguous pa in linear bucketing Signed-off-by: Youlei Yang --- vllm_gaudi/extension/bucketing/linear.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm_gaudi/extension/bucketing/linear.py b/vllm_gaudi/extension/bucketing/linear.py index d50a8e57a2..8d941ae2bd 100644 --- a/vllm_gaudi/extension/bucketing/linear.py +++ b/vllm_gaudi/extension/bucketing/linear.py @@ -63,17 +63,18 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_ if contiguous_pa: max_decode_blocks = max_blocks decode_block_bucket_cfg = read_bucket_settings('decode', 'block', min=1, step=block_size, max=max_decode_blocks) - if decode_block_bucket_cfg[2] > max_blocks: + if contiguous_pa and decode_block_bucket_cfg[2] > max_blocks: logger().info( f'VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} is higher than max_blocks={max_blocks}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MAX={max_blocks}' ) decode_block_bucket_cfg[2] = max_blocks - if decode_block_bucket_cfg[0] > max_blocks: - decode_block_bucket_min = max(1, max_blocks - decode_block_bucket_cfg[1]) - logger().info( - f'VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} is higher than max_blocks={max_blocks}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_min}' - ) - decode_block_bucket_cfg[0] = decode_block_bucket_min + + if decode_block_bucket_cfg[0] > decode_block_bucket_cfg[2]: + decode_block_bucket_min = max(1, decode_block_bucket_cfg[2] - decode_block_bucket_cfg[1]) + logger().info( + f"VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} is higher than max_blocks={decode_block_bucket_cfg[2]}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_min}" + ) + decode_block_bucket_cfg[0] = decode_block_bucket_min msg = ("Decode bucket config (min, step, max_warmup) " f"bs:{decode_bs_bucket_cfg}, " From 8d1c242ced5a4d64b1717f732a58d6aaeab1d131 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 5 Mar 2026 13:08:20 +0800 Subject: [PATCH 2/8] fix max decode blocks for non-contiguous pa in exp bucketing Signed-off-by: Youlei Yang --- vllm_gaudi/extension/bucketing/exponential.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm_gaudi/extension/bucketing/exponential.py b/vllm_gaudi/extension/bucketing/exponential.py index eece7ef7af..0ba9f43951 100644 --- a/vllm_gaudi/extension/bucketing/exponential.py +++ b/vllm_gaudi/extension/bucketing/exponential.py @@ -89,13 +89,8 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_ decode_bs_limit = math.ceil(math.log2(max_num_seqs)) + 1 decode_bs_bucket_cfg = [1, 2, max_num_seqs, decode_bs_limit] decode_query_bucket_cfg = [1, 1, 1, 1] - # With non-contiguous PA, total block references across all sequences - # can exceed physical num_hpu_blocks (same physical block appears in - # multiple sequence block tables). Use 3x headroom so prepared buckets - # cover realistic prefix-sharing scenarios and avoid costly HPU graph - # recompilation at high KV-cache utilization. - max_decode_blocks = max_blocks if use_contiguous_pa else \ - max_blocks * 3 + max_decode_blocks = math.ceil(max_model_len / block_size) * max_num_seqs + max_decode_blocks = min(max_blocks, max_decode_blocks) if use_contiguous_pa else max_decode_blocks max_decode_block_limit = math.ceil(math.log2(max_decode_blocks)) + 1 decode_block_bucket_cfg = [1, max_num_seqs, max_decode_blocks, max_decode_block_limit] From 85be44c60adccd96be7eb5bc97f5898296626a76 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 5 Mar 2026 13:10:49 +0800 Subject: [PATCH 3/8] fix decode blocks limit for exp bucketing Signed-off-by: Youlei Yang --- vllm_gaudi/extension/bucketing/exponential.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_gaudi/extension/bucketing/exponential.py b/vllm_gaudi/extension/bucketing/exponential.py index 0ba9f43951..312294995a 100644 --- a/vllm_gaudi/extension/bucketing/exponential.py +++ b/vllm_gaudi/extension/bucketing/exponential.py @@ -91,8 +91,8 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_ decode_query_bucket_cfg = [1, 1, 1, 1] max_decode_blocks = math.ceil(max_model_len / block_size) * max_num_seqs max_decode_blocks = min(max_blocks, max_decode_blocks) if use_contiguous_pa else max_decode_blocks - max_decode_block_limit = math.ceil(math.log2(max_decode_blocks)) + 1 - decode_block_bucket_cfg = [1, max_num_seqs, max_decode_blocks, max_decode_block_limit] + decode_blocks_limit = math.ceil(math.log2(max_decode_blocks)) + 1 + decode_block_bucket_cfg = [1, max_num_seqs, max_decode_blocks, decode_blocks_limit] msg = ("Decode bucket config (min, step, max_warmup, limit) " f"bs:{decode_bs_bucket_cfg}, " From 806769dc0ee424dc2a4858f4b8564c26c76418c2 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 12 Mar 2026 10:25:52 +0800 Subject: [PATCH 4/8] fix decode warmup OOM issue for large num_blocks Signed-off-by: Youlei Yang --- vllm_gaudi/v1/worker/hpu_model_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index caafb5e0b0..1953de9720 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -5096,6 +5096,9 @@ def _add_dummy_request(self, num_scheduled_tokens[req_id] = scheduled_tokens def _generate_seq_lengths(self, num_samples, num_blocks, block_size): + # ensure the actual number of blocks is less than the KV cache blocks + num_blocks = min(self.kv_cache_config.num_blocks, num_blocks) + assert num_samples <= num_blocks blocks = [num_blocks // num_samples] * num_samples missing_blocks = num_blocks - sum(blocks) From c15ec4779c0f722f4efc0826105d45d198ce5554 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 12 Mar 2026 10:28:01 +0800 Subject: [PATCH 5/8] add filter for decode buckets to ensure ctx <= max_model_len / block_size * bs Signed-off-by: Youlei Yang --- vllm_gaudi/extension/bucketing/common.py | 39 ++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index f35cb445f6..a6e6ef9274 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -1,3 +1,4 @@ +import logging import os import bisect import math @@ -446,6 +447,14 @@ def batch_size_smaller_than_blocks(bs, query, ctx): omitted_buckets.add(("condition: bs <= ctx, ", "-> bs, query, ctx: ", bs, query, ctx)) return bs <= ctx + def num_ctx_tokens_less_or_equal_batched_max_model_len(bs, query, ctx): + is_valid = ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_range[0] else True + if not is_valid: + omitted_buckets.add( + ("condition: ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_range[0] else True", + "-> bs, query, ctx: ", bs, query, ctx)) + return is_valid + filters_map = { "prompt": { # depends only on merged_prefill @@ -454,8 +463,8 @@ def batch_size_smaller_than_blocks(bs, query, ctx): }, "decode": { # depends only on contiguous PA - True: [], - False: [batch_size_smaller_than_blocks], + True: [num_ctx_tokens_less_or_equal_batched_max_model_len], + False: [batch_size_smaller_than_blocks, num_ctx_tokens_less_or_equal_batched_max_model_len], } } @@ -519,6 +528,32 @@ def is_ctx_allowed(ctx): raise RuntimeError("Generated 0 " + phase + " buckets. Please adjust the bucketing configuration according to README") + if logger().getEffectiveLevel() <= logging.DEBUG and omitted_buckets: + phase = "prompt" if is_prompt else "decode" + omitted_buckets_str = "\n".join(map(str, sorted(omitted_buckets))) + msg = f"Omitted {len(omitted_buckets)} {phase} buckets:\n{omitted_buckets_str}" + logger().debug(msg) + + return sorted(buckets) + + +def generate_unified_buckets(query_range, shared_ctx_range, unique_ctx_range, bs, block_size, max_model_len): + buckets = set() + is_causal = [0, 1] + + for query, shared_ctx, unique_ctx, causal in itertools.product(query_range, shared_ctx_range, unique_ctx_range, + is_causal): + if causal: + max_bs = min(bs, query) + if math.ceil(shared_ctx * block_size // max_bs) <= max_model_len: + buckets.add((query, shared_ctx, unique_ctx, causal)) + elif query <= bs: + # non causal query = current bs + if shared_ctx > 0 or unique_ctx > 0: + if shared_ctx == 0 or (math.ceil(shared_ctx * block_size // (query // 2)) <= max_model_len): + if shared_ctx > 0 or query <= unique_ctx: + buckets.add((query, shared_ctx, unique_ctx, causal)) + return sorted(buckets) From b1f9660c803f7d7d03a29f4eb3cca59aa4bebc21 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Tue, 14 Apr 2026 02:42:45 +0000 Subject: [PATCH 6/8] fix tests Signed-off-by: Youlei Yang --- tests/unit_tests/test_bucketing.py | 94 ++++++++++++++++++------------ 1 file changed, 58 insertions(+), 36 deletions(-) diff --git a/tests/unit_tests/test_bucketing.py b/tests/unit_tests/test_bucketing.py index 9807aa3389..d2f9b4d513 100644 --- a/tests/unit_tests/test_bucketing.py +++ b/tests/unit_tests/test_bucketing.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. ############################################################################### +import math + import pytest from unittest.mock import patch @@ -189,25 +191,25 @@ def __init__(self, **kwargs): @patch('vllm_gaudi.extension.bucketing.exponential.get_config') -def test_exponential_decode_cfgs_non_contiguous_pa_bounded(mock_get_config): - """max_decode_blocks should be max_blocks * 3 when use_contiguous_pa=False. - - The 3x multiplier accounts for prefix-cache block sharing: the same - physical block can appear in multiple sequences' block tables, so total - block references may exceed num_hpu_blocks. +def test_exponential_decode_cfgs_non_contiguous_pa_unbounded(mock_get_config): + """max_decode_blocks should be ceil(max_model_len/block_size)*max_num_seqs + when use_contiguous_pa=False. Actual bounding of generated buckets + happens via filters in generate_buckets(). """ mock_get_config.return_value = _MockConfig(use_contiguous_pa=False) strategy = ExponentialBucketingStrategy() max_blocks = 3593 block_size = 128 - _, _, block_cfg = strategy.get_decode_cfgs(max_num_seqs=256, + max_model_len = 91964 + max_num_seqs = 256 + _, _, block_cfg = strategy.get_decode_cfgs(max_num_seqs=max_num_seqs, block_size=block_size, max_num_batched_tokens=131072, - max_model_len=91964, + max_model_len=max_model_len, max_blocks=max_blocks) - expected_max = max_blocks * 3 # 10779 + expected_max = math.ceil(max_model_len / block_size) * max_num_seqs assert block_cfg[2] == expected_max, (f"Expected max_decode_blocks={expected_max}, got {block_cfg[2]}") @@ -229,8 +231,12 @@ def test_exponential_decode_cfgs_contiguous_pa_uses_max_blocks(mock_get_config): @patch('vllm_gaudi.extension.bucketing.exponential.get_config') -def test_exponential_decode_max_never_exceeds_bounded_value(mock_get_config): - """Regression test: large max_model_len must NOT produce gigantic decode buckets.""" +def test_exponential_decode_cfgs_non_contiguous_pa_formula(mock_get_config): + """Verify non-contiguous PA decode cfg uses ceil(max_model_len/block_size)*max_num_seqs. + + Actual bounding of excessive buckets happens via the + num_ctx_tokens_less_or_equal_batched_max_model_len filter in generate_buckets(). + """ mock_get_config.return_value = _MockConfig(use_contiguous_pa=False) strategy = ExponentialBucketingStrategy() @@ -245,14 +251,8 @@ def test_exponential_decode_max_never_exceeds_bounded_value(mock_get_config): max_model_len=max_model_len, max_blocks=max_blocks) - # The old (buggy) formula would produce min(91964//128*256, ...) = 183808 - # The fix should give max_blocks * 3 = 10779 - assert block_cfg[2] <= max_blocks * 3, (f"Decode bucket max {block_cfg[2]} exceeds bounded limit " - f"{max_blocks * 3}. Buckets are too large!") - # Sanity: must not be the old gigantic value - old_buggy_value = max_model_len // block_size * max_num_seqs - assert block_cfg[2] < old_buggy_value, (f"Decode bucket max {block_cfg[2]} matches buggy formula output " - f"{old_buggy_value}") + expected_max = math.ceil(max_model_len / block_size) * max_num_seqs + assert block_cfg[2] == expected_max, (f"Expected max_decode_blocks={expected_max}, got {block_cfg[2]}") @patch('vllm_gaudi.extension.bucketing.exponential.get_config') @@ -315,8 +315,8 @@ def test_fallback_bucket_ctx_uses_calc_fallback(): @patch('vllm_gaudi.extension.bucketing.exponential.get_config') def test_real_scenario_decode_cfg_matches_fixed_log(mock_get_config): """Verify decode bucket config matches expected values for real scenario. - - With max_blocks * 3: block config should be [1, 256, 10779, 14] + With non-contiguous PA: block config should be + [1, 256, ceil(91964/128)*256, ceil(log2(that))+1] """ mock_get_config.return_value = _MockConfig(use_contiguous_pa=False) strategy = ExponentialBucketingStrategy() @@ -327,13 +327,11 @@ def test_real_scenario_decode_cfg_matches_fixed_log(mock_get_config): max_model_len=_REAL_MAX_MODEL_LEN, max_blocks=_REAL_MAX_BLOCKS) - # Expected: [1, 256, 10779, 14] + expected_max = math.ceil(_REAL_MAX_MODEL_LEN / _REAL_BLOCK_SIZE) * _REAL_MAX_NUM_SEQS + expected_limit = math.ceil(math.log2(expected_max)) + 1 assert block_cfg[0] == 1, f"block min: expected 1, got {block_cfg[0]}" assert block_cfg[1] == _REAL_MAX_NUM_SEQS, (f"block step: expected {_REAL_MAX_NUM_SEQS}, got {block_cfg[1]}") - assert block_cfg[2] == _REAL_FIXED_MAX_DECODE_BLOCKS, ( - f"block max: expected {_REAL_FIXED_MAX_DECODE_BLOCKS}, got {block_cfg[2]}") - import math - expected_limit = math.ceil(math.log2(_REAL_MAX_BLOCKS * 3)) + 1 # 14 + assert block_cfg[2] == expected_max, (f"block max: expected {expected_max}, got {block_cfg[2]}") assert block_cfg[3] == expected_limit, (f"block limit: expected {expected_limit}, got {block_cfg[3]}") @@ -357,10 +355,11 @@ def test_real_scenario_decode_cfg_matches_fixed_bs_log(mock_get_config): @patch('vllm_gaudi.extension.bucketing.exponential.get_config') -def test_real_scenario_decode_block_range_bounded(mock_get_config): - """Verify generated decode block range stays within bounds (real scenario). +def test_real_scenario_decode_block_range_within_cfg_max(mock_get_config): + """Verify generated decode block range stays within cfg max (real scenario). - Fixed log showed blocks up to 3721. Buggy run had blocks up to 183808. + The block range from get_range() extends up to max_decode_blocks. + Actual bounding per (bs, ctx) pair happens via filters in generate_buckets(). """ mock_get_config.return_value = _MockConfig(use_contiguous_pa=False) strategy = ExponentialBucketingStrategy() @@ -372,14 +371,9 @@ def test_real_scenario_decode_block_range_bounded(mock_get_config): max_blocks=_REAL_MAX_BLOCKS) block_range = strategy.get_range(block_cfg) + expected_max = math.ceil(_REAL_MAX_MODEL_LEN / _REAL_BLOCK_SIZE) * _REAL_MAX_NUM_SEQS - assert max(block_range) <= _REAL_FIXED_MAX_DECODE_BLOCKS, ( - f"Largest block bucket {max(block_range)} exceeds bounded max " - f"{_REAL_FIXED_MAX_DECODE_BLOCKS}") - assert max(block_range) < _REAL_BUGGY_MAX_DECODE_BLOCKS, ( - f"Block range still contains buggy value {_REAL_BUGGY_MAX_DECODE_BLOCKS}") - # Verify reasonable number of buckets (log showed 13 unique block values) - assert len(block_range) <= 20, (f"Too many block buckets: {len(block_range)}") + assert max(block_range) <= expected_max, (f"Largest block bucket {max(block_range)} exceeds cfg max {expected_max}") @patch('vllm_gaudi.extension.bucketing.exponential.get_config') @@ -533,6 +527,34 @@ def test_real_scenario_fallback_ctx_7408_not_truncated(): assert new_ctx == calc_fallback_value(7408, 32), (f"Fallback ctx {new_ctx} should equal calc_fallback_value result") +def test_exponential_decode_block_limit_uncapped(monkeypatch): + """Verify that decode block limit is computed from log2(max_decode_blocks). + + With the new approach, excessive warmup buckets are controlled by + filters in generate_buckets() (num_ctx_tokens_less_or_equal_batched_max_model_len) + rather than by capping the block limit in get_decode_cfgs(). + """ + monkeypatch.setenv("VLLM_EXPONENTIAL_BUCKETING", "true") + monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "true") + clear_config() + get_config() + + strategy = ExponentialBucketingStrategy() + max_num_seqs = 21 + block_size = 128 + max_num_batched_tokens = 8192 + max_model_len = 131072 + max_blocks = 65536 + + bs_cfg, query_cfg, block_cfg = strategy.get_decode_cfgs(max_num_seqs, block_size, max_num_batched_tokens, + max_model_len, max_blocks) + + # max_decode_blocks = min(65536, ceil(131072/128)*21) = min(65536, 21504) = 21504 + expected_max_decode_blocks = min(max_blocks, math.ceil(max_model_len / block_size) * max_num_seqs) + expected_limit = math.ceil(math.log2(expected_max_decode_blocks)) + 1 + assert block_cfg[2] == expected_max_decode_blocks, ( + f"Expected max_decode_blocks={expected_max_decode_blocks}, got {block_cfg[2]}") + assert block_cfg[3] == expected_limit, (f"Expected decode_blocks_limit={expected_limit}, got {block_cfg[3]}") # --- Padding-aware bucketing tests --- From c3cc743eb6490f5630029871c3ced0fa504611a1 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Tue, 14 Apr 2026 03:19:34 +0000 Subject: [PATCH 7/8] remove unified buckets Signed-off-by: Youlei Yang --- vllm_gaudi/extension/bucketing/common.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index a6e6ef9274..2fd85dbdae 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -537,26 +537,6 @@ def is_ctx_allowed(ctx): return sorted(buckets) -def generate_unified_buckets(query_range, shared_ctx_range, unique_ctx_range, bs, block_size, max_model_len): - buckets = set() - is_causal = [0, 1] - - for query, shared_ctx, unique_ctx, causal in itertools.product(query_range, shared_ctx_range, unique_ctx_range, - is_causal): - if causal: - max_bs = min(bs, query) - if math.ceil(shared_ctx * block_size // max_bs) <= max_model_len: - buckets.add((query, shared_ctx, unique_ctx, causal)) - elif query <= bs: - # non causal query = current bs - if shared_ctx > 0 or unique_ctx > 0: - if shared_ctx == 0 or (math.ceil(shared_ctx * block_size // (query // 2)) <= max_model_len): - if shared_ctx > 0 or query <= unique_ctx: - buckets.add((query, shared_ctx, unique_ctx, causal)) - - return sorted(buckets) - - def is_greater_or_equal(tuple1, tuple2): return tuple1[0] >= tuple2[0] and tuple1[1] >= tuple2[1] \ and tuple1[2] >= tuple2[2] From 2c3b44b7979749536c4c0da923d7cf7c5208a447 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Mon, 27 Apr 2026 00:38:36 +0000 Subject: [PATCH 8/8] add UT for the newly added filter Signed-off-by: Youlei Yang --- tests/unit_tests/test_bucketing.py | 65 ++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/unit_tests/test_bucketing.py b/tests/unit_tests/test_bucketing.py index d2f9b4d513..3145c8d3c5 100644 --- a/tests/unit_tests/test_bucketing.py +++ b/tests/unit_tests/test_bucketing.py @@ -555,6 +555,8 @@ def test_exponential_decode_block_limit_uncapped(monkeypatch): assert block_cfg[2] == expected_max_decode_blocks, ( f"Expected max_decode_blocks={expected_max_decode_blocks}, got {block_cfg[2]}") assert block_cfg[3] == expected_limit, (f"Expected decode_blocks_limit={expected_limit}, got {block_cfg[3]}") + + # --- Padding-aware bucketing tests --- @@ -662,3 +664,66 @@ def test_padding_aware_decode_cfgs_contiguous_pa_clamps_block_range(mock_get_con max_blocks=3593) assert block_cfg == [3465, 128, 3593, 899, 25] + + +# --- Tests that num_ctx_tokens_less_or_equal_batched_max_model_len filter is applied --- + + +@pytest.mark.parametrize("use_contiguous_pa", [True, False], ids=["contiguous_pa", "non_contiguous_pa"]) +@pytest.mark.parametrize( + ("max_model_len", "block_size", "max_num_seqs", "max_blocks", "max_num_batched_tokens"), + [ + (91964, 128, 256, 3593, 2048), # Qwen3-32B real scenario + (4096, 128, 64, 500, 2048), # small model + (131072, 128, 21, 65536, 8192), # long context + ], + ids=["qwen3_32b", "small_model", "long_ctx"], +) +def test_decode_buckets_satisfy_ctx_filter(monkeypatch, use_contiguous_pa, max_model_len, block_size, max_num_seqs, + max_blocks, max_num_batched_tokens): + """Every decode bucket returned by generate_buckets must satisfy + num_ctx_tokens_less_or_equal_batched_max_model_len: + ctx <= ceil(max_model_len / block_size) * bs (when ctx > ctx_range[0]) + """ + monkeypatch.setenv("VLLM_CONTIGUOUS_PA", str(use_contiguous_pa).lower()) + clear_config() + get_config() + + strategy = ExponentialBucketingStrategy() + + bs_cfg, query_cfg, block_cfg = strategy.get_decode_cfgs( + max_num_seqs=max_num_seqs, + block_size=block_size, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + max_blocks=max_blocks, + ) + bs_range = strategy.get_range(bs_cfg) + query_range = strategy.get_range(query_cfg) + ctx_range = strategy.get_range(block_cfg) + + buckets = generate_buckets( + bs_range=bs_range, + query_range=query_range, + ctx_range=ctx_range, + is_prompt=False, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, + max_num_prefill_seqs=1, + max_num_batched_tokens=max_num_batched_tokens, + block_size=block_size, + max_blocks=max_blocks, + ) + + ctx_min = ctx_range[0] + max_blocks_per_seq = math.ceil(max_model_len / block_size) + + violations = [] + for bs, query, ctx in buckets: + if ctx > ctx_min and ctx > max_blocks_per_seq * bs: + violations.append((bs, query, ctx)) + + assert not violations, (f"Found {len(violations)} decode bucket(s) violating " + f"ctx <= ceil(max_model_len/block_size) * bs " + f"(max_blocks_per_seq={max_blocks_per_seq}):\n" + + "\n".join(f" bs={bs}, query={query}, ctx={ctx}" for bs, query, ctx in violations[:20]))