Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions vllm_ascend/attention/attention_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,6 @@ def build(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
block_tables=block_table,
query_start_loc=query_start_loc,
query_start_loc_list=query_start_loc_cpu[1:].tolist(),
query_lens=query_lens,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_query_len=common_attn_metadata.max_query_len,
Expand Down
49 changes: 17 additions & 32 deletions vllm_ascend/attention/attention_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills,
enable_cp, split_decodes_and_prefills,
using_paged_attention)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
Expand All @@ -52,19 +52,15 @@ def get_name() -> str:

@staticmethod
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
prefill_config = get_current_vllm_config().parallel_config
if (prefill_config.prefill_context_parallel_size > 1
or prefill_config.decode_context_parallel_size > 1):
if enable_cp():
from vllm_ascend.attention.attention_cp import \
AscendAttentionCPImpl
return AscendAttentionCPImpl
return AscendAttentionBackendImpl

@staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
prefill_config = get_current_vllm_config().parallel_config
if (prefill_config.prefill_context_parallel_size > 1
or prefill_config.decode_context_parallel_size > 1):
if enable_cp():
from vllm_ascend.attention.attention_cp import \
AscendAttentionCPMetadataBuilder
return AscendAttentionCPMetadataBuilder
Expand Down Expand Up @@ -191,10 +187,8 @@ class AscendMetadata:
seq_lens: torch.Tensor = None
seq_lens_list: List[int] = None # type: ignore
actual_seq_lengths_q: List[int] = None # type: ignore
query_start_loc_list: List[int] = None # type: ignore

query_start_loc: torch.Tensor = None
query_lens: torch.Tensor = None
# Maximum query length in the batch (None for decoding).
max_query_len: Optional[int] = None

Expand All @@ -214,9 +208,9 @@ class AscendMetadata:
# dcp
decode_meta: Optional[AscendMetadataForDecode] = None

# Whether is the pooling model with causal attention,
# used to guide the attention computation for pooling models.
is_causal_pooling: Optional[bool] = None
causal: bool = True
# runner_type in model_config.
model_runner_type: str = ""


class AscendAttentionMetadataBuilder:
Expand Down Expand Up @@ -276,11 +270,8 @@ def build(

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

block_table = common_attn_metadata.block_table_tensor
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
Expand All @@ -297,19 +288,13 @@ def build(
# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
query_start_loc = query_start_loc_cpu.pin_memory().to(
self.device, non_blocking=True)
is_causal_pooling = None
if self.model_config.runner_type == "pooling":
is_causal_pooling = common_attn_metadata.causal if hasattr(
common_attn_metadata, 'causal') else True

attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
num_decode_tokens=num_decode_tokens,
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
block_tables=block_table,
query_start_loc=query_start_loc,
query_start_loc_list=query_start_loc_cpu[1:].tolist(),
query_lens=query_lens,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_query_len=common_attn_metadata.max_query_len,
Expand All @@ -319,7 +304,8 @@ def build(
attn_state=attn_state,
num_prefills=num_prefills,
num_decodes=num_decodes,
is_causal_pooling=is_causal_pooling)
causal=common_attn_metadata.causal,
model_runner_type=self.model_config.runner_type)
return attn_metadata

def build_for_graph_capture(
Expand Down Expand Up @@ -384,9 +370,9 @@ def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
key, value, block_size, block_table, actual_seq_lengths_kv \
= self._get_fia_params(key, value, attn_metadata)

num_tokens = attn_metadata.query_start_loc_list[-1]
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
graph_params = get_graph_params()
query_start_loc = attn_metadata.query_start_loc_list
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
# Prepare tensors for attention output
# TODO: Refactor this to step-level instead of layer-level

Expand All @@ -402,7 +388,7 @@ def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
Expand All @@ -422,7 +408,7 @@ def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
(weak_ref_tensors(query), weak_ref_tensors(key),
weak_ref_tensors(value), weak_ref_tensors(block_table),
weak_ref_tensors(attn_metadata.attn_mask), block_size,
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads,
self.num_heads, self.scale, weak_ref_tensors(output),
weak_ref_tensors(softmax_lse)))

Expand All @@ -435,7 +421,7 @@ def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
Expand Down Expand Up @@ -518,10 +504,10 @@ def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor,
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.query_lens.shape[0]
batch_size = attn_metadata.seq_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
Expand Down Expand Up @@ -644,9 +630,8 @@ def _forward_encoder_attention(self, query: torch.Tensor,
attn_metadata: AscendMetadata,
_: torch.Tensor) -> torch.Tensor:
assert attn_metadata is not None
assert attn_metadata.is_causal_pooling is not None

if attn_metadata.is_causal_pooling:
if attn_metadata.causal:
# use sparse_mode 3 in causal scenario
return torch_npu.npu_fusion_attention(
query=query,
Expand Down Expand Up @@ -768,7 +753,7 @@ def forward(
key, value = self.reshape_and_cache(key, value, kv_cache,
attn_metadata)
# pooling model branch
if isinstance(attn_metadata.is_causal_pooling, bool):
if attn_metadata.model_runner_type == "pooling":
attn_output = self._forward_encoder_attention(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
Expand Down
7 changes: 3 additions & 4 deletions vllm_ascend/attention/mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
enable_cp,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
trans_rope_weight, transdata,
Expand Down Expand Up @@ -58,8 +59,7 @@ def get_name() -> str:

@staticmethod
def get_builder_cls():
prefill_config = get_current_vllm_config().parallel_config
if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
if enable_cp():
from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder
return AscendMlaCPMetadataBuilder
return AscendMLAMetadataBuilder
Expand All @@ -71,8 +71,7 @@ def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,

@staticmethod
def get_impl_cls() -> Type["MLAAttentionImpl"]:
prefill_config = get_current_vllm_config().parallel_config
if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
if enable_cp():
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
return AscendMlaCPImpl
return AscendMLAImpl
Expand Down
18 changes: 13 additions & 5 deletions vllm_ascend/attention/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
Expand All @@ -26,6 +27,13 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
return runtime_shape in get_ascend_config().pa_shape_list


@lru_cache(maxsize=1)
def enable_cp():
prefill_config = get_current_vllm_config().parallel_config
return prefill_config.prefill_context_parallel_size > 1 \
or prefill_config.decode_context_parallel_size > 1


@dataclass
# class AscendCommonLongSequenceMetadata:
class AscendPrefillContextParallelMetadata:
Expand Down Expand Up @@ -66,7 +74,7 @@ class AscendCommonAttentionMetadata:
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.

For many of the tensors we keep both GPU and CPU versions.
For many of the tensors we keep both NPU and CPU versions.
"""

query_start_loc: torch.Tensor
Expand Down Expand Up @@ -109,8 +117,6 @@ class AscendCommonAttentionMetadata:

attn_state: Any = None

is_only_prefill: bool = False

graph_pad_size: int = -1

# num_input_tokens refers to total number of tokens including
Expand All @@ -120,6 +126,8 @@ class AscendCommonAttentionMetadata:
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None

causal: bool = True

# TODO: Remove it when vLLM no longer uses this function.
def unpadded(self, num_actual_tokens: int,
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
Expand All @@ -137,12 +145,12 @@ def unpadded(self, num_actual_tokens: int,
decode_token_per_req=self.decode_token_per_req,
block_table_tensor=self.block_table_tensor[:num_actual_reqs],
slot_mapping=self.slot_mapping[:num_actual_tokens],
causal=self.causal,
actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
positions=self.positions[:num_actual_tokens],
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
is_only_prefill=self.is_only_prefill,
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
num_input_tokens=num_actual_tokens,
prefill_context_parallel_metadata=self.
Expand Down
6 changes: 3 additions & 3 deletions vllm_ascend/compilation/acl_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
attn_output, softmax_lse) = param

seq_lens = forward_context.attn_metadata[key].seq_lens_list
query_start_loc = forward_context.attn_metadata[
key].query_start_loc_list
actual_seq_lengths_q = forward_context.attn_metadata[
key].actual_seq_lengths_q
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
Expand All @@ -282,7 +282,7 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
atten_mask=attn_mask,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=seq_lens,
num_key_value_heads=num_kv_heads,
num_heads=num_heads,
Expand Down
5 changes: 2 additions & 3 deletions vllm_ascend/spec_decode/eagle_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,11 @@ def _propose(
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1]
attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
1:].tolist()
attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens

attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
attn_metadata.actual_seq_lengths_q = attn_metadata.query_start_loc[
1:].tolist()
attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
for now_speculative in range(
Expand Down
1 change: 0 additions & 1 deletion vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,6 @@ def _prepare_inputs(
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
max_query_len=max_num_scheduled_tokens,
decode_token_per_req=self.decode_token_per_req,
prefill_context_parallel_metadata=long_seq_metadata,
Expand Down
2 changes: 0 additions & 2 deletions vllm_ascend/worker/v2/attn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def build_attn_metadata(
| None = None,
spec_attn_mask: torch.Tensor | None = None,
attn_state: Any | None = None,
is_only_prefill: bool = False,
graph_pad_size: int = -1,
num_input_tokens: int = 0,
prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata
Expand Down Expand Up @@ -78,7 +77,6 @@ def build_attn_metadata(
attn_mask=attn_mask,
spec_attn_mask=spec_attn_mask,
attn_state=attn_state,
is_only_prefill=is_only_prefill,
graph_pad_size=graph_pad_size,
num_input_tokens=num_input_tokens,
prefill_context_parallel_metadata=prefill_context_parallel_metadata,
Expand Down
4 changes: 3 additions & 1 deletion vllm_ascend/xlite/xlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ def __call__(
if not with_prefill or self.full_mode:
batch = attn_metadata.num_prefills + attn_metadata.num_decodes
seq_lens = attn_metadata.seq_lens[:batch]
query_lens = attn_metadata.query_lens[:batch]
query_lens = attn_metadata.query_start_loc_cpu[
1:] - attn_metadata.query_start_loc_cpu[:-1]
query_lens = query_lens[:batch]
Comment thread
weijinqian0 marked this conversation as resolved.
cached_lens = seq_lens - query_lens

xlite_attn_metadata = ModelAttnMeta()
Expand Down
Loading