diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index a28630921771..28379f102af2 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -58,6 +58,7 @@ class ParallelSetup(NamedTuple): class CPTestOptions(NamedTuple): multi_node_only: bool attn_backend: str | None = None + kv_cache_dtype: str | None = None @dataclass @@ -77,6 +78,7 @@ def detailed( multi_node_only: bool = False, runner: RunnerOption = "auto", attn_backend: str | None = None, + kv_cache_dtype: str | None = None, ): parallel_setups = [] if dcp_multipliers is None: @@ -104,6 +106,7 @@ def detailed( test_options=CPTestOptions( multi_node_only=multi_node_only, attn_backend=attn_backend, + kv_cache_dtype=kv_cache_dtype, ), ) @@ -129,6 +132,10 @@ def iter_params(self, model_id: str): cp_kv_cache_interleave_size=64, attn_backend="FLASHMLA", ), + CPTestSettings.detailed( + dcp_multipliers=[1], + kv_cache_dtype="fp8", + ), ], "Qwen/Qwen2.5-1.5B-Instruct": [ CPTestSettings.detailed( @@ -161,7 +168,7 @@ def _test_cp_gsm8k( chunked_prefill, ) = parallel_setup - multi_node_only, attn_backend = test_options + multi_node_only, attn_backend, kv_cache_dtype = test_options model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_transformers_version(on_fail="skip") @@ -221,6 +228,8 @@ def _test_cp_gsm8k( if attn_backend: server_args.append(f"--attention-backend={attn_backend}") + if kv_cache_dtype: + server_args.extend(["--kv-cache-dtype", kv_cache_dtype]) with RemoteOpenAIServer( model_id, diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 926e8892eae8..e0a6c9f54687 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -642,7 +642,20 @@ def forward_impl( # Convert from (N, B, L) to (B, N, L) mqa_ql_nope = mqa_ql_nope.transpose(0, 1) - if fp8_attention and self.impl.supports_quant_query_input: + if self.impl.dcp_world_size > 1: + if fp8_attention and self.impl.supports_quant_query_input: + # Backend wants FP8 Q: quant first, allgather in FP8 + # (halves allgather bandwidth vs BF16) + assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0] + assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1] + mqa_q = self._decode_concat_quant_fp8_op( + mqa_ql_nope, mqa_q_pe, self._q_scale + ) + mqa_q = get_dcp_group().all_gather(mqa_q, dim=1) + else: + mqa_q = torch.cat((mqa_ql_nope, mqa_q_pe), dim=-1) + mqa_q = get_dcp_group().all_gather(mqa_q, dim=1) + elif fp8_attention and self.impl.supports_quant_query_input: assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0] assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1] mqa_q = self._decode_concat_quant_fp8_op( @@ -650,12 +663,6 @@ def forward_impl( ) else: mqa_q = (mqa_ql_nope, mqa_q_pe) - if self.impl.dcp_world_size > 1: - assert not fp8_attention, "DCP not support fp8 kvcache now." - # concatenate mqa_ql_nope and mqa_q_pe -> (B, N, L + P) - mqa_q = torch.cat(mqa_q, dim=-1) - # mqa_q do allgather in head dim. - mqa_q = get_dcp_group().all_gather(mqa_q, dim=1) # call decode attn if not is_sparse_impl: @@ -1145,6 +1152,9 @@ class ChunkedContextMetadata: padded_local_cu_seq_lens: torch.Tensor | None = None cu_seq_lens_lst: list[list[int]] | None = None chunk_size: int | None = None + # for mla DCP with FP8 KV cache (gather_and_maybe_dequant_cache) + padded_local_token_to_seq: torch.Tensor | None = None + padded_local_chunk_total_token: list[int] | None = None block_table: torch.Tensor query_start_loc: torch.Tensor @@ -1825,6 +1835,24 @@ def build( dtype=torch.int32, ) + # Compute padded-local token_to_seq and total_token + # for gather_and_maybe_dequant_cache (FP8 DCP support) + padded_local_chunk_total_token = padded_local_cu_chunk_seq_lens_cpu[ + :, -1 + ] + padded_local_max_token_num = ( + padded_local_chunk_total_token.max().item() + ) + padded_local_token_to_seq_cpu = torch.zeros( + [num_chunks, padded_local_max_token_num], + dtype=torch.int32, + ) + for i in range(num_chunks): + t2s = torch.repeat_interleave( + range_idx, padded_local_chunk_seq_lens[i] + ) + padded_local_token_to_seq_cpu[i, : t2s.shape[0]] = t2s + chunked_context_metadata_cls = ( CudnnPrefillMetadata.ChunkedContextMetadata if self._use_cudnn_prefill @@ -1849,6 +1877,12 @@ def build( ), cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), chunk_size=padded_local_max_context_chunk_across_ranks, + padded_local_token_to_seq=( + padded_local_token_to_seq_cpu.to(device, non_blocking=True) + ), + padded_local_chunk_total_token=( + padded_local_chunk_total_token.tolist() + ), ) else: chunked_context_metadata = chunked_context_metadata_cls( @@ -2512,7 +2546,6 @@ def _context_parallel_compute_prefill_context( k_scale: torch.Tensor, dcp_world_size: int, ): - assert k_scale is None, "DCP not support scaled kvcache now." assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill assert prefill_metadata.chunked_context is not None @@ -2526,18 +2559,48 @@ def _context_parallel_compute_prefill_context( iters = len(prefill_metadata.chunked_context.seq_tot) workspace = prefill_metadata.chunked_context.workspace + fp8_kv_cache = ( + self.kv_cache_dtype.startswith("fp8") + and self.kv_cache_dtype != "fp8_ds_mla" + ) + for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - ops.cp_gather_cache( - src_cache=kv_c_and_k_pe_cache, - dst=workspace, - block_table=prefill_metadata.block_table, - cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[ - i - ], - batch_size=attn_metadata.num_prefills, - seq_starts=prefill_metadata.chunked_context.starts[i], - ) + if fp8_kv_cache: + # FP8 KV cache: gather and dequant to BF16 workspace + assert k_scale is not None + token_to_seq = ( + prefill_metadata.chunked_context.padded_local_token_to_seq + ) + chunk_total = ( + prefill_metadata.chunked_context.padded_local_chunk_total_token + ) + assert token_to_seq is not None + assert chunk_total is not None + ops.gather_and_maybe_dequant_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[ + i + ], + token_to_seq=token_to_seq[i], + num_tokens=chunk_total[i], + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + else: + ops.cp_gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[ + i + ], + batch_size=attn_metadata.num_prefills, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) # workspace # |------- N tokens --------|--------- N*dcp_size tokens ----------| # |<- use for local_gather ->|<--------- use for allgather -------->| @@ -2652,12 +2715,22 @@ def forward_mha( if has_context: suffix_output, suffix_lse = output_prefill if self.dcp_world_size > 1: + if self.kv_cache_dtype == "fp8_ds_mla": + raise NotImplementedError( + "DCP > 1 with `kv_cache_dtype='fp8_ds_mla'` is not supported." + ) + assert not use_fp8_prefill, ( + "DCP>1 with FP8 prefill query quantization is not " + "supported. Use --attention-config " + "'{\"use_prefill_query_quantization\": false}' " + "or reduce decode_context_parallel_size to 1." + ) context_output, context_lse = ( self._context_parallel_compute_prefill_context( q, kv_c_and_k_pe_cache, attn_metadata, - k_scale=None, + k_scale=k_scale, dcp_world_size=self.dcp_world_size, ) )