62 changes: 56 additions & 6 deletions tests/ut/_310p/attention/test_attention_v1_310.py
@@ -78,7 +78,7 @@ def setUp(self):
def test_forward_prefill_310(
self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
):
"""Test forward pass in PrefillCacheHit state"""
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8, 64)
key = torch.randn(10, 8, 64)
value = torch.randn(10, 8, 64)
@@ -98,7 +98,7 @@ def test_forward_prefill_310(

mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64)
output = self.impl.forward_prefill_310(query, key, value, metadata, output)
output = self.impl.forward_impl(query, key, value, None, metadata, output)

mock_npu_npu_flash_attention.assert_called_once()

@@ -107,10 +107,15 @@ def test_forward_prefill_310(
@patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
def test_forward_chunked_prefill_310(
self, mock_get_forward_context, mock_npu_paged_attention_splitfuse, mock_npu_reshape_and_cache, mock_format_cast
self,
mock_get_forward_context,
mock_npu_paged_attention_splitfuse,
mock_npu_reshape_and_cache,
mock_format_cast,
):
"""Test forward pass in PrefillCacheHit state"""
"""Test forward pass in ChunkedPrefill state"""
query = torch.randn(5, 8, 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.ChunkedPrefill
@@ -128,7 +133,42 @@ def test_forward_chunked_prefill_310(

mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
output = self.impl.forward_chunked_prefill_310(query, metadata, output)
output = self.impl.forward_impl(query, key, value, None, metadata, output)

mock_npu_paged_attention_splitfuse.assert_called_once()

@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
@patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
def test_forward_prefill_cache_hit_310(
self,
mock_get_forward_context,
mock_npu_paged_attention_splitfuse,
mock_npu_reshape_and_cache,
mock_format_cast,
):
"""Test forward pass in PrefillCacheHit state"""
query = torch.randn(5, 8, 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillCacheHit
metadata.attn_mask = torch.randn(1, 128, 16, 16)
metadata.query_lens = torch.tensor([5])
metadata.seq_lens = torch.tensor([1, 4])
metadata.query_start_loc = torch.tensor([0, 1, 5])
metadata.actual_seq_lengths_q = [5]
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.num_decode_tokens = 0
metadata.num_decodes = 0
metadata.num_prefills = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)

mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
output = self.impl.forward_impl(query, key, value, None, metadata, output)

mock_npu_paged_attention_splitfuse.assert_called_once()

@@ -141,6 +181,7 @@ def test_forward_paged_attention_310(
):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(4, 8 * 64)
key, value = None, None
output = torch.empty_like(query)

metadata = self.attn_metadata
@@ -155,6 +196,15 @@

mock_get_forward_context.return_value = MagicMock(capturing=False)

output = self.impl.forward_paged_attention(query, metadata, output)
output = self.impl.forward_impl(query, key, value, None, metadata, output)

mock_paged_attention.assert_called_once()

def test_forward_mtp_310(self):
query = torch.randn(4, 8 * 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.SpecDecoding
with self.assertRaises(NotImplementedError):
output = self.impl.forward_impl(query, key, value, None, metadata, output)
35 changes: 17 additions & 18 deletions vllm_ascend/_310p/attention/attention_v1.py
@@ -198,6 +198,8 @@ def forward_chunked_prefill_310(self, query, attn_metadata, output):
out=output,
)

return output

def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
"""
Main dispatch method for attention operations.
@@ -218,22 +220,19 @@ def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
NotImplementedError: If the attention state is not supported on 310P.
"""
state = attn_metadata.attn_state

if state == AscendAttentionState.DecodeOnly:
return self.forward_paged_attention(query, attn_metadata, output)

# Condition for PrefillNoCache: No previous tokens have been processed yet
if state == AscendAttentionState.PrefillNoCache:
out = self.forward_prefill_310(query, key, value, attn_metadata, output)
return out

if state == AscendAttentionState.ChunkedPrefill:
self.forward_chunked_prefill_310(query, attn_metadata, output)
return output

raise NotImplementedError(
f"{self.__class__.__name__}.forward_impl: 310P only supports "
f"{AscendAttentionState.DecodeOnly.name}, "
f"{AscendAttentionState.PrefillNoCache.name}, "
f"{AscendAttentionState.ChunkedPrefill.name}, "
f"got {state!r}."
)
output = self.forward_prefill_310(query, key, value, attn_metadata, output)
# Condition for DecodeOnly: Pure decoding phase where each request generates one token
elif state == AscendAttentionState.DecodeOnly:
output = self.forward_paged_attention(query, attn_metadata, output)
# Condition for ChunkedPrefill:
# 1. During speculative decoding scenarios (except mtp)
# 2. Processing large prefill requests in chunks
# Condition for PrefillCacheHit: Indicates prefill with some cached tokens already processed
elif state in [AscendAttentionState.ChunkedPrefill, AscendAttentionState.PrefillCacheHit]:
output = self.forward_chunked_prefill_310(query, attn_metadata, output)
# Condition for SpecDecoding: Specified for mtp, which is not supported yet.
else:
raise NotImplementedError(f"AscendAttentionState: {state} is not supported for 310P currently.")
return output
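
Because the diff view above drops indentation and interleaves removed and added lines, here is a reconstructed sketch of how the unified dispatcher reads after this change (reassembled from the added lines; not a verbatim copy of the file):

    def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
        state = attn_metadata.attn_state
        # PrefillNoCache: no previous tokens have been processed yet.
        if state == AscendAttentionState.PrefillNoCache:
            output = self.forward_prefill_310(query, key, value, attn_metadata, output)
        # DecodeOnly: pure decoding phase where each request generates one token.
        elif state == AscendAttentionState.DecodeOnly:
            output = self.forward_paged_attention(query, attn_metadata, output)
        # ChunkedPrefill (chunked prefill, and speculative decoding other than mtp) and
        # PrefillCacheHit (prefill with some tokens already cached) share the split-fuse path.
        elif state in [AscendAttentionState.ChunkedPrefill, AscendAttentionState.PrefillCacheHit]:
            output = self.forward_chunked_prefill_310(query, attn_metadata, output)
        # SpecDecoding is reserved for mtp, which is not supported yet.
        else:
            raise NotImplementedError(f"AscendAttentionState: {state} is not supported for 310P currently.")
        return output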
96 changes: 96 additions & 0 deletions vllm_ascend/_310p/model_runner_310p.py
@@ -17,9 +17,11 @@

from __future__ import annotations

import numpy as np
import torch
import torch_npu
from vllm.logger import logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
@@ -185,3 +187,97 @@ def _allocate_kv_cache_and_reshape_tensors(
raise ValueError("Unknown KV cache spec type.")

return kv_caches

# Override this function because of a tensor.copy_(other) accuracy issue.
# TODO: Remove this override once the tensor.copy_(other) accuracy issue is resolved.
def _prepare_input_ids(
self,
scheduler_output: SchedulerOutput,
total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray,
) -> None:
"""Prepare the input IDs for the current batch.

Carefully handles `prev_sampled_token_ids`, which can be cached from the
previous engine iteration; in that case those tokens, already on the NPU,
need to be copied into the corresponding slots of input_ids."""

if self.input_batch.prev_sampled_token_ids is None:
# Normal scheduling case
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
return

# Async scheduling case, where some decode requests from the previous
# iteration won't have entries in input_ids_cpu and need to be copied
# on the NPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None
sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = []
prev_common_req_indices: list[int] = []
prev_draft_token_indices: list[int] = []
indices_match = True
max_flattened_index = -1
total_num_spec_tokens = 0
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens

for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1))
start = prev_index * self.num_spec_tokens
prev_draft_token_indices.extend(range(start, start + draft_len))
indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
num_commmon_tokens = len(sample_flattened_indices)
total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
if num_commmon_tokens < total_without_spec:
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
if num_commmon_tokens == 0:
return
if indices_match and max_flattened_index == (num_commmon_tokens - 1):
# NOTE: This is the overridden write: index_copy_ is used instead of tensor.copy_ (see accuracy note above).
indices = torch.arange(num_commmon_tokens, device=self.input_ids.gpu.device)
source = self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0]
self.input_ids.gpu.index_copy_(0, indices, source)
if self.enable_prompt_embeds:
self.is_token_ids.gpu[:num_commmon_tokens] = True
return
# Upload the index tensors asynchronously so the scatter can be non-blocking.
sampled_tokens_index_tensor = torch.tensor(
sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
self.input_ids.gpu.scatter_(
dim=0,
index=sampled_tokens_index_tensor,
src=self.input_batch.prev_sampled_token_ids[prev_common_req_indices_tensor, 0],
)
# Scatter the draft tokens after the sampled tokens are scattered.
if self._draft_token_ids is None or not spec_flattened_indices:
return
assert isinstance(self._draft_token_ids, torch.Tensor)
draft_tokens_index_tensor = torch.tensor(
spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
prev_draft_token_indices_tensor = torch.tensor(
prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
self.input_ids.gpu.scatter_(
dim=0,
index=draft_tokens_index_tensor,
src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
)
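
The whole override above exists because of the tensor.copy_(other) accuracy issue flagged at the top of the function: the contiguous fast path writes the sampled tokens with an explicit index_copy_, and the general path uploads index tensors and writes with scatter_. A minimal sketch of that substitution pattern on plain CPU tensors (illustrative only; the names and sizes below are invented, not taken from the runner):

    import torch

    num_common = 4
    dst = torch.zeros(8, dtype=torch.int64)   # stand-in for input_ids
    src = torch.arange(1, num_common + 1)     # stand-in for prev_sampled_token_ids

    # The slice copy this override avoids would look like: dst[:num_common].copy_(src)
    # Contiguous fast path: write the same elements through index_copy_ instead.
    indices = torch.arange(num_common, device=dst.device)
    dst.index_copy_(0, indices, src)

    # General path: upload the target positions as an index tensor and scatter into them,
    # mirroring the sampled/draft token writes in _prepare_input_ids.
    positions = torch.tensor([5, 6, 7], dtype=torch.int64)
    dst.scatter_(dim=0, index=positions, src=torch.tensor([9, 9, 9], dtype=torch.int64))

Functionally the index_copy_ form writes the same elements as the slice copy_ it replaces; the override simply avoids routing the write through tensor.copy_.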