396 changes: 396 additions & 0 deletions tests/scheduler/test_scheduler.py

Large diffs are not rendered by default.

60 changes: 53 additions & 7 deletions vllm_ascend/attention/attention.py
@@ -43,7 +43,7 @@
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)


def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
# Construct a lower-triangular matrix.
mask_flag = torch.tril(
torch.ones((max_seq_len, max_seq_len),
@@ -52,10 +52,11 @@ def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
mask_flag = ~mask_flag
# Currently for fp16 dtype, the mask value should be set to -inf.
# TODO: Eliminate this part in the future.
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
if mask_value is None:
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
mask_flag, mask_value).to(dtype)
return attn_mask
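A minimal usage sketch of the patched helper (illustrative values; the -10000 override mirrors the split-fuse mask value used later in this file):

    import torch
    from vllm_ascend.attention.attention import generate_attn_mask

    # Default fp16 behaviour: masked (future) positions are filled with
    # torch.finfo(torch.float32).min, which saturates to -inf in float16.
    causal_mask = generate_attn_mask(max_seq_len=4, dtype=torch.float16)

    # New in this change: callers may override the fill value, e.g. to match
    # the split-fuse mask convention.
    splitfuse_mask = generate_attn_mask(max_seq_len=4, dtype=torch.float16,
                                        mask_value=-10000)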
@@ -66,12 +67,14 @@ class AttentionMaskBuilder:
def __init__(self, attn_mask: torch.Tensor):
self._seq_len_cached = attn_mask.shape[0]
self.attn_mask_cache = attn_mask
self.splitfuse_mask_value = -10000

@classmethod
def initialize_from_len(cls,
max_seq_len: int,
dtype: torch.dtype = torch.float16):
return cls(generate_attn_mask(max_seq_len, dtype))
dtype: torch.dtype = torch.float16,
mask_value: Optional[int] = None):
return cls(generate_attn_mask(max_seq_len, dtype, mask_value))

def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
device: torch.device):
@@ -97,6 +100,49 @@ def get_decode_attn_mask(
return (self.attn_mask_cache.index_select(
0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())

def get_splitfuse_attn_mask(
self,
seq_lens,
query_lens,
position,
dtype,
device,
) -> torch.Tensor:
max_seq_len = max(seq_lens, default=0)
if max_seq_len <= self._seq_len_cached:
self.update_attn_cache(max_seq_len, dtype, device)
# FIXME: Currently the mask value used for the chunked-prefill and prefill-only
# situations is not the same. Fix this in the future when the kernel is ready.
if self.attn_mask_cache[0][1] > 0:
attn_mask = self.get_attn_mask( # type: ignore
max_seq_len, dtype, device)
attn_mask *= -10000
else:
attn_mask = self.attn_mask_cache
return torch.index_select(attn_mask, dim=0,
index=position)[:, :max_seq_len]
total_q_len = sum(query_lens)
attn_mask = torch.zeros((total_q_len, max_seq_len),
dtype=dtype,
device="cpu")

current_row = 0
for i in range(len(query_lens)):
seq_len = seq_lens[i]
q_len = query_lens[i]
context_len = seq_len - q_len

assert context_len >= 0
attn_mask[current_row:current_row + q_len,
context_len:] = self.splitfuse_mask_value
right_tensor = attn_mask[current_row:current_row + q_len,
context_len:seq_len]
right_tensor.masked_fill_(
right_tensor.tril() == self.splitfuse_mask_value, 0)
current_row += q_len

return attn_mask.to(device, non_blocking=True)
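A worked trace of the per-request construction above (this branch only runs when max_seq_len exceeds the cached mask size), assuming query_lens=[2, 1] and seq_lens=[4, 3], i.e. both requests already have 2 tokens of computed context:

    # total_q_len = 3, max_seq_len = 4, splitfuse_mask_value = -10000
    # row 0 (request 0, query token at position 2): [0, 0, 0, -10000]
    # row 1 (request 0, query token at position 3): [0, 0, 0,      0]
    # row 2 (request 1, query token at position 2): [0, 0, 0, -10000]

Each query row can attend to its request's context plus itself and any earlier new token; later positions, and padding beyond the request's own seq_len, stay masked.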


class AscendAttentionBackend(AttentionBackend):

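With mask_value now threaded through initialize_from_len, a caller can seed the builder with the split-fuse fill value directly; a minimal sketch (the PR's actual call sites are not shown in this diff):

    mask_builder = AttentionMaskBuilder.initialize_from_len(
        max_seq_len=2048, dtype=torch.float16, mask_value=-10000)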
70 changes: 51 additions & 19 deletions vllm_ascend/attention/attention_v1.py
@@ -16,6 +16,7 @@
#

from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
@@ -50,7 +51,7 @@ def get_kv_cache_shape(
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)
return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def swap_blocks(
@@ -83,6 +84,12 @@ def copy_blocks(
value_caches[dst_indices] = value_caches[src_indices]


class AscendAttentionState(Enum):
PrefillOnly = 0
DecodeOnly = 1
ChunkedPrefill = 2


@dataclass
class AscendMetadata:
# (batch_size, max_blocks_per_seq).
@@ -104,6 +111,8 @@ class AscendMetadata:
# FlashAttention has better performance than PagedAttention,
# but it does not support decode requests.
is_only_prefill: bool = False
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

attn_mask: Optional[torch.Tensor] = None

@@ -139,7 +148,8 @@ def __init__(

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.seq_len_cpu_tensor = None
self.key_cache = None
self.value_cache = None

def forward(
self,
@@ -190,30 +200,52 @@ def forward(
# TODO: Remove this contiguous in the future.
value = value.contiguous()

if kv_cache.numel() > 0:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)

if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=attn_metadata.context_lens,
out=output)
# Normal V1 situation.
else:
if kv_cache.numel() > 0:
key_cache, value_cache = kv_cache[0], kv_cache[1]
num_blocks, block_size, _ = key_cache.shape
key_cache = key_cache.view(num_blocks, block_size,
self.num_kv_heads, self.head_size)
value_cache = value_cache.view(num_blocks, block_size,
self.num_kv_heads,
self.head_size)
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_indices=slots)

# use paged attention
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=key_cache,
value_cache=value_cache,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.seq_lens,
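Reduced to its skeleton, the forward path above now dispatches on attn_state (a sketch of the control flow already shown in the diff, not additional code from the PR):

    if attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
        # V0-style batches in which every request is still prefilling
        torch_npu._npu_flash_attention(...)
    elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
        # pure decode batches served from the paged KV cache
        torch_npu._npu_paged_attention(...)
    else:  # AscendAttentionState.ChunkedPrefill, the V1 default
        torch_npu._npu_paged_attention_splitfuse(...)

The state itself is presumably set by the model runner when it builds AscendMetadata; nothing in this file changes it after construction.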
Empty file added vllm_ascend/core/__init__.py
Empty file.
73 changes: 73 additions & 0 deletions vllm_ascend/core/schedule_config.py
@@ -0,0 +1,73 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
Collaborator review comment (on the line above): remove this line

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass, fields
from typing import Type, Union

from vllm.config import SchedulerConfig


@dataclass
class AscendSchedulerConfig(SchedulerConfig):
enable_chunked_prefill: bool = False
policy: str = "fcfs"
num_scheduler_steps: int = 1
scheduler_cls: Union[str, Type[object]] = (
"vllm_ascend.core.scheduler.AscendScheduler")

@classmethod
def initialize_from_config(
cls,
vllm_scheduler_config: SchedulerConfig,
ascend_scheduler_config: dict,
):
scheduler_config = {
field.name: getattr(vllm_scheduler_config, field.name)
for field in fields(vllm_scheduler_config) if field.init
}
# Override default values in the original SchedulerConfig
scheduler_config["enable_chunked_prefill"] = False
scheduler_config["policy"] = "fcfs"
scheduler_config["num_scheduler_steps"] = 1
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.scheduler.AscendScheduler")
# Override params in original SchedulerConfig with params in additional_config.ascend_scheduler_config
for k, v in ascend_scheduler_config.items():
scheduler_config[k] = v
return cls(**scheduler_config)

def __post_init__(self) -> None:
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens
self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.policy != "fcfs":
raise NotImplementedError(
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
)
if self.is_multimodal_model:
raise NotImplementedError(
"currently AscendScheduler only supports LLM modles.")
if self.num_scheduler_steps > 1:
raise NotImplementedError(
"currently AscendScheduler doesn't support multi-step.")
if self.send_delta_data:
raise NotImplementedError(
"currently AscendScheduler doesn't support send_delta_data.")
if self.delay_factor > 0:
raise NotImplementedError(
"currently AscendScheduler doesn't support scheduler_delay_factor."
)
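A minimal sketch of how the new config is constructed (vllm_config below stands for the engine's existing config object and is illustrative; the override dict mirrors what the comment above calls additional_config.ascend_scheduler_config):

    ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
        vllm_config.scheduler_config,
        {"enable_chunked_prefill": False, "policy": "fcfs"},
    )
    assert ascend_scheduler_config.scheduler_cls == (
        "vllm_ascend.core.scheduler.AscendScheduler")

Because __post_init__ rejects any policy other than fcfs, multimodal models, multi-step scheduling, send_delta_data and a non-zero scheduler delay factor, unsupported overrides fail fast at construction time rather than deep inside the scheduler.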