diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index 74c8bb14b1..84e93ad1dc 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -55,7 +55,8 @@ def initialize(self, block_size, max_num_batched_tokens, max_model_len, - num_speculative_tokens=0): + num_speculative_tokens=0, + mamba_chunk_size=0): self.max_num_seqs = max_num_seqs self.max_num_prefill_seqs = max_num_prefill_seqs self.block_size = block_size @@ -63,6 +64,7 @@ def initialize(self, self.num_hpu_blocks = None self.max_model_len = max_model_len self.num_speculative_tokens = num_speculative_tokens + self.mamba_chunk_size = mamba_chunk_size self.initialized = True self.fallback_bs_base_step = 2 self.fallback_seq_base_step = 32 @@ -156,7 +158,7 @@ def generate_prompt_buckets(self): self.prompt_buckets = generate_buckets(bs_range, query_range, ctx_range, True, self.max_model_len, self.max_num_seqs, self.max_num_prefill_seqs, self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks, - buckets_from_file) + buckets_from_file, self.mamba_chunk_size) self.log_generate_info(True) if self.use_sliding_window: self.prompt_buckets = [ @@ -198,7 +200,7 @@ def generate_decode_buckets(self): self.decode_buckets = generate_buckets(bs_range, query_range, ctx_range, False, self.max_model_len, self.max_num_seqs, self.max_num_prefill_seqs, self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks, - buckets_from_file) + buckets_from_file, self.mamba_chunk_size) if self.num_speculative_tokens: # The existing buckets are used as seed decode buckets self.seed_decode_buckets = self.decode_buckets @@ -347,7 +349,8 @@ def generate_buckets(bs_range, max_num_batched_tokens, block_size, max_blocks, - file_buckets=None): + file_buckets=None, + mamba_chunk_size=0): use_merged_prefill = get_config().merged_prefill use_contiguous_pa = get_config().use_contiguous_pa @@ -399,6 +402,9 @@ def no_corrections(bs, query, ctx): def correct_for_max_model_len(bs, query, ctx): return (bs, query, min(ctx, bs * math.ceil(max_model_len / block_size))) + def correct_for_mamba_chunk_size(bs, query, ctx): + return (bs, math.ceil(query / mamba_chunk_size) * mamba_chunk_size, ctx) + def batch_size_smaller_than_blocks(bs, query, ctx): if not bs <= ctx: omitted_buckets.add(("condition: bs <= ctx, ", "-> bs, query, ctx: ", bs, query, ctx)) @@ -424,7 +430,9 @@ def get_filters(is_prompt, use_merged_prefill, use_contiguous_pa): return filters_map[phase][use_contiguous_pa] def get_corrector(is_prompt, use_contiguous_pa): - if is_prompt or use_contiguous_pa: + if is_prompt and mamba_chunk_size > 0: + return correct_for_mamba_chunk_size + elif is_prompt or use_contiguous_pa: return no_corrections else: return correct_for_max_model_len diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index c9488523e7..e2422da908 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -715,6 +715,9 @@ def __init__( self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) self.is_multimodal_raw_input_supported = (model_config.is_multimodal_raw_input_only_model) + self.num_mamba_layers = self.model_config.get_num_layers_by_block_type(self.parallel_config, "mamba") + self.mamba_chunk_size = self.model_config.get_mamba_chunk_size() if self.num_mamba_layers > 0 else 0 + # Lazy initialization # self.model: nn.Module # set after load_model self.kv_caches: list[torch.Tensor] = [] @@ -806,7 +809,8 @@ def __init__( block_size=self.block_size, max_num_batched_tokens=self.max_num_batched_tokens, max_model_len=self.max_model_len, - num_speculative_tokens=num_speculative_tokens) + num_speculative_tokens=num_speculative_tokens, + mamba_chunk_size=self.mamba_chunk_size) self.graphed_buckets: set[Any] = set() self.graphed_multimodal_buckets: set[Any] = set() else: