tests/e2e/multicard/4-cards/long_sequence/test_mtp.py (11 additions, 35 deletions)
@@ -20,17 +20,18 @@
 import os
 import pytest
 
-from tests.e2e.conftest import VllmRunner
+from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free
 
 os.environ["HCCL_BUFFSIZE"] = "512"
 
+prompts = [
+    "The capital of France is", "Hello, my name is Tom, I am",
+    "The president of United States is", "AI future is"
+]
+model = "wemaster/deepseek_mtp_main_random_bf16"
 
+@wait_until_npu_memory_free()
 def test_pcp_dcp_mtp1_eager():
-    prompts = [
-        "The capital of France is", "Hello, my name is Tom, I am",
-        "The president of United States is", "AI future is"
-    ]
-    model = "wemaster/deepseek_mtp_main_random_bf16"
     with VllmRunner(
             model,
             max_model_len=1024,
@@ -50,15 +51,8 @@ def test_pcp_dcp_mtp1_eager():
         runner.generate_greedy(prompts, 32)
 
 
-@pytest.mark.skip(
-    reason="vLLM PR-32118 break this",
-)
+@wait_until_npu_memory_free()
 def test_pcp_dcp_mtp3_eager():
-    prompts = [
-        "The capital of France is", "Hello, my name is Tom, I am",
-        "The president of United States is", "AI future is"
-    ]
-    model = "wemaster/deepseek_mtp_main_random_bf16"
     with VllmRunner(
             model,
             max_model_len=1024,
@@ -78,15 +72,8 @@ def test_pcp_dcp_mtp3_eager():
         runner.generate_greedy(prompts, 32)
 
 
-@pytest.mark.skip(
-    reason="vLLM PR-32118 break this",
-)
+@wait_until_npu_memory_free()
 def test_pcp_dcp_mtp3_piecewise_graph():
-    prompts = [
-        "The capital of France is", "Hello, my name is Tom, I am",
-        "The president of United States is", "AI future is"
-    ]
-    model = "wemaster/deepseek_mtp_main_random_bf16"
     with VllmRunner(
             model,
             max_model_len=1024,
@@ -109,15 +96,8 @@ def test_pcp_dcp_mtp3_piecewise_graph():
         runner.generate_greedy(prompts, 32)
 
 
-@pytest.mark.skip(
-    reason="vLLM PR-32118 break this",
-)
+@wait_until_npu_memory_free()
 def test_pcp_dcp_mtp3_full_graph():
-    prompts = [
-        "The capital of France is", "Hello, my name is Tom, I am",
-        "The president of United States is", "AI future is"
-    ]
-    model = "wemaster/deepseek_mtp_main_random_bf16"
     with VllmRunner(
             model,
             max_model_len=1024,
@@ -140,12 +120,8 @@ def test_pcp_dcp_mtp3_full_graph():
         runner.generate_greedy(prompts, 32)
 
 
+@wait_until_npu_memory_free()
 def test_dcp_mtp3_full_graph():
-    prompts = [
-        "The capital of France is", "Hello, my name is Tom, I am",
-        "The president of United States is", "AI future is"
-    ]
-    model = "wemaster/deepseek_mtp_main_random_bf16"
     with VllmRunner(
             model,
             max_model_len=1024,
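Note on this file: the skip markers referencing vLLM PR-32118 are removed, the shared prompts and model id move to module scope, and every test is gated by wait_until_npu_memory_free() so it only starts once earlier tests have released NPU memory. The decorator itself lives in tests/e2e/conftest.py; the sketch below is only an illustration of the kind of polling decorator this implies, and the threshold, timeout, and torch_npu call are assumptions, not the real implementation.

    import functools
    import time

    def wait_until_npu_memory_free(min_free_ratio: float = 0.9,
                                   timeout_s: float = 600.0,
                                   poll_s: float = 5.0):
        """Delay the wrapped test until enough NPU memory is free (sketch)."""
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                import torch_npu  # assumed available in the e2e environment
                deadline = time.monotonic() + timeout_s
                while time.monotonic() < deadline:
                    free, total = torch_npu.npu.mem_get_info()
                    if free / total >= min_free_ratio:
                        break
                    time.sleep(poll_s)  # earlier tests may still be tearing down
                return fn(*args, **kwargs)
            return wrapper
        return decorator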
tests/ut/worker/test_pcp_manager.py (5 additions, 3 deletions)
@@ -141,8 +141,10 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
                                             dtype=np.int32)
     input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32)
     arange_np = np.arange(10000)
+    num_scheduled_tokens = np.array(tokens)
+    pcp_manager.init_batch_info(num_scheduled_tokens, num_reqs)
     pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(
-        np.array(tokens), arange_np, num_reqs, 1)
+        num_scheduled_tokens, arange_np)
 
     assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
         f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"
@@ -305,8 +307,8 @@ def test_generate_pcp_mtp_input(
     for i, token_ids_tensor in enumerate(token_ids_tensor_list):
         token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor
 
-    pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens,
-                                       num_scheduled_tokens, False,
+    pcp_manager.init_batch_info(np.array(list(num_scheduled_tokens.values())), num_reqs)
+    pcp_manager.generate_pcp_mtp_input(total_num_scheduled_tokens, num_scheduled_tokens, False,
                                        input_batch, arange_np)
     assert torch.equal(
         pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
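Note on this file: the unit tests now reflect the split API on the PCP manager. Batch-level bookkeeping is set up once through init_batch_info(num_scheduled_tokens, num_reqs), after which update_tokens_for_pcp and generate_pcp_mtp_input no longer take num_reqs (or the reorder threshold) themselves. A minimal call-order sketch, assuming an already constructed pcp_manager and reusing the argument names from the test above:

    import numpy as np

    # One scheduled-token count per request in the batch.
    num_scheduled_tokens = np.array([8, 1, 1], dtype=np.int32)
    num_reqs = len(num_scheduled_tokens)
    arange_np = np.arange(10000)

    # init_batch_info must run first; it caches the per-batch state the
    # later calls used to receive as explicit arguments.
    pcp_manager.init_batch_info(num_scheduled_tokens, num_reqs)
    pcp_tokens, positions = pcp_manager.update_tokens_for_pcp(
        num_scheduled_tokens, arange_np)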
vllm_ascend/worker/model_runner_v1.py (17 additions, 10 deletions)
@@ -563,10 +563,16 @@ def _prepare_inputs(
                                               req_indices, positions_np)
         self.input_batch.block_table.commit_slot_mapping(
             total_num_scheduled_tokens)
+
+        if self.pcp_size * self.dcp_size > 1:
+            self.pcp_manager.init_batch_info(
+                num_scheduled_tokens,
+                self.input_batch.num_reqs,
+            )
+
         # For pcp, prefill MTP should use the original scheduler_output.
         if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             self.pcp_manager.generate_pcp_mtp_input(
-                num_reqs,
                 total_num_scheduled_tokens,
                 scheduler_output.num_scheduled_tokens,
                 with_prefill,
@@ -584,8 +590,6 @@ def _prepare_inputs(
                 num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
                     num_scheduled_tokens[:num_reqs],
                     self.arange_np,
-                    self.input_batch.num_reqs,
-                    self.reorder_batch_threshold,
                 )
             # Re-update after PCP split sequences.
             total_num_scheduled_tokens = sum(num_scheduled_tokens)
@@ -739,8 +743,7 @@ def _prepare_inputs(
             num_draft_tokens = None
             num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
             if self.pcp_size * self.dcp_size > 1:
-                logits_indices = self.pcp_manager.get_logits_indices(
-                    cu_num_tokens, num_reqs)
+                logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens)
                 logits_indices = logits_indices.pin_memory().to(
                     self.device, non_blocking=True)
             else:
@@ -987,9 +990,8 @@ def propose_draft_token_ids(
             num_reqs = self.input_batch.num_reqs
             ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
                 query_start_loc_pcp_full_cpu[:num_reqs]
-            num_prefill_reqs = (ori_query_lens
-                                > self.decode_threshold).sum().item()
-            num_decode_reqs = num_reqs - num_prefill_reqs
+            num_prefill_reqs = self.pcp_manager.num_prefill_reqs
+            num_decode_reqs = self.pcp_manager.num_decode_reqs
         else:
             long_seq_metadata = None  # type: ignore
             num_prefill_reqs = 0
@@ -1938,7 +1940,7 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int):
             )
             return blk_table_tensor, slot_mapping
 
-        long_seq_metdadata = _get_pcp_metadata(num_tokens)
+        self.long_seq_metadata = _get_pcp_metadata(num_tokens)
         block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
 
         cm_base = AscendCommonAttentionMetadata(
@@ -1963,7 +1965,7 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int):
             positions=self.positions.gpu,
             attn_state=self.attn_state,
             decode_token_per_req=self.decode_token_per_req,
-            prefill_context_parallel_metadata=long_seq_metdadata,
+            prefill_context_parallel_metadata=self.long_seq_metadata,
         )
 
         if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill:
@@ -2153,6 +2155,11 @@ def _dummy_run(
                     force_has_lora=activate_lora,
                 )
             )
+            if self.pcp_size * self.dcp_size > 1:
+                self.pcp_manager.init_batch_info(
+                    num_scheduled_tokens,
+                    num_reqs,
+                )
             if cudagraph_runtime_mode is None:
                 cudagraph_runtime_mode = _cudagraph_mode
             else:
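Note on this file: _prepare_inputs (and _dummy_run) now call pcp_manager.init_batch_info once per batch, and downstream code reads the cached results (num_prefill_reqs, num_decode_reqs, the trimmed get_logits_indices signature) instead of recomputing or re-passing them; long_seq_metadata is likewise promoted to an attribute so later code paths can reuse it. Since the diff deletes the inline prefill/decode split in propose_draft_token_ids, init_batch_info presumably caches something like the following. This is a hypothetical sketch inferred from the deleted lines, not the PR's actual implementation; decode_threshold stands in for whatever threshold the manager really uses.

    import numpy as np

    def init_batch_info(self, num_scheduled_tokens: np.ndarray,
                        num_reqs: int) -> None:
        # Cache per-batch scheduling state for the later PCP/DCP calls.
        self.num_scheduled_tokens = num_scheduled_tokens
        self.num_reqs = num_reqs
        # Requests scheduling more than decode_threshold tokens are prefills,
        # mirroring the check this PR removes from propose_draft_token_ids.
        self.num_prefill_reqs = int(
            (num_scheduled_tokens[:num_reqs] > self.decode_threshold).sum())
        self.num_decode_reqs = num_reqs - self.num_prefill_reqs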