diff --git a/tests/unit_tests/test_bucketing.py b/tests/unit_tests/test_bucketing.py index 9807aa3389..3145c8d3c5 100644 --- a/tests/unit_tests/test_bucketing.py +++ b/tests/unit_tests/test_bucketing.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. ############################################################################### +import math + import pytest from unittest.mock import patch @@ -189,25 +191,25 @@ def __init__(self, **kwargs): @patch('vllm_gaudi.extension.bucketing.exponential.get_config') -def test_exponential_decode_cfgs_non_contiguous_pa_bounded(mock_get_config): - """max_decode_blocks should be max_blocks * 3 when use_contiguous_pa=False. - - The 3x multiplier accounts for prefix-cache block sharing: the same - physical block can appear in multiple sequences' block tables, so total - block references may exceed num_hpu_blocks. +def test_exponential_decode_cfgs_non_contiguous_pa_unbounded(mock_get_config): + """max_decode_blocks should be ceil(max_model_len/block_size)*max_num_seqs + when use_contiguous_pa=False. Actual bounding of generated buckets + happens via filters in generate_buckets(). """ mock_get_config.return_value = _MockConfig(use_contiguous_pa=False) strategy = ExponentialBucketingStrategy() max_blocks = 3593 block_size = 128 - _, _, block_cfg = strategy.get_decode_cfgs(max_num_seqs=256, + max_model_len = 91964 + max_num_seqs = 256 + _, _, block_cfg = strategy.get_decode_cfgs(max_num_seqs=max_num_seqs, block_size=block_size, max_num_batched_tokens=131072, - max_model_len=91964, + max_model_len=max_model_len, max_blocks=max_blocks) - expected_max = max_blocks * 3 # 10779 + expected_max = math.ceil(max_model_len / block_size) * max_num_seqs assert block_cfg[2] == expected_max, (f"Expected max_decode_blocks={expected_max}, got {block_cfg[2]}") @@ -229,8 +231,12 @@ def test_exponential_decode_cfgs_contiguous_pa_uses_max_blocks(mock_get_config): @patch('vllm_gaudi.extension.bucketing.exponential.get_config') -def test_exponential_decode_max_never_exceeds_bounded_value(mock_get_config): - """Regression test: large max_model_len must NOT produce gigantic decode buckets.""" +def test_exponential_decode_cfgs_non_contiguous_pa_formula(mock_get_config): + """Verify non-contiguous PA decode cfg uses ceil(max_model_len/block_size)*max_num_seqs. + + Actual bounding of excessive buckets happens via the + num_ctx_tokens_less_or_equal_batched_max_model_len filter in generate_buckets(). + """ mock_get_config.return_value = _MockConfig(use_contiguous_pa=False) strategy = ExponentialBucketingStrategy() @@ -245,14 +251,8 @@ def test_exponential_decode_max_never_exceeds_bounded_value(mock_get_config): max_model_len=max_model_len, max_blocks=max_blocks) - # The old (buggy) formula would produce min(91964//128*256, ...) = 183808 - # The fix should give max_blocks * 3 = 10779 - assert block_cfg[2] <= max_blocks * 3, (f"Decode bucket max {block_cfg[2]} exceeds bounded limit " - f"{max_blocks * 3}. Buckets are too large!") - # Sanity: must not be the old gigantic value - old_buggy_value = max_model_len // block_size * max_num_seqs - assert block_cfg[2] < old_buggy_value, (f"Decode bucket max {block_cfg[2]} matches buggy formula output " - f"{old_buggy_value}") + expected_max = math.ceil(max_model_len / block_size) * max_num_seqs + assert block_cfg[2] == expected_max, (f"Expected max_decode_blocks={expected_max}, got {block_cfg[2]}") @patch('vllm_gaudi.extension.bucketing.exponential.get_config') @@ -315,8 +315,8 @@ def test_fallback_bucket_ctx_uses_calc_fallback(): @patch('vllm_gaudi.extension.bucketing.exponential.get_config') def test_real_scenario_decode_cfg_matches_fixed_log(mock_get_config): """Verify decode bucket config matches expected values for real scenario. - - With max_blocks * 3: block config should be [1, 256, 10779, 14] + With non-contiguous PA: block config should be + [1, 256, ceil(91964/128)*256, ceil(log2(that))+1] """ mock_get_config.return_value = _MockConfig(use_contiguous_pa=False) strategy = ExponentialBucketingStrategy() @@ -327,13 +327,11 @@ def test_real_scenario_decode_cfg_matches_fixed_log(mock_get_config): max_model_len=_REAL_MAX_MODEL_LEN, max_blocks=_REAL_MAX_BLOCKS) - # Expected: [1, 256, 10779, 14] + expected_max = math.ceil(_REAL_MAX_MODEL_LEN / _REAL_BLOCK_SIZE) * _REAL_MAX_NUM_SEQS + expected_limit = math.ceil(math.log2(expected_max)) + 1 assert block_cfg[0] == 1, f"block min: expected 1, got {block_cfg[0]}" assert block_cfg[1] == _REAL_MAX_NUM_SEQS, (f"block step: expected {_REAL_MAX_NUM_SEQS}, got {block_cfg[1]}") - assert block_cfg[2] == _REAL_FIXED_MAX_DECODE_BLOCKS, ( - f"block max: expected {_REAL_FIXED_MAX_DECODE_BLOCKS}, got {block_cfg[2]}") - import math - expected_limit = math.ceil(math.log2(_REAL_MAX_BLOCKS * 3)) + 1 # 14 + assert block_cfg[2] == expected_max, (f"block max: expected {expected_max}, got {block_cfg[2]}") assert block_cfg[3] == expected_limit, (f"block limit: expected {expected_limit}, got {block_cfg[3]}") @@ -357,10 +355,11 @@ def test_real_scenario_decode_cfg_matches_fixed_bs_log(mock_get_config): @patch('vllm_gaudi.extension.bucketing.exponential.get_config') -def test_real_scenario_decode_block_range_bounded(mock_get_config): - """Verify generated decode block range stays within bounds (real scenario). +def test_real_scenario_decode_block_range_within_cfg_max(mock_get_config): + """Verify generated decode block range stays within cfg max (real scenario). - Fixed log showed blocks up to 3721. Buggy run had blocks up to 183808. + The block range from get_range() extends up to max_decode_blocks. + Actual bounding per (bs, ctx) pair happens via filters in generate_buckets(). """ mock_get_config.return_value = _MockConfig(use_contiguous_pa=False) strategy = ExponentialBucketingStrategy() @@ -372,14 +371,9 @@ def test_real_scenario_decode_block_range_bounded(mock_get_config): max_blocks=_REAL_MAX_BLOCKS) block_range = strategy.get_range(block_cfg) + expected_max = math.ceil(_REAL_MAX_MODEL_LEN / _REAL_BLOCK_SIZE) * _REAL_MAX_NUM_SEQS - assert max(block_range) <= _REAL_FIXED_MAX_DECODE_BLOCKS, ( - f"Largest block bucket {max(block_range)} exceeds bounded max " - f"{_REAL_FIXED_MAX_DECODE_BLOCKS}") - assert max(block_range) < _REAL_BUGGY_MAX_DECODE_BLOCKS, ( - f"Block range still contains buggy value {_REAL_BUGGY_MAX_DECODE_BLOCKS}") - # Verify reasonable number of buckets (log showed 13 unique block values) - assert len(block_range) <= 20, (f"Too many block buckets: {len(block_range)}") + assert max(block_range) <= expected_max, (f"Largest block bucket {max(block_range)} exceeds cfg max {expected_max}") @patch('vllm_gaudi.extension.bucketing.exponential.get_config') @@ -533,6 +527,36 @@ def test_real_scenario_fallback_ctx_7408_not_truncated(): assert new_ctx == calc_fallback_value(7408, 32), (f"Fallback ctx {new_ctx} should equal calc_fallback_value result") +def test_exponential_decode_block_limit_uncapped(monkeypatch): + """Verify that decode block limit is computed from log2(max_decode_blocks). + + With the new approach, excessive warmup buckets are controlled by + filters in generate_buckets() (num_ctx_tokens_less_or_equal_batched_max_model_len) + rather than by capping the block limit in get_decode_cfgs(). + """ + monkeypatch.setenv("VLLM_EXPONENTIAL_BUCKETING", "true") + monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "true") + clear_config() + get_config() + + strategy = ExponentialBucketingStrategy() + max_num_seqs = 21 + block_size = 128 + max_num_batched_tokens = 8192 + max_model_len = 131072 + max_blocks = 65536 + + bs_cfg, query_cfg, block_cfg = strategy.get_decode_cfgs(max_num_seqs, block_size, max_num_batched_tokens, + max_model_len, max_blocks) + + # max_decode_blocks = min(65536, ceil(131072/128)*21) = min(65536, 21504) = 21504 + expected_max_decode_blocks = min(max_blocks, math.ceil(max_model_len / block_size) * max_num_seqs) + expected_limit = math.ceil(math.log2(expected_max_decode_blocks)) + 1 + assert block_cfg[2] == expected_max_decode_blocks, ( + f"Expected max_decode_blocks={expected_max_decode_blocks}, got {block_cfg[2]}") + assert block_cfg[3] == expected_limit, (f"Expected decode_blocks_limit={expected_limit}, got {block_cfg[3]}") + + # --- Padding-aware bucketing tests --- @@ -640,3 +664,66 @@ def test_padding_aware_decode_cfgs_contiguous_pa_clamps_block_range(mock_get_con max_blocks=3593) assert block_cfg == [3465, 128, 3593, 899, 25] + + +# --- Tests that num_ctx_tokens_less_or_equal_batched_max_model_len filter is applied --- + + +@pytest.mark.parametrize("use_contiguous_pa", [True, False], ids=["contiguous_pa", "non_contiguous_pa"]) +@pytest.mark.parametrize( + ("max_model_len", "block_size", "max_num_seqs", "max_blocks", "max_num_batched_tokens"), + [ + (91964, 128, 256, 3593, 2048), # Qwen3-32B real scenario + (4096, 128, 64, 500, 2048), # small model + (131072, 128, 21, 65536, 8192), # long context + ], + ids=["qwen3_32b", "small_model", "long_ctx"], +) +def test_decode_buckets_satisfy_ctx_filter(monkeypatch, use_contiguous_pa, max_model_len, block_size, max_num_seqs, + max_blocks, max_num_batched_tokens): + """Every decode bucket returned by generate_buckets must satisfy + num_ctx_tokens_less_or_equal_batched_max_model_len: + ctx <= ceil(max_model_len / block_size) * bs (when ctx > ctx_range[0]) + """ + monkeypatch.setenv("VLLM_CONTIGUOUS_PA", str(use_contiguous_pa).lower()) + clear_config() + get_config() + + strategy = ExponentialBucketingStrategy() + + bs_cfg, query_cfg, block_cfg = strategy.get_decode_cfgs( + max_num_seqs=max_num_seqs, + block_size=block_size, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + max_blocks=max_blocks, + ) + bs_range = strategy.get_range(bs_cfg) + query_range = strategy.get_range(query_cfg) + ctx_range = strategy.get_range(block_cfg) + + buckets = generate_buckets( + bs_range=bs_range, + query_range=query_range, + ctx_range=ctx_range, + is_prompt=False, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, + max_num_prefill_seqs=1, + max_num_batched_tokens=max_num_batched_tokens, + block_size=block_size, + max_blocks=max_blocks, + ) + + ctx_min = ctx_range[0] + max_blocks_per_seq = math.ceil(max_model_len / block_size) + + violations = [] + for bs, query, ctx in buckets: + if ctx > ctx_min and ctx > max_blocks_per_seq * bs: + violations.append((bs, query, ctx)) + + assert not violations, (f"Found {len(violations)} decode bucket(s) violating " + f"ctx <= ceil(max_model_len/block_size) * bs " + f"(max_blocks_per_seq={max_blocks_per_seq}):\n" + + "\n".join(f" bs={bs}, query={query}, ctx={ctx}" for bs, query, ctx in violations[:20])) diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index f35cb445f6..2fd85dbdae 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -1,3 +1,4 @@ +import logging import os import bisect import math @@ -446,6 +447,14 @@ def batch_size_smaller_than_blocks(bs, query, ctx): omitted_buckets.add(("condition: bs <= ctx, ", "-> bs, query, ctx: ", bs, query, ctx)) return bs <= ctx + def num_ctx_tokens_less_or_equal_batched_max_model_len(bs, query, ctx): + is_valid = ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_range[0] else True + if not is_valid: + omitted_buckets.add( + ("condition: ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_range[0] else True", + "-> bs, query, ctx: ", bs, query, ctx)) + return is_valid + filters_map = { "prompt": { # depends only on merged_prefill @@ -454,8 +463,8 @@ def batch_size_smaller_than_blocks(bs, query, ctx): }, "decode": { # depends only on contiguous PA - True: [], - False: [batch_size_smaller_than_blocks], + True: [num_ctx_tokens_less_or_equal_batched_max_model_len], + False: [batch_size_smaller_than_blocks, num_ctx_tokens_less_or_equal_batched_max_model_len], } } @@ -519,6 +528,12 @@ def is_ctx_allowed(ctx): raise RuntimeError("Generated 0 " + phase + " buckets. Please adjust the bucketing configuration according to README") + if logger().getEffectiveLevel() <= logging.DEBUG and omitted_buckets: + phase = "prompt" if is_prompt else "decode" + omitted_buckets_str = "\n".join(map(str, sorted(omitted_buckets))) + msg = f"Omitted {len(omitted_buckets)} {phase} buckets:\n{omitted_buckets_str}" + logger().debug(msg) + return sorted(buckets) diff --git a/vllm_gaudi/extension/bucketing/exponential.py b/vllm_gaudi/extension/bucketing/exponential.py index eece7ef7af..312294995a 100644 --- a/vllm_gaudi/extension/bucketing/exponential.py +++ b/vllm_gaudi/extension/bucketing/exponential.py @@ -89,15 +89,10 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_ decode_bs_limit = math.ceil(math.log2(max_num_seqs)) + 1 decode_bs_bucket_cfg = [1, 2, max_num_seqs, decode_bs_limit] decode_query_bucket_cfg = [1, 1, 1, 1] - # With non-contiguous PA, total block references across all sequences - # can exceed physical num_hpu_blocks (same physical block appears in - # multiple sequence block tables). Use 3x headroom so prepared buckets - # cover realistic prefix-sharing scenarios and avoid costly HPU graph - # recompilation at high KV-cache utilization. - max_decode_blocks = max_blocks if use_contiguous_pa else \ - max_blocks * 3 - max_decode_block_limit = math.ceil(math.log2(max_decode_blocks)) + 1 - decode_block_bucket_cfg = [1, max_num_seqs, max_decode_blocks, max_decode_block_limit] + max_decode_blocks = math.ceil(max_model_len / block_size) * max_num_seqs + max_decode_blocks = min(max_blocks, max_decode_blocks) if use_contiguous_pa else max_decode_blocks + decode_blocks_limit = math.ceil(math.log2(max_decode_blocks)) + 1 + decode_block_bucket_cfg = [1, max_num_seqs, max_decode_blocks, decode_blocks_limit] msg = ("Decode bucket config (min, step, max_warmup, limit) " f"bs:{decode_bs_bucket_cfg}, " diff --git a/vllm_gaudi/extension/bucketing/linear.py b/vllm_gaudi/extension/bucketing/linear.py index d50a8e57a2..8d941ae2bd 100644 --- a/vllm_gaudi/extension/bucketing/linear.py +++ b/vllm_gaudi/extension/bucketing/linear.py @@ -63,17 +63,18 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_ if contiguous_pa: max_decode_blocks = max_blocks decode_block_bucket_cfg = read_bucket_settings('decode', 'block', min=1, step=block_size, max=max_decode_blocks) - if decode_block_bucket_cfg[2] > max_blocks: + if contiguous_pa and decode_block_bucket_cfg[2] > max_blocks: logger().info( f'VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} is higher than max_blocks={max_blocks}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MAX={max_blocks}' ) decode_block_bucket_cfg[2] = max_blocks - if decode_block_bucket_cfg[0] > max_blocks: - decode_block_bucket_min = max(1, max_blocks - decode_block_bucket_cfg[1]) - logger().info( - f'VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} is higher than max_blocks={max_blocks}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_min}' - ) - decode_block_bucket_cfg[0] = decode_block_bucket_min + + if decode_block_bucket_cfg[0] > decode_block_bucket_cfg[2]: + decode_block_bucket_min = max(1, decode_block_bucket_cfg[2] - decode_block_bucket_cfg[1]) + logger().info( + f"VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} is higher than max_blocks={decode_block_bucket_cfg[2]}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_min}" + ) + decode_block_bucket_cfg[0] = decode_block_bucket_min msg = ("Decode bucket config (min, step, max_warmup) " f"bs:{decode_bs_bucket_cfg}, " diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index caafb5e0b0..1953de9720 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -5096,6 +5096,9 @@ def _add_dummy_request(self, num_scheduled_tokens[req_id] = scheduled_tokens def _generate_seq_lengths(self, num_samples, num_blocks, block_size): + # ensure the actual number of blocks is less than the KV cache blocks + num_blocks = min(self.kv_cache_config.num_blocks, num_blocks) + assert num_samples <= num_blocks blocks = [num_blocks // num_samples] * num_samples missing_blocks = num_blocks - sum(blocks)