diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 9168224c7da..be073c46196 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -117,8 +117,11 @@ def unpadded(self, num_actual_tokens: int, num_actual_tokens=num_actual_tokens, max_query_len=self.max_query_len, decode_token_per_req=self.decode_token_per_req, - block_table_tensor=self.block_table_tensor[:num_actual_reqs], - slot_mapping=self.slot_mapping[:num_actual_tokens], + # NOTE: keep all tokens for block_table_tensor and slot_mapping; otherwise + # there will be an error about a shape mismatch during reshape and cache. + # This is really strange, since vLLM slices them as well. + block_table_tensor=self.block_table_tensor, + slot_mapping=self.slot_mapping, causal=self.causal, actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens], positions=self.positions[:num_actual_tokens], diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 1684ab562a1..e7d31e94579 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -692,13 +692,6 @@ def prepare_inputs_padded( num_reqs = common_attn_metadata.num_reqs device = valid_sampled_tokens_count.device - if num_reqs != spec_decode_metadata.cu_num_draft_tokens.shape[0]: - # TODO: This is a serious issue and should be taken care of ASAP - # In short, why input_batch.num_reqs != attn_metadata.num_reqs? - # Previously in #4963, we modified `query_start_loc`, but this - # problem remains unsolved. 
- num_reqs = spec_decode_metadata.cu_num_draft_tokens.shape[0] - token_indices_to_sample = torch.empty((num_reqs, ), dtype=torch.int32, device=device) @@ -730,9 +723,9 @@ def prepare_inputs_padded( torch.zeros_like(num_draft_tokens_gpu), ) - query_start_loc = common_attn_metadata.query_start_loc[ - 1:1 + num_rejected_tokens_gpu.shape[0]] - token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu + token_indices_to_sample = ( + common_attn_metadata.query_start_loc[1:] - 1 - + num_rejected_tokens_gpu) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 5b16e433fe4..02858b41016 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1013,9 +1013,7 @@ def _prepare_inputs( if self.speculative_config and \ self.spec_decode_common_attn_metadata is None: self.spec_decode_common_attn_metadata = common_attn_metadata - if self.speculative_config.method in ("eagle", "eagle3") and \ - (self.vllm_config.speculative_config.enforce_eager \ - or self.use_async_scheduling): + if num_reqs != base_num_reqs or total_num_scheduled_tokens != num_input_tokens: self.spec_decode_common_attn_metadata = \ self.spec_decode_common_attn_metadata.unpadded( total_num_scheduled_tokens, base_num_reqs)