68 changes: 53 additions & 15 deletions vllm_ascend/attention/attention_v1.py
@@ -849,21 +849,37 @@ def forward_fused_infer_attention(
                learnable_sink=self.sinks,
            )
        else:
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                query=query,
                key=key,
                value=value,
                atten_mask=attn_metadata.attn_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=3,
            )
            if not attn_metadata.causal:
                attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                    query=query,
                    key=key,
                    value=value,
                    block_table=block_table,
                    input_layout="TND",
                    block_size=block_size,
                    actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
                    actual_seq_lengths_kv=actual_seq_lengths_kv,
                    num_key_value_heads=self.num_kv_heads,
                    num_heads=self.num_heads,
                    scale=self.scale,
                    sparse_mode=0,
                )
            else:
                attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                    query=query,
                    key=key,
                    value=value,
                    atten_mask=attn_metadata.attn_mask,
                    block_table=block_table,
                    input_layout="TND",
                    block_size=block_size,
                    actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
                    actual_seq_lengths_kv=actual_seq_lengths_kv,
                    num_key_value_heads=self.num_kv_heads,
                    num_heads=self.num_heads,
                    scale=self.scale,
                    sparse_mode=3,
                )
Comment on lines +852 to +882
Contributor (severity: high)
There is significant code duplication between the if and else blocks. This makes the code harder to maintain, as changes to the arguments of torch_npu.npu_fused_infer_attention_score must be applied in two places, increasing the risk of introducing bugs.

To improve maintainability, you can refactor the common arguments into a dictionary.

            common_args = {
                "query": query,
                "key": key,
                "value": value,
                "block_table": block_table,
                "input_layout": "TND",
                "block_size": block_size,
                "actual_seq_lengths": attn_metadata.actual_seq_lengths_q,
                "actual_seq_lengths_kv": actual_seq_lengths_kv,
                "num_key_value_heads": self.num_kv_heads,
                "num_heads": self.num_heads,
                "scale": self.scale,
            }
            if not attn_metadata.causal:
                attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                    **common_args,
                    sparse_mode=0,
                )
            else:
                attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                    **common_args,
                    atten_mask=attn_metadata.attn_mask,
                    sparse_mode=3,
                )


        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
        output[:num_tokens] = attn_output[:num_tokens]
@@ -910,6 +926,28 @@ def _forward_encoder_attention(
            actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
        )[0]

    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: list[torch.Tensor],
        slot_mapping: torch.Tensor,
    ) -> None:
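        # Encoder-only attention never writes the paged KV cache; nothing to do.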
        if self.attn_type == AttentionType.ENCODER_ONLY:
            return

        if self.key_cache is None:
            self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]

        DeviceOperator.reshape_and_cache(
            key=key,
            value=value,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            slot_mapping=slot_mapping,
        )

    def reshape_and_cache(
        self,
        query: torch.Tensor,
71 changes: 71 additions & 0 deletions vllm_ascend/ops/triton/spec_decode/utils.py
@@ -63,3 +63,74 @@ def prepare_inputs_padded_kernel(
    index_to_sample = q_last_tok_idx - num_rejected
    tl.store(token_indices_to_sample_ptr + offsets, index_to_sample, mask=mask)
    tl.store(num_rejected_tokens_gpu_ptr + offsets, num_rejected, mask=mask)


@triton.jit
def copy_and_expand_dflash_inputs_kernel_single_grid(
    # Inputs
    next_token_ids_ptr,  # [num_reqs]
    target_positions_ptr,  # [num_context]
    # Outputs
    out_input_ids_ptr,  # [num_query_total] (output)
    out_context_positions_ptr,  # [num_context] (output)
    out_query_positions_ptr,  # [num_query_total] (output)
    out_context_slot_mapping_ptr,  # [num_context] (output)
    out_query_slot_mapping_ptr,  # [num_query_total] (output)
    out_token_indices_ptr,  # [num_reqs * num_speculative_tokens] (output)
    # Block table
    block_table_ptr,  # [max_reqs, max_blocks]
    block_table_stride,  # stride of block_table dim 0 (in elements)
    # Metadata
    query_start_loc_ptr,  # [num_reqs + 1]
    num_rejected_tokens_ptr,  # [num_reqs] or null (0) when not padded
    # Scalars
    parallel_drafting_token_id,  # tl.int32
    block_size,  # tl.int32
    num_query_per_req,  # tl.int32
    num_speculative_tokens,  # tl.int32
    total_input_tokens,  # tl.int32
    batch_size,  # tl.int32
    HAS_NUM_REJECTED: tl.constexpr = False,
):
    for req_idx in range(0, batch_size):
        ctx_start = tl.load(query_start_loc_ptr + req_idx)
        ctx_end = tl.load(query_start_loc_ptr + req_idx + 1)
        num_ctx = ctx_end - ctx_start

        for j in range(0, num_ctx):
            ctx_pos_idx = ctx_start + j
            pos = tl.load(target_positions_ptr + ctx_pos_idx)
            tl.store(out_context_positions_ptr + ctx_pos_idx, pos)

            block_num = pos // block_size
            block_id = tl.load(block_table_ptr + req_idx * block_table_stride +
                               block_num).to(tl.int64)
            slot = block_id * block_size + (pos % block_size)
            tl.store(out_context_slot_mapping_ptr + ctx_pos_idx, slot)

        if HAS_NUM_REJECTED:
            num_rejected = tl.load(num_rejected_tokens_ptr + req_idx)
            valid_ctx_end = ctx_end - num_rejected
        else:
            valid_ctx_end = ctx_end

        last_pos = tl.load(target_positions_ptr + valid_ctx_end - 1)

        for q_idx in range(0, num_query_per_req):
            query_pos = last_pos + 1 + q_idx
            query_out_idx = req_idx * num_query_per_req + q_idx

            tl.store(out_query_positions_ptr + query_out_idx, query_pos)

            block_num_q = query_pos // block_size
            block_id_q = tl.load(block_table_ptr + req_idx * block_table_stride +
                                 block_num_q).to(tl.int64)
            slot_q = block_id_q * block_size + (query_pos % block_size)
            tl.store(out_query_slot_mapping_ptr + query_out_idx, slot_q)

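            # q_idx == 0 carries the bonus token accepted at the last verify
            # step; the remaining draft slots take the parallel-drafting
            # placeholder id and are indexed in out_token_indices for sampling.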
            if q_idx == 0:
                bonus_token = tl.load(next_token_ids_ptr + req_idx)
                tl.store(out_input_ids_ptr + query_out_idx, bonus_token)
            else:
                tl.store(out_input_ids_ptr + query_out_idx, parallel_drafting_token_id)

                sample_out_idx = req_idx * num_speculative_tokens + (q_idx - 1)
                tl.store(out_token_indices_ptr + sample_out_idx, query_out_idx)
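
A minimal host-side launch sketch for the kernel above (not part of this diff). The wrapper name launch_copy_and_expand, the output dtypes, the dummy pointer passed when num_rejected_tokens is absent, and the (1,) grid are all assumptions inferred from the kernel signature and its internal loop over batch_size.

import torch

def launch_copy_and_expand(next_token_ids, target_positions, block_table,
                           query_start_loc, num_rejected_tokens,
                           parallel_drafting_token_id, block_size,
                           num_query_per_req, num_speculative_tokens):
    num_reqs = next_token_ids.shape[0]
    num_context = target_positions.shape[0]
    num_query_total = num_reqs * num_query_per_req
    dev = target_positions.device

    # Output buffers sized per the shape comments in the kernel signature.
    out_input_ids = torch.empty(num_query_total, dtype=torch.int64, device=dev)
    out_context_positions = torch.empty_like(target_positions)
    out_query_positions = torch.empty(num_query_total, dtype=target_positions.dtype, device=dev)
    out_context_slot_mapping = torch.empty(num_context, dtype=torch.int64, device=dev)
    out_query_slot_mapping = torch.empty(num_query_total, dtype=torch.int64, device=dev)
    out_token_indices = torch.empty(num_reqs * num_speculative_tokens, dtype=torch.int64, device=dev)

    has_rejected = num_rejected_tokens is not None
    # Single program instance: the kernel walks every request serially itself.
    copy_and_expand_dflash_inputs_kernel_single_grid[(1, )](
        next_token_ids, target_positions,
        out_input_ids, out_context_positions, out_query_positions,
        out_context_slot_mapping, out_query_slot_mapping, out_token_indices,
        block_table, block_table.stride(0),
        query_start_loc,
        num_rejected_tokens if has_rejected else next_token_ids,  # dummy ptr when unused
        parallel_drafting_token_id, block_size,
        num_query_per_req, num_speculative_tokens,
        num_context, num_reqs,
        HAS_NUM_REJECTED=has_rejected,
    )
    return out_input_ids, out_query_slot_mapping, out_token_indices

Because the whole batch runs in one program instance, launch overhead stays constant, at the cost of serializing the per-request work.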
10 changes: 10 additions & 0 deletions vllm_ascend/patch/__init__.py
@@ -687,3 +687,13 @@
# when using mrope.
# Future Plan:
# Remove this patch when vllm-ascend supports pattern matching for this fused kernel.
# ** 29. File: worker/patch_qwen3_dflash.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.qwen3_dflash.DFlashQwen3Model.precompute_and_store_context_kv`
# Why:
# The function directly calls the ops.rms_norm and ops.rotary_embedding operators,
# but the NPU has no corresponding implementation.
# How:
# Replace ops.* with the internal implementation of vllm-ascend.
# Future Plan:
# Remove this patch when vllm-ascend supports pattern matching for ops.*.
5 changes: 4 additions & 1 deletion vllm_ascend/patch/worker/__init__.py
@@ -17,7 +17,7 @@

from vllm.triton_utils import HAS_TRITON

from vllm_ascend.utils import is_310p
from vllm_ascend.utils import is_310p, vllm_version_is

if HAS_TRITON:
import vllm_ascend.patch.worker.patch_triton
@@ -39,6 +39,9 @@
if not is_310p():
    import vllm_ascend.patch.worker.patch_qwen3_5  # noqa
    import vllm_ascend.patch.worker.patch_gdn_attn  # noqa

if not vllm_version_is("0.19.0"):
    import vllm_ascend.patch.worker.patch_qwen3_dflash  # noqa
import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
import vllm_ascend.patch.worker.patch_v2.patch_uva # noqa
import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa
62 changes: 62 additions & 0 deletions vllm_ascend/patch/worker/patch_qwen3_dflash.py
@@ -0,0 +1,62 @@
import torch
import torch.nn.functional as F
from vllm.model_executor.models.qwen3_dflash import DFlashQwen3Model


def precompute_and_store_context_kv(
    self,
    context_states: torch.Tensor,
    context_positions: torch.Tensor,
    context_slot_mapping: torch.Tensor | None = None,
) -> None:
    if not hasattr(self, "_num_attn_layers"):
        self._build_fused_kv_buffers()

    num_ctx = context_states.shape[0]
    L = self._num_attn_layers
    kv = self._kv_size
    hd = self._head_dim
    nkv = self._num_kv_heads

    # --- Fused KV projection (one GEMM for all layers) ---
    normed_context_states = self.hidden_norm(context_states)
    all_kv_flat = F.linear(normed_context_states, self._fused_kv_weight, self._fused_kv_bias)
    # Single contiguous copy that separates K/V and transposes to
    # layer-major layout. Result: [2, L, num_ctx, nkv, hd] contiguous.
    # Indexing dim-0 gives contiguous [L, num_ctx, nkv, hd] for K and V.
    all_kv = all_kv_flat.view(num_ctx, L, 2, nkv, hd).permute(2, 1, 0, 3, 4).contiguous()
    all_k = all_kv[0]  # [L, num_ctx, nkv, hd], contiguous
    all_v = all_kv[1]  # [L, num_ctx, nkv, hd], contiguous

    # --- Per-layer RMSNorm K (3D: [num_ctx, nkv, hd] per layer) ---
    all_k_normed = torch.empty_like(all_k)
    for i in range(L):
        k_norm_layer = self.layers[i].self_attn.k_norm
        all_k_normed[i] = k_norm_layer(all_k[i])

    # --- Fused RoPE across all layers ---
    # View as [L * num_ctx, kv] so RoPE sees one big batch (no copy).
    # In-place RoPE: pass K as the "query" arg and a scratch clone as the key.
    all_k_flat = all_k_normed.view(L * num_ctx, kv)
    positions_repeated = context_positions.repeat(L)
    tmpv = all_k_flat.clone()
    self.layers[0].self_attn.rotary_emb(positions_repeated, all_k_flat, tmpv)

    if context_slot_mapping is None:
        return

    # --- Per-layer cache insert ---
    all_k_final = all_k_flat.view(L, num_ctx, nkv, hd)
    for i in range(L):
        attn = self._attn_layers[i]
        kv_cache = attn.kv_cache
        attn.impl.do_kv_cache_update(
            attn,
            all_k_final[i],
            all_v[i],
            kv_cache,
            context_slot_mapping,
        )


DFlashQwen3Model.precompute_and_store_context_kv = precompute_and_store_context_kv
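
The view/permute layout trick in precompute_and_store_context_kv can be sanity-checked in isolation. A standalone sketch with toy dimensions (all sizes here are made up for illustration):

import torch

num_ctx, L, nkv, hd = 3, 2, 4, 8  # toy sizes, not taken from the model
kv = nkv * hd

# One fused GEMM output row per context token: all layers' K and V concatenated.
all_kv_flat = torch.randn(num_ctx, L * 2 * kv)

# [num_ctx, L, 2, nkv, hd] -> [2, L, num_ctx, nkv, hd]; indexing dim 0 then
# yields contiguous [L, num_ctx, nkv, hd] blocks for K and for V.
all_kv = all_kv_flat.view(num_ctx, L, 2, nkv, hd).permute(2, 1, 0, 3, 4).contiguous()
all_k, all_v = all_kv[0], all_kv[1]
assert all_k.shape == (L, num_ctx, nkv, hd) and all_k.is_contiguous()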
8 changes: 8 additions & 0 deletions vllm_ascend/spec_decode/__init__.py
@@ -17,11 +17,14 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#


from vllm_ascend.spec_decode.dflash_proposer import AscendDflashProposer
from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer
from vllm_ascend.utils import vllm_version_is


def get_spec_decode_method(method, vllm_config, device, runner):
@@ -33,6 +36,11 @@ def get_spec_decode_method(method, vllm_config, device, runner):
        return AscendMedusaProposer(vllm_config, device)
    elif method in ("eagle", "eagle3", "mtp"):
        return AscendEagleProposer(vllm_config, device, runner)
    elif method == "dflash":
        if not vllm_version_is("0.19.0"):
            return AscendDflashProposer(vllm_config, device, runner)
        else:
            raise ValueError(f"vLLM v0.19.0 does not support {method} yet")
    elif method == "draft_model":
        return AscendDraftModelProposer(vllm_config, device, runner)
    else:
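
For completeness, a hypothetical way the new proposer could be selected, assuming dflash is requested through speculative_config the same way as the existing methods. The model name and token count are placeholders, and any dflash-specific config fields are not shown in this PR:

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-8B",  # placeholder; any dflash-capable target model
    speculative_config={
        "method": "dflash",  # routed to AscendDflashProposer above
        "num_speculative_tokens": 4,
    },
)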