2 changes: 1 addition & 1 deletion tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
@@ -70,12 +70,12 @@ def test_pcp_dcp_mtp3_eager():
max_num_batched_tokens=1024,
enable_expert_parallel=True,
block_size=128,
async_scheduling=True,
Contributor comment (severity: critical):
The async_scheduling parameter is specified twice for the VllmRunner. The value True set on this line will be overridden by async_scheduling=False on line 79. As a result, this test case does not run with asynchronous scheduling enabled, and therefore does not validate the bugfix for the async scenario. To fix this, please remove the redundant async_scheduling=False on line 79.
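For reference, here is a minimal sketch of the keyword-argument set the corrected test would pass (values taken from the visible diff; `runner_kwargs` is just an illustrative name, and the model name plus any arguments hidden behind the collapsed lines are omitted):

```python
# Sketch only: the arguments visible in this diff, with async_scheduling given once.
runner_kwargs = dict(
    max_num_batched_tokens=1024,
    enable_expert_parallel=True,
    block_size=128,
    async_scheduling=True,  # keep this one; drop the later async_scheduling=False
    speculative_config={
        "num_speculative_tokens": 3,
        "method": "deepseek_mtp",
    },
    enforce_eager=True,
)
```

Expanding these via `VllmRunner(..., **runner_kwargs)` leaves a single, unambiguous `async_scheduling` value, so the test actually exercises the async-scheduling path.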

speculative_config={
"num_speculative_tokens": 3,
"method": "deepseek_mtp",
},
enforce_eager=True,
async_scheduling=False,
) as runner:
runner.generate_greedy(prompts, 32)

11 changes: 9 additions & 2 deletions tests/ut/worker/test_pcp_manager.py
@@ -47,6 +47,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
max_num_reqs=1000,
device="cpu",
vllm_config=vllm_config,
use_async_scheduling=False,
pin_memory=False)
input_batch = MagicMock()
input_batch.num_reqs = num_reqs
@@ -65,13 +66,16 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
num_prompt_tokens.append(query_lens[i])
num_tokens.append(query_lens[i])

input_batch.num_computed_tokens_cpu = torch.tensor(num_computed_tokens)
input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens)
input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
input_batch.num_tokens = torch.tensor(num_tokens)
num_scheduled_tokens = np.array(
query_lens) - input_batch.num_computed_tokens_cpu

query_lens = torch.tensor(query_lens)
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
input_batch)
input_batch,
num_scheduled_tokens)

if not expect_not_none:
assert result is None, f"Expected to return None, but got {type(result)}"
@@ -128,6 +132,7 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
max_num_reqs=1000,
device="cpu",
vllm_config=vllm_config,
use_async_scheduling=False,
pin_memory=False)
input_batch = MagicMock()
input_batch.num_reqs = num_reqs
@@ -193,6 +198,7 @@ def test_get_cp_local_seq_lens(
max_num_reqs=1000,
device="cpu",
vllm_config=vllm_config,
use_async_scheduling=False,
pin_memory=False)
ret = pcp_manager._get_cp_local_seq_lens(seq_lens, pcp_world_size,
dcp_world_size,
@@ -276,6 +282,7 @@ def test_generate_pcp_mtp_input(
max_num_reqs=max_num_reqs,
device="cpu",
vllm_config=vllm_config,
use_async_scheduling=False,
pin_memory=False)
arange_np = np.arange(max_model_len)
input_batch = MagicMock()
22 changes: 16 additions & 6 deletions vllm_ascend/worker/model_runner_v1.py
@@ -227,6 +227,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.max_num_reqs,
self.device,
self.vllm_config,
self.use_async_scheduling,
self.pin_memory,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
@@ -541,10 +542,18 @@ def _prepare_inputs(
# For PCP, prefill MTP should use the original scheduler_output.
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
self.pcp_manager.generate_pcp_mtp_input(
num_reqs, total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens, with_prefill,
self.input_batch, self.arange_np, req_indices, positions_np,
cu_num_tokens)
num_reqs,
total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens,
with_prefill,
self.input_batch,
self.arange_np,
req_indices,
positions_np,
cu_num_tokens,
self._draft_token_ids, # type: ignore[has-type]
scheduler_output,
self.num_spec_tokens)

if self.pcp_size > 1:
if not self.vllm_config.model_config.use_mla:
@@ -930,7 +939,7 @@ def _prepare_inputs(
if self.pcp_size * self.dcp_size > 1:
self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
total_num_scheduled_tokens, self.query_lens,
self.input_batch)
self.input_batch, num_scheduled_tokens)
blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
if self.pcp_size > 1:
slot_mapping_pcp = self.pcp_manager.get_padded_slot_mapping(
@@ -1947,7 +1956,8 @@ def _build_dummy_attn_metadata(
slot_mapping = self.input_batch.block_table[
kv_cache_group_id].slot_mapping
long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
num_tokens, self.query_lens, self.input_batch)
num_tokens, self.query_lens, self.input_batch,
num_scheduled_tokens)
if long_seq_metadata is not None:
pcp_world_size = get_pcp_group().world_size
dcp_world_size = get_dcp_group().world_size
115 changes: 112 additions & 3 deletions vllm_ascend/worker/pcp_utils.py
@@ -17,14 +17,17 @@
# Adapted from vllm-project/vllm/vllm/worker/worker.py
#

from typing import List
from typing import TYPE_CHECKING, List

import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput


class PCPManager:
"""
@@ -44,6 +47,7 @@ def __init__(
max_num_reqs: int,
device: torch.device,
vllm_config: VllmConfig,
use_async_scheduling: bool,
pin_memory: bool = False,
) -> None:
self.pcp_world_size = pcp_world_size
@@ -58,6 +62,7 @@ def __init__(
self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
self.device = device
self.use_async_scheduling = use_async_scheduling
self.pcp_allgather_restore_idx = CpuGpuBuffer(
max_buffer_num_tokens,
dtype=torch.int64,
@@ -354,6 +359,9 @@ def generate_pcp_mtp_input(
req_indices=None,
positions_np=None,
cu_num_tokens=None,
draft_token_ids=None,
scheduler_output=None,
num_spec_tokens=None,
):
"""
While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
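Aside: as a rough, self-contained illustration of the split described above (this is not vllm_ascend's actual sharding scheme; the function name, zero padding, and contiguous-chunk layout are assumptions made for the example), splitting a flattened token buffer across a PCP group could look like:

```python
import numpy as np

def shard_for_pcp_rank(input_ids: np.ndarray, pcp_world_size: int, rank: int) -> np.ndarray:
    """Pad the flattened token buffer to a multiple of the PCP world size,
    then return the contiguous shard owned by `rank`."""
    padded_len = ((len(input_ids) + pcp_world_size - 1) // pcp_world_size) * pcp_world_size
    padded = np.pad(input_ids, (0, padded_len - len(input_ids)))
    return np.split(padded, pcp_world_size)[rank]

# 10 tokens across a PCP group of size 4: each rank gets 3 tokens, the last shard is zero-padded.
print(shard_for_pcp_rank(np.arange(10), pcp_world_size=4, rank=3))  # -> [9 0 0]
```

Each rank only sees its own shard, which is why this helper rebuilds the full-sequence ("pcp_full") inputs that the MTP path needs.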
@@ -390,6 +398,12 @@ def generate_pcp_mtp_input(
torch.from_numpy(token_indices_pcp_full),
out=self.input_ids_pcp_full.
cpu[:total_num_scheduled_tokens_pcp_full])
if self.use_async_scheduling:
self._update_input_ids_pcp_full_ids(input_batch, draft_token_ids,
scheduler_output,
total_num_scheduled_tokens,
cu_num_tokens_pcp_full,
num_spec_tokens)
self.query_lens_pcp_full.copy_to_gpu()
self.query_start_loc_pcp_full.copy_to_gpu()
self.input_ids_pcp_full.copy_to_gpu(
@@ -428,6 +442,99 @@ def generate_pcp_mtp_input(
mtp_slot_pad[unpad_mask] = mtp_slot_ori
self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)

def _update_input_ids_pcp_full_ids(
self,
input_batch,
draft_token_ids,
scheduler_output: "SchedulerOutput",
total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray,
num_spec_tokens: int,
) -> None:
"""Prepare the input IDs for the current batch.

Carefully handles the `prev_sampled_token_ids` which can be cached
from the previous engine iteration, in which case those tokens on the
GPU need to be copied into the corresponding slots of input_ids."""

if (input_batch.prev_sampled_token_ids is None
or input_batch.prev_req_id_to_index is None):
return

# Async scheduling case, where some decode requests from the previous
# iteration won't have entries in input_ids_cpu and need to be copied
# on the GPU from prev_sampled_token_ids.
prev_req_id_to_index = input_batch.prev_req_id_to_index
sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = []
prev_common_req_indices: list[int] = []
prev_draft_token_indices: list[int] = []
total_num_spec_tokens = 0
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens

for req_id, cur_index in input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
# We need to compute the flattened input_ids index of the
# last token in each common request.
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
# example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
# sample_flattened_indices = [0, 2, 5]
# spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(
range(flattened_index - draft_len + 1,
flattened_index + 1))
start = prev_index * num_spec_tokens
# prev_draft_token_indices is used to find which draft_tokens_id
# should be copied to input_ids
# example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
# flatten draft_tokens_id [1,2,3,4,5,6]
# draft_len of each request [1, 2, 1]
# then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start,
start + draft_len))
num_common_tokens = len(sample_flattened_indices)

if num_common_tokens == 0:
# No requests in common with the previous iteration,
# so input_ids_pcp_full.cpu already holds all the input ids.
return
# Build the index tensors used to scatter the previously sampled tokens into the CPU input_ids buffer.
sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices,
dtype=torch.int64)
prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
dtype=torch.int64)
self.input_ids_pcp_full.cpu.scatter_(
dim=0,
index=sampled_tokens_index_tensor,
src=input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0].cpu(),
)

# Scatter the draft tokens after the sampled tokens are scattered.
if draft_token_ids is None or not spec_flattened_indices:
return

assert isinstance(draft_token_ids, torch.Tensor)
draft_tokens_index_tensor = torch.tensor(spec_flattened_indices,
dtype=torch.int64)
prev_draft_token_indices_tensor = torch.tensor(
prev_draft_token_indices, dtype=torch.int64)

# input_ids has dtype torch.int32, so convert draft_token_ids
# to torch.int32 here.
draft_token_ids = draft_token_ids.to(dtype=torch.int32)

self.input_ids_pcp_full.cpu.scatter_(
dim=0,
index=draft_tokens_index_tensor,
src=draft_token_ids.flatten()
[prev_draft_token_indices_tensor].cpu(),
)

def _get_cp_local_seq_lens(
self,
seq_lens: torch.Tensor,
@@ -498,7 +605,7 @@ def generate_kv_idx(self, scheduler_output, input_batch):
torch.float32).argsort().to(torch.int32)

def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
input_batch):
input_batch, num_scheduled_tokens):
from vllm_ascend.attention.utils import \
AscendPrefillContextParallelMetadata
num_reqs = input_batch.num_reqs or query_lens.size(0)
@@ -510,7 +617,9 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
if self.pcp_world_size * self.dcp_world_size > 1:
decode_context_lens = input_batch.num_tokens[:num_decodes]
decode_context_lens = input_batch.num_computed_tokens_cpu[:num_decodes] + num_scheduled_tokens[:num_decodes]
prefill_context_lens = input_batch.num_computed_tokens_cpu[
num_decodes:num_reqs]
context_lens = np.concatenate(