Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions vllm/v1/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,6 @@ def _context_parallel_compute_prefill_context(
k_scale: torch.Tensor,
dcp_world_size: int,
):
assert k_scale is None, "DCP not support scaled kvcache now."
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None
Expand All @@ -1812,14 +1811,19 @@ def _context_parallel_compute_prefill_context(

for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
ops.cp_gather_cache(
# gather_and_maybe_dequant_cache handles both FP8 (dequant) and
# non-FP8 (just copy) based on kv_cache_dtype
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
i
],
batch_size=attn_metadata.num_prefills,
token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
# workspace
Expand Down Expand Up @@ -1930,7 +1934,7 @@ def _forward_prefill(
q,
kv_c_and_k_pe_cache,
attn_metadata,
k_scale=None,
k_scale=k_scale,
dcp_world_size=self.dcp_world_size,
)
)
Expand Down Expand Up @@ -2131,9 +2135,9 @@ def forward(
else:
decode_q = (decode_ql_nope, decode_q_pe)
if self.dcp_world_size > 1:
assert not fp8_attention, "DCP not support fp8 kvcache now."
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
decode_q = torch.cat(decode_q, dim=-1)
if not fp8_attention:
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
decode_q = torch.cat(decode_q, dim=-1)
# All-gather decode_q across the DCP group along the head dim.
decode_q = get_dcp_group().all_gather(decode_q, dim=1)

Expand Down
Loading