diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index e11d9c39b2a..1bdf661de63 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -849,21 +849,37 @@ def forward_fused_infer_attention( learnable_sink=self.sinks, ) else: - attn_output, _ = torch_npu.npu_fused_infer_attention_score( - query=query, - key=key, - value=value, - atten_mask=attn_metadata.attn_mask, - block_table=block_table, - input_layout="TND", - block_size=block_size, - actual_seq_lengths=attn_metadata.actual_seq_lengths_q, - actual_seq_lengths_kv=actual_seq_lengths_kv, - num_key_value_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale=self.scale, - sparse_mode=3, - ) + if not attn_metadata.causal: + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + query=query, + key=key, + value=value, + block_table=block_table, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=attn_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=actual_seq_lengths_kv, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale=self.scale, + sparse_mode=0, + ) + else: + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + query=query, + key=key, + value=value, + atten_mask=attn_metadata.attn_mask, + block_table=block_table, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=attn_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=actual_seq_lengths_kv, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale=self.scale, + sparse_mode=3, + ) attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size) output[:num_tokens] = attn_output[:num_tokens] @@ -910,6 +926,28 @@ def _forward_encoder_attention( actual_seq_kvlen=attn_metadata.actual_seq_lengths_q, )[0] + def do_kv_cache_update( + self, + layer: torch.nn.Module, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: list[torch.Tensor], + slot_mapping: torch.Tensor, + ) -> None: + if self.attn_type in (AttentionType.ENCODER_ONLY): + return + + if self.key_cache is None: + self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] + + DeviceOperator.reshape_and_cache( + key=key, + value=value, + key_cache=self.key_cache, + value_cache=self.value_cache, + slot_mapping=slot_mapping, + ) + def reshape_and_cache( self, query: torch.Tensor, diff --git a/vllm_ascend/ops/triton/spec_decode/utils.py b/vllm_ascend/ops/triton/spec_decode/utils.py index 3c7aa450b7f..3429117c273 100644 --- a/vllm_ascend/ops/triton/spec_decode/utils.py +++ b/vllm_ascend/ops/triton/spec_decode/utils.py @@ -63,3 +63,74 @@ def prepare_inputs_padded_kernel( index_to_sample = q_last_tok_idx - num_rejected tl.store(token_indices_to_sample_ptr + offsets, index_to_sample, mask=mask) tl.store(num_rejected_tokens_gpu_ptr + offsets, num_rejected, mask=mask) + + +@triton.jit +def copy_and_expand_dflash_inputs_kernel_single_grid( + # Inputs + next_token_ids_ptr, # [num_reqs] + target_positions_ptr, # [num_context] + # Outputs + out_input_ids_ptr, # [num_query_total] (output) + out_context_positions_ptr, # [num_context] (output) + out_query_positions_ptr, # [num_query_total] (output) + out_context_slot_mapping_ptr, # [num_context] (output) + out_query_slot_mapping_ptr, # [num_query_total] (output) + out_token_indices_ptr, # [num_reqs * num_speculative_tokens] (output) + # Block table + block_table_ptr, # [max_reqs, max_blocks] + block_table_stride, # stride of block_table dim 0 (in elements) + # Metadata + query_start_loc_ptr, # [num_reqs + 1] + num_rejected_tokens_ptr, # [num_reqs] or null (0) when not padded + # Scalars + parallel_drafting_token_id, # tl.int32 + block_size, # tl.int32 + num_query_per_req, # tl.int32 + num_speculative_tokens, # tl.int32 + total_input_tokens, # tl.int32 + batch_size, # tl.int32 + HAS_NUM_REJECTED: tl.constexpr = False, +): + for req_idx in range(0, batch_size): + ctx_start = tl.load(query_start_loc_ptr + req_idx) + ctx_end = tl.load(query_start_loc_ptr + req_idx + 1) + num_ctx = ctx_end - ctx_start + + for j in range(0, num_ctx): + ctx_pos_idx = ctx_start + j + pos = tl.load(target_positions_ptr + ctx_pos_idx) + tl.store(out_context_positions_ptr + ctx_pos_idx, pos) + + block_num = pos // block_size + block_id = tl.load(block_table_ptr + req_idx * block_table_stride + block_num).to(tl.int64) + slot = block_id * block_size + (pos % block_size) + tl.store(out_context_slot_mapping_ptr + ctx_pos_idx, slot) + + if HAS_NUM_REJECTED: + num_rejected = tl.load(num_rejected_tokens_ptr + req_idx) + valid_ctx_end = ctx_end - num_rejected + else: + valid_ctx_end = ctx_end + + last_pos = tl.load(target_positions_ptr + valid_ctx_end - 1) + + for q_idx in range(0, num_query_per_req): + query_pos = last_pos + 1 + q_idx + query_out_idx = req_idx * num_query_per_req + q_idx + + tl.store(out_query_positions_ptr + query_out_idx, query_pos) + + block_num_q = query_pos // block_size + block_id_q = tl.load(block_table_ptr + req_idx * block_table_stride + block_num_q).to(tl.int64) + slot_q = block_id_q * block_size + (query_pos % block_size) + tl.store(out_query_slot_mapping_ptr + query_out_idx, slot_q) + + if q_idx == 0: + bonus_token = tl.load(next_token_ids_ptr + req_idx) + tl.store(out_input_ids_ptr + query_out_idx, bonus_token) + else: + tl.store(out_input_ids_ptr + query_out_idx, parallel_drafting_token_id) + + sample_out_idx = req_idx * num_speculative_tokens + (q_idx - 1) + tl.store(out_token_indices_ptr + sample_out_idx, query_out_idx) diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 804c98242c3..57c9ea1c913 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -687,3 +687,13 @@ # when using mrope. # Future Plan: # Remove this patch when vllm-ascend supports pattern matching for this fused kernel. +# ** 29. File: worker/patch_qwen3_dflash.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.model_executor.models.qwen3_dflash.DFlashQwen3Model.precompute_and_store_context_kv` +# Why: +# The function directly calls the ops.rms_norm and ops.rotary_imbedding operators, +# but NPU does not have a corresponding implementation. +# How: +# Replace ops.* with the internal implementation of vllm-ascend. +# Future Plan: +# Remove this patch when vllm-ascend supports pattern matching for ops.*. diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index 60ae3ad64f2..6805f3db9b1 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -17,7 +17,7 @@ from vllm.triton_utils import HAS_TRITON -from vllm_ascend.utils import is_310p +from vllm_ascend.utils import is_310p, vllm_version_is if HAS_TRITON: import vllm_ascend.patch.worker.patch_triton @@ -39,6 +39,9 @@ if not is_310p(): import vllm_ascend.patch.worker.patch_qwen3_5 # noqa import vllm_ascend.patch.worker.patch_gdn_attn # noqa + + if not vllm_version_is("0.19.0"): + import vllm_ascend.patch.worker.patch_qwen3_dflash # noqa import vllm_ascend.patch.worker.patch_rejection_sampler # noqa import vllm_ascend.patch.worker.patch_v2.patch_uva # noqa import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa diff --git a/vllm_ascend/patch/worker/patch_qwen3_dflash.py b/vllm_ascend/patch/worker/patch_qwen3_dflash.py new file mode 100644 index 00000000000..80dd31c7152 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_qwen3_dflash.py @@ -0,0 +1,62 @@ +import torch +import torch.nn.functional as F +from vllm.model_executor.models.qwen3_dflash import DFlashQwen3Model + + +def precompute_and_store_context_kv( + self, + context_states: torch.Tensor, + context_positions: torch.Tensor, + context_slot_mapping: torch.Tensor | None = None, +) -> None: + if not hasattr(self, "_num_attn_layers"): + self._build_fused_kv_buffers() + + num_ctx = context_states.shape[0] + L = self._num_attn_layers + kv = self._kv_size + hd = self._head_dim + nkv = self._num_kv_heads + + # --- Fused KV projection (one GEMM for all layers) --- + normed_context_states = self.hidden_norm(context_states) + all_kv_flat = F.linear(normed_context_states, self._fused_kv_weight, self._fused_kv_bias) + # Single contiguous copy that separates K/V and transposes to + # layer-major layout. Result: [2, L, num_ctx, nkv, hd] contiguous. + # Indexing dim-0 gives contiguous [L, num_ctx, nkv, hd] for K and V. + all_kv = all_kv_flat.view(num_ctx, L, 2, nkv, hd).permute(2, 1, 0, 3, 4).contiguous() + all_k = all_kv[0] # [L, num_ctx, nkv, hd], contiguous + all_v = all_kv[1] # [L, num_ctx, nkv, hd], contiguous + + # --- Per-layer RMSNorm K (3D: [num_ctx, nkv, hd] per layer) --- + all_k_normed = torch.empty_like(all_k) + for i in range(L): + k_norm_layer = self.layers[i].self_attn.k_norm + all_k_normed[i] = k_norm_layer(all_k[i]) + + # --- Fused RoPE across all layers --- + # View as [L * num_ctx, kv] so RoPE sees one big batch (no copy). + # In-place RoPE: pass K as the "query" arg with key=None. + all_k_flat = all_k_normed.view(L * num_ctx, kv) + positions_repeated = context_positions.repeat(L) + tmpv = all_k_flat.clone() + self.layers[0].self_attn.rotary_emb(positions_repeated, all_k_flat, tmpv) + + if context_slot_mapping is None: + return + + # --- Per-layer cache insert --- + all_k_final = all_k_flat.view(L, num_ctx, nkv, hd) + for i in range(L): + attn = self._attn_layers[i] + kv_cache = attn.kv_cache + attn.impl.do_kv_cache_update( + attn, + all_k_final[i], + all_v[i], + kv_cache, + context_slot_mapping, + ) + + +DFlashQwen3Model.precompute_and_store_context_kv = precompute_and_store_context_kv diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index c17e9398722..618695a7ad2 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -17,11 +17,14 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # + +from vllm_ascend.spec_decode.dflash_proposer import AscendDflashProposer from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer +from vllm_ascend.utils import vllm_version_is def get_spec_decode_method(method, vllm_config, device, runner): @@ -33,6 +36,11 @@ def get_spec_decode_method(method, vllm_config, device, runner): return AscendMedusaProposer(vllm_config, device) elif method in ("eagle", "eagle3", "mtp"): return AscendEagleProposer(vllm_config, device, runner) + elif method == "dflash": + if not vllm_version_is("0.19.0"): + return AscendDflashProposer(vllm_config, device, runner) + else: + raise ValueError(f"VLLM v0.19.0 doesn't support {method} now") elif method == "draft_model": return AscendDraftModelProposer(vllm_config, device, runner) else: diff --git a/vllm_ascend/spec_decode/dflash_proposer.py b/vllm_ascend/spec_decode/dflash_proposer.py new file mode 100644 index 00000000000..0f3b8aad2c7 --- /dev/null +++ b/vllm_ascend/spec_decode/dflash_proposer.py @@ -0,0 +1,206 @@ +from typing import Any + +import torch +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.v1.attention.backends.utils import CommonAttentionMetadata + +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.ops.triton.spec_decode.utils import copy_and_expand_dflash_inputs_kernel_single_grid +from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer + + +class AscendDflashProposer(SpecDecodeBaseProposer): + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__( + vllm_config, + device, + pass_hidden_states_to_model=True, + runner=runner, + ) + + self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens) + self.max_positions = self.max_num_tokens + self.max_query_tokens + + self._context_slot_mapping_buffer = torch.zeros( + self.max_num_tokens, + dtype=torch.int32, + device=device, + ) + + self._slot_mapping_buffer = torch.zeros( + self.max_query_tokens, + dtype=torch.int32, + device=device, + ) + + self._context_positions_buffer = torch.zeros( + self.max_num_tokens, + dtype=torch.int32, + device=device, + ) + + self.positions = torch.zeros( + self.max_query_tokens, + dtype=torch.int32, + device=device, + ) + + self.arange_dflash = torch.arange(self.max_positions + 1, device=device, dtype=torch.int32) + + self.parallel_drafting_hidden_state_tensor = None + + def set_inputs_first_pass( + self, + target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, + target_positions: torch.Tensor, + target_hidden_states: torch.Tensor, + token_indices_to_sample: torch.Tensor | None, + cad: CommonAttentionMetadata, + num_rejected_tokens_gpu: torch.Tensor | None, + req_scheduled_tokens=None, + long_seq_metadata=None, + num_prefill_reqs=0, + num_decode_reqs=0, + ) -> tuple[int, torch.Tensor, CommonAttentionMetadata, tuple[Any, Any] | None]: + # DFlash cross-attention: context K/V from target hidden states, + # Q from query embeddings (bonus + mask tokens). + batch_size = cad.num_reqs + num_context = target_token_ids.shape[0] + num_query_per_req = 1 + self.num_speculative_tokens + num_query_total = batch_size * num_query_per_req + + self._dflash_num_context = num_context + self._dflash_hidden_states = target_hidden_states + + token_indices_to_sample = torch.empty( + batch_size * self.num_speculative_tokens, + dtype=torch.int32, + device=self.device, + ) + + has_num_rejected = num_rejected_tokens_gpu is not None + + copy_and_expand_dflash_inputs_kernel_single_grid[1,]( + # Inputs + next_token_ids_ptr=next_token_ids, + target_positions_ptr=target_positions, + # Outputs + out_input_ids_ptr=self.input_ids, + out_context_positions_ptr=self._context_positions_buffer, + out_query_positions_ptr=self.positions, + out_context_slot_mapping_ptr=self._context_slot_mapping_buffer, + out_query_slot_mapping_ptr=self._slot_mapping_buffer, + out_token_indices_ptr=token_indices_to_sample, + # Block table + block_table_ptr=cad.block_table_tensor, + block_table_stride=cad.block_table_tensor.stride(0), + # Metadata + query_start_loc_ptr=cad.query_start_loc, + num_rejected_tokens_ptr=(num_rejected_tokens_gpu if has_num_rejected else 0), + # Scalars + parallel_drafting_token_id=self.parallel_drafting_token_id, + block_size=self.block_size, + num_query_per_req=num_query_per_req, + num_speculative_tokens=self.num_speculative_tokens, + total_input_tokens=num_context, + batch_size=batch_size, + HAS_NUM_REJECTED=has_num_rejected, + ) + + query_slot_mapping = self._slot_mapping_buffer[:num_query_total] + new_query_start_loc = self.arange_dflash[: batch_size + 1] * num_query_per_req + + effective_seq_lens = cad.seq_lens + if has_num_rejected: + effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu + + cad.query_start_loc = new_query_start_loc + cad.seq_lens = effective_seq_lens + num_query_per_req + cad.query_start_loc_cpu = ( + torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone() * num_query_per_req + ).to(torch.int32) + + if hasattr(cad, "actual_seq_lengths_q"): + cad.actual_seq_lengths_q = [num_query_per_req] * batch_size + if hasattr(cad, "decode_token_per_req"): + cad.decode_token_per_req = num_query_per_req + + cad.num_actual_tokens = num_query_total + cad.max_query_len = num_query_per_req + cad.max_seq_len = cad.max_seq_len + num_query_per_req + cad.slot_mapping = query_slot_mapping + cad.causal = False + cad.attn_mask = None + cad.attn_state = AscendAttentionState.ChunkedPrefill + + return num_query_total, token_indices_to_sample, cad, None + + @torch.inference_mode() + def dummy_run( + self, + num_tokens: int, + num_reqs: int = 0, + num_tokens_across_dp: torch.Tensor | None = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None, + is_profile=False, + **kwargs, + ) -> None: + num_query_tokens = min(num_tokens, self.max_query_tokens) + + ( + num_query_tokens, + num_tokens_across_dp, + _, + ) = self.runner._sync_metadata_across_dp(num_query_tokens, is_draft_model=True) + + num_input_tokens = num_query_tokens + + num_context = num_tokens + context_positions = self._context_positions_buffer[:num_context] + context_states = self.hidden_states[:num_context] + + input_ids = self.input_ids[:num_input_tokens] + positions = self.positions[:num_input_tokens] + + with set_ascend_forward_context( + None, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + num_actual_tokens=0, + in_profile_run=is_profile, + batch_descriptor=batch_descriptor, + aclgraph_runtime_mode=aclgraph_runtime_mode, + is_draft_model=True, + ): + self.model.precompute_and_store_context_kv(context_states, context_positions) + self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=None, + ) + + def build_model_inputs_first_pass( + self, + num_input_tokens: int, + ) -> dict[str, Any]: + num_context = self._dflash_num_context + + self.model.precompute_and_store_context_kv( + self._dflash_hidden_states, + self._context_positions_buffer[:num_context], + self._context_slot_mapping_buffer[:num_context], + ) + + return dict( + input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], inputs_embeds=None + ) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 407524a63cd..067a98f44eb 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -46,7 +46,12 @@ from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num -from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled +from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is + +if not vllm_version_is("0.19.0"): + from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM +else: + DFlashQwen3ForCausalLM = None # Currently we will fix block size to a small one since `num_reqs` can't be too large _PREPARE_INPUTS_BLOCK_SIZE = 4 @@ -243,8 +248,11 @@ def load_model(self, model: nn.Module) -> None: self._maybe_share_embeddings(target_language_model) self._maybe_share_lm_head(model) - if self.parallel_drafting and self.pass_hidden_states_to_model: - assert self.parallel_drafting_hidden_state_tensor is not None + if ( + self.parallel_drafting + and self.pass_hidden_states_to_model + and self.parallel_drafting_hidden_state_tensor is not None + ): self.parallel_drafting_hidden_state_tensor.copy_( self.model.combine_hidden_states(self.model.mask_hidden.view(3 * self.hidden_size)) if self.eagle3_use_aux_hidden_state @@ -317,8 +325,8 @@ def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None: def _maybe_share_lm_head(self, model: nn.Module) -> None: # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM - if self.method == "eagle" and hasattr(model, "lm_head"): - logger.info("Loading EAGLE LM head weights from the target model.") + if self.method in ("eagle", "dflash") and hasattr(model, "lm_head"): + logger.info("Loading EAGLE or DFLASH LM head weights from the target model.") if supports_multimodal(model): self.model.lm_head = model.get_language_model().lm_head else: @@ -493,8 +501,8 @@ def _propose( if token_indices_to_sample is None: token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 - if self.method == "eagle3": - assert isinstance(self.get_model(), Eagle3LlamaForCausalLM) + if self.method in ("eagle3", "dflash"): + assert isinstance(self.get_model(), (Eagle3LlamaForCausalLM, DFlashQwen3ForCausalLM)) target_hidden_states = self.model.combine_hidden_states(target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size @@ -769,13 +777,16 @@ def _run_merged_draft( model_input_ids = self.input_ids[:num_input_tokens] model_positions = self._get_positions(num_input_tokens) - model_kwargs = { - "input_ids": model_input_ids, - "positions": model_positions, - "inputs_embeds": inputs_embeds, - } + if self.method == "dflash": + model_kwargs = self.build_model_inputs_first_pass(num_input_tokens) + else: + model_kwargs = { + "input_ids": model_input_ids, + "positions": model_positions, + "inputs_embeds": inputs_embeds, + } - if self.pass_hidden_states_to_model: + if self.pass_hidden_states_to_model and self.method != "dflash": model_hidden_states = self.hidden_states[:num_input_tokens] model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions) model_kwargs["hidden_states"] = model_hidden_states @@ -789,9 +800,10 @@ def _run_merged_draft( else: last_hidden_states, hidden_states = ret_hidden_states - last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad( - last_hidden_states, model_positions, hidden_states - ) + if self.method != "dflash": + last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad( + last_hidden_states, model_positions, hidden_states + ) num_indices = token_indices_to_sample.shape[0] if self.pcp_size > 1: @@ -1142,7 +1154,7 @@ def set_inputs_first_pass( return total_num_output_tokens, token_indices_to_sample, new_cad, None def model_returns_tuple(self) -> bool: - return self.method not in ("mtp", "draft_model") + return self.method not in ("mtp", "draft_model", "dflash") def attn_update_stack_num_spec_norm( self, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 919954582dd..d44f97da6e9 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -113,6 +113,7 @@ from vllm_ascend.quantization.utils import enable_fa_quant from vllm_ascend.sample.sampler import AscendSampler from vllm_ascend.spec_decode import get_spec_decode_method +from vllm_ascend.spec_decode.dflash_proposer import AscendDflashProposer from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer @@ -451,6 +452,7 @@ def _set_up_drafter(self): AscendNgramProposer | AscendEagleProposer | AscendDraftModelProposer + | AscendDflashProposer | AscendSuffixDecodingProposer | AscendMedusaProposer | None @@ -2354,7 +2356,7 @@ def _build_attn_group_metadata( if kv_cache_gid > 0: cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid) if self.speculative_config and spec_decode_common_attn_metadata is None: - if isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer): + if isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer | AscendDflashProposer): if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: spec_decode_common_attn_metadata = cm else: @@ -2735,7 +2737,9 @@ def mock_true(): "Model does not support EAGLE3 interface but " "aux_hidden_state_outputs was requested" ) - aux_layers = self.model.get_eagle3_default_aux_hidden_state_layers() + aux_layers = self._get_eagle3_aux_layers_from_config() + if not aux_layers: + aux_layers = self.model.get_eagle3_default_aux_hidden_state_layers() self.model.set_aux_hidden_state_layers(aux_layers) if self.lora_config: @@ -2775,7 +2779,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if self.speculative_config and ( self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model() ): - assert isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer) + assert isinstance(self.drafter, AscendEagleProposer | AscendDflashProposer | AscendDraftModelProposer) block_size = (self.kernel_block_sizes[0] if isinstance( self.kernel_block_sizes, list) else self.kernel_block_sizes) self.drafter.initialize_attn_backend(kv_cache_config, block_size)