Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 123 additions & 36 deletions tests/unit_tests/test_bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.
###############################################################################

import math

import pytest
from unittest.mock import patch

Expand Down Expand Up @@ -189,25 +191,25 @@ def __init__(self, **kwargs):


@patch('vllm_gaudi.extension.bucketing.exponential.get_config')
def test_exponential_decode_cfgs_non_contiguous_pa_bounded(mock_get_config):
"""max_decode_blocks should be max_blocks * 3 when use_contiguous_pa=False.

The 3x multiplier accounts for prefix-cache block sharing: the same
physical block can appear in multiple sequences' block tables, so total
block references may exceed num_hpu_blocks.
def test_exponential_decode_cfgs_non_contiguous_pa_unbounded(mock_get_config):
"""max_decode_blocks should be ceil(max_model_len/block_size)*max_num_seqs
when use_contiguous_pa=False. Actual bounding of generated buckets
happens via filters in generate_buckets().
"""
mock_get_config.return_value = _MockConfig(use_contiguous_pa=False)
strategy = ExponentialBucketingStrategy()

max_blocks = 3593
block_size = 128
_, _, block_cfg = strategy.get_decode_cfgs(max_num_seqs=256,
max_model_len = 91964
max_num_seqs = 256
_, _, block_cfg = strategy.get_decode_cfgs(max_num_seqs=max_num_seqs,
block_size=block_size,
max_num_batched_tokens=131072,
max_model_len=91964,
max_model_len=max_model_len,
max_blocks=max_blocks)

expected_max = max_blocks * 3 # 10779
expected_max = math.ceil(max_model_len / block_size) * max_num_seqs
Copy link
Copy Markdown
Collaborator

@michalkuligowski michalkuligowski Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the maximum for the expected_max value here? This value is calculated — in fact, the calculation is copied from the algorithm. This test should check specific values and verify that the filter in fact reduces the number of buckets.

Copy link
Copy Markdown
Collaborator Author

@yangulei yangulei Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expected_max is calculated based on the corner decoding case with context_lens = [max_model_len - 1] + [1] * (max_num_seqs - 1). The context lens will be padded to the maximum one, and the total number of decoding blocks can be calculated with math.ceil((max_model_len - 1) / block_size) * max_num_seqs, which can be simplified to math.ceil(max_model_len / block_size) * max_num_seqs since max_model_len >> block_size in common cases.
For example, with max_model_len=16384, max_num_seqs=128 and block_size=128, the corner decoding case with context_lens = [16383, 1, 1, 1, ... 1, 1, 1] needs ceil(16383 / 128) * 128 = ceil(16384 / 128) * 128 = 16384 blocks after padding.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the unit test for the newly added filter has been added. Thanks for reminding me of that.

assert block_cfg[2] == expected_max, (f"Expected max_decode_blocks={expected_max}, got {block_cfg[2]}")


Expand All @@ -229,8 +231,12 @@ def test_exponential_decode_cfgs_contiguous_pa_uses_max_blocks(mock_get_config):


@patch('vllm_gaudi.extension.bucketing.exponential.get_config')
def test_exponential_decode_max_never_exceeds_bounded_value(mock_get_config):
"""Regression test: large max_model_len must NOT produce gigantic decode buckets."""
def test_exponential_decode_cfgs_non_contiguous_pa_formula(mock_get_config):
"""Verify non-contiguous PA decode cfg uses ceil(max_model_len/block_size)*max_num_seqs.

Actual bounding of excessive buckets happens via the
num_ctx_tokens_less_or_equal_batched_max_model_len filter in generate_buckets().
"""
mock_get_config.return_value = _MockConfig(use_contiguous_pa=False)
strategy = ExponentialBucketingStrategy()

Expand All @@ -245,14 +251,8 @@ def test_exponential_decode_max_never_exceeds_bounded_value(mock_get_config):
max_model_len=max_model_len,
max_blocks=max_blocks)

# The old (buggy) formula would produce min(91964//128*256, ...) = 183808
# The fix should give max_blocks * 3 = 10779
assert block_cfg[2] <= max_blocks * 3, (f"Decode bucket max {block_cfg[2]} exceeds bounded limit "
f"{max_blocks * 3}. Buckets are too large!")
# Sanity: must not be the old gigantic value
old_buggy_value = max_model_len // block_size * max_num_seqs
assert block_cfg[2] < old_buggy_value, (f"Decode bucket max {block_cfg[2]} matches buggy formula output "
f"{old_buggy_value}")
expected_max = math.ceil(max_model_len / block_size) * max_num_seqs
assert block_cfg[2] == expected_max, (f"Expected max_decode_blocks={expected_max}, got {block_cfg[2]}")


@patch('vllm_gaudi.extension.bucketing.exponential.get_config')
Expand Down Expand Up @@ -315,8 +315,8 @@ def test_fallback_bucket_ctx_uses_calc_fallback():
@patch('vllm_gaudi.extension.bucketing.exponential.get_config')
def test_real_scenario_decode_cfg_matches_fixed_log(mock_get_config):
"""Verify decode bucket config matches expected values for real scenario.

With max_blocks * 3: block config should be [1, 256, 10779, 14]
With non-contiguous PA: block config should be
[1, 256, ceil(91964/128)*256, ceil(log2(that))+1]
"""
mock_get_config.return_value = _MockConfig(use_contiguous_pa=False)
strategy = ExponentialBucketingStrategy()
Expand All @@ -327,13 +327,11 @@ def test_real_scenario_decode_cfg_matches_fixed_log(mock_get_config):
max_model_len=_REAL_MAX_MODEL_LEN,
max_blocks=_REAL_MAX_BLOCKS)

# Expected: [1, 256, 10779, 14]
expected_max = math.ceil(_REAL_MAX_MODEL_LEN / _REAL_BLOCK_SIZE) * _REAL_MAX_NUM_SEQS
expected_limit = math.ceil(math.log2(expected_max)) + 1
assert block_cfg[0] == 1, f"block min: expected 1, got {block_cfg[0]}"
assert block_cfg[1] == _REAL_MAX_NUM_SEQS, (f"block step: expected {_REAL_MAX_NUM_SEQS}, got {block_cfg[1]}")
assert block_cfg[2] == _REAL_FIXED_MAX_DECODE_BLOCKS, (
f"block max: expected {_REAL_FIXED_MAX_DECODE_BLOCKS}, got {block_cfg[2]}")
import math
expected_limit = math.ceil(math.log2(_REAL_MAX_BLOCKS * 3)) + 1 # 14
assert block_cfg[2] == expected_max, (f"block max: expected {expected_max}, got {block_cfg[2]}")
assert block_cfg[3] == expected_limit, (f"block limit: expected {expected_limit}, got {block_cfg[3]}")


Expand All @@ -357,10 +355,11 @@ def test_real_scenario_decode_cfg_matches_fixed_bs_log(mock_get_config):


@patch('vllm_gaudi.extension.bucketing.exponential.get_config')
def test_real_scenario_decode_block_range_bounded(mock_get_config):
"""Verify generated decode block range stays within bounds (real scenario).
def test_real_scenario_decode_block_range_within_cfg_max(mock_get_config):
"""Verify generated decode block range stays within cfg max (real scenario).

Fixed log showed blocks up to 3721. Buggy run had blocks up to 183808.
The block range from get_range() extends up to max_decode_blocks.
Actual bounding per (bs, ctx) pair happens via filters in generate_buckets().
"""
mock_get_config.return_value = _MockConfig(use_contiguous_pa=False)
strategy = ExponentialBucketingStrategy()
Expand All @@ -372,14 +371,9 @@ def test_real_scenario_decode_block_range_bounded(mock_get_config):
max_blocks=_REAL_MAX_BLOCKS)

block_range = strategy.get_range(block_cfg)
expected_max = math.ceil(_REAL_MAX_MODEL_LEN / _REAL_BLOCK_SIZE) * _REAL_MAX_NUM_SEQS

assert max(block_range) <= _REAL_FIXED_MAX_DECODE_BLOCKS, (
f"Largest block bucket {max(block_range)} exceeds bounded max "
f"{_REAL_FIXED_MAX_DECODE_BLOCKS}")
assert max(block_range) < _REAL_BUGGY_MAX_DECODE_BLOCKS, (
f"Block range still contains buggy value {_REAL_BUGGY_MAX_DECODE_BLOCKS}")
# Verify reasonable number of buckets (log showed 13 unique block values)
assert len(block_range) <= 20, (f"Too many block buckets: {len(block_range)}")
assert max(block_range) <= expected_max, (f"Largest block bucket {max(block_range)} exceeds cfg max {expected_max}")


@patch('vllm_gaudi.extension.bucketing.exponential.get_config')
Expand Down Expand Up @@ -533,6 +527,36 @@ def test_real_scenario_fallback_ctx_7408_not_truncated():
assert new_ctx == calc_fallback_value(7408, 32), (f"Fallback ctx {new_ctx} should equal calc_fallback_value result")


def test_exponential_decode_block_limit_uncapped(monkeypatch):
    """Check that the decode block limit is derived from log2(max_decode_blocks).

    Under the new scheme, excessive warmup buckets are pruned by filters in
    generate_buckets() (num_ctx_tokens_less_or_equal_batched_max_model_len)
    instead of capping the block limit inside get_decode_cfgs().
    """
    # Force the exponential strategy with contiguous PA before building config.
    monkeypatch.setenv("VLLM_EXPONENTIAL_BUCKETING", "true")
    monkeypatch.setenv("VLLM_CONTIGUOUS_PA", "true")
    clear_config()
    get_config()

    strategy = ExponentialBucketingStrategy()
    max_num_seqs = 21
    block_size = 128
    max_num_batched_tokens = 8192
    max_model_len = 131072
    max_blocks = 65536

    cfgs = strategy.get_decode_cfgs(max_num_seqs, block_size, max_num_batched_tokens, max_model_len, max_blocks)
    bs_cfg, query_cfg, block_cfg = cfgs

    # max_decode_blocks = min(65536, ceil(131072/128)*21) = min(65536, 21504) = 21504
    expected_max_decode_blocks = min(max_blocks, math.ceil(max_model_len / block_size) * max_num_seqs)
    expected_limit = math.ceil(math.log2(expected_max_decode_blocks)) + 1
    assert block_cfg[2] == expected_max_decode_blocks, (
        f"Expected max_decode_blocks={expected_max_decode_blocks}, got {block_cfg[2]}")
    assert block_cfg[3] == expected_limit, (f"Expected decode_blocks_limit={expected_limit}, got {block_cfg[3]}")


# --- Padding-aware bucketing tests ---


Expand Down Expand Up @@ -640,3 +664,66 @@ def test_padding_aware_decode_cfgs_contiguous_pa_clamps_block_range(mock_get_con
max_blocks=3593)

assert block_cfg == [3465, 128, 3593, 899, 25]


# --- Tests that num_ctx_tokens_less_or_equal_batched_max_model_len filter is applied ---


@pytest.mark.parametrize("use_contiguous_pa", [True, False], ids=["contiguous_pa", "non_contiguous_pa"])
@pytest.mark.parametrize(
    ("max_model_len", "block_size", "max_num_seqs", "max_blocks", "max_num_batched_tokens"),
    [
        (91964, 128, 256, 3593, 2048),  # Qwen3-32B real scenario
        (4096, 128, 64, 500, 2048),  # small model
        (131072, 128, 21, 65536, 8192),  # long context
    ],
    ids=["qwen3_32b", "small_model", "long_ctx"],
)
def test_decode_buckets_satisfy_ctx_filter(monkeypatch, use_contiguous_pa, max_model_len, block_size, max_num_seqs,
                                           max_blocks, max_num_batched_tokens):
    """Check every decode bucket from generate_buckets() satisfies
    num_ctx_tokens_less_or_equal_batched_max_model_len:
    ctx <= ceil(max_model_len / block_size) * bs (whenever ctx > ctx_range[0]).
    """
    # Select the contiguous-PA mode for this case before (re)building the config.
    monkeypatch.setenv("VLLM_CONTIGUOUS_PA", str(use_contiguous_pa).lower())
    clear_config()
    get_config()

    strategy = ExponentialBucketingStrategy()
    bs_cfg, query_cfg, block_cfg = strategy.get_decode_cfgs(
        max_num_seqs=max_num_seqs,
        block_size=block_size,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        max_blocks=max_blocks,
    )

    bs_range = strategy.get_range(bs_cfg)
    query_range = strategy.get_range(query_cfg)
    ctx_range = strategy.get_range(block_cfg)
    buckets = generate_buckets(
        bs_range=bs_range,
        query_range=query_range,
        ctx_range=ctx_range,
        is_prompt=False,
        max_model_len=max_model_len,
        max_num_seqs=max_num_seqs,
        max_num_prefill_seqs=1,
        max_num_batched_tokens=max_num_batched_tokens,
        block_size=block_size,
        max_blocks=max_blocks,
    )

    max_blocks_per_seq = math.ceil(max_model_len / block_size)
    ctx_min = ctx_range[0]

    # Collect every bucket that slipped past the ctx filter.
    violations = [(bs, query, ctx) for bs, query, ctx in buckets if ctx > ctx_min and ctx > max_blocks_per_seq * bs]

    assert not violations, (f"Found {len(violations)} decode bucket(s) violating "
                            f"ctx <= ceil(max_model_len/block_size) * bs "
                            f"(max_blocks_per_seq={max_blocks_per_seq}):\n" +
                            "\n".join(f"  bs={bs}, query={query}, ctx={ctx}" for bs, query, ctx in violations[:20]))
19 changes: 17 additions & 2 deletions vllm_gaudi/extension/bucketing/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
import bisect
import math
Expand Down Expand Up @@ -446,6 +447,14 @@ def batch_size_smaller_than_blocks(bs, query, ctx):
omitted_buckets.add(("condition: bs <= ctx, ", "-> bs, query, ctx: ", bs, query, ctx))
return bs <= ctx

def num_ctx_tokens_less_or_equal_batched_max_model_len(bs, query, ctx):
    # Bucket filter: keep a decode bucket only if its ctx (in blocks) fits
    # within ceil(max_model_len / block_size) blocks per sequence times the
    # batch size bs.  The smallest ctx value (ctx_range[0]) is always kept,
    # so at least one bucket per batch size survives the filter.
    # NOTE(review): max_model_len, block_size, ctx_range and omitted_buckets
    # are closed over from the enclosing scope (not visible in this hunk).
    is_valid = ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_range[0] else True
    if not is_valid:
        # Record why the bucket was dropped, for the debug-level summary log.
        omitted_buckets.add(
            ("condition: ctx <= math.ceil(max_model_len / block_size) * bs if ctx > ctx_range[0] else True",
             "-> bs, query, ctx: ", bs, query, ctx))
    return is_valid

filters_map = {
"prompt": {
# depends only on merged_prefill
Expand All @@ -454,8 +463,8 @@ def batch_size_smaller_than_blocks(bs, query, ctx):
},
"decode": {
# depends only on contiguous PA
True: [],
False: [batch_size_smaller_than_blocks],
True: [num_ctx_tokens_less_or_equal_batched_max_model_len],
False: [batch_size_smaller_than_blocks, num_ctx_tokens_less_or_equal_batched_max_model_len],
}
}

Expand Down Expand Up @@ -519,6 +528,12 @@ def is_ctx_allowed(ctx):
raise RuntimeError("Generated 0 " + phase +
" buckets. Please adjust the bucketing configuration according to README")

if logger().getEffectiveLevel() <= logging.DEBUG and omitted_buckets:
phase = "prompt" if is_prompt else "decode"
omitted_buckets_str = "\n".join(map(str, sorted(omitted_buckets)))
msg = f"Omitted {len(omitted_buckets)} {phase} buckets:\n{omitted_buckets_str}"
logger().debug(msg)

return sorted(buckets)


Expand Down
13 changes: 4 additions & 9 deletions vllm_gaudi/extension/bucketing/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,10 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_
decode_bs_limit = math.ceil(math.log2(max_num_seqs)) + 1
decode_bs_bucket_cfg = [1, 2, max_num_seqs, decode_bs_limit]
decode_query_bucket_cfg = [1, 1, 1, 1]
# With non-contiguous PA, total block references across all sequences
# can exceed physical num_hpu_blocks (same physical block appears in
# multiple sequence block tables). Use 3x headroom so prepared buckets
# cover realistic prefix-sharing scenarios and avoid costly HPU graph
# recompilation at high KV-cache utilization.
max_decode_blocks = max_blocks if use_contiguous_pa else \
max_blocks * 3
max_decode_block_limit = math.ceil(math.log2(max_decode_blocks)) + 1
decode_block_bucket_cfg = [1, max_num_seqs, max_decode_blocks, max_decode_block_limit]
max_decode_blocks = math.ceil(max_model_len / block_size) * max_num_seqs
max_decode_blocks = min(max_blocks, max_decode_blocks) if use_contiguous_pa else max_decode_blocks
decode_blocks_limit = math.ceil(math.log2(max_decode_blocks)) + 1
decode_block_bucket_cfg = [1, max_num_seqs, max_decode_blocks, decode_blocks_limit]

msg = ("Decode bucket config (min, step, max_warmup, limit) "
f"bs:{decode_bs_bucket_cfg}, "
Expand Down
15 changes: 8 additions & 7 deletions vllm_gaudi/extension/bucketing/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,18 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_
if contiguous_pa:
max_decode_blocks = max_blocks
decode_block_bucket_cfg = read_bucket_settings('decode', 'block', min=1, step=block_size, max=max_decode_blocks)
if decode_block_bucket_cfg[2] > max_blocks:
if contiguous_pa and decode_block_bucket_cfg[2] > max_blocks:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the decode bucket count now?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug fix for the missing buckets which cause "not warmed-up" warnings in non-contiguous PA cases. The actual number of decode buckets for linear bucketing is sensitive to the *_BUCKET_STEP_* configuration, and the default settings usually produce too many buckets, which can take hours or even days to warm up in cases with a long max_model_len.

logger().info(
f'VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} is higher than max_blocks={max_blocks}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MAX={max_blocks}'
)
decode_block_bucket_cfg[2] = max_blocks
if decode_block_bucket_cfg[0] > max_blocks:
decode_block_bucket_min = max(1, max_blocks - decode_block_bucket_cfg[1])
logger().info(
f'VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} is higher than max_blocks={max_blocks}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_min}'
)
decode_block_bucket_cfg[0] = decode_block_bucket_min

if decode_block_bucket_cfg[0] > decode_block_bucket_cfg[2]:
decode_block_bucket_min = max(1, decode_block_bucket_cfg[2] - decode_block_bucket_cfg[1])
logger().info(
f"VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} is higher than max_blocks={decode_block_bucket_cfg[2]}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_min}"
)
decode_block_bucket_cfg[0] = decode_block_bucket_min

msg = ("Decode bucket config (min, step, max_warmup) "
f"bs:{decode_bs_bucket_cfg}, "
Expand Down
3 changes: 3 additions & 0 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5096,6 +5096,9 @@ def _add_dummy_request(self,
num_scheduled_tokens[req_id] = scheduled_tokens

def _generate_seq_lengths(self, num_samples, num_blocks, block_size):
# ensure the actual number of blocks is less than the KV cache blocks
num_blocks = min(self.kv_cache_config.num_blocks, num_blocks)

Comment thread
yangulei marked this conversation as resolved.
assert num_samples <= num_blocks
blocks = [num_blocks // num_samples] * num_samples
missing_blocks = num_blocks - sum(blocks)
Expand Down
Loading