From c708f8d42ebca7f90441c8f205107f2edf8d3460 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 26 Jan 2026 15:21:25 +0000 Subject: [PATCH 01/11] update CODEOWNERS for v0.14.1 Signed-off-by: Luca Calabria --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 3b7fbad7e6..82c0b38156 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @mgawarkiewicz-intel @wpyszka @piotrbocian @adobrzyn +* @mgawarkiewicz-intel @wpyszka @piotrbocian From b95ee8d57949836de9ea5c7a4962d0c3724f8d5e Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Mon, 26 Jan 2026 23:31:04 +0100 Subject: [PATCH 02/11] cherry pick llama4 chunked attn and 32k+ context fix Signed-off-by: Luca Calabria --- vllm_gaudi/attention/backends/hpu_attn.py | 18 ++ vllm_gaudi/ops/hpu_fused_moe.py | 5 +- vllm_gaudi/v1/attention/backends/hpu_attn.py | 6 + vllm_gaudi/v1/spec_decode/hpu_eagle.py | 3 + vllm_gaudi/v1/worker/hpu_model_runner.py | 168 +++++++++++++++++-- 5 files changed, 183 insertions(+), 17 deletions(-) diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index 5bd6460ed2..6a681c687c 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -169,6 +169,12 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): window_block_groups: Optional[torch.Tensor] = None window_block_usage: Optional[torch.Tensor] = None window_attn_bias: Optional[torch.Tensor] = None + chunked_slot_mapping: Optional[torch.Tensor] = None + chunked_attn_bias: Optional[torch.Tensor] = None + chunked_block_mapping: Optional[torch.Tensor] = None + chunked_block_list: Optional[torch.Tensor] = None + chunked_block_groups: Optional[torch.Tensor] = None + chunked_block_usage: Optional[torch.Tensor] = None @dataclass @@ -492,6 +498,8 @@ def __init__( raise NotImplementedError("Encoder self-attention " "is not implemented for " "HPUAttentionImpl") + + self.is_chunked_attention = False def _maybe_init_alibi_biases( self, @@ -630,6 +638,10 @@ def forward( attn_bias = None window_size = (self.sliding_window, 0) common_args['window_size'] = window_size + + if self.is_chunked_attention and \ + hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None: + attn_bias = attn_metadata.chunked_attn_bias out = ops.prompt_attention(impl=self.prefill_impl, query=query.view(query_shape), @@ -650,6 +662,12 @@ def forward( block_groups = attn_metadata.window_block_groups block_mapping = attn_metadata.window_block_mapping attn_bias = attn_metadata.window_attn_bias + elif self.is_chunked_attention and \ + attn_metadata.chunked_block_list is not None: + block_list = attn_metadata.chunked_block_list + block_groups = attn_metadata.chunked_block_groups + block_mapping = attn_metadata.chunked_block_mapping + attn_bias = attn_metadata.chunked_attn_bias else: block_list = attn_metadata.block_list block_groups = attn_metadata.block_groups diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index a27710fa74..4168d515ee 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -160,7 +160,10 @@ def forward_oot( permuted_weights=True, activation=layer.activation, ) - return output.view(*(output.size(0), *input_shape[1:])) + if layer.dp_size > 1: + return output.view(*(output.size(0), *input_shape[1:])) + else: + return output.view(*input_shape) def reduce_output(self, states: torch.Tensor) -> torch.Tensor: diff --git a/vllm_gaudi/v1/attention/backends/hpu_attn.py b/vllm_gaudi/v1/attention/backends/hpu_attn.py index 1d2c1d2363..978cec0cd9 100644 --- a/vllm_gaudi/v1/attention/backends/hpu_attn.py +++ b/vllm_gaudi/v1/attention/backends/hpu_attn.py @@ -87,6 +87,9 @@ def make_decode_metadata(cls, window_block_list, window_block_usage, window_block_groups, + chunked_block_list, + chunked_block_usage, + chunked_block_groups, query_start_loc=None): return cls(is_prompt=False, block_mapping=None, @@ -100,6 +103,9 @@ def make_decode_metadata(cls, window_block_list=window_block_list, window_block_usage=window_block_usage, window_block_groups=window_block_groups, + chunked_block_list=chunked_block_list, + chunked_block_usage=chunked_block_usage, + chunked_block_groups=chunked_block_groups, input_positions=input_positions, slot_mapping=slot_mapping, block_size=block_size, diff --git a/vllm_gaudi/v1/spec_decode/hpu_eagle.py b/vllm_gaudi/v1/spec_decode/hpu_eagle.py index 7a0d508b0a..1c9159a810 100644 --- a/vllm_gaudi/v1/spec_decode/hpu_eagle.py +++ b/vllm_gaudi/v1/spec_decode/hpu_eagle.py @@ -237,6 +237,9 @@ def prepare_attn_metadata( window_block_list=None, window_block_usage=None, window_block_groups=None, + chunked_block_list=None, + chunked_block_usage=None, + chunked_block_groups=None, ) return common_attn_metadata diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index c9488523e7..b80ceda4a6 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -463,10 +463,12 @@ def forward(self, *args, **kwargs): if 'warmup_mode' in kwargs: kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] + model_has_chunked_attention = kwargs.pop('model_has_chunked_attention', False) if not self.unified_attn: kwargs['attn_metadata'] = self.metadata_processor.process_metadata(kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), - input_ids.device, self.dtype) + input_ids.device, self.dtype, + model_has_chunked_attention) if self._rotary_prepare_cos_sin is not None: self._rotary_prepare_cos_sin(kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin) attn_meta = kwargs.pop('attn_metadata') @@ -591,6 +593,11 @@ def trim_attn_metadata(metadata: HPUAttentionMetadataV1) -> object: 'window_block_usage', 'window_block_groups', 'window_attn_bias', + 'chunked_block_mapping', + 'chunked_attn_bias', + 'chunked_block_list', + 'chunked_block_usage', + 'chunked_block_groups', ]) return attention_metadata @@ -838,6 +845,8 @@ def __init__( self.scheduler_output: SchedulerOutput | None = None self.warmup_mode: bool = False self.batch_changed: bool = False + # WA for chunked attention support + self.model_has_chunked_attention = False assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!' assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!' @@ -1411,6 +1420,18 @@ def _get_num_decodes(self) -> int: continue num_decodes += 1 return num_decodes + + def maybe_set_chunked_attention_layers(self, model): + if hasattr(model.config, 'text_config') and \ + hasattr(model.config.text_config, 'attention_chunk_size') and \ + model.config.text_config.attention_chunk_size: + self.model_has_chunked_attention = True + try: + for layer in model.language_model.model.layers: + if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__: + layer.self_attn.attn.impl.is_chunked_attention = True + except Exception: + pass def _get_prompts_and_decodes( self, @@ -2128,6 +2149,19 @@ def _create_decode_input_data(self, window_block_tables, slot_mapping.tolist(), padded_batch_size * num_tokens) + if self.model_has_chunked_attention: + chunk_size_in_blocks = (self.model.model.config.text_config.attention_chunk_size // self.block_size) + seq_lens_block = [len(block_table) for block_table in block_tables_list] + num_seq_chunks = [math.ceil(sl / chunk_size_in_blocks) - 1 for sl in seq_lens_block] + block_tables_chunk = [ + block_table[num_seq_chunks[i] * chunk_size_in_blocks:] + for i, block_table in enumerate(block_tables_list) + ] + chunked_block_list, chunked_block_groups, chunked_block_usage = \ + self.get_habana_paged_attn_buffers( + block_tables_chunk, slot_mapping.tolist(), + padded_batch_size * num_tokens) + # CPU<>HPU sync *should not* happen here. block_list_device = async_h2d_copy(block_list, device=self.device) block_usage_device = async_h2d_copy(block_usage, device=self.device) @@ -2139,7 +2173,13 @@ def _create_decode_input_data(self, device=self.device) if self.interleaved_sliding_window else None window_block_groups_device = async_h2d_copy(window_block_groups, device=self.device) if self.interleaved_sliding_window else None - + chunked_block_list_device = async_h2d_copy(chunked_block_list, + device=self.device) if self.model_has_chunked_attention else None + chunked_block_usage_device = async_h2d_copy(chunked_block_usage, + device=self.device) if self.model_has_chunked_attention else None + chunked_block_groups_device = async_h2d_copy(chunked_block_groups, + device=self.device) if self.model_has_chunked_attention else None + token_ids_device = async_h2d_copy(token_ids, device=self.device) # when DP also enabled, some DP ranks will exeucte dummy run with empty # SchedulerOutput, in this case we need skip the prepare_input_ids @@ -2169,21 +2209,26 @@ def _create_decode_input_data(self, spec_decode_metadata = None logits_indices_device = async_h2d_copy(logits_indices, device=self.device) + attn_metadata = HPUAttentionMetadataV1.make_decode_metadata( + block_list=block_list_device, + block_usage=block_usage_device, + block_groups=block_groups_device, + input_positions=None, + slot_mapping=slot_mapping_device, + block_size=self.block_size, + window_block_list=window_block_list_device, + window_block_usage=window_block_usage_device, + window_block_groups=window_block_groups_device, + chunked_block_list=chunked_block_list_device, + chunked_block_usage=chunked_block_usage_device, + chunked_block_groups=chunked_block_groups_device, + ) + return DecodeInputData(num_decodes=num_decodes, token_ids=token_ids_device, position_ids=positions_device, logits_indices=logits_indices_device, - attn_metadata=HPUAttentionMetadataV1.make_decode_metadata( - block_list=block_list_device, - block_usage=block_usage_device, - block_groups=block_groups_device, - input_positions=None, - slot_mapping=slot_mapping_device, - block_size=self.block_size, - window_block_list=window_block_list_device, - window_block_usage=window_block_usage_device, - window_block_groups=window_block_groups_device, - ), + attn_metadata=attn_metadata, spec_decode_metadata=spec_decode_metadata) def _prepare_decode_inputs(self, @@ -2611,6 +2656,8 @@ def _execute_model_generic(self, else: # no hpu graphs for t.compile? use_graphs = False + if self.model_has_chunked_attention: + additional_kwargs.update({"model_has_chunked_attention": True}) trimmed_attn_metadata = attn_metadata if self.unified_attn else trim_attn_metadata(attn_metadata) if self.is_driver_worker: model_event_name = ("model_forward_" @@ -3718,6 +3765,7 @@ def load_model(self) -> None: htcore.mark_step() apply_model_specific_patches(self.model) + self.maybe_set_chunked_attention_layers(self.model) hidden_layer_markstep_interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) model_config = getattr(self.model, "config", None) modify_model_layers(self.model, @@ -5586,12 +5634,75 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias) return attn_metadata + def _set_attn_bias_for_chunked_attention(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, + chunk_size: int, device: torch.device, + dtype: torch.dtype) -> HPUAttentionMetadataV1: + """Set attention bias for chunked attention. + Args: + attn_metadata (HPUAttentionMetadataV1): The attention metadata. + batch_size (int): The batch size. + seq_len (int): The sequence length. + chunk_size (int): The chunk size. + device (torch.device): The device to use. + dtype (torch.dtype): The data type. + Returns: + HPUAttentionMetadataV1: The updated attention metadata. + """ + + if (attn_metadata is None or not attn_metadata.is_prompt): + return attn_metadata + + prefill_metadata = attn_metadata + shift = 0 + + if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None: + + context_lens_t = prefill_metadata.context_lens_tensor + assert context_lens_t is not None + block_list = prefill_metadata.block_list + max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) + max_context_len = max_context_len * self.block_size + query_positions = torch.arange(seq_len, device=device) + total_token_positions = context_lens_t.unsqueeze(-1) + query_positions.unsqueeze(0) + which_chunk = (total_token_positions // chunk_size) + chunk_start_positions = which_chunk * chunk_size + invalid_lens_t = chunk_start_positions - 1 + + past_indices = torch.arange(max_context_len, device=device) + past_mask = ( + (past_indices.unsqueeze(0).unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) & + (past_indices.unsqueeze(0).unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(-1))).unsqueeze(1) + + causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift) + query_chunk_ids = which_chunk[0] + same_chunk_mask = query_chunk_ids.unsqueeze(0) == query_chunk_ids.unsqueeze(1) + + causal_mask = causal_mask & same_chunk_mask + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len) + + mask = torch.concat((past_mask, causal_mask), dim=-1) + attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), + torch.tensor(float('-inf'), dtype=dtype, device=device)) + else: + tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) + mask = torch.tril(tensor, diagonal=shift) + idx = torch.arange(seq_len, device=device) + chunk_id = idx // chunk_size + same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1) + same_chunk = same_chunk.unsqueeze(0).unsqueeze(0) + mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device)) + attn_bias = torch.log(mask) + + attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", chunked_attn_bias=attn_bias) + return attn_metadata + def _set_block_mapping(self, metadata: HPUAttentionMetadataV1, batch_size: int, device: torch.device, dtype: torch.dtype, - is_window_block: bool = False) -> HPUAttentionMetadataV1: + is_window_block: bool = False, + update_for_chunked_attention: bool = False) -> HPUAttentionMetadataV1: """ Set block mapping for decode phase. @@ -5603,6 +5714,7 @@ def _set_block_mapping(self, device: Device to create tensors on dtype: Data type for tensors is_window_block: Whether this is for window blocks + update_for_chunked_attention: Whether to update for chunked attention Returns: Updated attention metadata with block_mapping and attn_bias set @@ -5610,6 +5722,9 @@ def _set_block_mapping(self, if is_window_block: block_usage = metadata.window_block_usage block_groups = metadata.window_block_groups + elif update_for_chunked_attention: + block_usage = metadata.chunked_block_usage + block_groups = metadata.chunked_block_groups else: block_usage = metadata.block_usage block_groups = metadata.block_groups @@ -5640,6 +5755,11 @@ def _set_block_mapping(self, "TrimmedAttentionMetadata", window_block_mapping=block_mapping, window_attn_bias=attn_bias) + elif update_for_chunked_attention: + metadata = custom_tuple_replace(metadata, + "TrimmedAttentionMetadata", + chunked_block_mapping=block_mapping, + chunked_attn_bias=attn_bias) else: metadata = custom_tuple_replace(metadata, "TrimmedAttentionMetadata", @@ -5647,8 +5767,13 @@ def _set_block_mapping(self, attn_bias=attn_bias) return metadata - def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, - device: torch.device, dtype: torch.dtype) -> HPUAttentionMetadataV1: + def process_metadata(self, + attn_metadata: HPUAttentionMetadataV1, + batch_size: int, + seq_len: int, + device: torch.device, + dtype: torch.dtype, + model_has_chunked_attention: bool = False) -> HPUAttentionMetadataV1: """ Post-process attention metadata with appropriate masks and mappings. @@ -5662,6 +5787,7 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in seq_len: Sequence length (for prompt phase) device: Device to create tensors on dtype: Data type for tensors + model_has_chunked_attention: Whether the model has chunked attention Returns: Post-processed attention metadata with additional tensors @@ -5671,8 +5797,18 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in if self.interleaved_sliding_window: attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len, self.sliding_window, device, dtype) + if model_has_chunked_attention: + attention_chunk_size = self.vllm_config.model_config.hf_config.text_config.attention_chunk_size + attn_metadata = self._set_attn_bias_for_chunked_attention(attn_metadata, batch_size, seq_len, + attention_chunk_size, device, dtype) else: attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype) + if model_has_chunked_attention: + attn_metadata = self._set_block_mapping(attn_metadata, + batch_size, + device, + dtype, + update_for_chunked_attention=True) if self.interleaved_sliding_window: attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True) return attn_metadata From 7fafd820f5e4e5c060a7258690b1c616df2821a2 Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Tue, 27 Jan 2026 00:10:16 +0100 Subject: [PATCH 03/11] add comment Signed-off-by: Luca Calabria --- vllm_gaudi/ops/hpu_fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index 4168d515ee..a096ce2ce0 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -160,6 +160,7 @@ def forward_oot( permuted_weights=True, activation=layer.activation, ) + # fix needed for llama4 when context len > 32K. Change in shape if layer.dp_size > 1: return output.view(*(output.size(0), *input_shape[1:])) else: From c537112c821fb063c6570bf5f7b1aeb06827bfb7 Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Tue, 27 Jan 2026 00:29:03 +0100 Subject: [PATCH 04/11] fix format Signed-off-by: Luca Calabria --- vllm_gaudi/attention/backends/hpu_attn.py | 2 -- vllm_gaudi/v1/worker/hpu_model_runner.py | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index 6a681c687c..ea9ec9ed8e 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -498,7 +498,6 @@ def __init__( raise NotImplementedError("Encoder self-attention " "is not implemented for " "HPUAttentionImpl") - self.is_chunked_attention = False def _maybe_init_alibi_biases( @@ -638,7 +637,6 @@ def forward( attn_bias = None window_size = (self.sliding_window, 0) common_args['window_size'] = window_size - if self.is_chunked_attention and \ hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None: attn_bias = attn_metadata.chunked_attn_bias diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index b80ceda4a6..5456dd8747 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1420,7 +1420,7 @@ def _get_num_decodes(self) -> int: continue num_decodes += 1 return num_decodes - + def maybe_set_chunked_attention_layers(self, model): if hasattr(model.config, 'text_config') and \ hasattr(model.config.text_config, 'attention_chunk_size') and \ @@ -2161,7 +2161,7 @@ def _create_decode_input_data(self, self.get_habana_paged_attn_buffers( block_tables_chunk, slot_mapping.tolist(), padded_batch_size * num_tokens) - + # CPU<>HPU sync *should not* happen here. block_list_device = async_h2d_copy(block_list, device=self.device) block_usage_device = async_h2d_copy(block_usage, device=self.device) @@ -2179,7 +2179,6 @@ def _create_decode_input_data(self, device=self.device) if self.model_has_chunked_attention else None chunked_block_groups_device = async_h2d_copy(chunked_block_groups, device=self.device) if self.model_has_chunked_attention else None - token_ids_device = async_h2d_copy(token_ids, device=self.device) # when DP also enabled, some DP ranks will exeucte dummy run with empty # SchedulerOutput, in this case we need skip the prepare_input_ids From 306ab473e4d149c167ebd0ac31d38bcc8a74198b Mon Sep 17 00:00:00 2001 From: Jakub Byczkowski Date: Wed, 28 Jan 2026 08:03:36 +0100 Subject: [PATCH 05/11] Implement bucket corrector for Mamba chunk size - v0.14.1 (#885) Due to MambaMixer2 implementation requirements, all buckets used for mamba must be a multiple of mamba chunk size. Signed-off-by: Jakub Byczkowski Signed-off-by: Luca Calabria --- vllm_gaudi/extension/bucketing/common.py | 18 +++++++++++++----- vllm_gaudi/v1/worker/hpu_model_runner.py | 6 +++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index 74c8bb14b1..84e93ad1dc 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -55,7 +55,8 @@ def initialize(self, block_size, max_num_batched_tokens, max_model_len, - num_speculative_tokens=0): + num_speculative_tokens=0, + mamba_chunk_size=0): self.max_num_seqs = max_num_seqs self.max_num_prefill_seqs = max_num_prefill_seqs self.block_size = block_size @@ -63,6 +64,7 @@ def initialize(self, self.num_hpu_blocks = None self.max_model_len = max_model_len self.num_speculative_tokens = num_speculative_tokens + self.mamba_chunk_size = mamba_chunk_size self.initialized = True self.fallback_bs_base_step = 2 self.fallback_seq_base_step = 32 @@ -156,7 +158,7 @@ def generate_prompt_buckets(self): self.prompt_buckets = generate_buckets(bs_range, query_range, ctx_range, True, self.max_model_len, self.max_num_seqs, self.max_num_prefill_seqs, self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks, - buckets_from_file) + buckets_from_file, self.mamba_chunk_size) self.log_generate_info(True) if self.use_sliding_window: self.prompt_buckets = [ @@ -198,7 +200,7 @@ def generate_decode_buckets(self): self.decode_buckets = generate_buckets(bs_range, query_range, ctx_range, False, self.max_model_len, self.max_num_seqs, self.max_num_prefill_seqs, self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks, - buckets_from_file) + buckets_from_file, self.mamba_chunk_size) if self.num_speculative_tokens: # The existing buckets are used as seed decode buckets self.seed_decode_buckets = self.decode_buckets @@ -347,7 +349,8 @@ def generate_buckets(bs_range, max_num_batched_tokens, block_size, max_blocks, - file_buckets=None): + file_buckets=None, + mamba_chunk_size=0): use_merged_prefill = get_config().merged_prefill use_contiguous_pa = get_config().use_contiguous_pa @@ -399,6 +402,9 @@ def no_corrections(bs, query, ctx): def correct_for_max_model_len(bs, query, ctx): return (bs, query, min(ctx, bs * math.ceil(max_model_len / block_size))) + def correct_for_mamba_chunk_size(bs, query, ctx): + return (bs, math.ceil(query / mamba_chunk_size) * mamba_chunk_size, ctx) + def batch_size_smaller_than_blocks(bs, query, ctx): if not bs <= ctx: omitted_buckets.add(("condition: bs <= ctx, ", "-> bs, query, ctx: ", bs, query, ctx)) @@ -424,7 +430,9 @@ def get_filters(is_prompt, use_merged_prefill, use_contiguous_pa): return filters_map[phase][use_contiguous_pa] def get_corrector(is_prompt, use_contiguous_pa): - if is_prompt or use_contiguous_pa: + if is_prompt and mamba_chunk_size > 0: + return correct_for_mamba_chunk_size + elif is_prompt or use_contiguous_pa: return no_corrections else: return correct_for_max_model_len diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 5456dd8747..363b3210e9 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -722,6 +722,9 @@ def __init__( self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) self.is_multimodal_raw_input_supported = (model_config.is_multimodal_raw_input_only_model) + self.num_mamba_layers = self.model_config.get_num_layers_by_block_type(self.parallel_config, "mamba") + self.mamba_chunk_size = self.model_config.get_mamba_chunk_size() if self.num_mamba_layers > 0 else 0 + # Lazy initialization # self.model: nn.Module # set after load_model self.kv_caches: list[torch.Tensor] = [] @@ -813,7 +816,8 @@ def __init__( block_size=self.block_size, max_num_batched_tokens=self.max_num_batched_tokens, max_model_len=self.max_model_len, - num_speculative_tokens=num_speculative_tokens) + num_speculative_tokens=num_speculative_tokens, + mamba_chunk_size=self.mamba_chunk_size) self.graphed_buckets: set[Any] = set() self.graphed_multimodal_buckets: set[Any] = set() else: From 8243450374fa40af2f9aef880a2f55fcd8504c1e Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:45:37 +0100 Subject: [PATCH 06/11] Revert "skip HPU graphs for long prefills" (#850) (#888) Reverts vllm-project/vllm-gaudi#780 --------- Signed-off-by: Agata Dobrzyniewicz Co-authored-by: Chendi.Xue Signed-off-by: Luca Calabria --- tests/full_tests/ci_gsm8k_tests.sh | 2 +- tests/full_tests/ci_perf_tests.sh | 2 +- vllm_gaudi/v1/worker/hpu_model_runner.py | 10 ++++------ 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/full_tests/ci_gsm8k_tests.sh b/tests/full_tests/ci_gsm8k_tests.sh index 8225926295..0cba7bec7d 100644 --- a/tests/full_tests/ci_gsm8k_tests.sh +++ b/tests/full_tests/ci_gsm8k_tests.sh @@ -99,7 +99,7 @@ run_qwen3_compressed_tensor_dynamic_scaling_test() { # QWEN3 FP8 + MOE compressed tensor + dynamic scaling run_qwen3_moe_compressed_tensor_dynamic_scaling_test() { echo "➡️ Testing Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 + moe + compressed-tensor + dynamic scaling..." - HABANA_VISIBLE_DEVICES=all VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/generate.py" --model Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --trust-remote-code + HABANA_VISIBLE_DEVICES=all VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/generate.py" --model Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --trust-remote-code --max-model-len 131072 echo "✅ Test with Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 + moe + compressed-tensor + dynamic scaling successful." } diff --git a/tests/full_tests/ci_perf_tests.sh b/tests/full_tests/ci_perf_tests.sh index fb94a19565..2066572eea 100644 --- a/tests/full_tests/ci_perf_tests.sh +++ b/tests/full_tests/ci_perf_tests.sh @@ -37,4 +37,4 @@ vllm bench throughput \ --dataset_path ShareGPT_V3_unfiltered_cleaned_split.json \ --dataset_name sharegpt \ --num-prompts 1000 \ - --max-model-len 32768 + --max-model-len 16384 diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 363b3210e9..ce460bb3cb 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -796,14 +796,14 @@ def __init__( self.use_hpu_graph = not self.model_config.enforce_eager self.max_batch_size = self.scheduler_config.max_num_seqs self.max_num_seqs = self.scheduler_config.max_num_seqs + self.max_cudagraph_capture_size = self.vllm_config.compilation_config.max_cudagraph_capture_size if prompt_profile_cfg: self.max_prefill_batch_size = prompt_profile_cfg[0] else: self.max_prefill_batch_size = with_default(get_config().VLLM_PROMPT_BS_BUCKET_MAX, 1) self.seen_configs: set = set() - self.max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - self.max_graph_capture_tokens = self.vllm_config.compilation_config.max_cudagraph_capture_size if \ - self.vllm_config.compilation_config.max_cudagraph_capture_size is not None else self.max_num_batched_tokens + self.max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens self.use_prefix_caching = (self.vllm_config.cache_config.enable_prefix_caching) self.bucketing_manager = HPUBucketingManager() max_num_prefill_seqs = self.max_num_seqs if self.use_merged_prefill \ @@ -2651,9 +2651,7 @@ def _execute_model_generic(self, additional_kwargs = {} if htorch.utils.internal.is_lazy(): use_graphs = self._use_graphs() - # skip HPU graphs for long prefills - if seq_len > 1 and \ - batch_size * (seq_len + num_blocks * self.block_size) > self.max_graph_capture_tokens: + if self.max_cudagraph_capture_size is not None and batch_size * seq_len > self.max_cudagraph_capture_size: use_graphs = False additional_kwargs.update({"bypass_hpu_graphs": not use_graphs}) else: From 6758ac64e7f13c432f15185df3d2042c4066edea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Smyrek?= Date: Wed, 28 Jan 2026 10:14:59 +0100 Subject: [PATCH 07/11] Cherry-picks to enable Llama4 Maverick (#882) 1. #805 2. #837 3. #855 4. #862 --------- Signed-off-by: Radoslaw Smyrek Signed-off-by: linoy buchnik Signed-off-by: Iryna Boiko Signed-off-by: Artur Fierka Co-authored-by: Linoy Buchnik Co-authored-by: Iryna Boiko Co-authored-by: Artur Fierka Signed-off-by: Luca Calabria --- vllm_gaudi/v1/worker/hpu_model_runner.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index ce460bb3cb..ec4795f7f7 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -370,7 +370,8 @@ def patch_llama4_get_attn_scale(model): orig = attn._get_attn_scale def _get_attn_scale_for_hpu(self, positions, _orig=orig): - positions = positions.flatten() + if self.qk_norm is not None: + positions = positions.flatten() return _orig(positions) attn._get_attn_scale = types.MethodType(_get_attn_scale_for_hpu, attn) @@ -4669,8 +4670,7 @@ def warmup_multimodal_graphs(self, buckets): phase = 'Graph/Multimodal' from vllm.v1.worker.utils import MultiModalBudget self.mm_budget = MultiModalBudget( - self.model_config, - self.scheduler_config, + self.vllm_config, self.mm_registry, ) if self.supports_mm_inputs else None @@ -5795,7 +5795,7 @@ def process_metadata(self, """ if attn_metadata.is_prompt: attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype) - if self.interleaved_sliding_window: + if self.interleaved_sliding_window and self.sliding_window is not None: attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len, self.sliding_window, device, dtype) if model_has_chunked_attention: @@ -5810,7 +5810,7 @@ def process_metadata(self, device, dtype, update_for_chunked_attention=True) - if self.interleaved_sliding_window: + if self.interleaved_sliding_window and self.sliding_window is not None: attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True) return attn_metadata From 7e5adea5369fe4cfccdb85a5e43afc16f63499c0 Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Wed, 28 Jan 2026 13:59:58 +0100 Subject: [PATCH 08/11] moved as default function Signed-off-by: Luca Calabria --- vllm_gaudi/v1/worker/hpu_model_runner.py | 34 ++++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index ec4795f7f7..f4d61fc1d5 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -377,10 +377,23 @@ def _get_attn_scale_for_hpu(self, positions, _orig=orig): attn._get_attn_scale = types.MethodType(_get_attn_scale_for_hpu, attn) -def apply_model_specific_patches(model): - """The function applies model-specific monkey patches.""" +def maybe_set_chunked_attention_layers(model_runner): + if hasattr(model_runner.model.config, 'text_config') and \ + hasattr(model_runner.model.config.text_config, 'attention_chunk_size') and \ + model_runner.model.config.text_config.attention_chunk_size: + model_runner.model_has_chunked_attention = True + try: + for layer in model_runner.model.language_model.model.layers: + if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__: + layer.self_attn.attn.impl.is_chunked_attention = True + except Exception: + pass + - patch_llama4_get_attn_scale(model) +def apply_model_specific_patches(model_runner): + """The function applies model-specific monkey patches.""" + maybe_set_chunked_attention_layers(model_runner) + patch_llama4_get_attn_scale(model_runner.model) class HpuModelAdapter(torch.nn.Module, KVConnectorModelRunnerMixin): @@ -1426,18 +1439,6 @@ def _get_num_decodes(self) -> int: num_decodes += 1 return num_decodes - def maybe_set_chunked_attention_layers(self, model): - if hasattr(model.config, 'text_config') and \ - hasattr(model.config.text_config, 'attention_chunk_size') and \ - model.config.text_config.attention_chunk_size: - self.model_has_chunked_attention = True - try: - for layer in model.language_model.model.layers: - if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__: - layer.self_attn.attn.impl.is_chunked_attention = True - except Exception: - pass - def _get_prompts_and_decodes( self, scheduler_output: "SchedulerOutput", @@ -3766,8 +3767,7 @@ def load_model(self) -> None: self.model = self.model.to("hpu") htcore.mark_step() - apply_model_specific_patches(self.model) - self.maybe_set_chunked_attention_layers(self.model) + apply_model_specific_patches(self) hidden_layer_markstep_interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) model_config = getattr(self.model, "config", None) modify_model_layers(self.model, From 93eb78599e04a4b6e191a008aeee02832f31f911 Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Wed, 28 Jan 2026 14:02:38 +0100 Subject: [PATCH 09/11] moved as default function Signed-off-by: Luca Calabria --- vllm_gaudi/v1/worker/hpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index f4d61fc1d5..d85a0d71ea 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -387,6 +387,7 @@ def maybe_set_chunked_attention_layers(model_runner): if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__: layer.self_attn.attn.impl.is_chunked_attention = True except Exception: + # add explicit warning pass From 1dc2f3b1f970bc3c152e7fe6b1addc0913b055f2 Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Wed, 28 Jan 2026 14:13:56 +0100 Subject: [PATCH 10/11] fix format Signed-off-by: Luca Calabria --- vllm_gaudi/v1/worker/hpu_model_runner.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index d85a0d71ea..c5c7d0e1b4 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -378,17 +378,17 @@ def _get_attn_scale_for_hpu(self, positions, _orig=orig): def maybe_set_chunked_attention_layers(model_runner): - if hasattr(model_runner.model.config, 'text_config') and \ - hasattr(model_runner.model.config.text_config, 'attention_chunk_size') and \ - model_runner.model.config.text_config.attention_chunk_size: - model_runner.model_has_chunked_attention = True - try: - for layer in model_runner.model.language_model.model.layers: - if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__: - layer.self_attn.attn.impl.is_chunked_attention = True - except Exception: - # add explicit warning - pass + if hasattr(model_runner.model.config, 'text_config') and \ + hasattr(model_runner.model.config.text_config, 'attention_chunk_size') and \ + model_runner.model.config.text_config.attention_chunk_size: + model_runner.model_has_chunked_attention = True + try: + for layer in model_runner.model.language_model.model.layers: + if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__: + layer.self_attn.attn.impl.is_chunked_attention = True + except Exception: + # add explicit warning + pass def apply_model_specific_patches(model_runner): From 0206d3f1bf9d4f600b41181fed84d5a27180ba6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Smyrek?= Date: Wed, 28 Jan 2026 10:14:59 +0100 Subject: [PATCH 11/11] Cherry-picks to enable Llama4 Maverick (#882) 1. #805 2. #837 3. #855 4. #862 --------- Signed-off-by: Radoslaw Smyrek Signed-off-by: linoy buchnik Signed-off-by: Iryna Boiko Signed-off-by: Artur Fierka Co-authored-by: Linoy Buchnik Co-authored-by: Iryna Boiko Co-authored-by: Artur Fierka Signed-off-by: Luca Calabria --- vllm_gaudi/ops/hpu_fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index a096ce2ce0..90b5555d43 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -161,6 +161,7 @@ def forward_oot( activation=layer.activation, ) # fix needed for llama4 when context len > 32K. Change in shape + if layer.dp_size > 1: return output.view(*(output.size(0), *input_shape[1:])) else: