requirements-hpu.txt: 2 changes (1 addition, 1 deletion)
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@3e0fb39
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@62ad004
vllm/attention/backends/hpu_attn.py: 56 changes (26 additions, 30 deletions)
@@ -241,45 +241,40 @@ def forward(
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
-if attn_metadata is None or attn_metadata.block_list is None:
-    out = ops.prompt_attention(
-        impl=self.prefill_impl,
-        query=query.view(query_shape),
-        key=key.view(kv_shape),
-        value=value.view(kv_shape),
-        is_causal=True,
-        attn_bias=attn_bias,
-        valid_seq_lengths=attn_metadata.seq_lens_tensor,
-        **self.common_attention_args())
-else:
-    # TODO: enable FusedSDPA
-    out = HPUPagedAttention.forward_prefix(
-        query=query.view(query_shape),
-        key=key.view(kv_shape),
-        value=value.view(kv_shape),
-        key_cache=key_cache,
-        value_cache=value_cache,
-        block_list=attn_metadata.block_list,
-        attn_bias=attn_metadata.attn_bias,
-        **self.common_attention_args())
+block_list = attn_metadata.block_list if attn_metadata \
+    and attn_metadata.block_list is not None else None
+
+out = ops.prompt_attention(
+    impl=self.prefill_impl,
+    query=query.view(query_shape),
+    key=key.view(kv_shape),
+    value=value.view(kv_shape),
+    is_causal=True,
+    attn_bias=attn_bias,
+    valid_seq_lengths=attn_metadata.seq_lens_tensor,
+    **self.common_attention_args(block_list, key_cache,
+                                 value_cache))
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_groups=attn_metadata.block_groups,
-**self.common_attention_args())
+**self.common_attention_args(attn_metadata.block_list,
+                             key_cache, value_cache))
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

-def common_attention_args(self):
+def common_attention_args(self,
+                          block_list=None,
+                          key_cache=None,
+                          value_cache=None):
fsdpa_op = self.fused_scaled_dot_product_attention.apply \
if self.fused_scaled_dot_product_attention is not None else None

return {
'scale': self.scale,
'matmul_qk_op': self.matmul_qk,
@@ -290,6 +285,9 @@ def common_attention_args(self):
'keys_fetch_func': self.k_cache.fetch_from_cache,
'values_fetch_func': self.v_cache.fetch_from_cache,
'softmax_op': self.softmax,
+'block_list': block_list,
+'key_cache': key_cache,
+'value_cache': value_cache,
}

def forward_encoder_decoder(
@@ -371,13 +369,11 @@ def forward_encoder_decoder(
# Decoding run.
output = HPUPagedAttention.forward_decode(
query=query,
-key_cache=key_cache,
-value_cache=value_cache,
block_list=block_list,
block_mapping=block_mapping,
block_bias=attn_bias,
block_groups=block_groups,
-**self.common_attention_args())
+**self.common_attention_args(block_list, key_cache,
+                             value_cache))
# Reshape the output tensor.
return output.view(batch_size, -1, hidden_size)

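Note on the refactor above: `common_attention_args` now bundles `block_list` and the paged-cache handles into the shared kwargs dict, so the prompt, decode, and encoder-decoder paths all route them through one builder instead of spelling them out at each call site. Below is a minimal standalone sketch of that pattern; `ToyAttention` and `toy_prompt_attention` are illustrative stand-ins, not the vLLM or HPU-extension API.

```python
# Minimal sketch: one kwargs builder shared by every attention entry point.
# ToyAttention and toy_prompt_attention are hypothetical stand-ins used only
# to illustrate the calling pattern from the diff above.
from typing import Optional


class ToyAttention:

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def common_attention_args(self,
                              block_list: Optional[list] = None,
                              key_cache: Optional[list] = None,
                              value_cache: Optional[list] = None) -> dict:
        # Cache handles and the block list travel with the rest of the
        # shared arguments instead of being separate keyword arguments
        # at every call site.
        return {
            'scale': self.scale,
            'block_list': block_list,
            'key_cache': key_cache,
            'value_cache': value_cache,
        }


def toy_prompt_attention(query: float, **kwargs) -> float:
    # With block_list=None this behaves like plain causal prefill; with a
    # block list it could also read prior context from the supplied caches.
    return query * kwargs['scale']


layer = ToyAttention(scale=0.125)

# Prefill without cached context: no block list, no caches needed.
out = toy_prompt_attention(1.0, **layer.common_attention_args())

# Prefill with cached context: same entry point, now given the caches.
out_ctx = toy_prompt_attention(
    1.0, **layer.common_attention_args(block_list=[0, 1],
                                       key_cache=[], value_cache=[]))
```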
vllm/attention/ops/hpu_paged_attn.py: 4 changes (0 additions, 4 deletions)
@@ -63,10 +63,6 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
def forward_decode(**kwargs) -> torch.Tensor:
return ops.flat_pa(**kwargs)

-@staticmethod
-def forward_prefix(**kwargs) -> torch.Tensor:
-    return ops.prompt_attention_with_context(**kwargs)

@staticmethod
def swap_blocks(
src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
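With `forward_prefix` removed, prefix (context) attention during prefill is reached through the single `ops.prompt_attention` entry point, selected by whether a `block_list` is supplied. A hedged toy sketch of that dispatch follows, assuming the extension keys on `block_list`; the functions are illustrative, not the actual HPU-extension implementation.

```python
# Toy dispatch mirroring the branch the diff removes from hpu_attn.py: one
# prompt-attention entry point chooses between plain prefill and
# prefill-with-context based on block_list. All names are illustrative.
from typing import Optional


def plain_prefill(query: float) -> float:
    # Causal attention over the prompt tokens alone.
    return query


def prefill_with_context(query: float, block_list: list,
                         key_cache: list, value_cache: list) -> float:
    # Prefill that additionally attends to previously cached blocks.
    return query


def toy_prompt_attention(query: float,
                         block_list: Optional[list] = None,
                         key_cache: Optional[list] = None,
                         value_cache: Optional[list] = None,
                         **_ignored) -> float:
    if block_list is None:
        return plain_prefill(query)
    return prefill_with_context(query, block_list, key_cache, value_cache)
```

Under this reading, dropping `forward_prefix` removes a wrapper rather than the capability: the context path survives, reached through the shared kwargs instead of a dedicated call site.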