requirements-hpu.txt (2 changes: 1 addition & 1 deletion)
@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@3e0fb39
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@7208458
vllm/attention/backends/hpu_attn.py (56 changes: 26 additions & 30 deletions)
@@ -241,45 +241,40 @@ def forward(
                                                attn_bias.shape[-1])
                 attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
                 attn_bias.add_(position_bias)
-            if attn_metadata is None or attn_metadata.block_list is None:
-                out = ops.prompt_attention(
-                    impl=self.prefill_impl,
-                    query=query.view(query_shape),
-                    key=key.view(kv_shape),
-                    value=value.view(kv_shape),
-                    is_causal=True,
-                    attn_bias=attn_bias,
-                    valid_seq_lengths=attn_metadata.seq_lens_tensor,
-                    **self.common_attention_args())
-            else:
-                # TODO: enable FusedSDPA
-                out = HPUPagedAttention.forward_prefix(
-                    query=query.view(query_shape),
-                    key=key.view(kv_shape),
-                    value=value.view(kv_shape),
-                    key_cache=key_cache,
-                    value_cache=value_cache,
-                    block_list=attn_metadata.block_list,
-                    attn_bias=attn_metadata.attn_bias,
-                    **self.common_attention_args())
+            block_list = attn_metadata.block_list if attn_metadata \
+                and attn_metadata.block_list is not None else None
+
+            out = ops.prompt_attention(
+                impl=self.prefill_impl,
+                query=query.view(query_shape),
+                key=key.view(kv_shape),
+                value=value.view(kv_shape),
+                is_causal=True,
+                attn_bias=attn_bias,
+                valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                **self.common_attention_args(block_list, key_cache,
+                                             value_cache))
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
             output = HPUPagedAttention.forward_decode(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_list=attn_metadata.block_list,
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
                 block_groups=attn_metadata.block_groups,
-                **self.common_attention_args())
+                **self.common_attention_args(attn_metadata.block_list,
+                                             key_cache, value_cache))
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)

-    def common_attention_args(self):
+    def common_attention_args(self,
+                              block_list=None,
+                              key_cache=None,
+                              value_cache=None):
         fsdpa_op = self.fused_scaled_dot_product_attention.apply \
             if self.fused_scaled_dot_product_attention is not None else None

         return {
             'scale': self.scale,
             'matmul_qk_op': self.matmul_qk,
@@ -290,6 +285,9 @@ def common_attention_args(self):
             'keys_fetch_func': self.k_cache.fetch_from_cache,
             'values_fetch_func': self.v_cache.fetch_from_cache,
             'softmax_op': self.softmax,
+            'block_list': block_list,
+            'key_cache': key_cache,
+            'value_cache': value_cache,
         }

     def forward_encoder_decoder(
@@ -371,13 +369,11 @@ def forward_encoder_decoder(
             # Decoding run.
             output = HPUPagedAttention.forward_decode(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_list=block_list,
                 block_mapping=block_mapping,
                 block_bias=attn_bias,
                 block_groups=block_groups,
-                **self.common_attention_args())
+                **self.common_attention_args(block_list, key_cache,
+                                             value_cache))
         # Reshape the output tensor.
         return output.view(batch_size, -1, hidden_size)

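With this change, the prefill path no longer branches into HPUPagedAttention.forward_prefix when a block_list is present: ops.prompt_attention covers both plain prefill and prefix-caching prefill, and common_attention_args threads block_list, key_cache and value_cache into the shared kwargs consumed by both the prompt and decode kernels. Below is a minimal, self-contained sketch of that call pattern; ToyAttention and fake_prompt_attention are illustrative stand-ins, not real vLLM or vllm-hpu-extension APIs.

```python
# Sketch only: mirrors the PR's "one kwargs builder, optional cache handles"
# pattern. Names here are hypothetical; the real entry point is
# ops.prompt_attention from vllm-hpu-extension.

def fake_prompt_attention(query, block_list=None, key_cache=None,
                          value_cache=None, **kwargs):
    # Stand-in for ops.prompt_attention: the real op is assumed to take the
    # prefix-aware path only when block_list and the caches are supplied.
    if block_list is not None:
        return f"prefix-caching prefill: q={query}, blocks={block_list}"
    return f"plain prefill: q={query}"


class ToyAttention:
    def __init__(self, scale):
        self.scale = scale

    def common_attention_args(self, block_list=None, key_cache=None,
                              value_cache=None):
        # Shared kwargs for prefill and decode; cache handles and block_list
        # ride along only when the caller has them, as in the PR.
        return {
            'scale': self.scale,
            'block_list': block_list,
            'key_cache': key_cache,
            'value_cache': value_cache,
        }

    def forward_prefill(self, query, block_list=None, key_cache=None,
                        value_cache=None):
        return fake_prompt_attention(
            query=query,
            **self.common_attention_args(block_list, key_cache, value_cache))


attn = ToyAttention(scale=0.125)
print(attn.forward_prefill("q0"))                      # plain prefill
print(attn.forward_prefill("q1", [3, 7], "kc", "vc"))  # prefix-aware prefill
```

Centralizing the cache handles in one kwargs builder keeps the prefill and decode call sites in sync, which is why the explicit key_cache/value_cache/block_list keyword arguments disappear from the forward_decode calls above.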
vllm/attention/ops/hpu_paged_attn.py (4 changes: 0 additions & 4 deletions)
@@ -63,10 +63,6 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
     def forward_decode(**kwargs) -> torch.Tensor:
         return ops.flat_pa(**kwargs)

-    @staticmethod
-    def forward_prefix(**kwargs) -> torch.Tensor:
-        return ops.prompt_attention_with_context(**kwargs)
-
     @staticmethod
     def swap_blocks(
         src_kv_cache: Tuple[torch.Tensor, torch.Tensor],