Skip to content
Merged
16 changes: 16 additions & 0 deletions vllm_gaudi/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,12 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
window_block_groups: Optional[torch.Tensor] = None
window_block_usage: Optional[torch.Tensor] = None
window_attn_bias: Optional[torch.Tensor] = None
chunked_slot_mapping: Optional[torch.Tensor] = None
chunked_attn_bias: Optional[torch.Tensor] = None
chunked_block_mapping: Optional[torch.Tensor] = None
chunked_block_list: Optional[torch.Tensor] = None
chunked_block_groups: Optional[torch.Tensor] = None
chunked_block_usage: Optional[torch.Tensor] = None


@dataclass
Expand Down Expand Up @@ -492,6 +498,7 @@ def __init__(
raise NotImplementedError("Encoder self-attention "
"is not implemented for "
"HPUAttentionImpl")
self.is_chunked_attention = False

def _maybe_init_alibi_biases(
self,
Expand Down Expand Up @@ -630,6 +637,9 @@ def forward(
attn_bias = None
window_size = (self.sliding_window, 0)
common_args['window_size'] = window_size
if self.is_chunked_attention and \
hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None:
attn_bias = attn_metadata.chunked_attn_bias

out = ops.prompt_attention(impl=self.prefill_impl,
query=query.view(query_shape),
Expand All @@ -650,6 +660,12 @@ def forward(
block_groups = attn_metadata.window_block_groups
block_mapping = attn_metadata.window_block_mapping
attn_bias = attn_metadata.window_attn_bias
elif self.is_chunked_attention and \
attn_metadata.chunked_block_list is not None:
block_list = attn_metadata.chunked_block_list
block_groups = attn_metadata.chunked_block_groups
block_mapping = attn_metadata.chunked_block_mapping
attn_bias = attn_metadata.chunked_attn_bias
else:
block_list = attn_metadata.block_list
block_groups = attn_metadata.block_groups
Expand Down
2 changes: 2 additions & 0 deletions vllm_gaudi/ops/hpu_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def forward_oot(
permuted_weights=True,
activation=layer.activation,
)

# fix needed for llama4 when context len > 32K. Change in shape
if layer.dp_size > 1:
return output.view(*(output.size(0), *input_shape[1:]))
else:
Expand Down
6 changes: 6 additions & 0 deletions vllm_gaudi/v1/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def make_decode_metadata(cls,
window_block_list,
window_block_usage,
window_block_groups,
chunked_block_list,
chunked_block_usage,
chunked_block_groups,
query_start_loc=None):
return cls(is_prompt=False,
block_mapping=None,
Expand All @@ -100,6 +103,9 @@ def make_decode_metadata(cls,
window_block_list=window_block_list,
window_block_usage=window_block_usage,
window_block_groups=window_block_groups,
chunked_block_list=chunked_block_list,
chunked_block_usage=chunked_block_usage,
chunked_block_groups=chunked_block_groups,
input_positions=input_positions,
slot_mapping=slot_mapping,
block_size=block_size,
Expand Down
3 changes: 3 additions & 0 deletions vllm_gaudi/v1/spec_decode/hpu_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ def prepare_attn_metadata(
window_block_list=None,
window_block_usage=None,
window_block_groups=None,
chunked_block_list=None,
chunked_block_usage=None,
chunked_block_groups=None,
)

return common_attn_metadata
Loading