Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,72 @@ def test_accuracy_pcp_only(max_tokens: int, ) -> None:
name_0="vllm_eager_outputs",
name_1="vllm_pcp_only_outputs",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [10])
def test_models_long_sequence_cp_kv_interleave_size_output_between_tp_and_cp(
    model: str,
    max_tokens: int,
) -> None:
    """Compare context-parallel decoding against a tensor-parallel baseline.

    Builds two runner configurations from a shared base config: one with
    prefill/decode context parallelism and a non-trivial
    ``cp_kv_cache_interleave_size`` (128), and one with tensor parallelism
    only. Greedy generations for the same prompts must be identical.

    Args:
        model: model identifier injected by the ``MODELS`` parametrization.
        max_tokens: number of tokens to generate greedily per prompt.
    """
    prompts = [
        "The president of the United States is", "The capital of France is"
    ]

    common_kwargs = {
        "max_model_len": 1024,
    }

    if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8":
        # Quantized MoE model: both runs need expert parallelism and the
        # Ascend quantization backend, and run in eager mode.
        cp_kwargs = {
            "tensor_parallel_size": 2,
            "decode_context_parallel_size": 2,
            "prefill_context_parallel_size": 2,
            "enable_expert_parallel": True,
            "cp_kv_cache_interleave_size": 128,
            "enforce_eager": True,
            "quantization": "ascend",
        }
        tp_kwargs = {
            "tensor_parallel_size": 4,
            "enable_expert_parallel": True,
            "enforce_eager": True,
            "quantization": "ascend",
        }
    else:
        # Dense model: exercise prefill context parallelism together with
        # full-decode-only graph capture; the TP baseline runs eagerly.
        cp_kwargs = {
            "tensor_parallel_size": 1,
            "decode_context_parallel_size": 1,
            "prefill_context_parallel_size": 2,
            "cp_kv_cache_interleave_size": 128,
            "compilation_config": {
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "cudagraph_capture_sizes": [4, 8, 24, 48, 60]
            },
        }
        tp_kwargs = {
            "tensor_parallel_size": 2,
            "enforce_eager": True,
        }

    # Merge via dict unpacking (PEP 448): builds each config in one
    # expression instead of mutating an empty dict, and removes the need
    # for the `# type: ignore` suppressions on .update() calls.
    cp_full_kwargs = {**common_kwargs, **cp_kwargs}
    tp_full_kwargs = {**common_kwargs, **tp_kwargs}

    with VllmRunner(model, **cp_full_kwargs) as runner:
        vllm_context_parallel_outputs = runner.generate_greedy(
            prompts, max_tokens)

    with VllmRunner(model, **tp_full_kwargs) as runner:
        vllm_eager_outputs = runner.generate_greedy(prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=vllm_eager_outputs,
        outputs_1_lst=vllm_context_parallel_outputs,
        name_0="vllm_eager_outputs",
        name_1="vllm_context_parallel_outputs",
    )
2 changes: 0 additions & 2 deletions tests/ut/attention/test_attention_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def mock_npu_fused_infer_attention_score_func(query, k_nope, value,

attn_metadata = MagicMock()
attn_metadata.decode_meta = MagicMock()
attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
[1, 0], dtype=torch.bool)
output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)

self.assertEqual(output.shape[0], 2)
Expand Down
14 changes: 2 additions & 12 deletions tests/ut/attention/test_mla_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,7 @@ def test_process_attn_out_lse(self):
decode_metadata = MagicMock()
decode_metadata.actual_seq_lengths_q = MagicMock()
decode_metadata.seq_lens_list = MagicMock()
decode_metadata.batch_seq_mask = torch.tensor([True, False],
dtype=torch.bool)

result = _process_attn_out_lse(attn_output, softmax_lse,
decode_metadata.batch_seq_mask)
result = _process_attn_out_lse(attn_output, softmax_lse)

self.assertEqual(result.shape[0], B * self.impl.pcp_size)
self.assertEqual(result.shape[1], N)
Expand Down Expand Up @@ -478,8 +474,6 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
attn_metadata.decode = MagicMock()
attn_metadata.decode.actual_seq_lengths_q = MagicMock()
attn_metadata.decode.seq_lens_list = MagicMock()
attn_metadata.decode.batch_seq_mask = torch.tensor([False, False],
dtype=torch.bool)

self.impl.enable_kv_nz = True

Expand Down Expand Up @@ -886,12 +880,8 @@ def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_to_all, mock_dcp,
# Inputs
attn_output = torch.randn(B, H, D)
softmax_lse = torch.randn(B, H, 1)
batch_seq_mask = torch.tensor([False, True, False, False]) # [B]
decode_meta = MagicMock()
decode_meta.batch_seq_mask = batch_seq_mask

result = _process_attn_out_lse(attn_output, softmax_lse,
batch_seq_mask)
result = _process_attn_out_lse(attn_output, softmax_lse)
# [PCP * S, DCP * H, D + 1]
self.assertIsInstance(result, torch.Tensor)
assert result.shape == (B * self.impl.pcp_size, H, D + 1)
Expand Down
5 changes: 1 addition & 4 deletions tests/ut/attention/test_mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,14 @@ def test_ascend_mla_decode_metadata_default(self):
seq_lens_list = [2, 3]
attn_mask = None
cp_seq_len = torch.tensor([2, 3])
batch_seq_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])

metadata = AscendMLADecodeMetadata(input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
max_seq_lens=max_seq_lens,
seq_lens_list=seq_lens_list,
attn_mask=attn_mask,
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
cp_seq_len=cp_seq_len)

self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
Expand All @@ -139,7 +137,6 @@ def test_ascend_mla_decode_metadata_default(self):
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
self.assertIsNone(attn_mask)
self.assertIs(metadata.cp_seq_len, cp_seq_len)
self.assertIs(metadata.batch_seq_mask, batch_seq_mask)


class TestAscendMLAMetadata(TestBase):
Expand Down
23 changes: 5 additions & 18 deletions vllm_ascend/attention/context_parallel/attention_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ def __init__(
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.batch_seq_mask_buf = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
dtype=torch.uint8,
device=device)
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
Expand Down Expand Up @@ -229,16 +225,11 @@ def build(
num_computed_tokens_array = np.array(
num_computed_tokens_of_pcp_dcp)
num_computed_tokens_array = num_computed_tokens_array[:num_decodes]
batch_seq_mask = (num_computed_tokens_array[:, self.pcp_rank,
self.dcp_rank] == 0)
# TODO: numpy array mode of the shared memory is used to improve performance
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
torch.from_numpy(batch_seq_mask), non_blocking=True)
decode_metadata = AscendMetadataForDecode(
num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask.
shape[0]],
block_tables=block_table[:num_decodes])
block_tables=block_table[:num_decodes],
)

attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
Expand Down Expand Up @@ -549,8 +540,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
else:
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
query, k_nope, value, **common_kwargs)
attn_out_lse = _process_attn_out_lse(
attn_out, attn_lse, attn_metadata.decode_meta.batch_seq_mask)
attn_out_lse = _process_attn_out_lse(attn_out, attn_lse)
attn_out = _npu_attention_update(self.head_size, attn_out_lse)
return attn_out

Expand Down Expand Up @@ -661,11 +651,8 @@ def _compute_prefill_context(self, query: torch.Tensor,
actual_seq_lengths_kv=prefill_metadata.chunked_context.
actual_seq_lengths_kv,
actual_seq_lengths=attn_metadata.prefill.chunked_context.
actual_chunk_seq_lengths)
batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
lse_mask = batch_chunk_seq_mask[:, None,
None].expand_as(prefix_chunk_lse)
prefix_chunk_lse = torch.where(lse_mask, -torch.inf, prefix_chunk_lse)
actual_chunk_seq_lengths,
)

return prefix_chunk_output, prefix_chunk_lse

Expand Down
14 changes: 5 additions & 9 deletions vllm_ascend/attention/context_parallel/common_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,17 @@ class ChunkedContextMetadata:

@dataclass
class AscendMetadataForDecode:
""" Decode Specific Metadata for Ascend"""
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
batch_seq_mask: torch.Tensor = None
"""Decode-specific metadata for Ascend attention with Context Parallelism."""

num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
block_tables: torch.Tensor = None


def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor,
batch_seq_mask: torch.Tensor) -> torch.Tensor:
def _process_attn_out_lse(attn_output: torch.Tensor,
softmax_lse: torch.Tensor) -> torch.Tensor:
pcp_size = get_pcp_group().world_size
dcp_size = get_decode_context_model_parallel_world_size()
dcp_group = get_dcp_group().device_group if dcp_size > 1 else None
out_mask = batch_seq_mask[:, None, None].expand_as(attn_output)
attn_output = torch.where(out_mask, 0, attn_output)
lse_mask = batch_seq_mask[:, None, None].expand_as(softmax_lse)
softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
softmax_lse = softmax_lse.to(torch.float32)
attn_output = attn_output.to(torch.float32)
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
Expand Down
17 changes: 1 addition & 16 deletions vllm_ascend/attention/context_parallel/mla_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,6 @@ def __init__(
) if self.dcp_size > 1 else 0
self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
scheduler_config = vllm_config.scheduler_config
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
self.batch_seq_mask_buf = torch.empty(max_num_seqs *
self.decode_threshold,
dtype=torch.uint8,
device=device)
self.block_size = (self.block_size *
self.cp_virtual_block_size) // np.gcd(
self.block_size, self.cp_virtual_block_size)
Expand Down Expand Up @@ -256,13 +248,7 @@ def build_decode_metadata(
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
self.dcp_rank]
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
batch_seq_mask = (cp_seq_len == 0)
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
decode_metadata.cp_seq_len = cp_seq_len.tolist()
decode_metadata.batch_seq_mask = batch_seq_mask

actual_seq_lengths_q = torch.arange(self.num_decodes_flatten) + 1
decode_metadata.actual_seq_lengths_q = actual_seq_lengths_q
Expand Down Expand Up @@ -688,8 +674,7 @@ def _forward_decode(
B_lse * Q_S, N_lse, 1)

# Update out&lse
attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse,
decode_meta.batch_seq_mask)
attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse)
attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse)
return self._v_up_proj(attn_output)

Expand Down
6 changes: 2 additions & 4 deletions vllm_ascend/attention/mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ class AscendMLADecodeMetadata:
sin: torch.Tensor = None
cos: torch.Tensor = None
cp_seq_len: torch.Tensor = None
batch_seq_mask: torch.Tensor = None


@dataclass
Expand Down Expand Up @@ -590,7 +589,7 @@ def build_decode_metadata(
self.block_table = self.block_table[:self.graph_pad_size, ...]
seq_lens_list = self.seq_lens.tolist()

cp_seq_len, batch_seq_mask = None, None
cp_seq_len = None

if self.graph_pad_size > num_reqs:
if self.speculative_config.disable_padded_drafter_batch:
Expand Down Expand Up @@ -651,8 +650,7 @@ def build_decode_metadata(
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[:self.num_decode_tokens, ...],
cos=cos[:self.num_decode_tokens, ...],
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
cp_seq_len=cp_seq_len)
return decode_metadata

def build_for_graph_capture(
Expand Down
9 changes: 1 addition & 8 deletions vllm_ascend/spec_decode/mtp_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def _propose(
self.positions[:batch_size] = clamped_positions
self.hidden_states[:hidden_states.shape[0]] = hidden_states
if self.pcp_size * self.dcp_size > 1:
# update local seq_len and batch_seq_mask
# update local seq_len
num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
ori_seq_len + step + 1,
self.pcp_size,
Expand All @@ -521,14 +521,7 @@ def _propose(
)
cp_seq_len = \
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
batch_seq_mask = (cp_seq_len == 0)
builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
batch_seq_mask, non_blocking=True)
batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.
shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
attn_metadata_i.decode.cp_seq_len = cp_seq_len
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The cp_seq_len variable is a numpy array, but it is being assigned to attn_metadata_i.decode.cp_seq_len, which is defined as a torch.Tensor in the AscendMLADecodeMetadata dataclass. This type mismatch will likely cause a runtime error in downstream operations that expect a tensor. You should convert cp_seq_len to a torch.Tensor before the assignment, similar to the implementation in vllm_ascend/attention/context_parallel/mla_cp.py.

Suggested change
attn_metadata_i.decode.cp_seq_len = cp_seq_len
attn_metadata_i.decode.cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)

attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
# update slot_mapping
slot_indices += self.pcp_size
slot_mapping = mtp_slot_mapping[slot_indices]
Expand Down