diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index fb28a59c582..aa86823fe45 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -195,7 +195,6 @@ def load_model(self, model: nn.Module) -> None:
         all_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys())
         self._draft_attn_layer_names = set(all_attn_layers.keys()) - target_attn_layer_names - all_indexer_layer_names
-        assert len(self._draft_attn_layer_names) == 1
         self.attn_layer_names = list(sorted(self._draft_attn_layer_names))
         draft_attn_layers_dict = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         self.kernel_block_size = (
@@ -699,10 +698,24 @@ def _propose(
                 multi_steps_attn_metadata.append(per_layer_attn_metadata)
         else:
             # Copy the old attn_metadata and update
-            for draft_step in range(1, self.num_speculative_tokens):
-                per_layer_attn_metadata = dict()
-                if vllm_version_is("0.17.0"):
-                    for attn_group in self.draft_attn_groups:
+            if not self.parallel_drafting:
+                for draft_step in range(1, self.num_speculative_tokens):
+                    per_layer_attn_metadata = dict()
+                    if vllm_version_is("0.17.0"):
+                        for attn_group in self.draft_attn_groups:
+                            common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
+                                draft_step,
+                                attn_metadata,
+                                common_attn_metadata,
+                                batch_size,
+                                num_input_tokens,
+                                used_update_positions,
+                                aclgraph_runtime_mode,
+                                attn_group=attn_group,
+                            )
+                            for layer_name in self.attn_layer_names:
+                                per_layer_attn_metadata[layer_name] = attn_metadata
+                    else:
                         common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
                             draft_step,
                             attn_metadata,
@@ -711,23 +724,10 @@ def _propose(
                             num_input_tokens,
                             used_update_positions,
                             aclgraph_runtime_mode,
-                            attn_group=attn_group,
                         )
                         for layer_name in self.attn_layer_names:
                             per_layer_attn_metadata[layer_name] = attn_metadata
-                else:
-                    common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
-                        draft_step,
-                        attn_metadata,
-                        common_attn_metadata,
-                        batch_size,
-                        num_input_tokens,
-                        used_update_positions,
-                        aclgraph_runtime_mode,
-                    )
-                    for layer_name in self.attn_layer_names:
-                        per_layer_attn_metadata[layer_name] = attn_metadata
-                multi_steps_attn_metadata.append(per_layer_attn_metadata)
+                    multi_steps_attn_metadata.append(per_layer_attn_metadata)

         token_indices_to_sample_len = token_indices_to_sample.shape[0]
         self.token_indices_to_sample[:token_indices_to_sample_len].copy_(token_indices_to_sample)
@@ -1064,16 +1064,21 @@ def set_inputs_first_pass(
         # 2.
         # Recompute the slot mapping based on the new positions and
         # rejection mask.
-        builder = (
-            self._get_attention_metadata_builder()
-            if self.attn_metadata_builder is None
-            else self.attn_metadata_builder
-        )
+        if vllm_version_is("0.17.0"):
+            # Use the first draft attention group's kv_cache_spec for block_size
+            # (all draft layers share the same kv-cache group)
+            assert len(self.draft_attn_groups) > 0
+            block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
+        else:
+            if self.attn_metadata_builder is None:
+                block_size = self._get_attention_metadata_builder().kv_cache_spec.block_size
+            else:
+                block_size = self.attn_metadata_builder.kv_cache_spec.block_size
         new_slot_mapping = compute_new_slot_mapping(
             cad=cad,
             new_positions=self.positions[:total_num_output_tokens],
             is_rejected_token_mask=self.is_rejected_token_mask[:total_num_output_tokens],
-            block_size=builder.kv_cache_spec.block_size,
+            block_size=block_size,
             num_new_tokens=self.net_num_new_slots_per_request,
             max_model_len=self.max_model_len,
         )
@@ -1152,14 +1157,14 @@ def attn_update_stack_num_spec_norm(
         # out-of-range access during the model execution. The draft tokens
         # generated with this adjustment should be ignored.
         if self.uses_mrope:
-            exceeds_max_model_len = used_update_positions[0] >= self.vllm_config.model_config.max_model_len
+            exceeds_max_model_len = used_update_positions[0] >= self.max_model_len
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
             clamped_positions = torch.where(
                 exceeds_max_model_len.unsqueeze(0), torch.zeros_like(used_update_positions), used_update_positions
             )
         else:
-            exceeds_max_model_len = used_update_positions >= self.vllm_config.model_config.max_model_len
+            exceeds_max_model_len = used_update_positions >= self.max_model_len
             clamped_positions = torch.where(exceeds_max_model_len, 0, used_update_positions)

         # For data integrity when async scheduling, we shouldn't use in place
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 8780b4d0fe8..896e90e032f 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -74,7 +74,6 @@
 from vllm.v1.sample.logits_processor import build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
-from vllm.v1.spec_decode.draft_model import DraftModelProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import record_function_or_nullcontext
@@ -2561,7 +2560,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         if self.speculative_config and (
             self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model()
         ):
-            assert isinstance(self.drafter, AscendEagleProposer | DraftModelProposer)
+            assert isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer)
             self.drafter.initialize_attn_backend(kv_cache_config, self.kernel_block_sizes)

         if has_kv_transfer_group():