diff --git a/vllm_ascend/xlite/xlite.py b/vllm_ascend/xlite/xlite.py index e6c7437aa35..534b1b15b7d 100644 --- a/vllm_ascend/xlite/xlite.py +++ b/vllm_ascend/xlite/xlite.py @@ -255,7 +255,14 @@ def __call__( ] if not with_prefill or self.full_mode: - batch = attn_metadata.num_prefills + attn_metadata.num_decodes + # TODO: When vllm_ascend runs in graph mode, attn_metadata.num_decodes + # is padded for decode requests, so the decode count is taken from + # num_decode_tokens instead. Once MTP (multi-token prediction) is + # enabled, a single request may contain multiple tokens, and this + # workaround will need to be revisited. + num_decodes = attn_metadata.num_decode_tokens + num_prefills = attn_metadata.num_prefills + batch = num_prefills + num_decodes seq_lens = attn_metadata.seq_lens[:batch] seq_tensor = torch.cat([ torch.tensor([0]), @@ -269,9 +276,9 @@ xlite_attn_metadata = ModelAttnMeta() xlite_attn_metadata.lens = query_lens.tolist() xlite_attn_metadata.cached_lens = cached_lens.tolist() - xlite_attn_metadata.is_prefills = [ - False - ] * attn_metadata.num_decodes + [True] * attn_metadata.num_prefills + xlite_attn_metadata.is_prefills = [False] * num_decodes + [ + True + ] * num_prefills xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu( ).tolist()