6 changes: 5 additions & 1 deletion tests/distributed/test_context_parallel.py
@@ -204,17 +204,21 @@ def _compare_cp_with_tp(


CP_TEXT_GENERATION_MODELS = {
# [MLA attention only]
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
],
}

CP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"deepseek-ai/DeepSeek-V2-Lite-Chat",
"bigcode/gpt_bigcode-santacoder",
]
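santacoder is an MQA model with a single KV head, so these entries exercise the new GQA/MQA decode context parallel path alongside the existing MLA coverage. For illustration, a hypothetical offline launch matching the `tp_base=2` setting; the `decode_context_parallel_size` kwarg name is assumed to mirror `ParallelConfig` and is not confirmed by this diff:

```python
from vllm import LLM

# Hypothetical sketch: the kwarg name is assumed to mirror
# ParallelConfig.decode_context_parallel_size. santacoder has 1 KV
# head, so with TP=2 the DCP size may be at most 2 // 1 = 2 under
# the new check in vllm/config/model.py.
llm = LLM(
    model="bigcode/gpt_bigcode-santacoder",
    tensor_parallel_size=2,
    decode_context_parallel_size=2,
)
print(llm.generate("def hello():")[0].outputs[0].text)
```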


5 changes: 4 additions & 1 deletion tests/models/registry.py
@@ -265,7 +265,10 @@ def check_available_online(
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo(
"bigcode/starcoder",
extras={"tiny": "bigcode/tiny_starcoder_py"},
extras={
"tiny": "bigcode/tiny_starcoder_py",
"santacoder": "bigcode/gpt_bigcode-santacoder",
},
min_transformers_version="4.55.1",
transformers_version_reason="HF model broken in 4.55.0",
),
10 changes: 9 additions & 1 deletion vllm/attention/ops/common.py
@@ -173,6 +173,7 @@ def cp_lse_ag_out_rs(
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
):
"""
cp_attn_out: [ B, H, D ]
@@ -192,8 +193,15 @@

cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
out = cp_group.reduce_scatter(out, dim=1)

if return_lse:
cp_num_heads = lse.shape[1] // cp_group.world_size
cp_rank = cp_group.rank_in_group
lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
return out, lse
return out
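The slice above mirrors the `reduce_scatter` on the head dimension: each rank keeps only the LSE columns belonging to its own heads. A minimal standalone sketch of that bookkeeping (hypothetical shapes, no distributed setup):

```python
import torch

# After the all-gather and correction, lse covers every head in the
# CP group: shape [B, H_total] with H_total = cp_world_size * H_local.
# reduce_scatter(out, dim=1) leaves each rank with its H_local heads,
# so the matching LSE slice is that rank's own column block.
cp_world_size, cp_rank = 4, 1
B, H_local = 2, 8
lse = torch.randn(B, cp_world_size * H_local)
lse_local = lse[:, H_local * cp_rank : H_local * (cp_rank + 1)]
assert lse_local.shape == (B, H_local)
```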


17 changes: 17 additions & 0 deletions vllm/config/model.py
@@ -1201,6 +1201,23 @@ def verify_with_parallel_config(
"Supported models implement the `SupportsPP` interface."
)

decode_context_parallel_size = parallel_config.decode_context_parallel_size
if decode_context_parallel_size > 1 and not self.use_mla:
total_num_kv_heads = self.get_total_num_kv_heads()
assert tensor_parallel_size > total_num_kv_heads, (
f"tensor parallel size {tensor_parallel_size} must be greater "
f"than total num kv heads {total_num_kv_heads} when enable "
f"decode context parallel for GQA/MQA"
)

max_dcp_size = tensor_parallel_size // total_num_kv_heads
assert decode_context_parallel_size <= max_dcp_size, (
f"decode context parallel size must less than or equal to "
f"(tensor parallel size {tensor_parallel_size} // total "
f"num kv heads {total_num_kv_heads}) = {max_dcp_size}, "
f"but got {decode_context_parallel_size}"
)
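A worked instance of the constraint: each KV head is replicated across `tensor_parallel_size // total_num_kv_heads` TP ranks, and DCP can only split the context across ranks that share a KV head.

```python
# With TP=8 and a GQA model that has 2 KV heads, each KV head is
# replicated on 8 // 2 = 4 TP ranks, so the cached context can be
# sharded across at most 4 DCP ranks.
tensor_parallel_size = 8
total_num_kv_heads = 2
max_dcp_size = tensor_parallel_size // total_num_kv_heads  # == 4
assert 4 <= max_dcp_size  # decode_context_parallel_size = 4 passes
```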

def get_sliding_window(self) -> int | None:
"""Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)
202 changes: 172 additions & 30 deletions vllm/v1/attention/backends/flash_attn.py
@@ -17,6 +17,7 @@
is_quantized_kv_cache,
)
from vllm.attention.layer import Attention
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8,
@@ -32,6 +33,7 @@
)

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
@@ -147,6 +149,10 @@ class FlashAttentionMetadata:
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None

# For GQA DCP
max_dcp_context_kv_len: int | None = None
dcp_context_kv_lens: torch.Tensor | None = None

# Optional aot scheduling
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
@@ -216,6 +222,16 @@ def __init__(
self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = get_flash_attn_version() == 3

try:
from vllm.distributed.parallel_state import get_dcp_group

self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0

self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
@@ -306,7 +322,7 @@ def schedule(
batch_size=batch_size,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
num_heads_q=self.num_heads_q,
num_heads_q=self.num_heads_q * self.dcp_world_size,
num_heads_kv=self.num_heads_kv,
headdim=self.headdim,
cache_seqlens=seqlens,
@@ -320,8 +336,35 @@
return None

use_cascade = common_prefix_len > 0
max_dcp_context_kv_len = 0
dcp_context_kv_lens = None

cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None

if self.dcp_world_size > 1:
query_kv_lens_cpu = (
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
)
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()

if use_cascade:
scheduler_metadata = schedule(
batch_size=num_reqs,
cu_query_lens=query_start_loc,
max_query_len=max_query_len,
seqlens=dcp_context_kv_lens,
max_seq_len=max_dcp_context_kv_len,
causal=False,
)
elif use_cascade:
cu_prefix_query_lens = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device=self.device
)
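The `dcp_context_kv_lens` arithmetic above hands each rank a near-equal share of the cached context, with earlier ranks holding one extra token when the split is uneven; the scheduler is also built with `num_heads_q * self.dcp_world_size` because each rank attends with the all-gathered queries of the whole DCP group. A standalone sketch of the split (hypothetical helper, CPU-only):

```python
import torch

def split_context_kv(seq_lens, query_lens, world_size, rank):
    """Per-rank share of the cached context, mirroring the formula above."""
    context = seq_lens - query_lens  # tokens already in the KV cache
    base = context // world_size
    # Ranks with rank <= (context - 1) % world_size hold one extra token.
    extra = (rank <= (context - 1) % world_size).int()
    return base + extra

seq_lens = torch.tensor([11, 7])
query_lens = torch.tensor([1, 1])
# Context lengths 10 and 6 over 4 ranks -> [3, 3, 2, 2] and [2, 2, 1, 1].
for rank in range(4):
    print(split_context_kv(seq_lens, query_lens, 4, rank).tolist())
```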
@@ -348,10 +391,6 @@
causal=True,
)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None
scheduler_metadata = schedule(
batch_size=num_reqs,
cu_query_lens=query_start_loc,
@@ -379,6 +418,8 @@
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
max_dcp_context_kv_len=max_dcp_context_kv_len,
dcp_context_kv_lens=dcp_context_kv_lens,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
scheduler_metadata=scheduler_metadata,
Expand All @@ -396,6 +437,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool:


class FlashAttentionImpl(AttentionImpl):
can_return_lse_for_decode: bool = True

def __init__(
self,
num_heads: int,
@@ -562,30 +605,45 @@ def forward(

descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
)
return output
if self.dcp_world_size > 1:
self._forward_with_dcp(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
key_cache,
value_cache,
output[:num_actual_tokens],
attn_metadata,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output
else:
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
)
return output

# Cascade attention (rare case).
cascade_attention(
@@ -615,6 +673,86 @@ def forward(
)
return output

def _forward_with_dcp(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
q_descale: torch.Tensor | None = None,
k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None,
) -> torch.Tensor:
cu_seqlens_q = attn_metadata.query_start_loc
max_seqlen_q = attn_metadata.max_query_len
block_table = attn_metadata.block_table

query = query.contiguous()
query_across_dcp = get_dcp_group().all_gather(query, dim=1)
context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
v=value_cache,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=attn_metadata.dcp_context_kv_lens,
max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
softmax_scale=self.scale,
causal=False,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=attn_metadata.scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
context_attn_out,
context_lse.transpose(0, 1),
get_dcp_group(),
return_lse=True,
)
context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()

query_attn_out, query_lse = flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
cu_seqlens_k=cu_seqlens_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
assert context_attn_out_cor.shape == query_attn_out.shape
assert context_lse_cor.shape == query_lse.shape
merge_attn_states(
output,
context_attn_out_cor,
context_lse_cor,
query_attn_out,
query_lse,
)
return output
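`_forward_with_dcp` computes two partial attentions per query token, one over the distributed context KV (all-gathered queries against per-rank KV shards, corrected across ranks) and one over the query tokens' own new KV, and combines them with `merge_attn_states`. A reference sketch of that LSE-weighted merge, written as an assumption about the kernel's semantics rather than the kernel itself:

```python
import torch

def merge_partial_attn(out1, lse1, out2, lse2):
    """Merge two softmax-attention partials over disjoint KV sets.

    out1/out2: [T, H, D] partial outputs; lse1/lse2: [T, H] log-sum-exp
    of each partial's attention logits. Each partial is reweighted by
    its share of the combined softmax mass.
    """
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return out1 * w1 + out2 * w2
```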

def _forward_encoder_attention(
self,
query: torch.Tensor,
@@ -684,6 +822,7 @@ def use_cascade_attention(
use_sliding_window: bool,
use_local_attention: bool,
num_sms: int,
dcp_world_size: int,
) -> bool:
"""Decide whether to use cascade attention.

@@ -705,6 +844,9 @@
num_reqs = len(query_lens)
if num_reqs < 8:
return False
# Cascade attention is not supported with decode context parallel (DCP)
if dcp_world_size > 1:
return False

# Heuristics to decide whether using cascade attention is beneficial.
# 1. When FlashDecoding is not used for normal attention, cascade attention
1 change: 1 addition & 0 deletions vllm/v1/attention/backends/utils.py
@@ -345,6 +345,7 @@ def use_cascade_attention(
use_sliding_window: bool,
use_local_attention: bool,
num_sms: int,
dcp_world_size: int,
) -> bool:
return False

1 change: 1 addition & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1523,6 +1523,7 @@ def _compute_cascade_attn_prefix_len(
use_sliding_window=use_sliding_window,
use_local_attention=use_local_attention,
num_sms=self.num_sms,
dcp_world_size=self.dcp_world_size,
)
return common_prefix_len if use_cascade else 0
