Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 22 additions & 31 deletions vllm/model_executor/layers/attention/mla_attention.py
Original file line number Diff line number Diff line change
@@ -522,22 +522,6 @@ def forward_impl(
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]

assert (
attn_metadata.num_decodes is not None
and attn_metadata.num_prefills is not None
and attn_metadata.num_decode_tokens is not None
)

has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens

decode_q = q[:num_decode_tokens]

prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]

# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
@@ -555,27 +539,32 @@ def forward_impl(
# Sparse MLA impls only support forward_mqa (decode-style attention)
is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl)

if has_prefill and not is_sparse_impl:
if is_sparse_impl:
num_mqa_tokens = q.size(0)
num_mha_tokens = 0
else:
assert (
attn_metadata.num_decodes is not None
and attn_metadata.num_prefills is not None
and attn_metadata.num_decode_tokens is not None
)
num_mqa_tokens = attn_metadata.num_decode_tokens
num_mha_tokens = q.size(0) - num_mqa_tokens

if num_mha_tokens > 0:
self.impl.forward_mha(
prefill_q,
prefill_k_c_normed,
prefill_k_pe,
q[num_mqa_tokens:],
k_c_normed[num_mqa_tokens:],
k_pe[num_mqa_tokens:],
kv_cache,
attn_metadata,
self._k_scale,
output=output[num_decode_tokens:],
output=output[num_mqa_tokens:],
)

if has_decode or (has_prefill and is_sparse_impl):
# For sparse impl, we always use forward_mqa for all tokens
# For non-sparse impl, we only use forward_mqa for decode tokens
if is_sparse_impl:
mqa_q = q
mqa_output_slice = output
else:
assert attn_metadata.decode is not None
mqa_q = decode_q
mqa_output_slice = output[:num_decode_tokens]
if num_mqa_tokens > 0:
mqa_q = q[:num_mqa_tokens]
mqa_output_slice = output[:num_mqa_tokens]

mqa_q_nope, mqa_q_pe = mqa_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
@@ -644,6 +633,8 @@ def forward_impl(
mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)

# call decode attn
if not is_sparse_impl:
assert attn_metadata.decode is not None
attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

# correct dcp attn_out with lse.
Expand Down