59 changes: 32 additions & 27 deletions vllm_ascend/spec_decode/eagle_proposer.py
@@ -195,7 +195,6 @@ def load_model(self, model: nn.Module) -> None:
         all_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys())
         self._draft_attn_layer_names = set(all_attn_layers.keys()) - target_attn_layer_names - all_indexer_layer_names
 
-        assert len(self._draft_attn_layer_names) == 1
         self.attn_layer_names = list(sorted(self._draft_attn_layer_names))
         draft_attn_layers_dict = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         self.kernel_block_size = (
@@ -699,10 +698,24 @@ def _propose(
             multi_steps_attn_metadata.append(per_layer_attn_metadata)
         else:
             # Copy the old attn_metadata and update
-            for draft_step in range(1, self.num_speculative_tokens):
-                per_layer_attn_metadata = dict()
-                if vllm_version_is("0.17.0"):
-                    for attn_group in self.draft_attn_groups:
+            if not self.parallel_drafting:
+                for draft_step in range(1, self.num_speculative_tokens):
+                    per_layer_attn_metadata = dict()
+                    if vllm_version_is("0.17.0"):
+                        for attn_group in self.draft_attn_groups:
+                            common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
+                                draft_step,
+                                attn_metadata,
+                                common_attn_metadata,
+                                batch_size,
+                                num_input_tokens,
+                                used_update_positions,
+                                aclgraph_runtime_mode,
+                                attn_group=attn_group,
+                            )
+                            for layer_name in self.attn_layer_names:
+                                per_layer_attn_metadata[layer_name] = attn_metadata
+                    else:
                         common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
                             draft_step,
                             attn_metadata,
@@ -711,23 +724,10 @@ def _propose(
                             num_input_tokens,
                             used_update_positions,
                             aclgraph_runtime_mode,
-                            attn_group=attn_group,
                         )
                         for layer_name in self.attn_layer_names:
                             per_layer_attn_metadata[layer_name] = attn_metadata
-                else:
-                    common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
-                        draft_step,
-                        attn_metadata,
-                        common_attn_metadata,
-                        batch_size,
-                        num_input_tokens,
-                        used_update_positions,
-                        aclgraph_runtime_mode,
-                    )
-                    for layer_name in self.attn_layer_names:
-                        per_layer_attn_metadata[layer_name] = attn_metadata
-                multi_steps_attn_metadata.append(per_layer_attn_metadata)
+                    multi_steps_attn_metadata.append(per_layer_attn_metadata)
 
         token_indices_to_sample_len = token_indices_to_sample.shape[0]
         self.token_indices_to_sample[:token_indices_to_sample_len].copy_(token_indices_to_sample)
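
Note: the reshaped loop in the two hunks above is easier to see in isolation. Per-step attention metadata is now only rebuilt when drafting sequentially, since parallel drafting consumes a single metadata set. A minimal sketch, with the hypothetical update_metadata standing in for attn_update_stack_num_spec_norm and its many arguments elided:

# Illustrative simplification only; the real method threads batch_size,
# num_input_tokens, positions, etc. through attn_update_stack_num_spec_norm.
def build_multi_step_metadata(layer_names, num_speculative_tokens,
                              parallel_drafting, update_metadata, attn_metadata):
    multi_steps_attn_metadata = []
    if not parallel_drafting:
        for draft_step in range(1, num_speculative_tokens):
            attn_metadata = update_metadata(draft_step, attn_metadata)
            # Every draft attention layer shares this step's metadata.
            multi_steps_attn_metadata.append(
                {name: attn_metadata for name in layer_names})
    return multi_steps_attn_metadata
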
@@ -1064,16 +1064,21 @@ def set_inputs_first_pass(
         # 2.
         # Recompute the slot mapping based on the new positions and
         # rejection mask.
-        builder = (
-            self._get_attention_metadata_builder()
-            if self.attn_metadata_builder is None
-            else self.attn_metadata_builder
-        )
+        if vllm_version_is("0.17.0"):
+            # Use the first draft attention group's kv_cache_spec for block_size
+            # (all draft layers share the same kv-cache group)
+            assert len(self.draft_attn_groups) > 0
+            block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
+        else:
+            if self.attn_metadata_builder is None:
+                block_size = self._get_attention_metadata_builder().kv_cache_spec.block_size
+            else:
+                block_size = self.attn_metadata_builder.kv_cache_spec.block_size
         new_slot_mapping = compute_new_slot_mapping(
             cad=cad,
             new_positions=self.positions[:total_num_output_tokens],
             is_rejected_token_mask=self.is_rejected_token_mask[:total_num_output_tokens],
-            block_size=builder.kv_cache_spec.block_size,
+            block_size=block_size,
             num_new_tokens=self.net_num_new_slots_per_request,
             max_model_len=self.max_model_len,
         )
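
Note: the version split above reduces to picking where block_size comes from. A hedged sketch of that decision, assuming vllm_version_is is importable as in this module and with proposer standing in for the EagleProposer instance:

# Illustrative only; attribute names mirror the diff, not a public API.
def resolve_block_size(proposer) -> int:
    if vllm_version_is("0.17.0"):
        # On 0.17.0 every draft layer shares one kv-cache group, so the
        # first attention group's spec is representative.
        return proposer.draft_attn_groups[0].kv_cache_spec.block_size
    # On newer versions, build the metadata builder lazily if needed.
    builder = (proposer._get_attention_metadata_builder()
               if proposer.attn_metadata_builder is None
               else proposer.attn_metadata_builder)
    return builder.kv_cache_spec.block_size
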
@@ -1152,14 +1157,14 @@ def attn_update_stack_num_spec_norm(
         # out-of-range access during the model execution. The draft tokens
         # generated with this adjustment should be ignored.
         if self.uses_mrope:
-            exceeds_max_model_len = used_update_positions[0] >= self.vllm_config.model_config.max_model_len
+            exceeds_max_model_len = used_update_positions[0] >= self.max_model_len
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
             clamped_positions = torch.where(
                 exceeds_max_model_len.unsqueeze(0), torch.zeros_like(used_update_positions), used_update_positions
             )
         else:
-            exceeds_max_model_len = used_update_positions >= self.vllm_config.model_config.max_model_len
+            exceeds_max_model_len = used_update_positions >= self.max_model_len
             clamped_positions = torch.where(exceeds_max_model_len, 0, used_update_positions)
 
         # For data integrity when async scheduling, we shouldn't use in place
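Note: switching both comparisons to self.max_model_len assumes the proposer caches vllm_config.model_config.max_model_len at init, consistent with the max_model_len= kwarg already used earlier in this file. The clamping trick itself is self-contained and easy to verify in isolation:

# Standalone sketch of the clamping above; names are illustrative.
import torch

def clamp_positions(positions: torch.Tensor, max_model_len: int) -> torch.Tensor:
    # Positions at or past max_model_len would index out of range in RoPE,
    # so mask them to 0; draft tokens generated there are discarded anyway.
    exceeds = positions >= max_model_len
    return torch.where(exceeds, torch.zeros_like(positions), positions)

# clamp_positions(torch.tensor([5, 4095, 4096, 4097]), 4096)
# -> tensor([   5, 4095,    0,    0])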
3 changes: 1 addition & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -74,7 +74,6 @@
 from vllm.v1.sample.logits_processor import build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
-from vllm.v1.spec_decode.draft_model import DraftModelProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import record_function_or_nullcontext
@@ -2561,7 +2560,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         if self.speculative_config and (
             self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model()
         ):
-            assert isinstance(self.drafter, AscendEagleProposer | DraftModelProposer)
+            assert isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer)
             self.drafter.initialize_attn_backend(kv_cache_config, self.kernel_block_sizes)
 
         if has_kv_transfer_group():