Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion tests/distributed/test_context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class ParallelSetup(NamedTuple):
class CPTestOptions(NamedTuple):
multi_node_only: bool
attn_backend: str | None = None
kv_cache_dtype: str | None = None


@dataclass
Expand All @@ -77,6 +78,7 @@ def detailed(
multi_node_only: bool = False,
runner: RunnerOption = "auto",
attn_backend: str | None = None,
kv_cache_dtype: str | None = None,
):
parallel_setups = []
if dcp_multipliers is None:
Expand Down Expand Up @@ -104,6 +106,7 @@ def detailed(
test_options=CPTestOptions(
multi_node_only=multi_node_only,
attn_backend=attn_backend,
kv_cache_dtype=kv_cache_dtype,
),
)

Expand All @@ -129,6 +132,10 @@ def iter_params(self, model_id: str):
cp_kv_cache_interleave_size=64,
attn_backend="FLASHMLA",
),
CPTestSettings.detailed(
dcp_multipliers=[1],
kv_cache_dtype="fp8",
),
],
"Qwen/Qwen2.5-1.5B-Instruct": [
CPTestSettings.detailed(
Expand Down Expand Up @@ -161,7 +168,7 @@ def _test_cp_gsm8k(
chunked_prefill,
) = parallel_setup

multi_node_only, attn_backend = test_options
multi_node_only, attn_backend, kv_cache_dtype = test_options

model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
Expand Down Expand Up @@ -221,6 +228,8 @@ def _test_cp_gsm8k(

if attn_backend:
server_args.append(f"--attention-backend={attn_backend}")
if kv_cache_dtype:
server_args.extend(["--kv-cache-dtype", kv_cache_dtype])

with RemoteOpenAIServer(
model_id,
Expand Down
111 changes: 92 additions & 19 deletions vllm/model_executor/layers/attention/mla_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,20 +642,27 @@ def forward_impl(
# Convert from (N, B, L) to (B, N, L)
mqa_ql_nope = mqa_ql_nope.transpose(0, 1)

if fp8_attention and self.impl.supports_quant_query_input:
if self.impl.dcp_world_size > 1:
if fp8_attention and self.impl.supports_quant_query_input:
# Backend wants FP8 Q: quant first, allgather in FP8
# (halves allgather bandwidth vs BF16)
assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
mqa_q = self._decode_concat_quant_fp8_op(
mqa_ql_nope, mqa_q_pe, self._q_scale
)
mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)
else:
mqa_q = torch.cat((mqa_ql_nope, mqa_q_pe), dim=-1)
mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)
elif fp8_attention and self.impl.supports_quant_query_input:
assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
mqa_q = self._decode_concat_quant_fp8_op(
mqa_ql_nope, mqa_q_pe, self._q_scale
)
else:
mqa_q = (mqa_ql_nope, mqa_q_pe)
if self.impl.dcp_world_size > 1:
assert not fp8_attention, "DCP not support fp8 kvcache now."
# concatenate mqa_ql_nope and mqa_q_pe -> (B, N, L + P)
mqa_q = torch.cat(mqa_q, dim=-1)
# mqa_q do allgather in head dim.
mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)

# call decode attn
if not is_sparse_impl:
Expand Down Expand Up @@ -1145,6 +1152,9 @@ class ChunkedContextMetadata:
padded_local_cu_seq_lens: torch.Tensor | None = None
cu_seq_lens_lst: list[list[int]] | None = None
chunk_size: int | None = None
# for mla DCP with FP8 KV cache (gather_and_maybe_dequant_cache)
padded_local_token_to_seq: torch.Tensor | None = None
padded_local_chunk_total_token: list[int] | None = None

block_table: torch.Tensor
query_start_loc: torch.Tensor
Expand Down Expand Up @@ -1825,6 +1835,24 @@ def build(
dtype=torch.int32,
)

# Compute padded-local token_to_seq and total_token
# for gather_and_maybe_dequant_cache (FP8 DCP support)
padded_local_chunk_total_token = padded_local_cu_chunk_seq_lens_cpu[
:, -1
]
padded_local_max_token_num = (
padded_local_chunk_total_token.max().item()
)
padded_local_token_to_seq_cpu = torch.zeros(
[num_chunks, padded_local_max_token_num],
dtype=torch.int32,
)
for i in range(num_chunks):
t2s = torch.repeat_interleave(
range_idx, padded_local_chunk_seq_lens[i]
)
padded_local_token_to_seq_cpu[i, : t2s.shape[0]] = t2s

chunked_context_metadata_cls = (
CudnnPrefillMetadata.ChunkedContextMetadata
if self._use_cudnn_prefill
Expand All @@ -1849,6 +1877,12 @@ def build(
),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
padded_local_token_to_seq=(
padded_local_token_to_seq_cpu.to(device, non_blocking=True)
),
padded_local_chunk_total_token=(
padded_local_chunk_total_token.tolist()
),
)
else:
chunked_context_metadata = chunked_context_metadata_cls(
Expand Down Expand Up @@ -2512,7 +2546,6 @@ def _context_parallel_compute_prefill_context(
k_scale: torch.Tensor,
dcp_world_size: int,
):
assert k_scale is None, "DCP not support scaled kvcache now."
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None
Expand All @@ -2526,18 +2559,48 @@ def _context_parallel_compute_prefill_context(
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace

fp8_kv_cache = (
self.kv_cache_dtype.startswith("fp8")
and self.kv_cache_dtype != "fp8_ds_mla"
)

for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
ops.cp_gather_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
i
],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
if fp8_kv_cache:
# FP8 KV cache: gather and dequant to BF16 workspace
assert k_scale is not None
token_to_seq = (
prefill_metadata.chunked_context.padded_local_token_to_seq
)
chunk_total = (
prefill_metadata.chunked_context.padded_local_chunk_total_token
)
assert token_to_seq is not None
assert chunk_total is not None
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
i
],
token_to_seq=token_to_seq[i],
num_tokens=chunk_total[i],
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
else:
ops.cp_gather_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
i
],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
# workspace
# |------- N tokens --------|--------- N*dcp_size tokens ----------|
# |<- use for local_gather ->|<--------- use for allgather -------->|
Expand Down Expand Up @@ -2652,12 +2715,22 @@ def forward_mha(
if has_context:
suffix_output, suffix_lse = output_prefill
if self.dcp_world_size > 1:
if self.kv_cache_dtype == "fp8_ds_mla":
raise NotImplementedError(
"DCP > 1 with `kv_cache_dtype='fp8_ds_mla'` is not supported."
)
assert not use_fp8_prefill, (
"DCP>1 with FP8 prefill query quantization is not "
"supported. Use --attention-config "
"'{\"use_prefill_query_quantization\": false}' "
"or reduce decode_context_parallel_size to 1."
)
context_output, context_lse = (
self._context_parallel_compute_prefill_context(
q,
kv_c_and_k_pe_cache,
attn_metadata,
k_scale=None,
k_scale=k_scale,
dcp_world_size=self.dcp_world_size,
)
)
Expand Down