Closed
22 commits
5dbefd6
[SW-224648] Redirect test logs to file (#1017)
bmyrcha Apr 8, 2025
236ac10
apply deepseek change
xuechendi Apr 8, 2025
d273848
update for mypy
xuechendi Apr 8, 2025
7cf1dcf
fix acc issue
xuechendi Apr 8, 2025
4560a09
fix mypy
xuechendi Apr 8, 2025
85f0693
update vllm-hpu-extension commit id for test
xuechendi Apr 8, 2025
ff61f89
[SW-224648] Fix test logs redirection (#1027)
bmyrcha Apr 9, 2025
b92af9c
[SW-225233] Adjust method of getting synapse_build (#1045)
bmyrcha Apr 9, 2025
5a9ddfd
Implement Pipeline Parallelism support for HPU. (#1000) (#1040)
jmaksymc Apr 10, 2025
ed47e1e
[1.21 cherry-pick] Fix async callback ordering (#1023) (#1028)
madamczyk-intel Apr 10, 2025
9a06a89
[1.21 cherry-pick] Make lazy mode autodetection more robust (#1038)
madamczyk-intel Apr 10, 2025
a93c26a
Add temporary workaround for V1
kwisniewski98 Apr 11, 2025
035db32
APC - Remove prompt attn with context and use existing implementation…
adobrzyn Apr 11, 2025
496938d
Resolve review comments
xuechendi Apr 12, 2025
a6358a5
update dependent vllm-hpu-extension
xuechendi Apr 12, 2025
d362dd4
Merge branch 'v1.21.0_next' into dev/chendi/deepseek_r1
xuechendi Apr 12, 2025
b576015
Cherry pick exponential bucketing integration from #642 (#1067)
kzawora-intel Apr 12, 2025
b49caca
Remove o_proj only for deepseek
kwisniewski98 Apr 14, 2025
214bcae
Change vllm-hpu-extension version
kwisniewski98 Apr 14, 2025
3cb8b06
Merge branch 'v1.21.0_next' into dev/chendi/deepseek_r1
kwisniewski98 Apr 14, 2025
9b85748
Explicitly disable t.compile for deepseek
kwisniewski98 Apr 15, 2025
1d7fb51
Change method of checking lazy mode
kwisniewski98 Apr 15, 2025
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@3e0fb39
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@35c8288
249 changes: 243 additions & 6 deletions vllm/attention/backends/hpu_attn.py
@@ -17,6 +17,7 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.mla.utils import MLACommonImpl
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
HPUPagedAttentionMetadata)
@@ -69,6 +70,49 @@ def copy_blocks(
HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)


class HPUMLAAttentionBackend(AttentionBackend):

@staticmethod
def get_name() -> str:
return "HPU_MLA"

@staticmethod
def get_impl_cls() -> Type["HPUMLAImpl"]:
return HPUMLAImpl

@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return HPUMLAMetadata

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, head_size)

@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
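
Side note (a sketch, not part of the diff): get_kv_cache_shape above describes a single latent cache instead of separate key and value caches. With illustrative DeepSeek-style dimensions (kv_lora_rank=512 and qk_rope_head_dim=64 are assumptions, not values from this PR), the cache shape works out as follows:

# Sketch only: all dimensions below are assumed for illustration.
kv_lora_rank = 512                            # compressed latent KV dimension (assumed)
qk_rope_head_dim = 64                         # rotary portion of the key (assumed)
head_size = kv_lora_rank + qk_rope_head_dim   # 576: one latent vector per cache slot

num_blocks, block_size = 128, 128
# mirrors HPUMLAAttentionBackend.get_kv_cache_shape(num_blocks, block_size, 1, head_size)
kv_cache_shape = (num_blocks, block_size, head_size)
print(kv_cache_shape)                         # (128, 128, 576): no separate K and V caches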


@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
"""Metadata for HPUAttentionbackend."""
@@ -78,6 +122,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]
input_positions: torch.Tensor
seq_lens: Optional[List[int]] = None
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
@@ -91,6 +136,204 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
cross_attn_bias: Optional[torch.Tensor] = None


@dataclass
class HPUMLAMetadata(HPUAttentionMetadata, AttentionMetadata):
pass


class HPUMLAImpl(MLACommonImpl[HPUAttentionMetadata], torch.nn.Module):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**kwargs) -> None:
torch.nn.Module.__init__(self)
MLACommonImpl.__init__(self, num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**kwargs)

self.matmul_qk = Matmul()
self.softmax = Softmax()
self.matmul_av = Matmul()
self.batch2block_matmul = Matmul()
self.block2batch_matmul = Matmul()
self.latent_cache_k = VLLMKVCache()
self.fused_scaled_dot_product_attention = kernels.fsdpa()

if "fsdpa" in enabled_flags():
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
self.prefill_impl = 'fsdpa'
else:
self.prefill_impl = 'naive'

unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")

def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor, # query in unified attn
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")

batch_size = hidden_states_or_q_c.shape[0]

is_prefill = attn_metadata.is_prompt

k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

# Restore head dim (for rotary embedding)
# k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata,
"input_positions"), f"attn meta: {attn_metadata}"

if not is_prefill:
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim)
else:
q_nope, q_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
input_positions = attn_metadata.input_positions.view(-1)
q_pe, k_pe = \
self.rotary_emb(input_positions, q_pe, k_pe)
else:
q = self.q_proj(hidden_states_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)

q_pe = q[..., self.qk_nope_head_dim:]

input_positions = attn_metadata.input_positions.view(-1)
# TODO(lucas): there must be a nicer way to write this line
q[..., self.qk_nope_head_dim:], k_pe = \
self.rotary_emb(input_positions, q_pe, k_pe)

block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets

latent_vec_k = torch.concat(
(k_c_normed, k_pe.view(batch_size, -1, self.qk_rope_head_dim)),
dim=-1)
latent_vec_k = latent_vec_k.view(
-1, self.qk_rope_head_dim + self.kv_lora_rank)
if is_prefill:
latent_vec_k = latent_vec_k.unflatten(0,
(block_indices.size(0), -1))

# write the latent and rope to kv cache
if kv_cache is not None and len(kv_cache) == 2:
self.latent_cache_k(latent_vec_k, kv_cache[0], block_indices,
block_offsets)
k_cache = kv_cache[0]
v_cache = None

if is_prefill:
return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata,
batch_size)
else:
return self._forward_decode(q_nope, q_pe, (k_cache, v_cache),
attn_metadata, batch_size)

def _forward_prefill( # type: ignore
self, q: torch.Tensor, k_c_normed: torch.Tensor,
k_pe: torch.Tensor, attn_metadata: HPUAttentionMetadata,
batch_size: int) -> torch.Tensor:
kv_nope = self.kv_b_proj(k_c_normed)[0]\
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
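# (Illustrative, assumed dims: with v_head_dim=128 and qk_head_dim=192, the pad
# width is 192 - 128 = 64 zero features per value vector.)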
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
q = q.view(batch_size, -1, self.num_heads, self.qk_head_dim)
k = k.view(batch_size, -1, self.num_heads, self.qk_head_dim)
v_padded = v_padded.view(batch_size, -1, self.num_heads,
self.qk_head_dim)
out = ops.prompt_attention(
impl=self.prefill_impl,
query=q,
key=k,
value=v_padded,
is_causal=True,
attn_bias=attn_metadata.attn_bias,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
fsdpa_op=self.fused_scaled_dot_product_attention.apply \
if self.fused_scaled_dot_product_attention is not None else None)
attn_output = out.view(batch_size, -1, self.num_heads, q.shape[-1])
attn_output = attn_output[..., :v.shape[-1]]\
.reshape(batch_size, -1, self.num_heads * v.shape[-1])

return self.o_proj(attn_output)[0]

def _forward_decode( # type: ignore
self, q_nope: torch.Tensor, q_pe: torch.Tensor,
kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata,
batch_size: int) -> torch.Tensor:
query = torch.cat([q_nope, q_pe], dim=-1)

key_cache = kv_cache[0].unsqueeze(2)
value_cache = kv_cache[1] # value_cache is None
output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_groups=attn_metadata.block_groups,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.latent_cache_k.fetch_from_cache,
values_fetch_func=None,
kv_lora_rank=self.kv_lora_rank)
output = output.view(batch_size, 1, -1)
result = self._v_up_proj_and_o_proj(output)
result = result.view(batch_size, 1, -1)
return result
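
The decode path above scores the query directly against the cached latent vectors and only up-projects the result afterwards (_v_up_proj_and_o_proj), the matrix-absorption trick inherited from MLACommonImpl. Below is a minimal self-contained sketch of why folding the key up-projection into the query yields the same attention scores; the sizes and the W_UK name are illustrative assumptions, not taken from this PR:

import torch

# Toy sizes, chosen only for the sketch.
num_heads, kv_lora_rank, qk_nope_head_dim, seq_len = 2, 8, 4, 3
W_UK = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)  # latent -> key up-projection (hypothetical name)
q_nope = torch.randn(num_heads, 1, qk_nope_head_dim)           # current-token query, no-rope part
latents = torch.randn(seq_len, kv_lora_rank)                   # cached latent KV vectors

# Naive: up-project every cached latent to a full key, then dot with the query.
keys = torch.einsum('sl,hdl->hsd', latents, W_UK)
scores_naive = torch.einsum('hqd,hsd->hqs', q_nope, keys)

# Absorbed: fold W_UK into the query once, then score directly against the latents.
q_absorbed = torch.einsum('hqd,hdl->hql', q_nope, W_UK)
scores_absorbed = torch.einsum('hql,sl->hqs', q_absorbed, latents)

assert torch.allclose(scores_naive, scores_absorbed, atol=1e-5)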


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
"""
If the input tensors contain prompt tokens, the layout is as follows:
@@ -153,12 +396,6 @@ def __init__(
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")

self.attn_type = attn_type
if (self.attn_type != AttentionType.DECODER
and self.attn_type != AttentionType.ENCODER_DECODER