Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,32 @@ def forward(

class AttentionImpl(ABC, Generic[T]):

# Whether the attention impl can return the softmax lse for decode.
# Some features like decode context parallelism require the softmax lse.
can_return_lse_for_decode: bool = False

# some attention backends might not always want to return lse
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False

dcp_world_size: int
dcp_rank: int

def __new__(cls, *args, **kwargs):
    # Implemented in __new__ (not __init__) so that every subclass runs
    # this setup automatically, even if it overrides __init__ without
    # calling super().
    instance = super().__new__(cls)
    try:
        from vllm.distributed.parallel_state import get_dcp_group
        dcp_group = get_dcp_group()
        instance.dcp_world_size = dcp_group.world_size
        instance.dcp_rank = dcp_group.rank_in_group
    except AssertionError:
        # DCP group may not be initialized (e.g. in unit tests);
        # fall back to a single-rank configuration.
        instance.dcp_world_size = 1
        instance.dcp_rank = 0
    # Only request the softmax LSE when DCP actually needs it and the
    # backend is able to produce it.
    instance.need_to_return_lse_for_decode = (
        instance.dcp_world_size > 1 and instance.can_return_lse_for_decode
    )
    return instance

@abstractmethod
def __init__(
self,
Expand Down
4 changes: 0 additions & 4 deletions vllm/v1/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,10 +1592,6 @@ def forward(

        # re-correct the DCP attn_out using the softmax lse.
if self.dcp_world_size > 1:
assert lse is not None, (
"For a mla backend want to enable"
"DCP, it is mandatory that the corresponding decode attn"
"kernel return the softmax lse.")
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())

# v_up projection
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/mla/flashmla.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def _build_decode(

class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):

can_return_lse_for_decode: bool = True

def __init__(
self,
num_heads: int,
Expand Down
15 changes: 10 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
get_dtype_size, is_pin_memory_available, round_up,
supports_dynamo)
from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
create_fast_prefill_custom_backend,
Expand Down Expand Up @@ -3321,10 +3320,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
copy_kv_blocks)

if self.dcp_world_size > 1:
assert self.attn_groups[0][0].backend is FlashMLABackend, (
"DCP only support flashmla now."
"For a mla backend want to enable DCP, it is mandatory that the"
"corresponding decode attn kernel return the softmax lse.")
layer_names = self.attn_groups[0][0].layer_names
layers = get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase,
layer_names)
for layer in layers.values():
assert layer.impl.need_to_return_lse_for_decode, (
"DCP requires attention impls to return"
" the softmax lse for decode, but the impl "
f"{layer.impl.__class__.__name__} "
"does not return the softmax lse for decode.")

def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Expand Down