13 changes: 12 additions & 1 deletion vllm_ascend/attention/attention_v1.py
@@ -132,6 +132,12 @@ class AscendAttentionState(Enum):

@dataclass
class AscendMetadata:
"""
Per-layer attention metadata for Ascend FlashAttention backend.

Contains attention masks, token counts, sequence lengths and KV cache
related properties for attention computation.
"""
# **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run.
@@ -186,7 +192,12 @@ class AscendMetadata:


class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
"""
Builder for constructing AscendMetadata from CommonAttentionMetadata.

Handles attention mask generation and metadata preparation for
Ascend FlashAttention backend.
"""
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
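To make the shape of this metadata concrete, here is a minimal, self-contained sketch of how a per-layer metadata object like `AscendMetadata` is typically built and consumed. The class and field names below are illustrative stand-ins, not the real vllm-ascend definitions:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class PerLayerAttentionMetadata:
    """Simplified stand-in for AscendMetadata: mask, lengths, token counts."""
    attn_mask: Optional[torch.Tensor] = None  # optional attention mask
    seq_lens: Optional[torch.Tensor] = None   # per-request sequence lengths
    num_actual_tokens: int = 0                # tokens excluding padding


# A builder fills one such object per forward pass; the attention
# implementation then reads shapes and masks from it instead of
# recomputing them.
meta = PerLayerAttentionMetadata(
    seq_lens=torch.tensor([128, 256, 64]),
    num_actual_tokens=448,
)
print(meta.num_actual_tokens)  # 448
```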
6 changes: 5 additions & 1 deletion vllm_ascend/attention/context_parallel/attention_cp.py
@@ -45,7 +45,11 @@


class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
"""
Builder for constructing AscendMetadata with Context Parallelism support.

Extends AscendAttentionMetadataBuilder with PCP/DCP metadata handling.
"""
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
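The extension pattern named in this docstring can be sketched as follows; the method name and returned fields are hypothetical, meant only to show how a CP-aware builder can delegate to the base builder and then attach PCP/DCP fields:

```python
# Hypothetical sketch of the subclassing pattern, not the real API.
class BaseMetadataBuilder:
    def build(self) -> dict:
        # Base builder: masks, sequence lengths, KV cache properties.
        return {"attn_mask": None, "seq_lens": [128, 64]}


class CPMetadataBuilder(BaseMetadataBuilder):
    def build(self) -> dict:
        metadata = super().build()
        # Layer PCP/DCP routing information on top of the base metadata.
        metadata["pcp"] = {"q_head_idx": [0, 1], "q_tail_idx": [6, 7]}
        return metadata


print(CPMetadataBuilder().build())
```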
16 changes: 14 additions & 2 deletions vllm_ascend/attention/context_parallel/common_cp.py
@@ -11,6 +11,12 @@

@dataclass
class AscendPCPMetadata:
"""
Metadata for Prefill Context Parallelism (PCP) on Ascend devices.

Stores index tensors and sequence lengths for routing attention
computations across PCP ranks during long sequence processing.
"""
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
@@ -26,7 +32,11 @@ class AscendPCPMetadata:

@dataclass
class CPChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
"""
Metadata for chunked context handling in Context Parallelism (CP).

Extends chunked prefill with per-rank chunk information for PCP/DCP.
"""
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
@@ -46,9 +56,11 @@ class CPChunkedContextMetadata:

@dataclass
class AscendMetadataForPrefill:
""" Prefill-specific metadata for Ascend attention with Context Parallelism."""

@dataclass
class ChunkedContextMetadata:
"""Metadata for chunked context processing within prefill phase."""
actual_chunk_seq_lengths: torch.Tensor
actual_seq_lengths_kv: torch.Tensor
starts: torch.Tensor
@@ -69,7 +81,7 @@ class ChunkedContextMetadata:

@dataclass
class AscendMetadataForDecode:
""" Decode Specific Metadata for Ascend"""
""" Decode-specific metadata for Ascend attention with Context Parallelism."""
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
batch_seq_mask: torch.Tensor = None
block_tables: torch.Tensor = None
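The index tensors in `AscendPCPMetadata` exist to route slices of a long query across PCP ranks. Below is a toy example of that head/tail selection, under the assumed semantics of `q_head_idx` and `q_tail_idx` (the real index layout may differ):

```python
import torch

# Toy head/tail routing under assumed q_head_idx / q_tail_idx semantics:
# each PCP rank attends over a head slice and a tail slice of the query.
tokens = torch.arange(8)                     # 8 query-token positions
q_head_idx = torch.tensor([0, 1])            # this rank's head slice
q_tail_idx = torch.tensor([6, 7])            # this rank's tail slice

q_head = tokens.index_select(0, q_head_idx)  # tensor([0, 1])
q_tail = tokens.index_select(0, q_tail_idx)  # tensor([6, 7])
print(q_head, q_tail)
```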
10 changes: 7 additions & 3 deletions vllm_ascend/attention/mla_v1.py
@@ -84,8 +84,11 @@ def get_impl_cls() -> Type["MLAAttentionImpl"]:

@dataclass
class ChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
"""
Metadata for chunked context handling in MLA attention.

Manages sequence boundaries and workspace for chunked prefill processing.
"""
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
@@ -116,7 +119,8 @@ class AscendMLAPrefillMetadata:

@dataclass
class AscendMLADecodeMetadata:
# Input positions for rotrary embeddings since for MLA the rotary
""" Decode-specific metadata for Ascend MLA attention."""
# Input positions for rotary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
block_table: torch.Tensor
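The `cu_seq_lens` and `starts` fields above follow the usual packed-buffer convention for chunked prefill. A minimal sketch of that convention, assuming cumulative lengths mark chunk boundaries (the real workspace handling is more involved):

```python
import torch

# Cumulative sequence lengths mark each chunk's boundaries in a packed
# token buffer: chunk i spans [cu_seq_lens[i], cu_seq_lens[i + 1]).
seq_lens = torch.tensor([128, 256, 64])
cu_seq_lens = torch.cat(
    [torch.zeros(1, dtype=torch.long), torch.cumsum(seq_lens, dim=0)]
)  # tensor([0, 128, 384, 448])

for i in range(len(seq_lens)):
    print(f"chunk {i}: [{cu_seq_lens[i]}, {cu_seq_lens[i + 1]})")
```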
24 changes: 20 additions & 4 deletions vllm_ascend/attention/utils.py
@@ -36,8 +36,12 @@ def enable_cp():


@dataclass
# class AscendCommonLongSequenceMetadata:
class AscendPrefillContextParallelMetadata:
"""
Metadata for Prefill Context Parallelism (PCP) in CommonAttentionMetadata.

Contains index tensors and sequence lengths for PCP operations.
"""
pcp_allgather_restore_idx: torch.Tensor = None

cp_kv_recover_idx_for_chunk: torch.Tensor = None
@@ -81,24 +85,36 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):

For many of the tensors we keep both NPU and CPU versions.
"""
# CPU tensor of sequence lengths for host-side operations.
# E.g., tensor([128, 256, 64]) for 3 requests with different seq lengths.
seq_lens_cpu: torch.Tensor = None

# CPU tensor of already computed tokens count per request.
# E.g., tensor([100, 200, 50]) means req0 has 100 tokens already computed.
num_computed_tokens_cpu: torch.Tensor = None

# Number of decode tokens per request, used for speculative decoding.
# E.g., 1 for normal decoding, >1 for speculative decoding.
decode_token_per_req: int = 1
"""decode token number per request"""

# Actual query sequence lengths for each token in the batch (CPU list).
# E.g., [1, 1, 1, 128] for 3 decode tokens and 1 prefill with 128 tokens.
actual_seq_lengths_q: list[int] = field(default_factory=list)

# NPU tensor of position indices for rotary embeddings computation.
# E.g., tensor([0, 1, 2, ...]) indicating token positions in sequence.
positions: torch.Tensor = None

# Current attention state (e.g., ChunkedPrefill, DecodeOnly).
attn_state: Any = None

# Padding size for graph capture, -1 means not in graph mode.
graph_pad_size: int = -1

# num_input_tokens refers to total number of tokens including
# padding tokens. It is used to handle some padding operations.
# Total number of tokens including padding, used for padding operations.
num_input_tokens: int = 0

# Metadata for Prefill Context Parallelism (PCP) operations.
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None

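Tying the documented fields together, here is a simplified, hypothetical stand-in for `AscendCommonAttentionMetadata`, populated to match the examples in the comments above: three single-token decode requests plus one 128-token prefill. The real class inherits from `CommonAttentionMetadata` and carries more fields:

```python
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class CommonAttentionMetadataSketch:
    """Hypothetical simplification of AscendCommonAttentionMetadata."""
    seq_lens_cpu: Optional[torch.Tensor] = None
    num_computed_tokens_cpu: Optional[torch.Tensor] = None
    decode_token_per_req: int = 1
    actual_seq_lengths_q: list[int] = field(default_factory=list)
    graph_pad_size: int = -1   # -1 => not in graph-capture mode
    num_input_tokens: int = 0  # total tokens including padding


# Three single-token decodes plus one 128-token prefill.
meta = CommonAttentionMetadataSketch(
    seq_lens_cpu=torch.tensor([128, 256, 64, 128]),
    num_computed_tokens_cpu=torch.tensor([127, 255, 63, 0]),
    actual_seq_lengths_q=[1, 1, 1, 128],
    num_input_tokens=131,
)
print(sum(meta.actual_seq_lengths_q))  # 131
```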