Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions vllm_gaudi/extension/bucketing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,16 @@ def initialize(self,
block_size,
max_num_batched_tokens,
max_model_len,
num_speculative_tokens=0):
num_speculative_tokens=0,
mamba_chunk_size=0):
self.max_num_seqs = max_num_seqs
self.max_num_prefill_seqs = max_num_prefill_seqs
self.block_size = block_size
self.max_num_batched_tokens = max_num_batched_tokens
self.num_hpu_blocks = None
self.max_model_len = max_model_len
self.num_speculative_tokens = num_speculative_tokens
self.mamba_chunk_size = mamba_chunk_size
self.initialized = True
self.fallback_bs_base_step = 2
self.fallback_seq_base_step = 32
Expand Down Expand Up @@ -156,7 +158,7 @@ def generate_prompt_buckets(self):
self.prompt_buckets = generate_buckets(bs_range, query_range, ctx_range, True, self.max_model_len,
self.max_num_seqs, self.max_num_prefill_seqs,
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
buckets_from_file)
buckets_from_file, self.mamba_chunk_size)
self.log_generate_info(True)
if self.use_sliding_window:
self.prompt_buckets = [
Expand Down Expand Up @@ -198,7 +200,7 @@ def generate_decode_buckets(self):
self.decode_buckets = generate_buckets(bs_range, query_range, ctx_range, False, self.max_model_len,
self.max_num_seqs, self.max_num_prefill_seqs,
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
buckets_from_file)
buckets_from_file, self.mamba_chunk_size)
if self.num_speculative_tokens:
# The existing buckets are used as seed decode buckets
self.seed_decode_buckets = self.decode_buckets
Expand Down Expand Up @@ -347,7 +349,8 @@ def generate_buckets(bs_range,
max_num_batched_tokens,
block_size,
max_blocks,
file_buckets=None):
file_buckets=None,
mamba_chunk_size=0):
use_merged_prefill = get_config().merged_prefill
use_contiguous_pa = get_config().use_contiguous_pa

Expand Down Expand Up @@ -399,6 +402,9 @@ def no_corrections(bs, query, ctx):
def correct_for_max_model_len(bs, query, ctx):
return (bs, query, min(ctx, bs * math.ceil(max_model_len / block_size)))

def correct_for_mamba_chunk_size(bs, query, ctx):
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

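Division by zero would occur if `correct_for_mamba_chunk_size` were ever called with `mamba_chunk_size` equal to 0. `get_corrector` currently selects this corrector only when `mamba_chunk_size > 0` (line 433), so the failing path is not reachable today; however, the function itself does not enforce that precondition. Consider adding a guard condition at the start of the function to protect against future misuse.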

Suggested change
def correct_for_mamba_chunk_size(bs, query, ctx):
def correct_for_mamba_chunk_size(bs, query, ctx):
if mamba_chunk_size <= 0:
raise ValueError("mamba_chunk_size must be greater than 0 to avoid division by zero.")

Copilot uses AI. Check for mistakes.
return (bs, math.ceil(query / mamba_chunk_size) * mamba_chunk_size, ctx)

def batch_size_smaller_than_blocks(bs, query, ctx):
if not bs <= ctx:
omitted_buckets.add(("condition: bs <= ctx, ", "-> bs, query, ctx: ", bs, query, ctx))
Expand All @@ -424,7 +430,9 @@ def get_filters(is_prompt, use_merged_prefill, use_contiguous_pa):
return filters_map[phase][use_contiguous_pa]

def get_corrector(is_prompt, use_contiguous_pa):
if is_prompt or use_contiguous_pa:
if is_prompt and mamba_chunk_size > 0:
return correct_for_mamba_chunk_size
elif is_prompt or use_contiguous_pa:
return no_corrections
else:
return correct_for_max_model_len
Expand Down
6 changes: 5 additions & 1 deletion vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,9 @@ def __init__(
self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
self.is_multimodal_raw_input_supported = (model_config.is_multimodal_raw_input_only_model)

self.num_mamba_layers = self.model_config.get_num_layers_by_block_type(self.parallel_config, "mamba")
self.mamba_chunk_size = self.model_config.get_mamba_chunk_size() if self.num_mamba_layers > 0 else 0

# Lazy initialization
# self.model: nn.Module # set after load_model
self.kv_caches: list[torch.Tensor] = []
Expand Down Expand Up @@ -806,7 +809,8 @@ def __init__(
block_size=self.block_size,
max_num_batched_tokens=self.max_num_batched_tokens,
max_model_len=self.max_model_len,
num_speculative_tokens=num_speculative_tokens)
num_speculative_tokens=num_speculative_tokens,
mamba_chunk_size=self.mamba_chunk_size)
self.graphed_buckets: set[Any] = set()
self.graphed_multimodal_buckets: set[Any] = set()
else:
Expand Down