3 changes: 1 addition & 2 deletions tests/ut/attention/test_attention_v1.py
@@ -53,6 +53,7 @@ def setUp(self):
self.mock_vllm_config = MagicMock()
self.mock_vllm_config.speculative_config = None
self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.model_config.hf_text_config.sliding_window = None
self.mock_vllm_config.cache_config.block_size = 64
self.mock_vllm_config.compilation_config.cudagraph_mode = None
self.mock_vllm_config.scheduler_config.max_num_seqs = 10
@@ -89,8 +90,6 @@ def test_build_non_310p(self, mock_soc_version, mock_ascend_metadata):
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None,
2 changes: 0 additions & 2 deletions tests/ut/attention/test_mla_cp.py
@@ -1004,8 +1004,6 @@ def mock_npu_ring_mla_effect(q_nope, q_rope, k_nope, k_rope, value,
[chunk_seqlens, chunk_seqlens], dtype=torch.int32)
attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens = kv_with_q_head_nomask_seqlens
attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens = kv_with_q_tail_nomask_seqlens
attn_metadata.prefill.pcp_metadata.pcp_prefill_mask = torch.triu(
torch.ones(10, 10, dtype=torch.float16), 1)

output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe,
value, kv_c_and_k_pe_cache,
56 changes: 41 additions & 15 deletions tests/ut/attention/test_mla_v1.py
@@ -244,8 +244,15 @@ def test_ascend_mla_metadata_builder_spec_decode(self):
mock_vllm_config.scheduler_config.enable_chunked_prefill)

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
def test_ascend_mla_metadata_builder_build_full_graph(
self, mock_get_cos_and_sin_mla):
self, mock_get_pcp_group, mock_get_pcp_group_mask,
mock_get_cos_and_sin_mla):
pcp_group = MagicMock()
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
mock_get_pcp_group_mask.return_value = pcp_group
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -400,14 +407,21 @@ def setUp(self):
self.kv_cache_spec.num_heads = 32

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
@patch("torch.Tensor.npu", new=lambda self: self)
@patch("torch.npu.is_available")
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
mock_zeros,
mock_zeros, mock_get_pcp_group,
mock_get_pcp_group_mask,
mock_get_cos_and_sin_mla):
mock_npu_available.return_value = False
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock()
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
mock_get_pcp_group_mask.return_value = pcp_group

def zeros_override(*args, **kwargs):
kwargs.pop('pin_memory', None)
@@ -426,8 +440,6 @@ def zeros_override(*args, **kwargs):
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_computed_tokens_cpu=None,
seq_lens=None,
@@ -458,14 +470,21 @@ def zeros_override(*args, **kwargs):
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
@patch("torch.Tensor.npu", new=lambda self: self)
@patch("torch.npu.is_available")
def test_build_chunked_prefix_metadata(self, mock_npu_available,
mock_zeros,
mock_zeros, mock_get_pcp_group,
mock_get_pcp_group_mask,
mock_get_cos_and_sin_mla):
mock_npu_available.return_value = False
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock()
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
mock_get_pcp_group_mask.return_value = pcp_group

def zeros_override(*args, **kwargs):
kwargs.pop('pin_memory', None)
@@ -485,8 +504,6 @@ def zeros_override(*args, **kwargs):
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None,
@@ -517,8 +534,16 @@ def zeros_override(*args, **kwargs):
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla):
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
def test_build_decode_only_metadata(self, mock_get_pcp_group,
mock_get_pcp_group_mask,
mock_get_cos_and_sin_mla):
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock()
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
mock_get_pcp_group_mask.return_value = pcp_group

common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
@@ -532,8 +557,6 @@ def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla):
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
decode_token_per_req=torch.tensor([1, 1, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((3, 3)),
spec_attn_mask=None,
attn_state=AscendAttentionState.DecodeOnly,
num_computed_tokens_cpu=None,
seq_lens=None,
@@ -563,9 +586,16 @@ def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla):
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
def test_build_for_graph_capture_decode_only(self,
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
def test_build_for_graph_capture_decode_only(self, mock_get_pcp_group,
mock_get_pcp_group_mask,
mock_get_cos_and_sin_mla):
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock()
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
mock_get_pcp_group_mask.return_value = pcp_group

common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
@@ -579,8 +609,6 @@ def test_build_for_graph_capture_decode_only(self,
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
decode_token_per_req=torch.tensor([1, 1, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((3, 3)),
spec_attn_mask=None,
attn_state=AscendAttentionState.DecodeOnly,
num_computed_tokens_cpu=None,
seq_lens=None,
@@ -625,8 +653,6 @@ def test_build_for_graph_capture_prefill(self, mock_get_cos_and_sin_mla):
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_computed_tokens_cpu=None,
seq_lens=None,
4 changes: 0 additions & 4 deletions tests/ut/spec_decode/test_mtp_proposer.py
@@ -291,8 +291,6 @@ def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):

mock_runner = MagicMock()
mock_runner.actual_seq_lengths_q = MagicMock()
mock_runner.attn_mask = MagicMock()
mock_runner.spec_attn_mask = MagicMock()
mock_runner.attn_state = MagicMock()
mock_runner.graph_pad_size = 0
mock_runner.decode_token_per_req = MagicMock()
@@ -334,5 +332,3 @@ def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):
assert spec_common_attn_metadata.num_actual_tokens == total_num_tokens
assert spec_common_attn_metadata.max_query_len == 8
assert spec_common_attn_metadata.actual_seq_lengths_q == proposer.runner.actual_seq_lengths_q
assert spec_common_attn_metadata.attn_mask == proposer.runner.attn_mask
assert spec_common_attn_metadata.spec_attn_mask == proposer.runner.spec_attn_mask
17 changes: 1 addition & 16 deletions tests/ut/worker/test_pcp_manager.py
@@ -70,7 +70,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
input_batch.num_tokens = torch.tensor(num_tokens)

query_lens = torch.tensor(query_lens)
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None,
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
input_batch)

if not expect_not_none:
@@ -97,21 +97,6 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
assert hasattr(result, 'head_attn_nomask_seqlens')
assert hasattr(result, 'tail_attn_nomask_seqlens')

if hasattr(result, 'pcp_prefill_mask'
) and result.pcp_prefill_mask is not None:
if use_mla:
assert result.pcp_prefill_mask.shape == (512, 512)
else:
assert result.pcp_prefill_mask.shape == (2048, 2048)
else:
if hasattr(result, 'pcp_prefill_mask'):
if result.pcp_prefill_mask is not None:
if use_mla:
assert result.pcp_prefill_mask.shape == (512, 512)
else:
assert result.pcp_prefill_mask.shape == (2048,
2048)


@pytest.mark.parametrize(
"tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
19 changes: 18 additions & 1 deletion vllm_ascend/attention/attention_mask.py
@@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from vllm.distributed import get_pcp_group

from vllm_ascend.platform import ModelConfig
from vllm_ascend.utils import singleton


def _generate_attn_mask(max_seq_len, dtype):
@@ -29,6 +33,7 @@ def _generate_attn_mask(max_seq_len, dtype):
return attn_mask


@singleton
class AttentionMaskBuilder:

def __init__(self, device: torch.device):
@@ -82,4 +87,16 @@ def get_swa_mask(self, dtype: torch.dtype, sliding_window):
triu_mask = torch.triu(mask, diagonal=1).to(self.device)
tril_mask = torch.tril(mask, -sliding_window).to(self.device)
self.swa_mask = triu_mask + tril_mask
return self.swa_mask
return self.swa_mask

def get_attention_mask(self, model_config: ModelConfig):
if model_config.runner_type == "pooling":
return self.get_attn_mask(2048, torch.bool)

return self.get_splitfuse_attn_mask()

def get_final_mla_mask(self, model_config: ModelConfig):
if get_pcp_group().world_size > 1:
return self.get_pcp_mla_mask(model_config.dtype)
# Prefill stages use 512x512 mask with appropriate dtype
return self.get_mla_mask(model_config.dtype)
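
This diff turns AttentionMaskBuilder into a singleton (via vllm_ascend.utils.singleton), so the attention metadata builders share one mask cache instead of receiving masks through common_attn_metadata. The decorator's implementation is not shown in the diff; the following is a minimal, hypothetical sketch of the kind of class-level singleton it is assumed to provide (the real helper may differ, for example by keying instances on the device argument).

def singleton(cls):
    """Build cls once and hand the same instance to every caller."""
    instances = {}

    def get_instance(*args, **kwargs):
        # Reuse the cached instance if one exists, otherwise build and cache it.
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class AttentionMaskBuilder:
    def __init__(self, device):
        self.device = device


# Every construction site gets the same builder object, so cached masks
# (splitfuse, SWA, MLA, PCP) are built once and reused across call sites.
a = AttentionMaskBuilder("npu:0")
b = AttentionMaskBuilder("npu:0")
assert a is b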
15 changes: 13 additions & 2 deletions vllm_ascend/attention/attention_v1.py
@@ -34,6 +34,7 @@
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import (
AscendMetadataForDecode, AscendMetadataForPrefill)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
@@ -219,6 +220,7 @@ def __init__(

scheduler_config = vllm_config.scheduler_config
self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
self.attn_mask_builder = AttentionMaskBuilder(self.device)

@classmethod
def get_cudagraph_support(
@@ -253,10 +255,19 @@ def build(
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
swa_mask = common_attn_metadata.swa_mask
attn_state = common_attn_metadata.attn_state

# Get attn_mask and swa_mask from singleton AttentionMaskBuilder
attn_mask = self.attn_mask_builder.get_attention_mask(
self.model_config)

swa_mask = None
is_swa = hasattr(self.model_config.hf_text_config, 'sliding_window')
if self.model_config is not None and is_swa:
swa_mask = self.attn_mask_builder.get_swa_mask(
self.model_config.dtype,
self.model_config.hf_text_config.sliding_window)

# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
query_start_loc = query_start_loc_cpu.pin_memory().to(
self.device, non_blocking=True)
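
With this change the builder, rather than the metadata, owns the masks: build() pulls attn_mask from the singleton AttentionMaskBuilder and derives swa_mask from get_swa_mask(dtype, sliding_window) when the model's hf_text_config carries a sliding window. As an illustration only (the helper name below is hypothetical, and the base mask is assumed to be a matrix of ones, which may differ from the builder's actual mask value), a sliding-window mask built the way get_swa_mask builds it looks like this:

import torch

def sliding_window_mask(seq_len: int, sliding_window: int) -> torch.Tensor:
    # Mirrors get_swa_mask: mask strictly-future positions (triu at diagonal=1)
    # and positions further back than the window (tril at -sliding_window).
    mask = torch.ones(seq_len, seq_len, dtype=torch.float16)
    triu_mask = torch.triu(mask, diagonal=1)
    tril_mask = torch.tril(mask, -sliding_window)
    return triu_mask + tril_mask

# Non-zero entries are masked out: each query position may attend only to the
# most recent `sliding_window` key positions, itself included.
print(sliding_window_mask(5, 2))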
7 changes: 3 additions & 4 deletions vllm_ascend/attention/context_parallel/attention_cp.py
@@ -121,7 +121,8 @@ def build(

slot_mapping = common_attn_metadata.slot_mapping[:
num_actual_tokens_pcp_padded]
attn_mask = common_attn_metadata.attn_mask
attn_mask = self.attn_mask_builder.get_attention_mask(
self.model_config)
attn_state = common_attn_metadata.attn_state
num_computed_tokens_cpu = (seq_lens - query_lens)

@@ -212,7 +213,6 @@ def build(
head_attn_nomask_seqlens=head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
pcp_allgather_restore_idx=common_long_seq_metadata.
pcp_allgather_restore_idx)

@@ -433,13 +433,12 @@ def _forward_prefill_cp_attn(self, data, is_head, attn_metadata):
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens \
if is_head else attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
output, lse = self._attention_with_nomask_and_mask(
**data,
q_seqlens=attn_mask_seqlens,
kv_seqlens_nomask=nomask_seqlens,
kv_seqlens_mask=attn_mask_seqlens,
mask=mask,
mask=attn_metadata.attn_mask,
attn_metadata=attn_metadata)
return output, lse

1 change: 0 additions & 1 deletion vllm_ascend/attention/context_parallel/common_cp.py
@@ -21,7 +21,6 @@ class AscendPCPMetadata:
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None


8 changes: 3 additions & 5 deletions vllm_ascend/attention/context_parallel/mla_cp.py
@@ -118,7 +118,6 @@ def build_cp_metadata(
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
pcp_allgather_restore_idx=common_long_seq_metadata.
pcp_allgather_restore_idx)

@@ -195,7 +194,7 @@ def get_block_table_size(
).item()
if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
# to avoid irregular attn_mask shape
return self.num_decodes_flatten + self.num_prefills
else:
return self.num_decodes_flatten
@@ -420,7 +419,6 @@ def _forward_prefill(
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
output_head, lse_head = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_head_idx),
q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -431,7 +429,7 @@
kv_nomask_idx=kv_with_q_head_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=head_attn_nomask_seqlens,
mask=mask)
mask=attn_metadata.attn_mask)

output_tail, lse_tail = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
@@ -443,7 +441,7 @@
kv_nomask_idx=kv_with_q_tail_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=tail_attn_nomask_seqlens,
mask=mask)
mask=attn_metadata.attn_mask)

q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
attn_output = torch.index_select(