Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
cee3155
[feat] ngram npu implementation compatible with async scheduler
Mar 16, 2026
fc0a481
[feat] fix format
Mar 16, 2026
ec52d9c
Merge branch 'main' into ngram_aync_dev
HF-001 Mar 19, 2026
5b242fb
Merge branch 'main' into ngram_aync_dev
wangxiyuan Mar 19, 2026
8430721
Merge branch 'main' into ngram_aync_dev
HF-001 Mar 23, 2026
897a4ed
Merge branch 'main' into ngram_aync_dev
HF-001 Mar 23, 2026
41f7993
Merge branch 'main' into ngram_aync_dev
HF-001 Mar 24, 2026
df46f31
add ci test
Mar 26, 2026
5be5f27
Merge branch 'main' into ngram_aync_dev
HF-001 Mar 26, 2026
2ca987e
[Feature] ngram npu implementation compatible with async scheduler
HF-001 Apr 16, 2026
a8892a3
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 16, 2026
83f89f0
[Feature] ngram npu implementation compatible with async scheduler
HF-001 Apr 16, 2026
6ed8f5e
[Feature] ngram npu implementation compatible with async scheduler
HF-001 Apr 16, 2026
69a09b6
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 16, 2026
0549943
[Feature] ngram npu implementation compatible with async scheduler
HF-001 Apr 17, 2026
7371dca
fix
HF-001 Apr 17, 2026
b80eaa6
fix
HF-001 Apr 17, 2026
d1d5e0e
fix
HF-001 Apr 17, 2026
333445a
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 17, 2026
f622da8
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 21, 2026
89b9e79
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 21, 2026
1be1c8c
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 21, 2026
f888711
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 21, 2026
1793862
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 22, 2026
e64d663
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 22, 2026
d6774f8
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 22, 2026
436ad8a
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 22, 2026
aaf0f37
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 22, 2026
415a23e
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 23, 2026
3e81f72
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 23, 2026
2ea648b
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 23, 2026
4abe961
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 23, 2026
a38f73e
fix
HF-001 Apr 23, 2026
c17d023
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 23, 2026
bc56054
fix
HF-001 Apr 23, 2026
790352d
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 23, 2026
1d2810a
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 23, 2026
00c3428
fix
HF-001 Apr 23, 2026
e6e2ac4
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 23, 2026
ca5ca2f
fix
HF-001 Apr 24, 2026
50f0f30
fix
HF-001 Apr 24, 2026
14b11f0
fix
HF-001 Apr 25, 2026
7b0d8f8
fix
HF-001 Apr 25, 2026
b41673b
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 25, 2026
6041181
fix
HF-001 Apr 25, 2026
d7478b3
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 25, 2026
1c7d325
fix
HF-001 Apr 25, 2026
2c8d4d8
fix ci
HF-001 Apr 25, 2026
d80aeb7
fix
HF-001 Apr 25, 2026
c599597
fix
HF-001 Apr 26, 2026
3ba1432
fix
HF-001 Apr 27, 2026
6736c84
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 27, 2026
9026c57
fix
HF-001 Apr 27, 2026
d288256
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 27, 2026
4ac556e
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 27, 2026
9ca4259
fix
HF-001 Apr 27, 2026
daf0fe4
fix
HF-001 Apr 28, 2026
71c4374
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 28, 2026
fbb37a6
fix
HF-001 Apr 28, 2026
c3ba517
fix
HF-001 Apr 29, 2026
854a421
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 29, 2026
9ef3490
fix
HF-001 Apr 29, 2026
bef0e34
Merge branch 'main' into ngram_async_dev2
HF-001 Apr 29, 2026
dc17c7f
fix
HF-001 Apr 29, 2026
a0abcf2
Revert "fix"
HF-001 Apr 29, 2026
67a0c1d
Merge branch 'main' into ngram_async_dev2
HF-001 May 8, 2026
5203bfa
fix
HF-001 May 8, 2026
871835f
fix
HF-001 May 9, 2026
918ad9d
Merge branch 'main' into ngram_async_dev2
HF-001 May 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion csrc/build_aclnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}


CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;hamming_dist_top_k;reshape_and_cache_bnsd;recurrent_gated_delta_rule;"
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;ngram_spec_decode;hamming_dist_top_k;reshape_and_cache_bnsd;recurrent_gated_delta_rule;"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
Expand Down Expand Up @@ -61,6 +61,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
"causal_conv1d"
"moe_grouped_matmul"
"lightning_indexer_quant"
"ngram_spec_decode"
"hamming_dist_top_k"
"reshape_and_cache_bnsd"
"recurrent_gated_delta_rule"
Expand Down
82 changes: 82 additions & 0 deletions csrc/ngram_spec_decode/ngram_spec_decode_torch_adpt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*/
#ifndef NGRAM_SPEC_DECODE_TORCH_ADPT_H
#define NGRAM_SPEC_DECODE_TORCH_ADPT_H

#include <torch/extension.h>
#include <torch_npu/csrc/framework/OpCommand.h>

namespace vllm_ascend {

// N-gram spec decode op
// inputs:
//   token_ids: [batch_size, max_seq_len], int32,
//   num_tokens_no_spec: [batch_size], int32
//   sampled_token_ids: [batch_size, max_new_tokens], int32
//   discard_request_mask: [batch_size], int32 (a bool mask is widened below)
//   vocab_size, min_n, max_n, k
// outputs:
//   token_ids (in-place change), next_token_ids, draft_token_ids, num_valid_draft_tokens
inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_ngram_spec_decode(
    at::Tensor &token_ids,
    const at::Tensor &num_tokens_no_spec,
    const at::Tensor &sampled_token_ids,
    const at::Tensor &discard_request_mask,
    int64_t vocab_size,
    int64_t min_n,
    int64_t max_n,
    int64_t k)
{
    const int64_t num_reqs = token_ids.size(0);
    const auto int_opts = at::dtype(at::kInt).device(token_ids.device());

    // The kernel consumes the mask as int32: widen a bool mask, pass any
    // other dtype through untouched.
    at::Tensor mask_i32 = discard_request_mask;
    if (mask_i32.dtype() == at::kBool) {
        mask_i32 = mask_i32.to(at::kInt);
    }

    // Every output buffer carries a trailing over-write cushion. The
    // kernel's CopyOut path issues DataCopyPad GM writes whose burst length
    // can be smaller than the NPU's 32-byte MTE alignment; under that
    // alignment the underlying MTE3 burst can write past the apparent
    // tensor end on the last row. A tightly-sized allocation (a plain
    // ``at::empty({batch_size}, ...)``) leaves no room for that
    // alignment-driven over-write and surfaces as a multi-core MTE OOB on
    // device (CI signature: fixp_error0 = 0x30266b9 across cores). We
    // therefore allocate extra rows and ``narrow`` back to the
    // user-visible shape: the narrowed view shares storage with the larger
    // allocation, so any kernel-side alignment over-write lands inside
    // owned memory rather than off the end.
    constexpr int64_t PAD_ROWS = 8; // 32 bytes / sizeof(int32) = 8 ints

    // Allocate ``rows + PAD_ROWS`` int32 elements, expose only ``rows``.
    const auto padded_1d = [&](int64_t rows) {
        return at::empty({rows + PAD_ROWS}, int_opts).narrow(0, 0, rows);
    };

    at::Tensor next_token_ids = padded_1d(num_reqs);
    at::Tensor draft_token_ids =
        at::empty({num_reqs + PAD_ROWS, k}, int_opts).narrow(0, 0, num_reqs);
    at::Tensor num_valid_draft_tokens = padded_1d(num_reqs);

    EXEC_NPU_CMD(aclnnNgramSpecDecode,
                 token_ids, num_tokens_no_spec, sampled_token_ids, mask_i32,
                 vocab_size, min_n, max_n, k,
                 next_token_ids, draft_token_ids, num_valid_draft_tokens);

    return std::make_tuple(token_ids, next_token_ids, draft_token_ids, num_valid_draft_tokens);
}

} // namespace vllm_ascend

#endif // NGRAM_SPEC_DECODE_TORCH_ADPT_H
40 changes: 40 additions & 0 deletions csrc/ngram_spec_decode/op_host/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Build wiring for the NgramSpecDecode custom op: compiler options, source
# registration for each op-build target, and install of the public aclnn header.
add_ops_compile_options(
OP_NAME NgramSpecDecode
OPTIONS --cce-auto-sync=on
-Wno-deprecated-declarations
-Werror
)

# Op IR definition goes into the generated-aclnn host library.
target_sources(op_host_aclnnInner PRIVATE
ngram_spec_decode_def.cpp
)

# Public two-phase aclnn wrapper (GetWorkspaceSize / launch).
target_sources(opapi PRIVATE
aclnn_ngram_spec_decode.cpp
)

# Closed-source project builds also link the wrapper into the train/infer libs.
if (NOT BUILD_OPEN_PROJECT)
target_sources(aclnn_ops_train PRIVATE
aclnn_ngram_spec_decode.cpp
)

target_sources(aclnn_ops_infer PRIVATE
aclnn_ngram_spec_decode.cpp
)
endif ()

# Host-side tiling implementation.
target_sources(optiling PRIVATE
ngram_spec_decode_tiling.cpp
)

target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)

# No shape-inference proto sources for this op.
target_sources(opsproto PRIVATE)

# Install the public aclnn header next to the other generated aclnn headers.
file(GLOB _Ngram_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_ngram_spec_decode.h")

install(FILES ${_Ngram_Aclnn_header}
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)
73 changes: 73 additions & 0 deletions csrc/ngram_spec_decode/op_host/aclnn_ngram_spec_decode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include <string.h>
#include "graph/types.h"
#include "aclnn_ngram_spec_decode.h"

// Local declaration of the server-type values accepted by the nnopbase
// NnopbaseSetHcclServerType hook, so this file does not need nnopbase headers.
enum NnopbaseHcclServerType {
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
NNOPBASE_HCCL_SERVER_TYPE_MTE,
NNOPBASE_HCCL_SERVER_TYPE_END
};
// Weak symbol: non-null only when the real nnopbase implementation is linked
// in, so callers must null-check before invoking (see aclnnNgramSpecDecode).
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);

#ifdef __cplusplus
extern "C" {
#endif

// Generated "inner" entry points, implemented elsewhere by the op build;
// the public wrappers below forward to them unchanged.
extern aclnnStatus aclnnInnerNgramSpecDecodeGetWorkspaceSize(
const aclTensor *tokenIds,
const aclTensor *numTokensNoSpec,
const aclTensor *sampledTokenIds,
const aclTensor *discardRequestMask,
int64_t vocabSize,
int64_t minN,
int64_t maxN,
int64_t k,
const aclTensor *nextTokenIds,
const aclTensor *draftTokenIds,
const aclTensor *numValidDraftTokens,
uint64_t *workspaceSize,
aclOpExecutor **executor);

extern aclnnStatus aclnnInnerNgramSpecDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);

/* Public wrapper for phase one of the two-phase aclnn call: computes the
 * workspace size and builds the executor for a later aclnnNgramSpecDecode
 * launch by delegating directly to the generated inner implementation. */
aclnnStatus aclnnNgramSpecDecodeGetWorkspaceSize(
    const aclTensor *tokenIds,
    const aclTensor *numTokensNoSpec,
    const aclTensor *sampledTokenIds,
    const aclTensor *discardRequestMask,
    int64_t vocabSize,
    int64_t minN,
    int64_t maxN,
    int64_t k,
    const aclTensor *nextTokenIds,
    const aclTensor *draftTokenIds,
    const aclTensor *numValidDraftTokens,
    uint64_t *workspaceSize,
    aclOpExecutor **executor)
{
    const aclnnStatus status = aclnnInnerNgramSpecDecodeGetWorkspaceSize(
        tokenIds, numTokensNoSpec, sampledTokenIds, discardRequestMask,
        vocabSize, minN, maxN, k,
        nextTokenIds, draftTokenIds, numValidDraftTokens,
        workspaceSize, executor);
    return status;
}

/* Public wrapper for phase two: launches NgramSpecDecode on the given
 * stream. When the weak nnopbase hook is linked in, the executor's HCCL
 * server type is routed to MTE first. */
aclnnStatus aclnnNgramSpecDecode(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream)
{
    if (NnopbaseSetHcclServerType != nullptr) {
        NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
    }
    return aclnnInnerNgramSpecDecode(workspace, workspaceSize, executor, stream);
}

#ifdef __cplusplus
}
#endif
56 changes: 56 additions & 0 deletions csrc/ngram_spec_decode/op_host/aclnn_ngram_spec_decode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#ifndef ACLNN_NGRAM_SPEC_DECODE_H_
#define ACLNN_NGRAM_SPEC_DECODE_H_

#include "aclnn/acl_meta.h"

#ifdef __cplusplus
extern "C" {
#endif

/* function: aclnnNgramSpecDecodeGetWorkspaceSize
 * Phase one of the two-phase aclnn call: computes the required workspace
 * size and builds the executor for a later aclnnNgramSpecDecode launch.
 * tokenIds : required, [batch_size, max_seq_len], int32
 * numTokensNoSpec : required, [batch_size], int32
 * sampledTokenIds : required, [batch_size, max_new_tokens], int32
 * discardRequestMask : required, [batch_size], int32
 * vocabSize : required, int
 * minN : required, int
 * maxN : required, int
 * k : required, int
 * nextTokenIds : required, [batch_size], int32
 * draftTokenIds : required, [batch_size, k], int32
 * numValidDraftTokens : required, [batch_size], int32
 * workspaceSize : size of workspace(output).
 * executor : executor context(output).
 */
__attribute__((visibility("default"))) aclnnStatus aclnnNgramSpecDecodeGetWorkspaceSize(
const aclTensor *tokenIds,
const aclTensor *numTokensNoSpec,
const aclTensor *sampledTokenIds,
const aclTensor *discardRequestMask,
int64_t vocabSize,
int64_t minN,
int64_t maxN,
int64_t k,
const aclTensor *nextTokenIds,
const aclTensor *draftTokenIds,
const aclTensor *numValidDraftTokens,
uint64_t *workspaceSize,
aclOpExecutor **executor);

/* function: aclnnNgramSpecDecode
 * Phase two: launches the op using the workspace and executor produced by
 * aclnnNgramSpecDecodeGetWorkspaceSize.
 * workspace : workspace memory addr(input).
 * workspaceSize : size of workspace(input).
 * executor : executor context(input).
 * stream : acl stream.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnNgramSpecDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);

#ifdef __cplusplus
}
#endif

#endif // ACLNN_NGRAM_SPEC_DECODE_H_
72 changes: 72 additions & 0 deletions csrc/ngram_spec_decode/op_host/ngram_spec_decode_def.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include "register/op_def_registry.h"

namespace ops {
// Op registration for NgramSpecDecode: declares the IR signature (four int32
// ND tensor inputs, four int attrs, three int32 ND outputs) and the AI Core
// build configuration for the ascend910b / ascend910_93 SoCs. Input, attr,
// and output declaration order fixes the op's argument order — keep in sync
// with aclnn_ngram_spec_decode.h.
class NgramSpecDecode : public OpDef {
public:
explicit NgramSpecDecode(const char *name) : OpDef(name)
{
this->Input("tokenIds")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});

this->Input("numTokensNoSpec")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});

this->Input("sampledTokenIds")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});

this->Input("discardRequestMask")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});

// Scalar attributes: vocabulary size, n-gram match window [min_n, max_n],
// and k (number of draft tokens per request).
this->Attr("vocab_size").Int();
this->Attr("min_n").Int();
this->Attr("max_n").Int();
this->Attr("k").Int();

this->Output("nextTokenIds")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});

this->Output("draftTokenIds")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});

this->Output("numValidDraftTokens")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});

// Shared AI Core config: dynamic-shape capable, aclnn-invocable,
// statically jit-compiled; registered for both supported SoC families.
OpAICoreConfig aicore_config;
aicore_config.DynamicCompileStaticFlag(true)
.DynamicFormatFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
.NeedCheckSupportFlag(false)
.PrecisionReduceFlag(true)
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
.ExtendCfgInfo("jitCompile.flag", "static_true")
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");

this->AICore().AddConfig("ascend910b", aicore_config);
this->AICore().AddConfig("ascend910_93", aicore_config);
}
};

OP_ADD(NgramSpecDecode);
} // namespace ops
Loading
Loading