2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f855191
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@aee363c
246 changes: 240 additions & 6 deletions vllm/attention/backends/hpu_attn.py
@@ -17,6 +17,7 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.mla.utils import MLACommonImpl
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                               HPUPagedAttentionMetadata)
@@ -69,6 +70,49 @@ def copy_blocks(
        HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)


class HPUMLAAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "HPU_MLA"

    @staticmethod
    def get_impl_cls() -> Type["HPUMLAImpl"]:
        return HPUMLAImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return HPUMLAMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
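Editorial note, not part of the diff: the MLA backend's cache shape is flat, with no num_kv_heads dimension, because each token stores a single compressed latent plus its rope slice. A minimal sketch of what that shape amounts to; the concrete sizes below are illustrative DeepSeek-style assumptions, not values taken from this PR.

```python
# Hypothetical sizes, for illustration only (not from this PR).
kv_lora_rank = 512       # width of the compressed KV latent
qk_rope_head_dim = 64    # rope slice appended to the latent
head_size = kv_lora_rank + qk_rope_head_dim

# HPUMLAAttentionBackend.get_kv_cache_shape drops num_kv_heads entirely:
shape = HPUMLAAttentionBackend.get_kv_cache_shape(
    num_blocks=128, block_size=128, num_kv_heads=16, head_size=head_size)
assert shape == (128, 128, 576)
```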


@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
"""Metadata for HPUAttentionbackend."""
@@ -78,6 +122,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    attn_bias: Optional[torch.Tensor]
    seq_lens_tensor: Optional[torch.Tensor]
    context_lens_tensor: Optional[torch.Tensor]
    input_positions: torch.Tensor
    seq_lens: Optional[List[int]] = None
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
@@ -91,6 +136,201 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    cross_attn_bias: Optional[torch.Tensor] = None


@dataclass
class HPUMLAMetadata(HPUAttentionMetadata, AttentionMetadata):
    pass


class HPUMLAImpl(MLACommonImpl[HPUAttentionMetadata], torch.nn.Module):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **kwargs) -> None:
        torch.nn.Module.__init__(self)
        MLACommonImpl.__init__(self, num_heads, head_size, scale, num_kv_heads,
                               alibi_slopes, sliding_window, kv_cache_dtype,
                               blocksparse_params, logits_soft_cap, attn_type,
                               **kwargs)

        self.matmul_qk = Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul()
        self.batch2block_matmul = Matmul()
        self.block2batch_matmul = Matmul()
        self.latent_cache_k = VLLMKVCache()
        self.fused_scaled_dot_product_attention = kernels.fsdpa()

        if "fsdpa" in enabled_flags():
            assert alibi_slopes is None, \
                'Prefill with FusedSDPA not supported with alibi slopes!'
            self.prefill_impl = 'fsdpa'
        else:
            self.prefill_impl = 'naive'

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "HPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "HPUMLAImpl")

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if output is not None:
            raise NotImplementedError(
                "output is not yet supported for MLAImplBase")

        batch_size = hidden_states_or_q_c.shape[0]

        is_prefill = attn_metadata.is_prompt

        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

        assert hasattr(attn_metadata,
                       "input_positions"), f"attn meta: {attn_metadata}"

        # Decode builds queries for attention over the compressed (latent) KV
        # cache; prefill builds full per-head queries. RoPE is applied to the
        # rope slice of q and to k_pe in both branches.
        if not is_prefill:
            if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
                q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
                q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                    .view(-1, self.num_heads, self.qk_rope_head_dim)
            else:
                q_nope, q_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
            input_positions = attn_metadata.input_positions.view(-1)
            q_pe, k_pe = \
                self.rotary_emb(input_positions, q_pe, k_pe)
        else:
            q = self.q_proj(hidden_states_or_q_c)[0]\
                .view(-1, self.num_heads, self.qk_head_dim)

            q_pe = q[..., self.qk_nope_head_dim:]

            input_positions = attn_metadata.input_positions.view(-1)
            q[..., self.qk_nope_head_dim:], k_pe = \
                self.rotary_emb(input_positions, q_pe, k_pe)

        block_indices = attn_metadata.block_indices
        block_offsets = attn_metadata.block_offsets

        latent_vec_k = torch.concat(
            (k_c_normed, k_pe.view(batch_size, -1, self.qk_rope_head_dim)),
            dim=-1)
        latent_vec_k = latent_vec_k.view(
            -1, self.qk_rope_head_dim + self.kv_lora_rank)
        if is_prefill:
            latent_vec_k = latent_vec_k.unflatten(0,
                                                  (block_indices.size(0), -1))

        # write the latent and rope to kv cache
        if kv_cache is not None and len(kv_cache) == 2:
            self.latent_cache_k(latent_vec_k, kv_cache[0], block_indices,
                                block_offsets)
        k_cache = kv_cache[0]
        v_cache = None

        if is_prefill:
            return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata,
                                         batch_size)
        else:
            return self._forward_decode(q_nope, q_pe, (k_cache, v_cache),
                                        attn_metadata, batch_size)
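Editorial aside on the VLLM_MLA_PERFORM_MATRIX_ABSORPTION branch above: with absorption enabled, the key up-projection is folded into the query side, so decode can score directly against the cached latents instead of up-projecting them every step. A rough sketch of the algebraic identity being exploited; all names and shapes below are made up for illustration and do not come from this PR.

```python
import torch

# Toy shapes, for illustration only.
d_latent, d_head, n_tokens = 512, 128, 4
W_UK = torch.randn(d_latent, d_head, dtype=torch.float64)     # latent -> key up-projection
q_nope = torch.randn(n_tokens, d_head, dtype=torch.float64)   # per-head query (nope part)
c_kv = torch.randn(n_tokens, d_latent, dtype=torch.float64)   # cached compressed latents

# Naive: up-project every cached latent to a full key, then score.
scores_naive = q_nope @ (c_kv @ W_UK).T
# Absorbed: fold W_UK into the query once, then score in latent space.
scores_absorbed = (q_nope @ W_UK.T) @ c_kv.T
assert torch.allclose(scores_naive, scores_absorbed)
```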

    def _forward_prefill(  # type: ignore
            self, q: torch.Tensor, k_c_normed: torch.Tensor,
            k_pe: torch.Tensor, attn_metadata: HPUAttentionMetadata,
            batch_size: int) -> torch.Tensor:
        kv_nope = self.kv_b_proj(k_c_normed)[0]\
            .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                           value=0)
        q = q.view(batch_size, -1, self.num_heads, self.qk_head_dim)
        k = k.view(batch_size, -1, self.num_heads, self.qk_head_dim)
        v_padded = v_padded.view(batch_size, -1, self.num_heads,
                                 self.qk_head_dim)
        out = ops.prompt_attention(
            impl=self.prefill_impl,
            query=q,
            key=k,
            value=v_padded,
            is_causal=True,
            attn_bias=attn_metadata.attn_bias,
            valid_seq_lengths=attn_metadata.seq_lens_tensor,
            scale=self.scale,
            matmul_qk_op=self.matmul_qk,
            softmax_op=self.softmax,
            matmul_av_op=self.matmul_av,
            fsdpa_op=self.fused_scaled_dot_product_attention.apply \
            if self.fused_scaled_dot_product_attention is not None else None)
        attn_output = out.view(batch_size, -1, self.num_heads, q.shape[-1])
        attn_output = attn_output[..., :v.shape[-1]]\
            .reshape(batch_size, -1, self.num_heads * v.shape[-1])

        return self.o_proj(attn_output)[0]
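Editorial aside on the zero-padding in _forward_prefill above: because v is padded with zeros up to the qk head dim, the padded output columns are exactly zero, so slicing [..., :v.shape[-1]] afterwards recovers the unpadded result. A tiny sketch with toy tensors, not taken from the PR.

```python
import torch

probs = torch.softmax(torch.randn(5, 5), dim=-1)   # toy attention weights
v = torch.randn(5, 3)                               # true v head dim (3) < qk head dim (5)
v_padded = torch.nn.functional.pad(v, [0, 2], value=0)

out = probs @ v_padded
assert torch.equal(out[:, 3:], torch.zeros(5, 2))   # padded columns stay zero
assert torch.allclose(out[:, :3], probs @ v)        # slicing recovers the result
```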

    def _forward_decode(  # type: ignore
            self, q_nope: torch.Tensor, q_pe: torch.Tensor,
            kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata,
            batch_size: int) -> torch.Tensor:
        query = torch.cat([q_nope, q_pe], dim=-1)

        key_cache = kv_cache[0].unsqueeze(2)
        value_cache = kv_cache[1]  # value_cache is None
        output = HPUPagedAttention.forward_decode(
            query=query,
            key_cache=key_cache,
            value_cache=value_cache,
            block_list=attn_metadata.block_list,
            block_mapping=attn_metadata.block_mapping,
            block_bias=attn_metadata.attn_bias,
            block_groups=attn_metadata.block_groups,
            scale=self.scale,
            matmul_qk_op=self.matmul_qk,
            matmul_av_op=self.matmul_av,
            batch2block_matmul_op=self.batch2block_matmul,
            block2batch_matmul_op=self.block2batch_matmul,
            keys_fetch_func=self.latent_cache_k.fetch_from_cache,
            values_fetch_func=None,
            kv_lora_rank=self.kv_lora_rank)
        output = output.view(batch_size, 1, -1)
        result = self._v_up_proj_and_o_proj(output)
        result = result.view(batch_size, 1, -1)
        return result


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
"""
If the input tensors contain prompt tokens, the layout is as follows:
Expand Down Expand Up @@ -153,12 +393,6 @@ def __init__(
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in suppored_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {suppored_head_sizes}.")

        self.attn_type = attn_type
        if (self.attn_type != AttentionType.DECODER
                and self.attn_type != AttentionType.ENCODER_DECODER