51 changes: 42 additions & 9 deletions vllm_ascend/attention/utils.py
@@ -16,8 +16,9 @@ class AscendPrefillContextParallelMetadata:

num_actual_tokens_pcp_padded: Optional[int] = None

num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
list[int]]]]]] = None
num_computed_tokens_of_pcp_dcp: Optional[
list[Optional[list[Optional[list[int]]]]]
] = None

q_head_idx_tensor: torch.Tensor = None

@@ -47,7 +48,7 @@ class AscendCommonAttentionMetadata:
"""
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.

For many of the tensors we keep both GPU and CPU versions.
"""

@@ -104,7 +105,16 @@ class AscendCommonAttentionMetadata:
sin: torch.Tensor = None

prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None
AscendPrefillContextParallelMetadata
] = None

max_seq_len: int = -1

def batch_size(self) -> int:
return self.seq_lens_cpu.shape[0]

def query_lens(self) -> torch.Tensor:
return self.query_start_loc[1:] - self.query_start_loc[:-1]


def split_decodes_and_prefills(
@@ -190,7 +200,8 @@ def trans_rope_weight(weight, rope_dim):
nope_part = weight[..., :-rope_dim, :]
rope_part = weight[..., -rope_dim:, :]
reordered_rope_part = torch.cat(
(rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
(rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2
)
return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
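
A quick illustrative check of the reordering above (not part of the diff; the tensor values and rope_dim are made up, and it assumes trans_rope_weight is in scope):

import torch

# weight has 2 "nope" rows followed by rope_dim=4 interleaved rope rows along dim -2
weight = torch.arange(6 * 2, dtype=torch.float32).reshape(6, 2)
out = trans_rope_weight(weight, rope_dim=4)
# nope rows stay in place; rope rows [r0, r1, r2, r3] are regrouped to [r0, r2, r1, r3]
assert torch.equal(out, weight[[0, 1, 2, 4, 3, 5]])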


@@ -203,12 +214,34 @@ def transdata(nd_mat, block_size: tuple = (16, 16)):
nz_mat = torch.permute(
torch.reshape(
nd_mat,
(r // block_size[0], block_size[0], c // block_size[1],
block_size[1]),
(r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
),
[2, 0, 1, 3],
)
nz_mat = torch.reshape(
nz_mat,
(nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])
)
return nz_mat
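
For reference, a small shape check of the ND-to-NZ reshaping above (illustrative only; it assumes the row and column counts are already multiples of the 16x16 block size, which the part of transdata not shown in this hunk presumably handles):

import torch

nd = torch.zeros(32, 64, dtype=torch.float16)   # r=32, c=64
nz = transdata(nd)                               # default block_size=(16, 16)
# result is (c // 16, r, 16): the column blocks become the leading dim
assert nz.shape == (4, 32, 16)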


def extend_flat_seqs(
seqs: torch.Tensor, end_locs: torch.Tensor, new_vals: torch.Tensor
) -> torch.Tensor:
"""
This function appends one new value to each of multiple sequences
that are stored in a flat (concatenated) format. E.g. the packed
sequences [x1, x2, y1] with new values [x3, y2] become [x1, x2, x3, y1, y2].
"""
new_len = seqs.shape[0] + new_vals.shape[0]
new_seqs = torch.zeros(new_len, device=seqs.device, dtype=seqs.dtype)
# indices for previous seqs
start_locs = end_locs[:-1] + 1
seqs_new_idxs = torch.ones_like(seqs)
seqs_new_idxs[start_locs] += 1
seqs_new_idxs = seqs_new_idxs.cumsum(0) - 1
# indices for new values
new_val_idxs = end_locs + 1 + torch.arange(new_vals.shape[0], device=seqs.device)
# assign seqs and new vals
new_seqs[seqs_new_idxs] = seqs
new_seqs[new_val_idxs] = new_vals
return new_seqs
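
A minimal usage sketch of the helper above (illustrative values; it assumes end_locs holds the index of the last element of each packed sequence, as in the docstring example):

import torch

seqs = torch.tensor([1, 2, 10])      # [x1, x2 | y1] packed flat
end_locs = torch.tensor([1, 2])      # last index of each sequence
new_vals = torch.tensor([3, 20])     # x3 appended to seq 0, y2 to seq 1
out = extend_flat_seqs(seqs, end_locs, new_vals)
assert torch.equal(out, torch.tensor([1, 2, 3, 10, 20]))  # [x1, x2, x3 | y1, y2]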
5 changes: 5 additions & 0 deletions vllm_ascend/core/schedule_config.py
@@ -27,6 +27,7 @@
class AscendSchedulerConfig(SchedulerConfig):
enable_chunked_prefill: bool = False
max_long_partial_prefills: int = 1
max_num_partial_prefills: int = 1
long_prefill_token_threshold: int = MAX_INT
policy: str = "fcfs"
scheduler_cls: Union[str, Type[object]] = (
@@ -47,6 +48,7 @@ def initialize_from_config(
# Override default values into original SchedulerConfig
scheduler_config["enable_chunked_prefill"] = False
scheduler_config["max_long_partial_prefills"] = None
scheduler_config["max_num_partial_prefills"] = None
scheduler_config["long_prefill_token_threshold"] = None
scheduler_config["policy"] = "fcfs"
scheduler_config["scheduler_cls"] = (
@@ -78,6 +80,9 @@ def __post_init__(self, *args) -> None:
self.max_long_partial_prefills = 1
self.long_prefill_token_threshold = MAX_INT

if self.max_num_partial_prefills is None:
self.max_num_partial_prefills = 1

if self.long_prefill_token_threshold is None or \
self.long_prefill_token_threshold <= 0:
if self.max_model_len is None:
5 changes: 0 additions & 5 deletions vllm_ascend/patch/platform/patch_config.py
@@ -155,11 +155,6 @@ def __post_init__(self):
)
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")

# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
16 changes: 8 additions & 8 deletions vllm_ascend/spec_decode/__init__.py
@@ -20,21 +20,21 @@
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
from vllm_ascend.spec_decode.draft_proposer import DraftModelProposer


def get_spec_decode_method(method,
vllm_config,
device,
runner,
is_torchair_graph=False):
def get_spec_decode_method(
method, vllm_config, device, runner, is_torchair_graph=False
):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
elif method in ["eagle", "eagle3"]:
return EagleProposer(vllm_config, device, runner)
elif method == 'deepseek_mtp':
elif method == "deepseek_mtp":
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
return MtpProposer(vllm_config, device, runner)
elif method == "draft_model":
return DraftModelProposer(vllm_config, device, runner)
else:
raise ValueError("Unknown speculative decoding method: "
f"{method}")
raise ValueError(f"Unknown speculative decoding method: {method}")