diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 2ee2740a51ba..27a1f3bc05b2 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -1796,7 +1796,6 @@ def _context_parallel_compute_prefill_context(
         k_scale: torch.Tensor,
         dcp_world_size: int,
     ):
-        assert k_scale is None, "DCP not support scaled kvcache now."
         assert attn_metadata.prefill is not None
         prefill_metadata = attn_metadata.prefill
         assert prefill_metadata.chunked_context is not None
@@ -1812,14 +1811,19 @@ def _context_parallel_compute_prefill_context(
 
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            ops.cp_gather_cache(
+            # gather_and_maybe_dequant_cache handles both FP8 (dequant) and
+            # non-FP8 (just copy) based on kv_cache_dtype
+            ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_table,
                 cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
                     i
                 ],
-                batch_size=attn_metadata.num_prefills,
+                token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
+                num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
+                kv_cache_dtype=self.kv_cache_dtype,
+                scale=k_scale,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
             )
             # workspace
@@ -1930,7 +1934,7 @@ def _forward_prefill(
                 q,
                 kv_c_and_k_pe_cache,
                 attn_metadata,
-                k_scale=None,
+                k_scale=k_scale,
                 dcp_world_size=self.dcp_world_size,
             )
         )
@@ -2131,9 +2135,9 @@ def forward(
             else:
                 decode_q = (decode_ql_nope, decode_q_pe)
 
             if self.dcp_world_size > 1:
-                assert not fp8_attention, "DCP not support fp8 kvcache now."
-                # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
-                decode_q = torch.cat(decode_q, dim=-1)
+                if not fp8_attention:
+                    # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
+                    decode_q = torch.cat(decode_q, dim=-1)
                 # decode_q do allgather in head dim.
                 decode_q = get_dcp_group().all_gather(decode_q, dim=1)
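
For readers of this patch, here is a rough Python sketch of the semantics the new ops.gather_and_maybe_dequant_cache call is expected to provide, based on the comment added above: gather each chunk's cached KV entries into the contiguous workspace, dequantizing with k_scale when the cache holds FP8 and copying as-is otherwise. Everything below (the reference function name, the exact index arithmetic, and the fp8 dtype check) is an illustrative assumption, not the actual CUDA kernel.

import torch

def gather_and_maybe_dequant_cache_ref(
    src_cache: torch.Tensor,    # [num_blocks, block_size, entry_dim]
    dst: torch.Tensor,          # [max_tokens, entry_dim] workspace
    block_table: torch.Tensor,  # [num_seqs, max_blocks_per_seq]
    cu_seq_lens: torch.Tensor,  # [num_seqs + 1], cumulative chunk lengths
    token_to_seq: torch.Tensor, # [num_tokens], token index -> sequence index
    num_tokens: int,
    kv_cache_dtype: str,
    scale: torch.Tensor,        # k_scale, used when dequantizing FP8
    seq_starts: torch.Tensor,   # [num_seqs], chunk start offset per sequence
) -> None:
    block_size = src_cache.shape[1]
    for tok in range(num_tokens):
        seq = int(token_to_seq[tok])
        # Position of this token inside its sequence's cached context:
        # chunk start + offset of the token within its sequence's chunk.
        pos = int(seq_starts[seq]) + (tok - int(cu_seq_lens[seq]))
        block_id = int(block_table[seq, pos // block_size])
        entry = src_cache[block_id, pos % block_size]
        if kv_cache_dtype.startswith("fp8"):
            # FP8 cache: dequantize with the k scale while gathering.
            dst[tok] = entry.to(torch.float32) * scale
        else:
            # Non-FP8 cache: a plain gather/copy.
            dst[tok] = entry

This is why the patch can drop both asserts: the gather step itself now dispatches on kv_cache_dtype, so the DCP prefill path no longer needs k_scale to be None, and the decode path only skips the q concatenation when fp8_attention has already packed the query.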