4 changes: 2 additions & 2 deletions csrc/build_aclnn.sh
@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}

CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
@@ -68,7 +68,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then

CUSTOM_OPS_ARRAY=(
"grouped_matmul_swiglu_quant_weight_nz_tensor_list"
"lightning_indexer"
"lightning_indexer_vllm"
"sparse_flash_attention"
"dispatch_ffn_combine"
"dispatch_ffn_combine_bf16"
@@ -8,27 +8,27 @@
# ======================================================================================================================

add_ops_compile_options(
-OP_NAME LightningIndexer
+OP_NAME LightningIndexerVllm
OPTIONS --cce-auto-sync=off
-Wno-deprecated-declarations
-Werror
-mllvm -cce-aicore-hoist-movemask=false
--op_relocatable_kernel_binary=true
)

-set(lightning_indexer_depends transformer/attention/lightning_indexer PARENT_SCOPE)
+set(lightning_indexer_vllm_depends transformer/attention/lightning_indexer_vllm PARENT_SCOPE)

target_sources(op_host_aclnn PRIVATE
-lightning_indexer_def.cpp
+lightning_indexer_vllm_def.cpp
)

target_sources(optiling PRIVATE
-lightning_indexer_tiling.cpp
+lightning_indexer_vllm_tiling.cpp
)

if (NOT BUILD_OPEN_PROJECT)
target_sources(opmaster_ct PRIVATE
-lightning_indexer_tiling.cpp
+lightning_indexer_vllm_tiling.cpp
)
endif ()

@@ -37,6 +37,6 @@ target_include_directories(optiling PRIVATE
)

target_sources(opsproto PRIVATE
-lightning_indexer_proto.cpp
+lightning_indexer_vllm_proto.cpp
)

@@ -16,9 +16,9 @@
#include "register/op_def_registry.h"

namespace ops {
-class LightningIndexer : public OpDef {
+class LightningIndexerVllm : public OpDef {
public:
-explicit LightningIndexer(const char *name) : OpDef(name)
+explicit LightningIndexerVllm(const char *name) : OpDef(name)
{
this->Input("query")
.ParamType(REQUIRED)
@@ -68,5 +68,5 @@ class LightningIndexer : public OpDef {
this->AICore().AddConfig("ascend910_93", aicore_config);
}
};
-OP_ADD(LightningIndexer);
+OP_ADD(LightningIndexerVllm);
} // namespace ops
@@ -90,7 +90,7 @@ static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext
return GRAPH_SUCCESS;
}

-IMPL_OP_INFERSHAPE(LightningIndexer)
+IMPL_OP_INFERSHAPE(LightningIndexerVllm)
.InferShape(InferShapeLightningIndexer)
.InferDataType(InferDataTypeLightningIndexer);
} // namespace ops
@@ -13,7 +13,7 @@
* \brief
*/

#include "lightning_indexer_tiling.h"
#include "lightning_indexer_vllm_tiling.h"
#include "../op_kernel/lightning_indexer_template_tiling_key.h"

using namespace ge;
@@ -687,7 +687,7 @@ ge::graphStatus TilingForLightningIndexer(gert::TilingContext *context)
return liTiling.DoTiling(&liInfo);
}

-IMPL_OP_OPTILING(LightningIndexer)
+IMPL_OP_OPTILING(LightningIndexerVllm)
.Tiling(TilingForLightningIndexer)
.TilingParse<LICompileInfo>(TilingPrepareForLightningIndexer);

@@ -80,7 +80,7 @@ TILING_DATA_FIELD_DEF(uint32_t, blockSize)
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)
END_TILING_DATA_DEF
-REGISTER_TILING_DATA_CLASS(LightningIndexer, LITilingData)
+REGISTER_TILING_DATA_CLASS(LightningIndexerVllm, LITilingData)

struct LICompileInfo {};

@@ -212,4 +212,4 @@ class LightningIndexerTiling {
};

} // namespace optiling
-#endif // LIGHTNING_INDEXER_TILING_H_
\ No newline at end of file
+#endif // LIGHTNING_INDEXER_TILING_H_
@@ -28,7 +28,7 @@

#define ASCENDC_TPL_4_BW 4

-ASCENDC_TPL_ARGS_DECL(LightningIndexer,
+ASCENDC_TPL_ARGS_DECL(LightningIndexerVllm,
ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_FP16, LI_TPL_BF16),
ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_FP16, LI_TPL_BF16),
ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 0, 1),
@@ -35,7 +35,7 @@ using namespace LIKernel;


template <int DT_Q, int DT_K, int DT_OUT, int PAGE_ATTENTION, int LAYOUT_T, int K_LAYOUT_T>
-__global__ __aicore__ void lightning_indexer(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
+__global__ __aicore__ void lightning_indexer_vllm(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
__gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices,
__gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
2 changes: 1 addition & 1 deletion csrc/torch_binding.cpp
@@ -739,7 +739,7 @@ at::Tensor npu_lightning_indexer(
char *query_layout_ptr = const_cast<char *>(query_layout_str.c_str());
char *key_layout_ptr = const_cast<char *>(key_layout_str.c_str());
EXEC_NPU_CMD(
-aclnnLightningIndexer,
+aclnnLightningIndexerVllm,
query,
key,
weights,
50 changes: 37 additions & 13 deletions vllm_ascend/attention/sfa_v1.py
@@ -431,6 +431,11 @@ def __init__(
self.weights_proj = self.indexer.weights_proj
self.k_norm = self.indexer.k_norm
self.cp_size = 1
+self.is_rope_neox_style = True
+self.use_torch_npu_lightning_indexer = False
+if self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"]:
+    self.is_rope_neox_style = False
+    self.use_torch_npu_lightning_indexer = True
[Review comment] @whx-sjtu (Collaborator), Feb 11, 2026: It's better to obtain is_neox_style from model_config. Refactor here later.
self.enable_dsa_cp = enable_dsa_cp()
self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard()
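A minimal sketch of the refactor the reviewer suggests, deriving the rope style from model_config instead of hardcoding model types at the attention layer; the helper name and the rope_is_neox_style attribute are assumptions, not an existing vLLM API:

```python
def _infer_is_neox_style(model_config) -> bool:
    """Hypothetical helper: read the rope style from the model config
    rather than hardcoding model types at the attention layer."""
    hf_config = model_config.hf_config
    # Prefer an explicit flag if the HF config carries one (attribute name assumed).
    explicit = getattr(hf_config, "rope_is_neox_style", None)
    if explicit is not None:
        return bool(explicit)
    # Fall back to the behavior introduced in this PR: glm_moe_dsa uses
    # interleaved (non-neox) rope; other models default to neox style.
    return hf_config.model_type not in ("glm_moe_dsa",)
```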
@@ -973,7 +978,9 @@ def indexer_select_pre_process(

cos = cos.view(-1, self.qk_rope_head_dim)
sin = sin.view(-1, self.qk_rope_head_dim)
-q, k = rope_forward_triton(q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=True)
+q, k = rope_forward_triton(
+    q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style
+)
else:
k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1)

@@ -1036,18 +1043,35 @@ def indexer_select_post_process(
key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
block_table = attn_metadata.sfa_cp_metadata.block_table_cp

-topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
-    query=q,
-    key=key,
-    weights=weights,
-    actual_seq_lengths_query=actual_seq_lengths_query,
-    actual_seq_lengths_key=actual_seq_lengths_key,
-    block_table=block_table,
-    layout_query="TND",
-    layout_key="PA_BSND",
-    sparse_count=2048,
-    sparse_mode=3,
-)
+# DSV3.2 currently has graph compilation issues when using torch_npu.npu_lightning_indexer,
+# so two branches are maintained temporarily.
+# TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
+if self.use_torch_npu_lightning_indexer:
+    topk_indices, _ = torch_npu.npu_lightning_indexer(
+        query=q,
+        key=key,
+        weights=weights,
+        actual_seq_lengths_query=actual_seq_lengths_query,
+        actual_seq_lengths_key=actual_seq_lengths_key,
+        block_table=block_table,
+        layout_query="TND",
+        layout_key="PA_BSND",
+        sparse_count=2048,
+        sparse_mode=3,
+    )
+else:
+    topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
+        query=q,
+        key=key,
+        weights=weights,
+        actual_seq_lengths_query=actual_seq_lengths_query,
+        actual_seq_lengths_key=actual_seq_lengths_key,
+        block_table=block_table,
+        layout_query="TND",
+        layout_key="PA_BSND",
+        sparse_count=2048,
+        sparse_mode=3,
+    )
return topk_indices

def _init_o_proj_tp_full_params(self):
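The two branches above pass identical keyword arguments and differ only in the callee and in how the result is unpacked; until the _C_ascend op is removed, the duplication could be collapsed. A sketch, assuming both ops keep the signature shown in the diff:

```python
indexer_kwargs = dict(
    query=q,
    key=key,
    weights=weights,
    actual_seq_lengths_query=actual_seq_lengths_query,
    actual_seq_lengths_key=actual_seq_lengths_key,
    block_table=block_table,
    layout_query="TND",
    layout_key="PA_BSND",
    sparse_count=2048,
    sparse_mode=3,
)
if self.use_torch_npu_lightning_indexer:
    # The torch_npu variant returns an extra output; only the indices are used here.
    topk_indices, _ = torch_npu.npu_lightning_indexer(**indexer_kwargs)
else:
    topk_indices = torch.ops._C_ascend.npu_lightning_indexer(**indexer_kwargs)
```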
5 changes: 5 additions & 0 deletions vllm_ascend/quantization/modelslim_config.py
@@ -96,6 +96,11 @@
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
"glm_moe_dsa": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
# NOTE 1. In quantized deepseek models on the NPU, the MTP layer itself is not quantized;
# NOTE 2. The description file generated by the current msmodelslim tool does not include
# MTP layer info. Please add it manually and set the value to FLOAT.
18 changes: 16 additions & 2 deletions vllm_ascend/spec_decode/mtp_proposer.py
@@ -36,7 +36,14 @@ def dummy_run(
dummy_compute_logits=lambda hidden_states: None,
is_profile=False,
) -> None:
-if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch:
+# Currently, both GLM and DS hit issues when full-graph mode is enabled on the EagleProposer path,
+# so we temporarily bypass the problem with an explicit full-graph check.
+# TODO: remove this conditional check once the bug is fixed.
+if (
+    self.pcp_size * self.dcp_size == 1
+    and not self.speculative_config.disable_padded_drafter_batch
+    and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
+):
super().dummy_run(
num_tokens,
with_prefill,
@@ -166,7 +173,14 @@ def _propose(
scheduler_output: SchedulerOutput = None,
num_scheduled_tokens: int = 0,
) -> torch.Tensor:
-if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch:
+# Currently, both GLM and DS hit issues when full-graph mode is enabled on the EagleProposer path,
+# so we temporarily bypass the problem with an explicit full-graph check.
+# TODO: remove this conditional check once the bug is fixed.
+if (
+    self.pcp_size * self.dcp_size == 1
+    and not self.speculative_config.disable_padded_drafter_batch
+    and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
+):
draft_token_ids = super()._propose(
target_token_ids,
target_positions,
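The same three-part guard now appears in both dummy_run and _propose; a small helper on the proposer would keep the two call sites in sync. A sketch built only from the attributes used in the diff:

```python
def _can_use_padded_drafter_path(self) -> bool:
    """True when the padded-drafter EagleProposer fast path is safe to take:
    no prefill/decode context parallelism, padded drafter batching enabled,
    and full cudagraphs disabled (GLM/DS currently fail to compile them here).
    """
    return (
        self.pcp_size * self.dcp_size == 1
        and not self.speculative_config.disable_padded_drafter_batch
        and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
    )
```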