17 changes: 17 additions & 0 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -169,6 +169,12 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
window_block_groups: Optional[torch.Tensor] = None
window_block_usage: Optional[torch.Tensor] = None
window_attn_bias: Optional[torch.Tensor] = None
chunked_slot_mapping: Optional[torch.Tensor] = None
chunked_attn_bias: Optional[torch.Tensor] = None
chunked_block_mapping: Optional[torch.Tensor] = None
chunked_block_list: Optional[torch.Tensor] = None
chunked_block_groups: Optional[torch.Tensor] = None
chunked_block_usage: Optional[torch.Tensor] = None


@dataclass
@@ -493,6 +499,8 @@ def __init__(
"is not implemented for "
"HPUAttentionImpl")

self.is_chunked_attention = False

def _maybe_init_alibi_biases(
self,
max_seq_len,
@@ -630,6 +638,9 @@ def forward(
attn_bias = None
window_size = (self.sliding_window, 0)
common_args['window_size'] = window_size
if self.is_chunked_attention and \
hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None:
attn_bias = attn_metadata.chunked_attn_bias

out = ops.prompt_attention(impl=self.prefill_impl,
query=query.view(query_shape),
@@ -650,6 +661,12 @@
block_groups = attn_metadata.window_block_groups
block_mapping = attn_metadata.window_block_mapping
attn_bias = attn_metadata.window_attn_bias
elif self.is_chunked_attention and \
attn_metadata.chunked_block_list is not None:
block_list = attn_metadata.chunked_block_list
block_groups = attn_metadata.chunked_block_groups
block_mapping = attn_metadata.chunked_block_mapping
attn_bias = attn_metadata.chunked_attn_bias
else:
block_list = attn_metadata.block_list
block_groups = attn_metadata.block_groups
6 changes: 6 additions & 0 deletions vllm_gaudi/v1/attention/backends/hpu_attn.py
@@ -87,6 +87,9 @@ def make_decode_metadata(cls,
window_block_list,
window_block_usage,
window_block_groups,
chunked_block_list,
chunked_block_usage,
chunked_block_groups,
query_start_loc=None):
return cls(is_prompt=False,
block_mapping=None,
@@ -100,6 +103,9 @@ def make_decode_metadata(cls,
window_block_list=window_block_list,
window_block_usage=window_block_usage,
window_block_groups=window_block_groups,
chunked_block_list=chunked_block_list,
chunked_block_usage=chunked_block_usage,
chunked_block_groups=chunked_block_groups,
input_positions=input_positions,
slot_mapping=slot_mapping,
block_size=block_size,
3 changes: 3 additions & 0 deletions vllm_gaudi/v1/spec_decode/hpu_eagle.py
@@ -237,6 +237,9 @@ def prepare_attn_metadata(
window_block_list=None,
window_block_usage=None,
window_block_groups=None,
chunked_block_list=None,
chunked_block_usage=None,
chunked_block_groups=None,
)

return common_attn_metadata
168 changes: 153 additions & 15 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -459,10 +459,12 @@ def forward(self, *args, **kwargs):
if 'warmup_mode' in kwargs:
kwargs.pop('warmup_mode')
input_ids = kwargs['input_ids']
model_has_chunked_attention = kwargs.pop('model_has_chunked_attention', False)
if not self.unified_attn:
kwargs['attn_metadata'] = self.metadata_processor.process_metadata(kwargs['attn_metadata'],
input_ids.size(0), input_ids.size(1),
input_ids.device, self.dtype)
input_ids.device, self.dtype,
model_has_chunked_attention)
if self._rotary_prepare_cos_sin is not None:
self._rotary_prepare_cos_sin(kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin)
attn_meta = kwargs.pop('attn_metadata')
@@ -587,6 +589,11 @@ def trim_attn_metadata(metadata: HPUAttentionMetadataV1) -> object:
'window_block_usage',
'window_block_groups',
'window_attn_bias',
'chunked_block_mapping',
'chunked_attn_bias',
'chunked_block_list',
'chunked_block_usage',
'chunked_block_groups',
])
return attention_metadata

@@ -833,6 +840,8 @@ def __init__(
self.scheduler_output: SchedulerOutput | None = None
self.warmup_mode: bool = False
self.batch_changed: bool = False
# WA for chunked attention support
self.model_has_chunked_attention = False

assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!'
assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!'
@@ -1400,6 +1409,18 @@ def _get_num_decodes(self) -> int:
num_decodes += 1
return num_decodes

def maybe_set_chunked_attention_layers(self, model):
if hasattr(model.config, 'text_config') and \
hasattr(model.config.text_config, 'attention_chunk_size') and \
model.config.text_config.attention_chunk_size:
self.model_has_chunked_attention = True
try:
for layer in model.language_model.model.layers:
if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__:
layer.self_attn.attn.impl.is_chunked_attention = True
except Exception:
pass
Comment on lines +1412 to +1422
Copilot AI (Jan 15, 2026):
The bare `except Exception: pass` silently suppresses all errors without logging. This makes debugging difficult if the chunked attention setup fails. Add logging to record when this exception occurs, including the exception details.
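A minimal sketch of the logged variant the comment asks for (assuming a module-level `logger` is in scope, e.g. from vllm's `init_logger(__name__)`; the attribute path is copied from the diff above):

try:
    for layer in model.language_model.model.layers:
        if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__:
            layer.self_attn.attn.impl.is_chunked_attention = True
except Exception as e:
    # Log the failure instead of swallowing it; otherwise chunked attention
    # silently stays disabled for these layers with no trace of why.
    logger.warning("Failed to mark chunked attention layers: %s", e)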

def _get_prompts_and_decodes(
self,
scheduler_output: "SchedulerOutput",
@@ -2116,6 +2137,19 @@ def _create_decode_input_data(self,
window_block_tables, slot_mapping.tolist(),
padded_batch_size * num_tokens)

if self.model_has_chunked_attention:
chunk_size_in_blocks = (self.model.model.config.text_config.attention_chunk_size // self.block_size)
Copilot AI (Jan 15, 2026):
Division could result in zero if attention_chunk_size is smaller than block_size, leading to incorrect chunking behavior. Add validation to ensure attention_chunk_size is at least equal to block_size, or handle the zero case appropriately.

Suggested change:
- chunk_size_in_blocks = (self.model.model.config.text_config.attention_chunk_size // self.block_size)
+ attention_chunk_size = self.model.model.config.text_config.attention_chunk_size
+ if attention_chunk_size < self.block_size:
+     raise ValueError(
+         f"Configured attention_chunk_size ({attention_chunk_size}) must be at least "
+         f"as large as block_size ({self.block_size}) when using chunked attention."
+     )
+ chunk_size_in_blocks = attention_chunk_size // self.block_size
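To make the failure mode concrete, a hypothetical configuration that trips it (illustrative numbers, not from the PR):

attention_chunk_size, block_size = 64, 128
chunk_size_in_blocks = attention_chunk_size // block_size  # floor division gives 0
# The later `math.ceil(sl / chunk_size_in_blocks)` in this hunk would then
# raise ZeroDivisionError before any chunked block table is built.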
seq_lens_block = [len(block_table) for block_table in block_tables_list]
num_seq_chunks = [math.ceil(sl / chunk_size_in_blocks) - 1 for sl in seq_lens_block]
block_tables_chunk = [
block_table[num_seq_chunks[i] * chunk_size_in_blocks:]
for i, block_table in enumerate(block_tables_list)
]
chunked_block_list, chunked_block_groups, chunked_block_usage = \
self.get_habana_paged_attn_buffers(
block_tables_chunk, slot_mapping.tolist(),
padded_batch_size * num_tokens)

# CPU<>HPU sync *should not* happen here.
block_list_device = async_h2d_copy(block_list, device=self.device)
block_usage_device = async_h2d_copy(block_usage, device=self.device)
window_block_groups_device = async_h2d_copy(
window_block_groups,
device=self.device) if self.interleaved_sliding_window and self.sliding_window is not None else None
chunked_block_list_device = async_h2d_copy(chunked_block_list,
device=self.device) if self.model_has_chunked_attention else None
chunked_block_usage_device = async_h2d_copy(chunked_block_usage,
device=self.device) if self.model_has_chunked_attention else None
chunked_block_groups_device = async_h2d_copy(chunked_block_groups,
device=self.device) if self.model_has_chunked_attention else None

token_ids_device = async_h2d_copy(token_ids, device=self.device)
# when DP also enabled, some DP ranks will execute dummy run with empty
@@ -2160,21 +2200,26 @@
spec_decode_metadata = None
logits_indices_device = async_h2d_copy(logits_indices, device=self.device)

attn_metadata = HPUAttentionMetadataV1.make_decode_metadata(
block_list=block_list_device,
block_usage=block_usage_device,
block_groups=block_groups_device,
input_positions=None,
slot_mapping=slot_mapping_device,
block_size=self.block_size,
window_block_list=window_block_list_device,
window_block_usage=window_block_usage_device,
window_block_groups=window_block_groups_device,
chunked_block_list=chunked_block_list_device,
chunked_block_usage=chunked_block_usage_device,
chunked_block_groups=chunked_block_groups_device,
)

return DecodeInputData(num_decodes=num_decodes,
token_ids=token_ids_device,
position_ids=positions_device,
logits_indices=logits_indices_device,
attn_metadata=HPUAttentionMetadataV1.make_decode_metadata(
block_list=block_list_device,
block_usage=block_usage_device,
block_groups=block_groups_device,
input_positions=None,
slot_mapping=slot_mapping_device,
block_size=self.block_size,
window_block_list=window_block_list_device,
window_block_usage=window_block_usage_device,
window_block_groups=window_block_groups_device,
),
attn_metadata=attn_metadata,
spec_decode_metadata=spec_decode_metadata)

def _prepare_decode_inputs(self,
@@ -2602,6 +2647,8 @@ def _execute_model_generic(self,
else:
# no hpu graphs for t.compile?
use_graphs = False
if self.model_has_chunked_attention:
additional_kwargs.update({"model_has_chunked_attention": True})
trimmed_attn_metadata = attn_metadata if self.unified_attn else trim_attn_metadata(attn_metadata)
if self.is_driver_worker:
model_event_name = ("model_forward_"
@@ -3709,6 +3756,7 @@ def load_model(self) -> None:
htcore.mark_step()

apply_model_specific_patches(self.model)
self.maybe_set_chunked_attention_layers(self.model)
hidden_layer_markstep_interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
model_config = getattr(self.model, "config", None)
modify_model_layers(self.model,
@@ -5571,12 +5619,77 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV1,
attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias)
return attn_metadata

def _set_attn_bias_for_chunked_attention(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int,
chunk_size: int, device: torch.device,
dtype: torch.dtype) -> HPUAttentionMetadataV1:
"""Set attention bias for chunked attention.

Args:
attn_metadata (HPUAttentionMetadataV1): The attention metadata.
batch_size (int): The batch size.
seq_len (int): The sequence length.
chunk_size (int): The chunk size.
device (torch.device): The device to use.
dtype (torch.dtype): The data type.

Returns:
HPUAttentionMetadataV1: The updated attention metadata.
"""

if (attn_metadata is None or not attn_metadata.is_prompt):
return attn_metadata

prefill_metadata = attn_metadata
shift = 0

if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None:

context_lens_t = prefill_metadata.context_lens_tensor
assert context_lens_t is not None
block_list = prefill_metadata.block_list
max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0)
max_context_len = max_context_len * self.block_size
Comment on lines +5650 to +5651
Copilot AI (Jan 15, 2026):
Division could result in zero or an incorrect value if block_list.size(-1) is smaller than batch_size, which could lead to an incorrect attention bias calculation. Add validation, or use math.ceil for the division to ensure proper handling of partial blocks.

Suggested change:
- max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0)
- max_context_len = max_context_len * self.block_size
+ if block_list is not None and batch_size > 0:
+     # Compute number of blocks per sequence using ceiling division to handle partial blocks.
+     blocks_per_seq = math.ceil(block_list.size(-1) / batch_size)
+     max_context_len = blocks_per_seq * self.block_size
+ else:
+     max_context_len = 0
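The floor-vs-ceiling behaviour at issue, on illustrative numbers (not from the PR):

import math
# e.g. 10 blocks spread over a batch of 4 sequences:
10 // 4            # 2 -> floor division drops the partial block
math.ceil(10 / 4)  # 3 -> ceiling division keeps it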
query_positions = torch.arange(seq_len, device=device)
total_token_positions = context_lens_t.unsqueeze(-1) + query_positions.unsqueeze(0)
which_chunk = (total_token_positions // chunk_size)
chunk_start_positions = which_chunk * chunk_size
invalid_lens_t = chunk_start_positions - 1

past_indices = torch.arange(max_context_len, device=device)
past_mask = (
(past_indices.unsqueeze(0).unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) &
(past_indices.unsqueeze(0).unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(-1))).unsqueeze(1)

causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift)
Copilot AI (Jan 15, 2026):
Indexing with [0] assumes which_chunk has at least one element. While this may be guaranteed by the context, the assumption is not immediately clear. Consider adding a comment explaining why the first element is used, or add an assertion to document this assumption.

Suggested change:
  causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift)
+ # which_chunk is expected to have at least one row (batch dimension > 0) in this code path.
+ assert which_chunk.size(0) > 0, "which_chunk is expected to have at least one row"
query_chunk_ids = which_chunk[0]
same_chunk_mask = query_chunk_ids.unsqueeze(0) == query_chunk_ids.unsqueeze(1)

causal_mask = causal_mask & same_chunk_mask
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len)

mask = torch.concat((past_mask, causal_mask), dim=-1)
attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device),
torch.tensor(float('-inf'), dtype=dtype, device=device))
else:
tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1)
mask = torch.tril(tensor, diagonal=shift)
idx = torch.arange(seq_len, device=device)
chunk_id = idx // chunk_size
same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
same_chunk = same_chunk.unsqueeze(0).unsqueeze(0)
mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device))
attn_bias = torch.log(mask)

attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", chunked_attn_bias=attn_bias)
return attn_metadata

def _set_block_mapping(self,
metadata: HPUAttentionMetadataV1,
batch_size: int,
device: torch.device,
dtype: torch.dtype,
is_window_block: bool = False) -> HPUAttentionMetadataV1:
is_window_block: bool = False,
update_for_chunked_attention: bool = False) -> HPUAttentionMetadataV1:
"""
Set block mapping for decode phase.

@@ -5588,13 +5701,17 @@ def _set_block_mapping(self,
device: Device to create tensors on
dtype: Data type for tensors
is_window_block: Whether this is for window blocks
update_for_chunked_attention: Whether to update for chunked attention

Returns:
Updated attention metadata with block_mapping and attn_bias set
"""
if is_window_block:
block_usage = metadata.window_block_usage
block_groups = metadata.window_block_groups
elif update_for_chunked_attention:
block_usage = metadata.chunked_block_usage
block_groups = metadata.chunked_block_groups
else:
block_usage = metadata.block_usage
block_groups = metadata.block_groups
@@ -5625,15 +5742,25 @@
"TrimmedAttentionMetadata",
window_block_mapping=block_mapping,
window_attn_bias=attn_bias)
elif update_for_chunked_attention:
metadata = custom_tuple_replace(metadata,
"TrimmedAttentionMetadata",
chunked_block_mapping=block_mapping,
chunked_attn_bias=attn_bias)
else:
metadata = custom_tuple_replace(metadata,
"TrimmedAttentionMetadata",
block_mapping=block_mapping,
attn_bias=attn_bias)
return metadata

def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int,
device: torch.device, dtype: torch.dtype) -> HPUAttentionMetadataV1:
def process_metadata(self,
attn_metadata: HPUAttentionMetadataV1,
batch_size: int,
seq_len: int,
device: torch.device,
dtype: torch.dtype,
model_has_chunked_attention: bool = False) -> HPUAttentionMetadataV1:
"""
Post-process attention metadata with appropriate masks and mappings.

@@ -5647,6 +5774,7 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int,
seq_len: Sequence length (for prompt phase)
device: Device to create tensors on
dtype: Data type for tensors
model_has_chunked_attention: Whether the model has chunked attention

Returns:
Post-processed attention metadata with additional tensors
@@ -5656,8 +5784,18 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int,
if self.interleaved_sliding_window and self.sliding_window is not None:
attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len,
self.sliding_window, device, dtype)
if model_has_chunked_attention:
attention_chunk_size = self.vllm_config.model_config.hf_config.text_config.attention_chunk_size
attn_metadata = self._set_attn_bias_for_chunked_attention(attn_metadata, batch_size, seq_len,
attention_chunk_size, device, dtype)
else:
attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype)
if model_has_chunked_attention:
attn_metadata = self._set_block_mapping(attn_metadata,
batch_size,
device,
dtype,
update_for_chunked_attention=True)
if self.interleaved_sliding_window and self.sliding_window is not None:
attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True)
return attn_metadata
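For reference, the fallback branch of _set_attn_bias_for_chunked_attention above encodes the per-chunk causal mask via torch.log of a 0/1 mask. A self-contained sketch of that construction (standalone, with illustrative sizes; not part of the PR):

import torch

seq_len, chunk_size = 8, 4
idx = torch.arange(seq_len)
chunk_id = idx // chunk_size                       # chunk index of each position
causal = torch.tril(torch.ones(seq_len, seq_len))  # standard causal mask
same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
mask = torch.where(same_chunk, causal, torch.tensor(0.0))
attn_bias = torch.log(mask)  # 0.0 where attention is allowed, -inf elsewhere
# Position 5 attends only to positions 4-5 (same chunk, causal past);
# positions 0-3 are masked out even though they are causally earlier.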