diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 6a016366537..0448818e6e1 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -592,22 +592,39 @@ def forward_fused_infer_attention(self, query: torch.Tensor,
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER:
             key = key[:num_tokens]
             value = value[:num_tokens]
-        # Get workspace from cache or calculate it if not present.
-        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
-            query=query,
-            key=key,
-            value=value,
-            atten_mask=attn_metadata.attn_mask,
-            block_table=block_table,
-            input_layout="TND",
-            block_size=block_size,
-            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
-            actual_seq_lengths_kv=actual_seq_lengths_kv,
-            num_key_value_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            scale=self.scale,
-            sparse_mode=3,
-        )
+        if not attn_metadata.causal:
+            # for dflash
+            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                block_table=block_table,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=actual_seq_lengths_kv,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=0,
+            )
+        else:
+            # Get workspace from cache or calculate it if not present.
+            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.attn_mask,
+                block_table=block_table,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=actual_seq_lengths_kv,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+            )
         attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index bd1f925d6d4..80e00ea6b9f 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -90,18 +90,22 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
     if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
         rope_dim = int(rope_dim * model_config.hf_text_config.partial_rotary_factor)
-    _cos = torch.ones(1,
-                      max_num_batched_tokens,
-                      1,
-                      rope_dim,
-                      dtype=dtype,
-                      device=device)
-    _sin = torch.zeros(1,
-                       max_num_batched_tokens,
-                       1,
-                       rope_dim,
-                       dtype=dtype,
-                       device=device)
+    if vllm_config.speculative_config.method == "dflash":
+        _cos = torch.ones(1, max_num_batched_tokens * 2, 1, rope_dim, dtype=dtype, device=device)
+        _sin = torch.zeros(1, max_num_batched_tokens * 2, 1, rope_dim, dtype=dtype, device=device)
+    else:
+        _cos = torch.ones(1,
+                          max_num_batched_tokens,
+                          1,
+                          rope_dim,
+                          dtype=dtype,
+                          device=device)
+        _sin = torch.zeros(1,
+                           max_num_batched_tokens,
+                           1,
+                           rope_dim,
+                           dtype=dtype,
+                           device=device)


 def get_cos_and_sin_mla(positions, use_cache=False):
diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py
index df5015f1f7a..1d75588955a 100644
--- a/vllm_ascend/spec_decode/__init__.py
+++ b/vllm_ascend/spec_decode/__init__.py
@@ -25,7 +25,7 @@ def get_spec_decode_method(method, vllm_config, device, runner):
     if method == "ngram":
         return NgramProposer(vllm_config, device, runner)
-    elif method in ("eagle", "eagle3"):
+    elif method in ("eagle", "eagle3", "dflash"):
"dflash"): return EagleProposer(vllm_config, device, runner) elif method == "mtp": return MtpProposer(vllm_config, device, runner) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 5dd4c4bc190..edb0f22b9bf 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -196,7 +196,6 @@ def load_model(self, model: nn.Module) -> None: draft_attn_layer_names = draft_attn_layer - target_attn_layer_names draft_indexer_layer_names = indexer_layers - target_indexer_layer_names draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names - assert len(draft_attn_layer_names) == 1 self.attn_layer_names = list(sorted(draft_attn_layer_names)) if supports_multimodal(model): @@ -276,6 +275,8 @@ def load_model(self, model: nn.Module) -> None: # share lm_head with the target model if needed # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM + if self.method == "dflash": + self.model.lm_head = target_language_model.lm_head if self.method == "eagle" and hasattr(model, "lm_head"): logger.info("Loading EAGLE LM head weights from the target model.") if supports_multimodal(model): @@ -325,6 +326,16 @@ def dummy_run(self, batch_descriptor=None, dummy_compute_logits=lambda hidden_states: None, is_profile=False): + # DFlash uses a different dummy_run path + if self.method == "dflash": + self._dummy_run_dflash( + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=aclgraph_runtime_mode, + is_profile=is_profile, + ) + return + # update global cos, sin update_cos_sin(self.positions[:num_tokens]) @@ -416,6 +427,199 @@ def dummy_run(self, self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata) + @torch.inference_mode() + def _dummy_run_dflash( + self, + num_tokens: int, + num_tokens_across_dp: Optional[torch.Tensor] = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + is_profile: bool = False + ): + num_tokens = num_tokens // 2 + if not self.use_cuda_graph: + aclgraph_runtime_mode = CUDAGraphMode.NONE + num_iters_to_capture = 1 + for _ in range(num_iters_to_capture): + ( + num_input_tokens, + num_tokens_across_dp, + _, + _, + ) = self.runner._sync_metadata_across_dp(num_tokens, + is_draft_model=True) + if self.use_cuda_graph and \ + num_input_tokens <= self.runner.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_input_tokens) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + else: + cudagraph_runtime_mode = CUDAGraphMode.NONE + + positions_len = 2 * num_input_tokens + update_cos_sin(self.positions[:positions_len]) + + with set_ascend_forward_context( + None, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + num_actual_tokens=0, + in_profile_run=is_profile, + aclgraph_runtime_mode=cudagraph_runtime_mode, + is_draft_model=True): + + input_ids = self.input_ids[:num_input_tokens] + positions = self.positions[:positions_len] + hidden_states = self.hidden_states[:num_input_tokens] + + self.model( + input_ids=input_ids, + positions=positions, + hidden_states=hidden_states, + inputs_embeds=None, + ) + + def _dflash_propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] or [3, num_tokens] when M-RoPE is enabled + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [batch_size] + next_token_ids: 
+        last_token_indices: torch.Tensor | None,
+        common_attn_metadata: CommonAttentionMetadata,
+    ) -> torch.Tensor:
+        batch_size = common_attn_metadata.num_reqs
+        assert self.runner is not None
+        target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
+        assert target_hidden_states.shape[-1] == self.hidden_size
+
+        if last_token_indices is None:
+            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
+        query_start_loc = common_attn_metadata.query_start_loc[:batch_size + 1]
+        num_context_tokens = target_token_ids.shape[0]
+        num_query_tokens = 1 + self.num_speculative_tokens
+        total_num_query_tokens = batch_size * num_query_tokens
+        num_kv_tokens = num_context_tokens + total_num_query_tokens
+
+        MASK_TOKEN_ID = 151669
+
+        query_positions_list = []
+        for i in range(batch_size):
+            last_position = target_positions[last_token_indices[i]]
+            query_positions = (
+                torch.arange(
+                    num_query_tokens,
+                    device=target_positions.device,
+                    dtype=target_positions.dtype,
+                )
+                + 1
+                + last_position
+            )
+            query_positions_list.append(query_positions)
+        position_ids = torch.cat([target_positions] + query_positions_list)
+        assert position_ids.shape[0] == num_kv_tokens
+
+        if self.attn_metadata_builder is None:
+            attn_metadata_builder = self._get_attention_metadata_builder()
+        else:
+            attn_metadata_builder = self.attn_metadata_builder
+        block_size = attn_metadata_builder.kv_cache_spec.block_size
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        query_slot_mapping_list = []
+        for i in range(batch_size):
+            query_position_ids_i = query_positions_list[i]
+            block_numbers_i = query_position_ids_i // block_size
+            block_ids_i = block_table_tensor[i].gather(
+                dim=0, index=block_numbers_i.long()
+            )
+            slot_mapping_i = (
+                block_ids_i * block_size + query_position_ids_i % block_size
+            ).to(torch.int32)
+            query_slot_mapping_list.append(slot_mapping_i)
+        common_attn_metadata.slot_mapping = torch.cat(
+            [common_attn_metadata.slot_mapping[:target_hidden_states.shape[0]]] + query_slot_mapping_list)
+        common_attn_metadata.num_actual_tokens = num_kv_tokens
+        common_attn_metadata.max_query_len = num_kv_tokens
+        common_attn_metadata.query_start_loc = (
+            self.arange[:batch_size + 1] * num_query_tokens
+        )
+        common_attn_metadata.query_start_loc_cpu = (
+            torch.from_numpy(self.token_arange_np[:batch_size + 1]).clone()
+            * num_query_tokens
+        )
+        common_attn_metadata.max_seq_len += num_query_tokens
+        common_attn_metadata.seq_lens[:batch_size] = (
+            common_attn_metadata.seq_lens[:batch_size] + num_query_tokens
+        )
+        common_attn_metadata.seq_lens_cpu = (
+            common_attn_metadata.seq_lens.cpu()
+        )
+        common_attn_metadata.causal = False
+        common_attn_metadata.attn_mask = None
+        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
+
+        builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        attn_metadata = builder.build(0, common_attn_metadata,
+                                      self.runner.get_model())
+        per_layer_attn_metadata = {}
+        for layer_name in self.attn_layer_names:
+            per_layer_attn_metadata[layer_name] = attn_metadata
+
+        self.positions[:num_kv_tokens] = position_ids
+        self.input_ids[:total_num_query_tokens] = MASK_TOKEN_ID
+        for i in range(batch_size):
+            self.input_ids[i * num_query_tokens] = next_token_ids[i]
+        self.hidden_states[:num_context_tokens] = target_hidden_states
+
+        update_cos_sin(position_ids)
+
+        if self.use_cuda_graph and total_num_query_tokens <= self.runner.cudagraph_batch_sizes[-1]:
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(
+                total_num_query_tokens
+            )
+            aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
+        else:
+            num_input_tokens = total_num_query_tokens
+            aclgraph_runtime_mode = CUDAGraphMode.NONE
+
+        (
+            num_input_tokens,
+            num_tokens_across_dp,
+            _,
+            _,
+        ) = self.runner._sync_metadata_across_dp(
+            num_input_tokens, is_draft_model=True
+        )
+
+        with set_ascend_forward_context(
+                per_layer_attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_actual_tokens=num_kv_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                aclgraph_runtime_mode=CUDAGraphMode.NONE,
+                is_draft_model=True):
+
+            ret_hidden_states = self.model(
+                input_ids=self.input_ids[:num_input_tokens],
+                positions=self.positions[:num_kv_tokens],
+                hidden_states=self.hidden_states[:num_context_tokens],
+                inputs_embeds=None,
+            )
+
+        valid_hidden_list = []
+        for i in range(batch_size):
+            start = i * num_query_tokens + 1
+            end = (i + 1) * num_query_tokens
+            valid_hidden_list.append(ret_hidden_states[start:end])
+        valid_hidden_states = torch.cat(valid_hidden_list, dim=0)
+        logits = self.model.compute_logits(valid_hidden_states)
+        draft_token_ids = logits.argmax(dim=-1)
+        return draft_token_ids.view(batch_size, self.num_speculative_tokens)
+
     def _propose(
         self,
         # [num_tokens]
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index bbc691466b7..c4857a62002 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -383,7 +383,7 @@ def _set_up_drafter(self):
         self.decode_token_per_req = 1 + spec_token_num
         if get_pp_group().is_last_rank:
             self.drafter = self._get_drafter()
-            if self.speculative_config.method == "eagle3":
+            if self.speculative_config.method in ("eagle3", "dflash"):
                 assert isinstance(self.drafter, EagleProposer)
                 self.use_aux_hidden_state_outputs = (
                     self.drafter.eagle3_use_aux_hidden_state)
@@ -1450,21 +1450,31 @@ def propose_draft_token_ids(
             else:
                 target_hidden_states = hidden_states[token_indices]
             assert self.drafter is not None
-            draft_token_ids = self.drafter._propose(
-                target_token_ids=target_token_ids,
-                target_positions=target_positions,
-                target_hidden_states=target_hidden_states,
-                next_token_ids=next_token_ids,
-                last_token_indices=token_indices_to_sample,
-                common_attn_metadata=common_attn_metadata,
-                sampling_metadata=sampling_metadata,
-                req_scheduled_tokens=req_scheduled_tokens,
-                long_seq_metadata=long_seq_metadata,
-                num_prefill_reqs=num_prefill_reqs,
-                num_decode_reqs=num_decode_reqs,
-                scheduler_output=scheduler_output,
-                num_scheduled_tokens=num_scheduled_tokens,
-            )
+            if self.speculative_config.method == "dflash":
+                draft_token_ids = self.drafter._dflash_propose(
+                    target_token_ids=target_token_ids,
+                    target_positions=target_positions,
+                    target_hidden_states=target_hidden_states,
+                    next_token_ids=next_token_ids,
+                    last_token_indices=token_indices_to_sample,
+                    common_attn_metadata=common_attn_metadata
+                )
+            else:
+                draft_token_ids = self.drafter._propose(
+                    target_token_ids=target_token_ids,
+                    target_positions=target_positions,
+                    target_hidden_states=target_hidden_states,
+                    next_token_ids=next_token_ids,
+                    last_token_indices=token_indices_to_sample,
+                    common_attn_metadata=common_attn_metadata,
+                    sampling_metadata=sampling_metadata,
+                    req_scheduled_tokens=req_scheduled_tokens,
+                    long_seq_metadata=long_seq_metadata,
+                    num_prefill_reqs=num_prefill_reqs,
+                    num_decode_reqs=num_decode_reqs,
+                    scheduler_output=scheduler_output,
+                    num_scheduled_tokens=num_scheduled_tokens,
+                )
         else:
             raise ValueError("Unknown speculative decoding method: "
@@ -2387,7 +2397,7 @@ def load_model(self) -> None:
                 self.drafter.load_model(self.model)
             if self.use_aux_hidden_state_outputs:
                 self.model.set_aux_hidden_state_layers(
-                    self.model.get_eagle3_aux_hidden_state_layers())
+                    self.model.get_eagle3_aux_hidden_state_layers(self.drafter.method))
         if self.lora_config:
             self.model = self.load_lora_model(self.model, self.vllm_config,
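
For context, the sketch below shows how the new "dflash" drafter could be selected from client code once this patch is applied. It is a minimal, hypothetical example, not part of the diff: it assumes vLLM's speculative_config dictionary interface, and the target model name, draft checkpoint path, and speculative-token count are placeholders chosen for illustration.

from vllm import LLM, SamplingParams

# Hypothetical usage sketch for the "dflash" method wired up in this patch.
# Paths and token counts are placeholders, not values taken from the diff.
llm = LLM(
    model="/path/to/target-model",
    speculative_config={
        "method": "dflash",                # string matched in get_spec_decode_method / _set_up_drafter
        "model": "/path/to/dflash-draft",  # draft checkpoint loaded by EagleProposer.load_model
        "num_speculative_tokens": 4,       # becomes num_speculative_tokens inside _dflash_propose
    },
)
outputs = llm.generate(["An example prompt."], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)

Under these assumptions, "dflash" is routed to EagleProposer like "eagle" and "eagle3", and the dedicated _dflash_propose / _dummy_run_dflash paths above take over drafting on the Ascend backend.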