Merged
2 changes: 1 addition & 1 deletion cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 4695e6bed5366c41e28c06cd86170166e4f43d00
GIT_TAG 8f468e7da54a8e2f98abfa7c38636aac91c0cba1
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
16 changes: 14 additions & 2 deletions vllm/v1/attention/backends/mla/common.py
@@ -370,6 +370,7 @@ class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
class MLACommonDecodeMetadata:
block_table: torch.Tensor
seq_lens: torch.Tensor
dcp_tot_seq_lens: Optional[torch.Tensor]


D = TypeVar("D", bound=MLACommonDecodeMetadata)
@@ -682,10 +683,12 @@ def _build_decode(
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> MLACommonDecodeMetadata:
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)

def build_for_cudagraph_capture(
@@ -727,6 +730,7 @@ def build(
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens

query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

@@ -742,7 +746,10 @@

# Note(hc): compute the per-rank local seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
seq_lens[:num_decodes] = seq_lens[:num_decodes] // self.dcp_world_size + (
assert dcp_local_seq_lens is not None
dcp_local_seq_lens[:num_decodes] = seq_lens[
:num_decodes
] // self.dcp_world_size + (
self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
)
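For intuition, here is a self-contained sketch of the split implemented above. It assumes round-robin interleaving of KV tokens across DCP ranks; the helper name is illustrative.

```python
import torch

# Hedged sketch of the per-rank length split: a rank gets one extra KV slot
# when its index is at most (seq_len - 1) % world_size.
def dcp_local_lens(seq_lens: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    return seq_lens // world_size + (rank <= (seq_lens - 1) % world_size).int()

seq_lens = torch.tensor([10, 7])  # global (total) decode lengths
for rank in range(4):
    print(rank, dcp_local_lens(seq_lens, rank, 4).tolist())
# rank 0 -> [3, 2], rank 1 -> [3, 2], rank 2 -> [2, 2], rank 3 -> [2, 1]
```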

@@ -899,10 +906,15 @@ def build(
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
seq_lens_device=dcp_local_seq_lens[:num_decodes]
if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
else seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
query_start_loc_device=query_start_loc[: num_decodes + 1],
num_decode_tokens=num_decode_tokens,
dcp_tot_seq_lens_device=seq_lens[:num_decodes]
if self.dcp_world_size > 1
else None,
)
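Net effect: backends now receive the local lengths as seq_lens_device, while the global lengths ride along separately. A hedged restatement of the selection (the function name is illustrative):

```python
from typing import Optional
import torch

# seq_lens_device carries *local* lengths under DCP; the *global* lengths go
# into dcp_tot_seq_lens, which stays None when DCP is off.
def select_decode_lens(
    seq_lens: torch.Tensor,
    dcp_local_seq_lens: Optional[torch.Tensor],
    dcp_world_size: int,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    if dcp_world_size > 1 and dcp_local_seq_lens is not None:
        return dcp_local_seq_lens, seq_lens
    return seq_lens, None
```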

attn_metadata = self.metadata_cls(
16 changes: 7 additions & 9 deletions vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -17,7 +17,6 @@
get_flash_attn_version,
)
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
@@ -107,12 +106,6 @@ def __init__(
# pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

# TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled.
self.reorder_batch_threshold = (
1 if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
)

def _schedule_decode(
self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
):
@@ -121,7 +114,7 @@
batch_size=num_reqs,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
num_heads_q=self.num_heads,
num_heads_q=self.num_heads * self.dcp_world_size,
num_heads_kv=1,
headdim=self.mla_dims.qk_rope_head_dim,
cache_seqlens=seqlens,
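The head-count change reads (hedged) as sizing the FA3 scheduler for the query all-gather in the MLA DCP decode path: each rank contributes num_heads query heads and then attends with the gathered set over only its local KV shard.

```python
import torch

# Illustrative shapes only; torch.cat stands in for the DCP all_gather.
dcp_world_size, num_heads, head_dim = 4, 16, 64
q_local = torch.randn(1, num_heads, head_dim)              # this rank's decode query
q_gathered = torch.cat([q_local] * dcp_world_size, dim=1)  # gathered query heads
assert q_gathered.shape[1] == num_heads * dcp_world_size   # == num_heads_q above
```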
@@ -142,10 +135,11 @@ def _build_decode(
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
max_seq_len = seq_lens_device.max().item()
Member
@minosfuture this introduces a CPU sync and will be problematic for performance, especially in the async scheduling case. It looks like it should not be too hard to ensure we obtain this from the cpu-side tensor in the DCP case too?
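A minimal sketch of what the reviewer suggests, assuming a CPU-side mirror of the DCP-local lengths (dcp_local_seq_lens_cpu is hypothetical; this PR only builds the device tensor):

```python
from typing import Optional
import torch

def decode_max_seq_len(
    seq_lens_cpu: torch.Tensor,
    dcp_local_seq_lens_cpu: Optional[torch.Tensor],  # hypothetical CPU mirror
    dcp_world_size: int,
) -> int:
    # Take the max on host tensors so metadata building never syncs the GPU.
    if dcp_world_size > 1 and dcp_local_seq_lens_cpu is not None:
        return int(dcp_local_seq_lens_cpu.max().item())
    return int(seq_lens_cpu.max().item())
```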


scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(),
@@ -188,6 +182,7 @@ def _build_decode(
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
max_num_splits=max_num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)


@@ -289,6 +284,9 @@ def _forward_decode(
fa_version=3, # only version 3 is supported
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
num_splits=attn_metadata.decode.max_num_splits,
cp_world_size=self.dcp_world_size,
cp_rank=self.dcp_rank,
cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
)

if self.need_to_return_lse_for_decode:
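The new cp_* arguments give the kernel each rank's position in the DCP group and the total KV lengths, and the LSE returned for decode (see need_to_return_lse_for_decode above) is what makes cross-rank combination exact. A self-contained numerical check of that identity in plain PyTorch, independent of the kernel:

```python
import torch

# Partial attention over disjoint KV shards, combined with log-sum-exp (LSE)
# weights, equals full attention over all KV.
torch.manual_seed(0)
d, seq_len, world = 8, 10, 2
q = torch.randn(1, d)                                # one decode query
k, v = torch.randn(seq_len, d), torch.randn(seq_len, d)
scores = (q @ k.T) / d**0.5

ref = torch.softmax(scores, dim=-1) @ v              # full attention

outs, lses = [], []
for rank in range(world):                            # token t lives on rank t % world
    mask = torch.arange(seq_len) % world == rank
    s = scores[:, mask]
    lses.append(torch.logsumexp(s, dim=-1))          # per-rank LSE
    outs.append(torch.softmax(s, dim=-1) @ v[mask])  # per-rank partial output

w = torch.softmax(torch.stack(lses), dim=0)          # combine weights from LSEs
combined = (w.unsqueeze(-1) * torch.stack(outs)).sum(dim=0)
assert torch.allclose(combined, ref, atol=1e-5)
```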
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/mla/flashmla.py
@@ -106,6 +106,7 @@ def _build_decode(
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = get_mla_metadata(
seq_lens_device,
@@ -146,6 +147,7 @@ def _build_decode(
seq_lens=seq_lens_device,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)


2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -116,6 +116,7 @@ def _build_decode(
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
@@ -174,6 +175,7 @@ def _build_decode(
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)

return attn_metadata
3 changes: 3 additions & 0 deletions vllm/v1/attention/backends/utils.py
@@ -93,6 +93,9 @@ class CommonAttentionMetadata:
# Needed by CrossAttentionBuilder
encoder_seq_lens: Optional[np.ndarray] = None

dcp_local_seq_lens: Optional[torch.Tensor] = None
"""Sequence lengths of the local rank in decode context parallelism world"""


def slice_query_start_locs(
query_start_loc: torch.Tensor,
2 changes: 2 additions & 0 deletions vllm/v1/spec_decode/eagle.py
@@ -597,6 +597,7 @@ def prepare_inputs_padded(
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
causal=True,
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
)

token_indices_to_sample = (
@@ -868,6 +869,7 @@ def prepare_inputs(
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
causal=True,
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
)

return spec_common_attn_metadata, token_indices
15 changes: 14 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -398,6 +398,10 @@ def __init__(
self.max_num_reqs + 1, dtype=torch.int32
)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
if self.dcp_world_size > 1:
self.dcp_local_seq_lens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
# Because inputs_embeds may be bfloat16 and we don't need a numpy
# version of this tensor, avoid a RuntimeError by not creating a
# numpy buffer.
@@ -581,7 +585,10 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
# NOTE(lucas): currently no backend supports the custom masking
# required for DCP with q_len > 1, so we assert here. Remove this
# assert once the custom mask support is added to FA3.
if self.dcp_world_size > 1:
if (
self.dcp_world_size > 1
and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
):
assert self.reorder_batch_threshold == 1, (
"DCP not support reorder_batch_threshold > 1 now."
Contributor
since only flash_attn_mla supports the custom mask, we can't just remove this assert right now?

Contributor Author
Makes sense. I'll make a whitelist here for FA3 MLA.

)
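A hedged sketch of the whitelist the author proposes; the set name and helper are illustrative, mirroring the env check in the diff above.

```python
# Backends believed to handle the DCP custom mask (q_len > 1).
DCP_CUSTOM_MASK_BACKENDS = {"FLASH_ATTN_MLA"}

def check_dcp_reorder(backend: str, dcp_world_size: int, reorder_batch_threshold: int) -> None:
    if dcp_world_size > 1 and backend not in DCP_CUSTOM_MASK_BACKENDS:
        assert reorder_batch_threshold == 1, (
            f"DCP does not support reorder_batch_threshold > 1 for {backend!r} yet."
        )
```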
@@ -1335,6 +1342,9 @@ def _prepare_inputs(
num_logits_indices=logits_indices.size(0),
causal=True,
encoder_seq_lens=encoder_seq_lens,
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
if self.dcp_world_size > 1
else None,
)

if self.speculative_config and spec_decode_common_attn_metadata is None:
@@ -3309,6 +3319,9 @@ def _dummy_run(
kv_cache_group_id
].slot_mapping.gpu[:num_tokens],
causal=True,
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
if self.dcp_world_size > 1
else None,
)
for attn_group in self.attn_groups[kv_cache_group_id]:
if ubatch_slices is not None: