From 4b57887d8737664459803007979cb6284cc4bf6e Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Fri, 23 Jan 2026 16:21:30 +0100 Subject: [PATCH 1/4] cherry-pick chunked attention --- vllm_gaudi/attention/backends/hpu_attn.py | 18 ++ vllm_gaudi/v1/attention/backends/hpu_attn.py | 6 + vllm_gaudi/v1/spec_decode/hpu_eagle.py | 3 + vllm_gaudi/v1/worker/hpu_model_runner.py | 166 +++++++++++++++++-- 4 files changed, 178 insertions(+), 15 deletions(-) diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index 5bd6460ed2..5058dd1dec 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -169,6 +169,12 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): window_block_groups: Optional[torch.Tensor] = None window_block_usage: Optional[torch.Tensor] = None window_attn_bias: Optional[torch.Tensor] = None + chunked_slot_mapping: Optional[torch.Tensor] = None + chunked_attn_bias: Optional[torch.Tensor] = None + chunked_block_mapping: Optional[torch.Tensor] = None + chunked_block_list: Optional[torch.Tensor] = None + chunked_block_groups: Optional[torch.Tensor] = None + chunked_block_usage: Optional[torch.Tensor] = None @dataclass @@ -493,6 +499,8 @@ def __init__( "is not implemented for " "HPUAttentionImpl") + self.is_chunked_attention = False + def _maybe_init_alibi_biases( self, max_seq_len, @@ -631,6 +639,10 @@ def forward( window_size = (self.sliding_window, 0) common_args['window_size'] = window_size + if self.is_chunked_attention and \ + hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None: + attn_bias = attn_metadata.chunked_attn_bias + out = ops.prompt_attention(impl=self.prefill_impl, query=query.view(query_shape), key=key.view(kv_shape), @@ -650,6 +662,12 @@ def forward( block_groups = attn_metadata.window_block_groups block_mapping = attn_metadata.window_block_mapping attn_bias = attn_metadata.window_attn_bias + elif self.is_chunked_attention and \ + attn_metadata.chunked_block_list is not None: + block_list = attn_metadata.chunked_block_list + block_groups = attn_metadata.chunked_block_groups + block_mapping = attn_metadata.chunked_block_mapping + attn_bias = attn_metadata.chunked_attn_bias else: block_list = attn_metadata.block_list block_groups = attn_metadata.block_groups diff --git a/vllm_gaudi/v1/attention/backends/hpu_attn.py b/vllm_gaudi/v1/attention/backends/hpu_attn.py index 1d2c1d2363..978cec0cd9 100644 --- a/vllm_gaudi/v1/attention/backends/hpu_attn.py +++ b/vllm_gaudi/v1/attention/backends/hpu_attn.py @@ -87,6 +87,9 @@ def make_decode_metadata(cls, window_block_list, window_block_usage, window_block_groups, + chunked_block_list, + chunked_block_usage, + chunked_block_groups, query_start_loc=None): return cls(is_prompt=False, block_mapping=None, @@ -100,6 +103,9 @@ def make_decode_metadata(cls, window_block_list=window_block_list, window_block_usage=window_block_usage, window_block_groups=window_block_groups, + chunked_block_list=chunked_block_list, + chunked_block_usage=chunked_block_usage, + chunked_block_groups=chunked_block_groups, input_positions=input_positions, slot_mapping=slot_mapping, block_size=block_size, diff --git a/vllm_gaudi/v1/spec_decode/hpu_eagle.py b/vllm_gaudi/v1/spec_decode/hpu_eagle.py index 7a0d508b0a..1c9159a810 100644 --- a/vllm_gaudi/v1/spec_decode/hpu_eagle.py +++ b/vllm_gaudi/v1/spec_decode/hpu_eagle.py @@ -237,6 +237,9 @@ def prepare_attn_metadata( window_block_list=None, window_block_usage=None, window_block_groups=None, + chunked_block_list=None, + chunked_block_usage=None, + chunked_block_groups=None, ) return common_attn_metadata diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 541334cd91..87c411efd2 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -459,10 +459,12 @@ def forward(self, *args, **kwargs): if 'warmup_mode' in kwargs: kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] + model_has_chunked_attention = kwargs.pop('model_has_chunked_attention', False) if not self.unified_attn: kwargs['attn_metadata'] = self.metadata_processor.process_metadata(kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), - input_ids.device, self.dtype) + input_ids.device, self.dtype, + model_has_chunked_attention) if self._rotary_prepare_cos_sin is not None: self._rotary_prepare_cos_sin(kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin) attn_meta = kwargs.pop('attn_metadata') @@ -587,6 +589,11 @@ def trim_attn_metadata(metadata: HPUAttentionMetadataV1) -> object: 'window_block_usage', 'window_block_groups', 'window_attn_bias', + 'chunked_block_mapping', + 'chunked_attn_bias', + 'chunked_block_list', + 'chunked_block_usage', + 'chunked_block_groups', ]) return attention_metadata @@ -833,6 +840,8 @@ def __init__( self.scheduler_output: SchedulerOutput | None = None self.warmup_mode: bool = False self.batch_changed: bool = False + # WA for chunked attention support + self.model_has_chunked_attention = False assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!' assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!' @@ -1400,6 +1409,18 @@ def _get_num_decodes(self) -> int: num_decodes += 1 return num_decodes + def maybe_set_chunked_attention_layers(self, model): + if hasattr(model.config, 'text_config') and \ + hasattr(model.config.text_config, 'attention_chunk_size') and \ + model.config.text_config.attention_chunk_size: + self.model_has_chunked_attention = True + try: + for layer in model.language_model.model.layers: + if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__: + layer.self_attn.attn.impl.is_chunked_attention = True + except Exception: + pass + def _get_prompts_and_decodes( self, scheduler_output: "SchedulerOutput", @@ -2116,6 +2137,19 @@ def _create_decode_input_data(self, window_block_tables, slot_mapping.tolist(), padded_batch_size * num_tokens) + if self.model_has_chunked_attention: + chunk_size_in_blocks = (self.model.model.config.text_config.attention_chunk_size // self.block_size) + seq_lens_block = [len(block_table) for block_table in block_tables_list] + num_seq_chunks = [math.ceil(sl / chunk_size_in_blocks) - 1 for sl in seq_lens_block] + block_tables_chunk = [ + block_table[num_seq_chunks[i] * chunk_size_in_blocks:] + for i, block_table in enumerate(block_tables_list) + ] + chunked_block_list, chunked_block_groups, chunked_block_usage = \ + self.get_habana_paged_attn_buffers( + block_tables_chunk, slot_mapping.tolist(), + padded_batch_size * num_tokens) + # CPU<>HPU sync *should not* happen here. block_list_device = async_h2d_copy(block_list, device=self.device) block_usage_device = async_h2d_copy(block_usage, device=self.device) @@ -2130,6 +2164,12 @@ def _create_decode_input_data(self, window_block_groups_device = async_h2d_copy( window_block_groups, device=self.device) if self.interleaved_sliding_window and self.sliding_window is not None else None + chunked_block_list_device = async_h2d_copy(chunked_block_list, + device=self.device) if self.model_has_chunked_attention else None + chunked_block_usage_device = async_h2d_copy(chunked_block_usage, + device=self.device) if self.model_has_chunked_attention else None + chunked_block_groups_device = async_h2d_copy(chunked_block_groups, + device=self.device) if self.model_has_chunked_attention else None token_ids_device = async_h2d_copy(token_ids, device=self.device) # when DP also enabled, some DP ranks will exeucte dummy run with empty @@ -2160,21 +2200,26 @@ def _create_decode_input_data(self, spec_decode_metadata = None logits_indices_device = async_h2d_copy(logits_indices, device=self.device) + attn_metadata = HPUAttentionMetadataV1.make_decode_metadata( + block_list=block_list_device, + block_usage=block_usage_device, + block_groups=block_groups_device, + input_positions=None, + slot_mapping=slot_mapping_device, + block_size=self.block_size, + window_block_list=window_block_list_device, + window_block_usage=window_block_usage_device, + window_block_groups=window_block_groups_device, + chunked_block_list=chunked_block_list_device, + chunked_block_usage=chunked_block_usage_device, + chunked_block_groups=chunked_block_groups_device, + ) + return DecodeInputData(num_decodes=num_decodes, token_ids=token_ids_device, position_ids=positions_device, logits_indices=logits_indices_device, - attn_metadata=HPUAttentionMetadataV1.make_decode_metadata( - block_list=block_list_device, - block_usage=block_usage_device, - block_groups=block_groups_device, - input_positions=None, - slot_mapping=slot_mapping_device, - block_size=self.block_size, - window_block_list=window_block_list_device, - window_block_usage=window_block_usage_device, - window_block_groups=window_block_groups_device, - ), + attn_metadata=attn_metadata, spec_decode_metadata=spec_decode_metadata) def _prepare_decode_inputs(self, @@ -2602,6 +2647,8 @@ def _execute_model_generic(self, else: # no hpu graphs for t.compile? use_graphs = False + if self.model_has_chunked_attention: + additional_kwargs.update({"model_has_chunked_attention": True}) trimmed_attn_metadata = attn_metadata if self.unified_attn else trim_attn_metadata(attn_metadata) if self.is_driver_worker: model_event_name = ("model_forward_" @@ -3709,6 +3756,7 @@ def load_model(self) -> None: htcore.mark_step() apply_model_specific_patches(self.model) + self.maybe_set_chunked_attention_layers(self.model) hidden_layer_markstep_interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) model_config = getattr(self.model, "config", None) modify_model_layers(self.model, @@ -5571,13 +5619,76 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias) return attn_metadata + + def _set_attn_bias_for_chunked_attention(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, + chunk_size: int, device: torch.device, + dtype: torch.dtype) -> HPUAttentionMetadataV1: + """Set attention bias for chunked attention. + Args: + attn_metadata (HPUAttentionMetadataV1): The attention metadata. + batch_size (int): The batch size. + seq_len (int): The sequence length. + chunk_size (int): The chunk size. + device (torch.device): The device to use. + dtype (torch.dtype): The data type. + Returns: + HPUAttentionMetadataV1: The updated attention metadata. + """ + + if (attn_metadata is None or not attn_metadata.is_prompt): + return attn_metadata + + prefill_metadata = attn_metadata + shift = 0 + + if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None: + + context_lens_t = prefill_metadata.context_lens_tensor + assert context_lens_t is not None + block_list = prefill_metadata.block_list + max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) + max_context_len = max_context_len * self.block_size + query_positions = torch.arange(seq_len, device=device) + total_token_positions = context_lens_t.unsqueeze(-1) + query_positions.unsqueeze(0) + which_chunk = (total_token_positions // chunk_size) + chunk_start_positions = which_chunk * chunk_size + invalid_lens_t = chunk_start_positions - 1 + + past_indices = torch.arange(max_context_len, device=device) + past_mask = ( + (past_indices.unsqueeze(0).unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) & + (past_indices.unsqueeze(0).unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(-1))).unsqueeze(1) + + causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift) + query_chunk_ids = which_chunk[0] + same_chunk_mask = query_chunk_ids.unsqueeze(0) == query_chunk_ids.unsqueeze(1) + + causal_mask = causal_mask & same_chunk_mask + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len) + + mask = torch.concat((past_mask, causal_mask), dim=-1) + attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), + torch.tensor(float('-inf'), dtype=dtype, device=device)) + else: + tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) + mask = torch.tril(tensor, diagonal=shift) + idx = torch.arange(seq_len, device=device) + chunk_id = idx // chunk_size + same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1) + same_chunk = same_chunk.unsqueeze(0).unsqueeze(0) + mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device)) + attn_bias = torch.log(mask) + + attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", chunked_attn_bias=attn_bias) + return attn_metadata def _set_block_mapping(self, metadata: HPUAttentionMetadataV1, batch_size: int, device: torch.device, dtype: torch.dtype, - is_window_block: bool = False) -> HPUAttentionMetadataV1: + is_window_block: bool = False, + update_for_chunked_attention: bool = False) -> HPUAttentionMetadataV1: """ Set block mapping for decode phase. @@ -5589,6 +5700,7 @@ def _set_block_mapping(self, device: Device to create tensors on dtype: Data type for tensors is_window_block: Whether this is for window blocks + update_for_chunked_attention: Whether to update for chunked attention Returns: Updated attention metadata with block_mapping and attn_bias set @@ -5596,6 +5708,9 @@ def _set_block_mapping(self, if is_window_block: block_usage = metadata.window_block_usage block_groups = metadata.window_block_groups + elif update_for_chunked_attention: + block_usage = metadata.chunked_block_usage + block_groups = metadata.chunked_block_groups else: block_usage = metadata.block_usage block_groups = metadata.block_groups @@ -5626,6 +5741,11 @@ def _set_block_mapping(self, "TrimmedAttentionMetadata", window_block_mapping=block_mapping, window_attn_bias=attn_bias) + elif update_for_chunked_attention: + metadata = custom_tuple_replace(metadata, + "TrimmedAttentionMetadata", + chunked_block_mapping=block_mapping, + chunked_attn_bias=attn_bias) else: metadata = custom_tuple_replace(metadata, "TrimmedAttentionMetadata", @@ -5633,8 +5753,13 @@ def _set_block_mapping(self, attn_bias=attn_bias) return metadata - def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, - device: torch.device, dtype: torch.dtype) -> HPUAttentionMetadataV1: + def process_metadata(self, + attn_metadata: HPUAttentionMetadataV1, + batch_size: int, + seq_len: int, + device: torch.device, + dtype: torch.dtype, + model_has_chunked_attention: bool = False) -> HPUAttentionMetadataV1: """ Post-process attention metadata with appropriate masks and mappings. @@ -5648,6 +5773,7 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in seq_len: Sequence length (for prompt phase) device: Device to create tensors on dtype: Data type for tensors + model_has_chunked_attention: Whether the model has chunked attention Returns: Post-processed attention metadata with additional tensors @@ -5657,8 +5783,18 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in if self.interleaved_sliding_window: attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len, self.sliding_window, device, dtype) + if model_has_chunked_attention: + attention_chunk_size = self.vllm_config.model_config.hf_config.text_config.attention_chunk_size + attn_metadata = self._set_attn_bias_for_chunked_attention(attn_metadata, batch_size, seq_len, + attention_chunk_size, device, dtype) else: attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype) + if model_has_chunked_attention: + attn_metadata = self._set_block_mapping(attn_metadata, + batch_size, + device, + dtype, + update_for_chunked_attention=True) if self.interleaved_sliding_window: attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True) return attn_metadata From 15ed298cca88bc1923c48b881abd9af3baf2a3dc Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Fri, 23 Jan 2026 16:59:57 +0100 Subject: [PATCH 2/4] Update vllm_gaudi/v1/worker/hpu_model_runner.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Luca Calabria --- vllm_gaudi/v1/worker/hpu_model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 87c411efd2..067b8363f5 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1420,7 +1420,6 @@ def maybe_set_chunked_attention_layers(self, model): layer.self_attn.attn.impl.is_chunked_attention = True except Exception: pass - def _get_prompts_and_decodes( self, scheduler_output: "SchedulerOutput", From 1b72163a6d695c100f8e359541561b82bd0ea6e1 Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Fri, 23 Jan 2026 17:00:27 +0100 Subject: [PATCH 3/4] Update vllm_gaudi/v1/worker/hpu_model_runner.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Luca Calabria --- vllm_gaudi/v1/worker/hpu_model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 067b8363f5..c704332a10 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2148,7 +2148,6 @@ def _create_decode_input_data(self, self.get_habana_paged_attn_buffers( block_tables_chunk, slot_mapping.tolist(), padded_batch_size * num_tokens) - # CPU<>HPU sync *should not* happen here. block_list_device = async_h2d_copy(block_list, device=self.device) block_usage_device = async_h2d_copy(block_usage, device=self.device) From 326ca333ccc29454c225c8d19e73535d0400706c Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Mon, 26 Jan 2026 10:02:19 +0100 Subject: [PATCH 4/4] fix for 32k+ context window Llama4 --- vllm_gaudi/ops/hpu_fused_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index a27710fa74..4168d515ee 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -160,7 +160,10 @@ def forward_oot( permuted_weights=True, activation=layer.activation, ) - return output.view(*(output.size(0), *input_shape[1:])) + if layer.dp_size > 1: + return output.view(*(output.size(0), *input_shape[1:])) + else: + return output.view(*input_shape) def reduce_output(self, states: torch.Tensor) -> torch.Tensor: