Merged
Changes from 6 commits
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@1d157d0
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@adobrzyn/move_attn_with_context
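For context, the bumped pin now points at a branch of the extension rather than a fixed commit. A hedged sanity check that the installed build exposes the entry point the backend change below relies on (the vllm_hpu_extension.ops module path is an assumption inferred from the call sites, not shown in this diff):

# Hedged check: assumes the extension installs as `vllm_hpu_extension` and
# exposes its kernels under `ops`, as the hpu_attn.py call sites imply.
from vllm_hpu_extension import ops

# The rewritten prefill path calls ops.prompt_attention for the prefix-cached
# case as well, so the pinned branch build must provide it.
assert hasattr(ops, "prompt_attention"), (
    "pinned vllm-hpu-extension build does not expose prompt_attention")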
20 changes: 10 additions & 10 deletions vllm/attention/backends/hpu_attn.py
@@ -252,16 +252,16 @@ def forward(
valid_seq_lengths=attn_metadata.seq_lens_tensor,
**self.common_attention_args())
else:
-# TODO: enable FusedSDPA
-out = HPUPagedAttention.forward_prefix(
-    query=query.view(query_shape),
-    key=key.view(kv_shape),
-    value=value.view(kv_shape),
-    key_cache=key_cache,
-    value_cache=value_cache,
-    block_list=attn_metadata.block_list,
-    attn_bias=attn_metadata.attn_bias,
-    **self.common_attention_args())
+out = ops.prompt_attention(impl=self.prefill_impl,
+                           query=query.view(query_shape),
+                           key=key.view(kv_shape),
+                           value=value.view(kv_shape),
+                           key_cache=key_cache,
+                           value_cache=value_cache,
+                           block_list=attn_metadata.block_list,
+                           is_causal=True,
+                           attn_bias=attn_metadata.attn_bias,
+                           **self.common_attention_args())
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
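For readability, a minimal sketch of what the rewritten prefix-cached prefill branch amounts to, pulled out of the attention class; the import path, the standalone function shape, and the tensor-layout comment are assumptions based on the call site above, not a verified copy of the backend:

from vllm_hpu_extension import ops  # assumed import path for the extension kernels

def prefill_with_context(query, key, value, key_cache, value_cache,
                         attn_metadata, prefill_impl, **common_args):
    # Previously this path went through HPUPagedAttention.forward_prefix,
    # which forwarded to ops.prompt_attention_with_context. It now calls the
    # unified ops.prompt_attention directly, handing over the paged KV cache
    # and the block list so the extension can attend over cached context.
    return ops.prompt_attention(
        impl=prefill_impl,
        query=query,        # viewed as (batch, seq, heads, head_size) at the call site (assumed)
        key=key,
        value=value,
        key_cache=key_cache,
        value_cache=value_cache,
        block_list=attn_metadata.block_list,
        is_causal=True,
        attn_bias=attn_metadata.attn_bias,
        **common_args)

Under this reading, a prompt with no cached prefix presumably calls the same entry point without the cache arguments, which is why the dedicated forward_prefix wrapper is removed in the next file.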
4 changes: 0 additions & 4 deletions vllm/attention/ops/hpu_paged_attn.py
@@ -63,10 +63,6 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
def forward_decode(**kwargs) -> torch.Tensor:
return ops.flat_pa(**kwargs)

-    @staticmethod
-    def forward_prefix(**kwargs) -> torch.Tensor:
-        return ops.prompt_attention_with_context(**kwargs)
-
@staticmethod
def swap_blocks(
src_kv_cache: Tuple[torch.Tensor, torch.Tensor],