diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index bbfe70a3a66..43c7a031ba3 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -23,7 +23,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;hamming_dist_top_k;reshape_and_cache_bnsd;recurrent_gated_delta_rule;" + CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;ngram_spec_decode;hamming_dist_top_k;reshape_and_cache_bnsd;recurrent_gated_delta_rule;" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -61,6 +61,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "causal_conv1d" "moe_grouped_matmul" "lightning_indexer_quant" + "ngram_spec_decode" "hamming_dist_top_k" "reshape_and_cache_bnsd" "recurrent_gated_delta_rule" diff --git a/csrc/ngram_spec_decode/ngram_spec_decode_torch_adpt.h b/csrc/ngram_spec_decode/ngram_spec_decode_torch_adpt.h new file mode 100644 index 00000000000..8bec08c2925 --- /dev/null +++ b/csrc/ngram_spec_decode/ngram_spec_decode_torch_adpt.h @@ -0,0 +1,82 @@ +/* + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ */
+#ifndef NGRAM_SPEC_DECODE_TORCH_ADPT_H
+#define NGRAM_SPEC_DECODE_TORCH_ADPT_H
+
+// NOTE(review): the angle-bracket include targets were stripped when this
+// patch was mangled; the two below are reconstructed from usage
+// (std::tuple, at::Tensor / at::empty).  EXEC_NPU_CMD is presumably pulled
+// in by the project's common op-api header -- confirm against the sibling
+// adapters in csrc/ before merging.
+#include <tuple>
+#include <ATen/ATen.h>
+
+namespace vllm_ascend {
+
+// N-gram spec decode op
+// inputs:
+//   token_ids: [batch_size, max_seq_len], int32,
+//   num_tokens_no_spec: [batch_size], int32
+//   sampled_token_ids: [batch_size, max_new_tokens], int32
+//   discard_request_mask: [batch_size], int32
+//   vocab_size, min_n, max_n, k
+// outputs:
+//   token_ids (in-place change), next_token_ids, draft_token_ids, num_valid_draft_tokens
+// NOTE(review): the tuple element list was stripped from the patch; restored
+// from the return statement (four at::Tensor values).
+inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_ngram_spec_decode(
+    at::Tensor &token_ids,
+    const at::Tensor &num_tokens_no_spec,
+    const at::Tensor &sampled_token_ids,
+    const at::Tensor &discard_request_mask,
+    int64_t vocab_size,
+    int64_t min_n,
+    int64_t max_n,
+    int64_t k)
+{
+    int64_t batch_size = token_ids.size(0);
+    auto device = token_ids.device();
+
+    // Kernel expects an int32 mask; accept bool from Python callers.
+    at::Tensor discard_mask_int = discard_request_mask.dtype() == at::kBool
+        ? discard_request_mask.to(at::kInt)
+        : discard_request_mask;
+
+    // Allocate outputs with a trailing over-write cushion. The kernel's
+    // CopyOut path issues DataCopyPad GM writes whose burst length can
+    // be smaller than the NPU's 32-byte MTE alignment; under that
+    // alignment the underlying MTE3 burst can write past the apparent
+    // tensor end on the last row. Tightly-sized allocations (the original
+    // ``at::empty({batch_size}, ...)``) leave no room for that
+    // alignment-driven over-write, surfacing as a multi-core MTE OOB on
+    // device (CI signature: fixp_error0 = 0x30266b9 across cores).
+    //
+    // We therefore allocate ``batch_size + OVER_WRITE_MARGIN`` rows /
+    // ``(batch_size + OVER_WRITE_MARGIN) * k`` elements and ``narrow``
+    // back to the user-visible shape. The narrowed view shares storage
+    // with the larger allocation, so any kernel-side alignment
+    // over-write lands inside owned memory rather than off the end.
+    constexpr int64_t OVER_WRITE_MARGIN = 8; // 32 bytes / sizeof(int32) = 8 ints
+
+    at::Tensor next_token_ids_storage = at::empty(
+        {batch_size + OVER_WRITE_MARGIN},
+        at::dtype(at::kInt).device(device));
+    at::Tensor next_token_ids = next_token_ids_storage.narrow(0, 0, batch_size);
+
+    at::Tensor draft_token_ids_storage = at::empty(
+        {batch_size + OVER_WRITE_MARGIN, k},
+        at::dtype(at::kInt).device(device));
+    at::Tensor draft_token_ids = draft_token_ids_storage.narrow(0, 0, batch_size);
+
+    at::Tensor num_valid_draft_tokens_storage = at::empty(
+        {batch_size + OVER_WRITE_MARGIN},
+        at::dtype(at::kInt).device(device));
+    at::Tensor num_valid_draft_tokens =
+        num_valid_draft_tokens_storage.narrow(0, 0, batch_size);
+
+    EXEC_NPU_CMD(aclnnNgramSpecDecode,
+                 token_ids, num_tokens_no_spec, sampled_token_ids, discard_mask_int,
+                 vocab_size, min_n, max_n, k,
+                 next_token_ids, draft_token_ids, num_valid_draft_tokens);
+
+    return std::make_tuple(token_ids, next_token_ids, draft_token_ids, num_valid_draft_tokens);
+}
+
+} // namespace vllm_ascend
+
+#endif // NGRAM_SPEC_DECODE_TORCH_ADPT_H
diff --git a/csrc/ngram_spec_decode/op_host/CMakeLists.txt b/csrc/ngram_spec_decode/op_host/CMakeLists.txt
new file mode 100644
index 00000000000..e22ecd87c20
--- /dev/null
+++ b/csrc/ngram_spec_decode/op_host/CMakeLists.txt
@@ -0,0 +1,40 @@
+add_ops_compile_options(
+    OP_NAME NgramSpecDecode
+    OPTIONS --cce-auto-sync=on
+    -Wno-deprecated-declarations
+    -Werror
+)
+
+target_sources(op_host_aclnnInner PRIVATE
+    ngram_spec_decode_def.cpp
+)
+
+target_sources(opapi PRIVATE
+    aclnn_ngram_spec_decode.cpp
+)
+
+if (NOT BUILD_OPEN_PROJECT)
+    target_sources(aclnn_ops_train PRIVATE
+        aclnn_ngram_spec_decode.cpp
+    )
+
+    target_sources(aclnn_ops_infer PRIVATE
+        aclnn_ngram_spec_decode.cpp
+    )
+endif ()
+
+target_sources(optiling PRIVATE
+    ngram_spec_decode_tiling.cpp
+)
+
+target_include_directories(optiling PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}
+)
+
+target_sources(opsproto PRIVATE)
+
+file(GLOB
_Ngram_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_ngram_spec_decode.h")
+
+install(FILES ${_Ngram_Aclnn_header}
+    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
+)
diff --git a/csrc/ngram_spec_decode/op_host/aclnn_ngram_spec_decode.cpp b/csrc/ngram_spec_decode/op_host/aclnn_ngram_spec_decode.cpp
new file mode 100644
index 00000000000..3bc617fce0c
--- /dev/null
+++ b/csrc/ngram_spec_decode/op_host/aclnn_ngram_spec_decode.cpp
@@ -0,0 +1,73 @@
+// Public aclnn entry points for NgramSpecDecode; thin forwarders to the
+// generated aclnnInner* symbols.
+// NOTE(review): the first include target was stripped when this patch was
+// mangled; <string.h> matches the standard CANN aclnn wrapper template --
+// confirm against the other generated wrappers in csrc/.
+#include <string.h>
+#include "graph/types.h"
+#include "aclnn_ngram_spec_decode.h"
+
+enum NnopbaseHcclServerType {
+    NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
+    NNOPBASE_HCCL_SERVER_TYPE_MTE,
+    NNOPBASE_HCCL_SERVER_TYPE_END
+};
+// Weak so the wrapper still links against runtimes that predate this hook.
+extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+extern aclnnStatus aclnnInnerNgramSpecDecodeGetWorkspaceSize(
+    const aclTensor *tokenIds,
+    const aclTensor *numTokensNoSpec,
+    const aclTensor *sampledTokenIds,
+    const aclTensor *discardRequestMask,
+    int64_t vocabSize,
+    int64_t minN,
+    int64_t maxN,
+    int64_t k,
+    const aclTensor *nextTokenIds,
+    const aclTensor *draftTokenIds,
+    const aclTensor *numValidDraftTokens,
+    uint64_t *workspaceSize,
+    aclOpExecutor **executor);
+
+extern aclnnStatus aclnnInnerNgramSpecDecode(
+    void *workspace,
+    uint64_t workspaceSize,
+    aclOpExecutor *executor,
+    aclrtStream stream);
+
+// Forwards all tensors/attrs unchanged to the generated inner symbol.
+aclnnStatus aclnnNgramSpecDecodeGetWorkspaceSize(
+    const aclTensor *tokenIds,
+    const aclTensor *numTokensNoSpec,
+    const aclTensor *sampledTokenIds,
+    const aclTensor *discardRequestMask,
+    int64_t vocabSize,
+    int64_t minN,
+    int64_t maxN,
+    int64_t k,
+    const aclTensor *nextTokenIds,
+    const aclTensor *draftTokenIds,
+    const aclTensor *numValidDraftTokens,
+    uint64_t *workspaceSize,
+    aclOpExecutor **executor)
+{
+    return aclnnInnerNgramSpecDecodeGetWorkspaceSize(
+        tokenIds, numTokensNoSpec, sampledTokenIds, discardRequestMask,
+        vocabSize, minN, maxN, k,
+        nextTokenIds, draftTokenIds,
+        numValidDraftTokens,
+        workspaceSize, executor);
+}
+
+aclnnStatus aclnnNgramSpecDecode(
+    void *workspace,
+    uint64_t workspaceSize,
+    aclOpExecutor *executor,
+    aclrtStream stream)
+{
+    // Guard the weak symbol: older runtimes may not provide it.
+    if (NnopbaseSetHcclServerType) {
+        NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
+    }
+    return aclnnInnerNgramSpecDecode(workspace, workspaceSize, executor, stream);
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/csrc/ngram_spec_decode/op_host/aclnn_ngram_spec_decode.h b/csrc/ngram_spec_decode/op_host/aclnn_ngram_spec_decode.h
new file mode 100644
index 00000000000..5e52d5fddea
--- /dev/null
+++ b/csrc/ngram_spec_decode/op_host/aclnn_ngram_spec_decode.h
@@ -0,0 +1,56 @@
+#ifndef ACLNN_NGRAM_SPEC_DECODE_H_
+#define ACLNN_NGRAM_SPEC_DECODE_H_
+
+#include "aclnn/acl_meta.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* function: aclnnNgramSpecDecodeGetWorkspaceSize
+ * tokenIds            : required, [batch_size, max_seq_len], int32
+ * numTokensNoSpec     : required, [batch_size], int32
+ * sampledTokenIds     : required, [batch_size, max_new_tokens], int32
+ * discardRequestMask  : required, [batch_size], int32
+ * vocabSize           : required, int
+ * minN                : required, int
+ * maxN                : required, int
+ * k                   : required, int
+ * nextTokenIds        : required, [batch_size], int32
+ * draftTokenIds       : required, [batch_size, k], int32
+ * numValidDraftTokens : required, [batch_size], int32
+ * workspaceSize       : size of workspace(output).
+ * executor            : executor context(output).
+ */
+__attribute__((visibility("default"))) aclnnStatus aclnnNgramSpecDecodeGetWorkspaceSize(
+    const aclTensor *tokenIds,
+    const aclTensor *numTokensNoSpec,
+    const aclTensor *sampledTokenIds,
+    const aclTensor *discardRequestMask,
+    int64_t vocabSize,
+    int64_t minN,
+    int64_t maxN,
+    int64_t k,
+    const aclTensor *nextTokenIds,
+    const aclTensor *draftTokenIds,
+    const aclTensor *numValidDraftTokens,
+    uint64_t *workspaceSize,
+    aclOpExecutor **executor);
+
+/* function: aclnnNgramSpecDecode
+ * workspace     : workspace memory addr(input).
+ * workspaceSize : size of workspace(input).
+ * executor      : executor context(input).
+ * stream        : acl stream.
+ */
+__attribute__((visibility("default"))) aclnnStatus aclnnNgramSpecDecode(
+    void *workspace,
+    uint64_t workspaceSize,
+    aclOpExecutor *executor,
+    aclrtStream stream);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // ACLNN_NGRAM_SPEC_DECODE_H_
diff --git a/csrc/ngram_spec_decode/op_host/ngram_spec_decode_def.cpp b/csrc/ngram_spec_decode/op_host/ngram_spec_decode_def.cpp
new file mode 100644
index 00000000000..b48b4de7f9d
--- /dev/null
+++ b/csrc/ngram_spec_decode/op_host/ngram_spec_decode_def.cpp
@@ -0,0 +1,72 @@
+#include "register/op_def_registry.h"
+
+namespace ops {
+// GE op registration: all tensors are int32/ND; four int attrs.
+class NgramSpecDecode : public OpDef {
+public:
+    explicit NgramSpecDecode(const char *name) : OpDef(name)
+    {
+        this->Input("tokenIds")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+
+        this->Input("numTokensNoSpec")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+
+        this->Input("sampledTokenIds")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+
+        this->Input("discardRequestMask")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+
+        // Attr order must match ATTR_*_INDEX in ngram_spec_decode_tiling.cpp.
+        this->Attr("vocab_size").Int();
+        this->Attr("min_n").Int();
+        this->Attr("max_n").Int();
+        this->Attr("k").Int();
+
+        this->Output("nextTokenIds")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+
+        this->Output("draftTokenIds")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+
+        this->Output("numValidDraftTokens")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+
+        OpAICoreConfig aicore_config;
+        aicore_config.DynamicCompileStaticFlag(true)
+            .DynamicFormatFlag(true)
+            .DynamicRankSupportFlag(true)
+            .DynamicShapeSupportFlag(true)
+            .NeedCheckSupportFlag(false)
+            .PrecisionReduceFlag(true)
+            .ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
+            .ExtendCfgInfo("jitCompile.flag", "static_true")
+            .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
+
+        this->AICore().AddConfig("ascend910b", aicore_config);
+        this->AICore().AddConfig("ascend910_93", aicore_config);
+    }
+};
+
+OP_ADD(NgramSpecDecode);
+} // namespace ops
diff --git a/csrc/ngram_spec_decode/op_host/ngram_spec_decode_tiling.cpp b/csrc/ngram_spec_decode/op_host/ngram_spec_decode_tiling.cpp
new file mode 100644
index 00000000000..05cc4711d4e
--- /dev/null
+++ b/csrc/ngram_spec_decode/op_host/ngram_spec_decode_tiling.cpp
@@ -0,0 +1,118 @@
+// NOTE(review): the two angle-bracket include targets were stripped when this
+// patch was mangled; restored from usage (std::min/std::max, int64_t) --
+// confirm against the original patch.
+#include <algorithm>
+#include <cstdint>
+#include "log/ops_log.h"
+#include "graph/utils/type_utils.h"
+#include "register/op_def_registry.h"
+#include "../op_kernel/ngram_spec_decode_tiling.h"
+#include "tiling/platform/platform_ascendc.h"
+#include "platform/platform_infos_def.h"
+
+using namespace ge;
+namespace {
+// Input / attr slots; must match the order declared in ngram_spec_decode_def.cpp.
+constexpr uint32_t INPUT_TOKEN_IDS_INDEX = 0;
+constexpr uint32_t INPUT_NUM_TOKENS_INDEX = 1;
+constexpr uint32_t INPUT_SAMPLED_INDEX = 2;
+constexpr uint32_t INPUT_DISCARD_INDEX = 3;
+
+constexpr uint32_t ATTR_VOCAB_SIZE_INDEX = 0;
+constexpr uint32_t ATTR_MIN_N_INDEX = 1;
+constexpr uint32_t ATTR_MAX_N_INDEX = 2;
+constexpr uint32_t ATTR_K_INDEX = 3;
+
+constexpr int64_t ELEM_SIZE = 4; // int32
+} // namespace
+
+namespace optiling {
+
+// Computes per-core row split and the UB-bounded block size for the kernel.
+// NOTE(review): template arguments below (<NgramSpecDecodeTilingData>,
+// <int64_t>, static_cast<...>) were stripped when this patch was mangled and
+// have been reconstructed from the surrounding types -- confirm against the
+// original patch.
+static ge::graphStatus NgramSpecDecodeTilingFunc(gert::TilingContext *context)
+{
+    const char *nodeName = context->GetNodeName();
+    NgramSpecDecodeTilingData *tilingData = context->GetTilingData<NgramSpecDecodeTilingData>();
+    OPS_CHECK(tilingData == nullptr,
+              OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
+
+    auto attrs = context->GetAttrs();
+    OPS_CHECK(attrs == nullptr,
+              OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
+
+    auto vocabSizePtr = attrs->GetAttrPointer<int64_t>(static_cast<size_t>(ATTR_VOCAB_SIZE_INDEX));
+    auto minNPtr = attrs->GetAttrPointer<int64_t>(static_cast<size_t>(ATTR_MIN_N_INDEX));
+    auto maxNPtr = attrs->GetAttrPointer<int64_t>(static_cast<size_t>(ATTR_MAX_N_INDEX));
+    auto kPtr = attrs->GetAttrPointer<int64_t>(static_cast<size_t>(ATTR_K_INDEX));
+
+    OPS_CHECK(vocabSizePtr == nullptr, OPS_LOG_E(nodeName, "vocabSizePtr is null."), return ge::GRAPH_FAILED);
+    OPS_CHECK(minNPtr == nullptr, OPS_LOG_E(nodeName, "minNPtr is null."), return ge::GRAPH_FAILED);
+    OPS_CHECK(maxNPtr == nullptr, OPS_LOG_E(nodeName, "maxNPtr is null."), return ge::GRAPH_FAILED);
+    OPS_CHECK(kPtr == nullptr, OPS_LOG_E(nodeName, "kPtr is null."), return ge::GRAPH_FAILED);
+
+    int64_t vocab_size = *vocabSizePtr;
+    int64_t min_n = *minNPtr;
+    int64_t max_n = *maxNPtr;
+    int64_t k = *kPtr;
+
+    const gert::StorageShape *tokenIdsShape = context->GetInputShape(INPUT_TOKEN_IDS_INDEX);
+    const gert::StorageShape *sampledShape = context->GetInputShape(INPUT_SAMPLED_INDEX);
+    OPS_CHECK(tokenIdsShape == nullptr, OPS_LOG_E(nodeName, "tokenIdsShape is null."), return ge::GRAPH_FAILED);
+    OPS_CHECK(sampledShape == nullptr, OPS_LOG_E(nodeName, "sampledShape is null."), return ge::GRAPH_FAILED);
+
+    int64_t batch_size = tokenIdsShape->GetStorageShape().GetDim(0);
+    int64_t max_seq_len = tokenIdsShape->GetStorageShape().GetDim(1);
+    int64_t max_new_tokens = sampledShape->GetStorageShape().GetDim(1);
+
+    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
+    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
+    uint64_t ubSize = 0UL;
+    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
+    int64_t ub_size_limit = static_cast<int64_t>(ubSize);
+
+    // 32-byte alignment for int32 rows, matching the kernel's DataCopyPad use.
+    int64_t align_elems = 32 / ELEM_SIZE;
+    int64_t max_seq_len_align = ((max_seq_len + align_elems - 1) / align_elems) * align_elems;
+    int64_t max_new_tokens_align = ((max_new_tokens + align_elems - 1) / align_elems) * align_elems;
+    int64_t k_align = ((k + align_elems - 1) / align_elems) * align_elems;
+
+    int64_t ub_per_row = (max_seq_len_align + max_new_tokens_align + k_align) * ELEM_SIZE;
+    int64_t ub_overhead = 4 * 32 + static_cast<int64_t>(max_n) * ELEM_SIZE +
+                          ((max_seq_len_align + 7) / 8); // maskBuf
+    int64_t ub_available = ub_size_limit - ub_overhead;
+    int64_t max_block_rows = (ub_available > 0) ? (ub_available / ub_per_row) : 1;
+    max_block_rows = std::max(max_block_rows, static_cast<int64_t>(1));
+
+    int64_t block_dim = std::min(batch_size, static_cast<int64_t>(aivNum));
+    int64_t rows_per_core = (block_dim > 0) ? (batch_size / block_dim) : 0;
+    int64_t former_num = (block_dim > 0) ? (block_dim - 1) : 0;
+    int64_t tail_rows = batch_size - former_num * rows_per_core;
+    int64_t block_rows = std::min(rows_per_core, max_block_rows);
+
+    tilingData->ngramInfo.batchSize = static_cast<uint32_t>(batch_size);
+    tilingData->ngramInfo.maxSeqLen = static_cast<uint32_t>(max_seq_len);
+    tilingData->ngramInfo.maxNewTokens = static_cast<uint32_t>(max_new_tokens);
+    tilingData->ngramInfo.vocabSize = static_cast<uint32_t>(vocab_size);
+    tilingData->ngramInfo.minN = static_cast<uint32_t>(min_n);
+    tilingData->ngramInfo.maxN = static_cast<uint32_t>(max_n);
+    tilingData->ngramInfo.k = static_cast<uint32_t>(k);
+    tilingData->ngramInfo.formerNum = static_cast<uint32_t>(former_num);
+    tilingData->ngramInfo.rowsPerCore = static_cast<uint32_t>(rows_per_core);
+    tilingData->ngramInfo.tailRows = static_cast<uint32_t>(tail_rows);
+    tilingData->ngramInfo.blockRows = static_cast<uint32_t>(block_rows);
+
+    context->SetBlockDim(static_cast<uint32_t>(block_dim));
+
+    // Fixed: original used %lu for int64_t arguments; %ld matches the types.
+    OPS_LOG_D(nodeName, "batchSize=%ld, maxSeqLen=%ld, maxNewTokens=%ld, k=%ld, blockDim=%ld, blockRows=%ld",
+              batch_size, max_seq_len, max_new_tokens, k, block_dim, block_rows);
+
+    return ge::GRAPH_SUCCESS;
+}
+
+struct NgramSpecDecodeCompileInfo {};
+
+ge::graphStatus TilingParseForNgramSpecDecode(gert::TilingParseContext *context)
+{
+    (void)context;
+    return ge::GRAPH_SUCCESS;
+}
+
+IMPL_OP_OPTILING(NgramSpecDecode)
+    .Tiling(NgramSpecDecodeTilingFunc)
+    .TilingParse(TilingParseForNgramSpecDecode);
+
+} // namespace optiling
diff --git a/csrc/ngram_spec_decode/op_host/ngram_spec_decode_tiling.h b/csrc/ngram_spec_decode/op_host/ngram_spec_decode_tiling.h
new file mode 100644
index 00000000000..27847391458
--- /dev/null
+++ b/csrc/ngram_spec_decode/op_host/ngram_spec_decode_tiling.h
@@ -0,0 +1,26 @@
+#ifndef NGRAM_SPEC_DECODE_TILING_H
+#define NGRAM_SPEC_DECODE_TILING_H
+
+#include "kernel_tiling/kernel_tiling.h"
+
+struct NgramSpecDecodeInfo {
+    uint32_t batchSize;
+    uint32_t maxSeqLen;
+    uint32_t maxNewTokens;
+    uint32_t vocabSize;
+    uint32_t minN;
+    uint32_t maxN;
+    uint32_t k;
+    uint32_t formerNum;
+    uint32_t rowsPerCore;
+
uint32_t tailRows; + uint32_t blockRows; +}; + +struct NgramSpecDecodeTilingData { + Mc2InitTiling mc2InitTiling; + Mc2CcTiling mc2CcTiling1; + NgramSpecDecodeInfo ngramInfo; +}; + +#endif // NGRAM_SPEC_DECODE_TILING_H diff --git a/csrc/ngram_spec_decode/op_kernel/ngram_spec_decode.cpp b/csrc/ngram_spec_decode/op_kernel/ngram_spec_decode.cpp new file mode 100644 index 00000000000..60634149297 --- /dev/null +++ b/csrc/ngram_spec_decode/op_kernel/ngram_spec_decode.cpp @@ -0,0 +1,655 @@ +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kernel_operator.h" +#include "ngram_spec_decode_tiling.h" + +constexpr int32_t ELEM_SIZE = sizeof(int32_t); // 4 bytes +// Safety UB buffer size:32768(128KB) +constexpr uint32_t SAFE_CHUNK = 32768u; + +class KernelNgramSpecDecode { +public: + __aicore__ inline KernelNgramSpecDecode() {} + + __aicore__ inline void Init( + GM_ADDR token_ids_gm, GM_ADDR num_tokens_gm, GM_ADDR sampled_gm, + GM_ADDR discard_gm, GM_ADDR next_tokens_gm, GM_ADDR draft_tokens_gm, + GM_ADDR num_valid_gm, GM_ADDR workspace, GM_ADDR tiling) + { + REGISTER_TILING_DEFAULT(NgramSpecDecodeTilingData); + GET_TILING_DATA_WITH_STRUCT(NgramSpecDecodeTilingData, tilingData, tiling); + + this->batch_size = static_cast(tilingData.ngramInfo.batchSize); + this->max_seq_len = static_cast(tilingData.ngramInfo.maxSeqLen); + this->max_new_tokens = static_cast(tilingData.ngramInfo.maxNewTokens); + this->vocab_size_val = static_cast(tilingData.ngramInfo.vocabSize); + this->min_n_val = static_cast(tilingData.ngramInfo.minN); + 
this->max_n_val = static_cast(tilingData.ngramInfo.maxN); + this->k_val = static_cast(tilingData.ngramInfo.k); + this->former_num = static_cast(tilingData.ngramInfo.formerNum); + this->rows_per_core = static_cast(tilingData.ngramInfo.rowsPerCore); + this->tail_rows = static_cast(tilingData.ngramInfo.tailRows); + this->block_rows = static_cast(tilingData.ngramInfo.blockRows); + + int32_t align_elems = 32 / ELEM_SIZE; // = 8 + this->max_seq_len_align = ((this->max_seq_len + align_elems - 1) / align_elems) * align_elems; + this->max_new_tokens_align = ((this->max_new_tokens + align_elems - 1) / align_elems) * align_elems; + this->k_align = ((this->k_val + align_elems - 1) / align_elems) * align_elems; + + this->is_large_row = (this->max_seq_len_align > static_cast(SAFE_CHUNK)); + + uint32_t blockIdx = AscendC::GetBlockIdx(); + if (blockIdx < static_cast(this->former_num)) { + this->my_rows = static_cast(this->rows_per_core); + this->row_offset = static_cast(this->rows_per_core) * blockIdx; + } else { + this->my_rows = static_cast(this->tail_rows); + this->row_offset = static_cast(this->rows_per_core) * static_cast(this->former_num); + } + + tokenGm.SetGlobalBuffer((__gm__ int32_t *)token_ids_gm, + static_cast(this->batch_size) * this->max_seq_len); + numTokensGm.SetGlobalBuffer((__gm__ int32_t *)num_tokens_gm, + static_cast(this->batch_size)); + sampledGm.SetGlobalBuffer((__gm__ int32_t *)sampled_gm, + static_cast(this->batch_size) * this->max_new_tokens); + discardGm.SetGlobalBuffer((__gm__ int32_t *)discard_gm, + static_cast(this->batch_size)); + nextTokensGm.SetGlobalBuffer((__gm__ int32_t *)next_tokens_gm, + static_cast(this->batch_size)); + draftTokensGm.SetGlobalBuffer((__gm__ int32_t *)draft_tokens_gm, + static_cast(this->batch_size) * this->k_val); + numValidGm.SetGlobalBuffer((__gm__ int32_t *)num_valid_gm, + static_cast(this->batch_size)); + + uint32_t br = static_cast(this->block_rows); + uint32_t br_align = ((br * ELEM_SIZE + 31) / 32) * 32 / ELEM_SIZE; + 
+ if (!this->is_large_row) { + pipe.InitBuffer(tokenTileBuf, br * static_cast(this->max_seq_len_align) * ELEM_SIZE); + } else { + uint32_t chunk_ub = SAFE_CHUNK + static_cast(this->max_n_val); + uint32_t chunk_ub_align = ((chunk_ub + 7u) / 8u) * 8u; + pipe.InitBuffer(tokenTileBuf, chunk_ub_align * ELEM_SIZE); + } + + uint32_t mask_bytes = ((SAFE_CHUNK + 7u) / 8u); + pipe.InitBuffer(maskBuf, mask_bytes); + + pipe.InitBuffer(sampledTileBuf, br * static_cast(this->max_new_tokens_align) * ELEM_SIZE); + pipe.InitBuffer(numTokensBuf, br_align * ELEM_SIZE); + pipe.InitBuffer(discardTileBuf, br_align * ELEM_SIZE); + pipe.InitBuffer(nextTokenBuf, br_align * ELEM_SIZE); + pipe.InitBuffer(draftBuf, br * static_cast(this->k_align) * ELEM_SIZE); + pipe.InitBuffer(numValidBuf, br_align * ELEM_SIZE); + pipe.InitBuffer(suffixBuf, static_cast(this->max_n_val) * ELEM_SIZE); + } + + __aicore__ inline void Process() + { + uint32_t remaining = this->my_rows; + uint32_t cur_offset = 0; + while (remaining > 0) { + uint32_t cur_rows = (remaining > static_cast(this->block_rows)) + ? 
static_cast(this->block_rows) : remaining; + if (this->is_large_row) { + ProcessChunkedRows(this->row_offset + cur_offset, cur_rows); + } else { + CopyIn(this->row_offset + cur_offset, cur_rows); + Compute(cur_rows); + CopyOut(this->row_offset + cur_offset, cur_rows); + } + cur_offset += cur_rows; + remaining -= cur_rows; + } + } + +private: + + __aicore__ inline void ProcessChunkedRows(uint32_t start_row, uint32_t rows) + { + uint32_t msl = static_cast(this->max_seq_len); + uint32_t mnta = static_cast(this->max_new_tokens_align); + uint32_t ka = static_cast(this->k_align); + + auto sampledLocal = sampledTileBuf.Get(); + auto numTokensLocal = numTokensBuf.Get(); + auto discardLocal = discardTileBuf.Get(); + auto nextLocal = nextTokenBuf.Get(); + auto draftLocal = draftBuf.Get(); + auto numValidLocal = numValidBuf.Get(); + auto suffixLocal = suffixBuf.Get(); + auto tokenLocal = tokenTileBuf.Get(); + auto maskLocal = maskBuf.Get(); + + uint32_t metaBytes = rows * ELEM_SIZE; + AscendC::DataCopyExtParams metaParams{1, metaBytes, 0, metaBytes, 0}; + AscendC::DataCopyPadExtParams noPadT{false, 0, 0, 0}; + AscendC::DataCopyPad(numTokensLocal, numTokensGm[start_row], metaParams, noPadT); + AscendC::DataCopyPad(discardLocal, discardGm[start_row], metaParams, noPadT); + + uint32_t srcRowBytes2 = static_cast(this->max_new_tokens) * ELEM_SIZE; + uint32_t dstRowBytes2 = mnta * ELEM_SIZE; + AscendC::DataCopyExtParams sampledParams{1, srcRowBytes2, 0, dstRowBytes2, 0}; + AscendC::DataCopyPadExtParams sampledPad{ + false, 0, static_cast(mnta - this->max_new_tokens), 0}; + for (uint32_t r = 0; r < rows; ++r) { + AscendC::DataCopyPad(sampledLocal[static_cast(r) * mnta], + sampledGm[static_cast(start_row + r) * this->max_new_tokens], + sampledParams, sampledPad); + } + + for (uint32_t i = 0; i < rows; ++i) { + uint64_t gmRow = static_cast(start_row + i) * msl; + int32_t seq_len = numTokensLocal.GetValue(i); + int32_t discard = discardLocal.GetValue(i); + int32_t valid_count = 0; + + 
int32_t backup_pos = (seq_len > 0) ? (seq_len - 1) : 0; + + for (int32_t j = 0; j < this->max_new_tokens; ++j) { + int32_t val = sampledLocal.GetValue(i * mnta + j); + if (discard != 0) { + sampledLocal.SetValue(i * mnta + j, -1); + } else if (val != -1 && val < this->vocab_size_val) { + valid_count++; + } else { + sampledLocal.SetValue(i * mnta + j, -1); + } + } + + int32_t avail_space = this->max_seq_len - seq_len; + if (avail_space < 0) avail_space = 0; + if (valid_count > avail_space) valid_count = avail_space; + + LoadGmElements(gmRow + backup_pos, 1); + int32_t backup_token = tokenLocal.GetValue(0); + + if (valid_count > 0) { + nextLocal.SetValue(i, sampledLocal.GetValue(i * mnta + valid_count - 1)); + } else { + nextLocal.SetValue(i, backup_token); + } + + int32_t nt = seq_len + valid_count; + if (valid_count > 0) { + for (int32_t j = 0; j < valid_count; ++j) { + tokenLocal.SetValue(j, sampledLocal.GetValue(i * mnta + j)); + } + StoreGmElements(gmRow + seq_len, valid_count); + } + + int32_t best_match_pos = -1; + int32_t best_ngram_len = 0; + + if (valid_count > 0 && nt >= this->min_n_val) { + int32_t suffix_gm_start = nt - this->max_n_val; + if (suffix_gm_start < 0) suffix_gm_start = 0; + LoadGmElements(gmRow + suffix_gm_start, this->max_n_val); + for (int32_t s = 0; s < this->max_n_val; ++s) { + suffixLocal.SetValue(static_cast(s), tokenLocal.GetValue(static_cast(s))); + } + + for (int32_t ngram_len = this->min_n_val; ngram_len <= this->max_n_val; ++ngram_len) { + if (ngram_len > nt) break; + int32_t wc = nt - ngram_len; + if (wc <= 0) break; + + int32_t suffix_offset = this->max_n_val - ngram_len; + int32_t suffix0 = suffixLocal.GetValue(static_cast(suffix_offset)); + + for (int32_t chunk_start = 0; chunk_start < wc; chunk_start += SAFE_CHUNK) { + int32_t chunk_count = (chunk_start + SAFE_CHUNK <= wc) ? 
SAFE_CHUNK : (wc - chunk_start); + int32_t load_count = chunk_count + (ngram_len - 1); + if (chunk_start + load_count > nt) load_count = nt - chunk_start; + LoadGmElements(gmRow + chunk_start, load_count); + + uint32_t cmp_count = ((static_cast(chunk_count) + 63u) / 64u) * 64u; + uint32_t max_cmp = SAFE_CHUNK > 8192u ? 8192u : SAFE_CHUNK; + if (cmp_count > max_cmp) cmp_count = max_cmp; + if (cmp_count > static_cast(load_count)) { + cmp_count = ((static_cast(load_count) + 63u) / 64u) * 64u; + } + + for (uint32_t cmp_off = 0; cmp_off < static_cast(chunk_count); cmp_off += cmp_count) { + uint32_t rem = static_cast(chunk_count) - cmp_off; + uint32_t elements = (rem >= cmp_count) ? cmp_count : rem; + uint32_t aligned = ((elements + 63u) / 64u) * 64u; + + AscendC::CompareScalar( + maskLocal, tokenLocal[cmp_off], + suffix0, AscendC::CMPMODE::EQ, aligned); + + for (uint32_t p = 0; p < elements; ++p) { + uint8_t bv = maskLocal.GetValue(p >> 3); + if (bv & (1u << (p & 7u))) { + bool all_match = true; + for (int32_t s = 1; s < ngram_len; ++s) { + int32_t sv = suffixLocal.GetValue(static_cast(suffix_offset + s)); + if (cmp_off + p + s < static_cast(load_count)) { + int32_t tv = tokenLocal.GetValue(cmp_off + p + static_cast(s)); + if (tv != sv) { all_match = false; break; } + } else { + all_match = false; break; + } + } + if (all_match) { + best_match_pos = chunk_start + static_cast(cmp_off + p); + best_ngram_len = ngram_len; + break; + } + } + } + if (best_match_pos >= 0) break; + } + if (best_match_pos >= 0) break; + } + if (best_match_pos >= 0) break; + } + } + + if (best_match_pos >= 0) { + int32_t draft_start = best_match_pos + best_ngram_len; + int32_t tokens_available = nt - draft_start; + int32_t draft_load = (tokens_available < this->k_val) ? 
tokens_available : this->k_val; + if (draft_load > 0) { + LoadGmElements(gmRow + draft_start, draft_load); + for (int32_t j = 0; j < this->k_val; ++j) { + if (j < draft_load) { + draftLocal.SetValue(i * ka + j, tokenLocal.GetValue(static_cast(j))); + } else { + draftLocal.SetValue(i * ka + j, -1); + } + } + } else { + for (int32_t j = 0; j < this->k_val; ++j) { + draftLocal.SetValue(i * ka + j, -1); + } + } + } else { + for (int32_t j = 0; j < this->k_val; ++j) { + draftLocal.SetValue(i * ka + j, -1); + } + } + + int32_t valid_draft_count = 0; + for (int32_t j = 0; j < this->k_val; ++j) { + if (draftLocal.GetValue(i * ka + j) != -1) { + valid_draft_count++; + } else { + break; + } + } + numValidLocal.SetValue(i, valid_draft_count); + } + + uint32_t metaBytes32 = static_cast(rows) * ELEM_SIZE; + AscendC::DataCopyExtParams nextParams{1, metaBytes32, 0, 0, 0}; + AscendC::DataCopyPad(nextTokensGm[start_row], nextLocal, nextParams); + + uint32_t kBytes = static_cast(this->k_val) * ELEM_SIZE; + for (uint32_t r = 0; r < rows; ++r) { + AscendC::DataCopyExtParams draftRowParams{1, kBytes, 0, 0, 0}; + AscendC::DataCopyPad( + draftTokensGm[static_cast(start_row + r) * this->k_val], + draftLocal[static_cast(r) * this->k_align], draftRowParams); + } + + AscendC::DataCopyPad(numValidGm[start_row], numValidLocal, nextParams); + } + + __aicore__ inline void LoadGmElements(uint64_t gm_offset, int32_t count) + { + if (count <= 0) return; + auto tokenLocal = tokenTileBuf.Get(); + uint32_t c = static_cast(count); + uint32_t aligned = ((c + 7u) / 8u) * 8u; + uint8_t pad = static_cast(aligned - c); + AscendC::DataCopyExtParams p{1, c * ELEM_SIZE, 0, aligned * ELEM_SIZE, 0}; + AscendC::DataCopyPadExtParams pp{false, 0, pad, 0}; + AscendC::DataCopyPad(tokenLocal[0], tokenGm[gm_offset], p, pp); + } + + __aicore__ inline void StoreGmElements(uint64_t gm_offset, int32_t count) + { + if (count <= 0) return; + auto tokenLocal = tokenTileBuf.Get(); + constexpr uint32_t STORE_MAX = 16383u; + 
uint32_t c = static_cast(count); + for (uint32_t off = 0; off < c; off += STORE_MAX) { + uint32_t chunk = (off + STORE_MAX <= c) ? STORE_MAX : (c - off); + AscendC::DataCopyExtParams p{1, chunk * ELEM_SIZE, 0, 0, 0}; + AscendC::DataCopyPad(tokenGm[gm_offset + off], tokenLocal[off], p); + } + } + + + __aicore__ inline void CopyIn(uint32_t start_row, uint32_t rows) + { + uint32_t msa = static_cast(this->max_seq_len_align); + uint32_t mnta = static_cast(this->max_new_tokens_align); + constexpr uint32_t MAX_CHUNK_ELEMS = 8192u; + + auto tokenLocal = tokenTileBuf.Get(); + uint32_t msl = static_cast(this->max_seq_len); + for (uint32_t r = 0; r < rows; ++r) { + uint64_t gmRow = static_cast(start_row + r) * msl; + uint32_t ubRow = r * msa; + for (uint32_t off = 0; off < msl; off += MAX_CHUNK_ELEMS) { + uint32_t chunk = (off + MAX_CHUNK_ELEMS <= msl) ? MAX_CHUNK_ELEMS : (msl - off); + uint32_t isLast = (off + chunk >= msl) ? 1u : 0u; + uint32_t dstChunk = isLast ? (msa - off) : MAX_CHUNK_ELEMS; + uint8_t pad = static_cast(dstChunk - chunk); + AscendC::DataCopyExtParams p{1, chunk * ELEM_SIZE, 0, dstChunk * ELEM_SIZE, 0}; + AscendC::DataCopyPadExtParams pp{false, 0, pad, 0}; + AscendC::DataCopyPad(tokenLocal[ubRow + off], tokenGm[gmRow + off], p, pp); + } + } + + auto sampledLocal = sampledTileBuf.Get(); + uint32_t srcRowBytes2 = static_cast(this->max_new_tokens) * ELEM_SIZE; + uint32_t dstRowBytes2 = mnta * ELEM_SIZE; + AscendC::DataCopyExtParams sampledParams{1, srcRowBytes2, 0, dstRowBytes2, 0}; + AscendC::DataCopyPadExtParams sampledPad{ + false, 0, static_cast(mnta - this->max_new_tokens), 0}; + for (uint32_t r = 0; r < rows; ++r) { + AscendC::DataCopyPad(sampledLocal[static_cast(r) * mnta], + sampledGm[static_cast(start_row + r) * this->max_new_tokens], + sampledParams, sampledPad); + } + + auto numTokensLocal = numTokensBuf.Get(); + uint32_t metaBytes = static_cast(rows) * ELEM_SIZE; + AscendC::DataCopyExtParams metaParams{1, metaBytes, 0, metaBytes, 0}; + 
AscendC::DataCopyPadExtParams noPadT{false, 0, 0, 0}; + AscendC::DataCopyPad(numTokensLocal, numTokensGm[start_row], metaParams, noPadT); + + auto discardLocal = discardTileBuf.Get(); + AscendC::DataCopyPad(discardLocal, discardGm[start_row], metaParams, noPadT); + } + + __aicore__ inline void Compute(uint32_t rows) + { + auto tokenLocal = tokenTileBuf.Get(); + auto sampledLocal = sampledTileBuf.Get(); + auto numTokensLocal = numTokensBuf.Get(); + auto discardLocal = discardTileBuf.Get(); + auto nextLocal = nextTokenBuf.Get(); + auto draftLocal = draftBuf.Get(); + auto numValidLocal = numValidBuf.Get(); + auto suffixLocal = suffixBuf.Get(); + auto maskLocal = maskBuf.Get(); + + for (uint32_t i = 0; i < rows; ++i) { + ComputeOneRow(i, tokenLocal, sampledLocal, numTokensLocal, + discardLocal, nextLocal, draftLocal, numValidLocal, + suffixLocal, maskLocal); + } + } + + __aicore__ inline void ComputeOneRow( + uint32_t idx, + AscendC::LocalTensor &tokenLocal, + AscendC::LocalTensor &sampledLocal, + AscendC::LocalTensor &numTokensLocal, + AscendC::LocalTensor &discardLocal, + AscendC::LocalTensor &nextLocal, + AscendC::LocalTensor &draftLocal, + AscendC::LocalTensor &numValidLocal, + AscendC::LocalTensor &suffixLocal, + AscendC::LocalTensor &maskLocal) + { + uint32_t msa = this->max_seq_len_align; + uint32_t mnta = this->max_new_tokens_align; + uint32_t ka = this->k_align; + + int32_t seq_len = numTokensLocal.GetValue(idx); + int32_t discard = discardLocal.GetValue(idx); + int32_t valid_count = 0; + + int32_t backup_pos = (seq_len > 0) ? 
(seq_len - 1) : 0; + int32_t backup_token = tokenLocal.GetValue(idx * msa + backup_pos); + + for (int32_t j = 0; j < this->max_new_tokens; ++j) { + int32_t val = sampledLocal.GetValue(idx * mnta + j); + if (discard != 0) { + sampledLocal.SetValue(idx * mnta + j, -1); + } else if (val != -1 && val < this->vocab_size_val) { + valid_count++; + } else { + sampledLocal.SetValue(idx * mnta + j, -1); + } + } + + int32_t avail_space = this->max_seq_len - seq_len; + if (avail_space < 0) avail_space = 0; + if (valid_count > avail_space) valid_count = avail_space; + + if (valid_count > 0) { + nextLocal.SetValue(idx, sampledLocal.GetValue(idx * mnta + valid_count - 1)); + } else { + nextLocal.SetValue(idx, backup_token); + } + + int32_t num_tokens_tmp = seq_len + valid_count; + for (int32_t j = 0; j < valid_count; ++j) { + tokenLocal.SetValue(idx * msa + seq_len + j, sampledLocal.GetValue(idx * mnta + j)); + } + + int32_t best_match_pos = -1; + int32_t best_ngram_len = 0; + + if (valid_count > 0 && num_tokens_tmp >= this->min_n_val) { + if (this->block_rows <= 1) { + int32_t nt = num_tokens_tmp; + constexpr uint32_t CMP_MAX = 8192u; + + for (int32_t ngram_len = this->min_n_val; ngram_len <= this->max_n_val; ++ngram_len) { + if (ngram_len > nt) break; + int32_t wc = nt - ngram_len; + if (wc <= 0) break; + + int32_t suffix0 = tokenLocal.GetValue(static_cast(nt - ngram_len)); + uint32_t msa_cmp = static_cast(msa); + + for (int32_t cmp_off = 0; cmp_off < wc; cmp_off += CMP_MAX) { + uint32_t remaining = static_cast(wc - cmp_off); + uint32_t elements = (remaining >= CMP_MAX) ? 
CMP_MAX : remaining; + uint32_t count_aligned = ((elements + 63u) / 64u) * 64u; + uint32_t buf_avail = msa_cmp - static_cast(cmp_off); + if (count_aligned > buf_avail) { + count_aligned = (buf_avail / 64u) * 64u; + } + + if (count_aligned == 0) { + for (int32_t p = 0; p < static_cast(elements); ++p) { + if (tokenLocal.GetValue(static_cast(cmp_off + p)) == suffix0) { + bool all_match = true; + for (int32_t s = 1; s < ngram_len; ++s) { + int32_t sv = tokenLocal.GetValue(static_cast(nt - ngram_len + s)); + int32_t tv = tokenLocal.GetValue(static_cast(cmp_off + p + s)); + if (tv != sv) { all_match = false; break; } + } + if (all_match) { + best_match_pos = cmp_off + p; + best_ngram_len = ngram_len; + break; + } + } + } + } else { + AscendC::CompareScalar( + maskLocal, tokenLocal[static_cast(cmp_off)], + suffix0, AscendC::CMPMODE::EQ, count_aligned); + + for (int32_t p = 0; p < static_cast(elements); ++p) { + uint8_t byte_val = maskLocal.GetValue(static_cast(p) >> 3); + if (byte_val & (1u << (static_cast(p) & 7u))) { + bool all_match = true; + for (int32_t s = 1; s < ngram_len; ++s) { + int32_t sv = tokenLocal.GetValue(static_cast(nt - ngram_len + s)); + int32_t tv = tokenLocal.GetValue(static_cast(cmp_off + p + s)); + if (tv != sv) { all_match = false; break; } + } + if (all_match) { + best_match_pos = cmp_off + p; + best_ngram_len = ngram_len; + break; + } + } + } + } + if (best_match_pos >= 0) break; + } + } + } else { + int32_t row_base = static_cast(idx) * static_cast(msa); + + for (int32_t ngram_len = this->min_n_val; ngram_len <= this->max_n_val; ++ngram_len) { + if (ngram_len > num_tokens_tmp) break; + + for (int32_t s = 0; s < ngram_len; ++s) { + suffixLocal.SetValue(static_cast(s), + tokenLocal.GetValue(static_cast( + row_base + num_tokens_tmp - ngram_len + s))); + } + + int32_t max_pos = num_tokens_tmp - ngram_len - 1; + for (int32_t pos = 0; pos <= max_pos; ++pos) { + bool match = true; + for (int32_t s = 0; s < ngram_len; ++s) { + if 
(tokenLocal.GetValue(static_cast(row_base + pos + s)) + != suffixLocal.GetValue(static_cast(s))) { + match = false; + break; + } + } + if (match) { + best_match_pos = pos; + best_ngram_len = ngram_len; + break; + } + } + } + } + } + + if (best_match_pos >= 0) { + int32_t draft_start = best_match_pos + best_ngram_len; + int32_t tokens_available = num_tokens_tmp - draft_start; + for (int32_t j = 0; j < this->k_val; ++j) { + if (j < tokens_available) { + draftLocal.SetValue(idx * ka + j, tokenLocal.GetValue(idx * msa + draft_start + j)); + } else { + draftLocal.SetValue(idx * ka + j, -1); + } + } + } else { + for (int32_t j = 0; j < this->k_val; ++j) { + draftLocal.SetValue(idx * ka + j, -1); + } + } + + int32_t valid_draft_count = 0; + for (int32_t j = 0; j < this->k_val; ++j) { + if (draftLocal.GetValue(idx * ka + j) != -1) { + valid_draft_count++; + } else { + break; + } + } + numValidLocal.SetValue(idx, valid_draft_count); + } + + __aicore__ inline void CopyOut(uint32_t start_row, uint32_t rows) + { + uint32_t msa = static_cast(this->max_seq_len_align); + uint32_t msl = static_cast(this->max_seq_len); + constexpr uint32_t OUT_CHUNK_ELEMS = 8192u; + + auto tokenLocal = tokenTileBuf.Get(); + for (uint32_t r = 0; r < rows; ++r) { + uint64_t gmRow = static_cast(start_row + r) * msl; + uint32_t ubRow = r * msa; + for (uint32_t off = 0; off < msl; off += OUT_CHUNK_ELEMS) { + uint32_t chunk = (off + OUT_CHUNK_ELEMS <= msl) ? 
OUT_CHUNK_ELEMS : (msl - off); + AscendC::DataCopyExtParams p{1, chunk * ELEM_SIZE, 0, 0, 0}; + AscendC::DataCopyPad(tokenGm[gmRow + off], tokenLocal[ubRow + off], p); + } + } + + auto nextLocal = nextTokenBuf.Get(); + uint32_t metaBytes32 = static_cast(rows) * ELEM_SIZE; + AscendC::DataCopyExtParams nextParams{1, metaBytes32, 0, 0, 0}; + AscendC::DataCopyPad(nextTokensGm[start_row], nextLocal, nextParams); + + auto draftLocal = draftBuf.Get(); + uint32_t kBytes = static_cast(this->k_val) * ELEM_SIZE; + for (uint32_t r = 0; r < rows; ++r) { + AscendC::DataCopyExtParams draftRowParams{1, kBytes, 0, 0, 0}; + AscendC::DataCopyPad( + draftTokensGm[static_cast(start_row + r) * this->k_val], + draftLocal[static_cast(r) * this->k_align], draftRowParams); + } + + auto numValidLocal = numValidBuf.Get(); + AscendC::DataCopyPad(numValidGm[start_row], numValidLocal, nextParams); + } + +private: + AscendC::TPipe pipe; + AscendC::TBuf tokenTileBuf; + AscendC::TBuf sampledTileBuf; + AscendC::TBuf numTokensBuf; + AscendC::TBuf discardTileBuf; + AscendC::TBuf nextTokenBuf; + AscendC::TBuf draftBuf; + AscendC::TBuf numValidBuf; + AscendC::TBuf suffixBuf; + AscendC::TBuf maskBuf; + + AscendC::GlobalTensor tokenGm; + AscendC::GlobalTensor numTokensGm; + AscendC::GlobalTensor sampledGm; + AscendC::GlobalTensor discardGm; + AscendC::GlobalTensor nextTokensGm; + AscendC::GlobalTensor draftTokensGm; + AscendC::GlobalTensor numValidGm; + + int32_t batch_size; + int32_t max_seq_len; + int32_t max_seq_len_align; + int32_t max_new_tokens; + int32_t max_new_tokens_align; + int32_t k_val; + int32_t k_align; + int32_t vocab_size_val; + int32_t min_n_val; + int32_t max_n_val; + int32_t former_num; + int32_t rows_per_core; + int32_t tail_rows; + int32_t block_rows; + uint32_t my_rows; + uint32_t row_offset; + bool is_large_row; +}; + +extern "C" __global__ __aicore__ void ngram_spec_decode( + GM_ADDR token_ids, GM_ADDR num_tokens, GM_ADDR sampled, + GM_ADDR discard, GM_ADDR next_tokens, GM_ADDR 
draft_tokens, + GM_ADDR num_valid, GM_ADDR workspace, GM_ADDR tiling) +{ + KernelNgramSpecDecode op; + op.Init(token_ids, num_tokens, sampled, discard, next_tokens, + draft_tokens, num_valid, workspace, tiling); + op.Process(); +} diff --git a/csrc/ngram_spec_decode/op_kernel/ngram_spec_decode_tiling.h b/csrc/ngram_spec_decode/op_kernel/ngram_spec_decode_tiling.h new file mode 100644 index 00000000000..27847391458 --- /dev/null +++ b/csrc/ngram_spec_decode/op_kernel/ngram_spec_decode_tiling.h @@ -0,0 +1,26 @@ +#ifndef NGRAM_SPEC_DECODE_TILING_H +#define NGRAM_SPEC_DECODE_TILING_H + +#include "kernel_tiling/kernel_tiling.h" + +struct NgramSpecDecodeInfo { + uint32_t batchSize; + uint32_t maxSeqLen; + uint32_t maxNewTokens; + uint32_t vocabSize; + uint32_t minN; + uint32_t maxN; + uint32_t k; + uint32_t formerNum; + uint32_t rowsPerCore; + uint32_t tailRows; + uint32_t blockRows; +}; + +struct NgramSpecDecodeTilingData { + Mc2InitTiling mc2InitTiling; + Mc2CcTiling mc2CcTiling1; + NgramSpecDecodeInfo ngramInfo; +}; + +#endif // NGRAM_SPEC_DECODE_TILING_H diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index e80bccfb004..04254fc2f82 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -44,6 +44,7 @@ #include "moe_init_routing_custom/moe_init_routing_custom_torch_adpt.h" #include "sparse_flash_attention/sparse_flash_attention_torch_adpt.h" #include "lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h" +#include "ngram_spec_decode/ngram_spec_decode_torch_adpt.h" #include "causal_conv1d_v310/causal_conv1d_310_torch_adpt.h" #include "recurrent_gated_delta_rule_v310/recurrent_gated_delta_rule_310_torch_adpt.h" #include @@ -1279,5 +1280,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " int sparse_count=2048, int sparse_mode=3) -> Tensor" ); ops.impl("npu_lightning_indexer_quant", torch::kPrivateUse1, &vllm_ascend::npu_lightning_indexer_quant); + + // N-gram spec decode + ops.def( + "npu_ngram_spec_decode(Tensor(a!) 
token_ids, Tensor num_tokens_no_spec, " + "Tensor sampled_token_ids, Tensor discard_request_mask, " + "int vocab_size, int min_n, int max_n, int k) -> " + "(Tensor token_ids, Tensor next_token_ids, Tensor draft_token_ids, Tensor num_valid_draft_tokens)" + ); + ops.impl("npu_ngram_spec_decode", torch::kPrivateUse1, + &vllm_ascend::npu_ngram_spec_decode); } #endif diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index 028aaae6c3b..1974a6c0f99 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -665,6 +665,24 @@ at::Tensor npu_lightning_indexer_quant_meta( return lightning_indexer_quant_output; } +// N-gram spec decode meta +std::tuple npu_ngram_spec_decode_meta( + at::Tensor &token_ids, + const at::Tensor &num_tokens_no_spec, + const at::Tensor &sampled_token_ids, + const at::Tensor &discard_request_mask, + int64_t vocab_size, + int64_t min_n, + int64_t max_n, + int64_t k) +{ + int64_t batch_size = token_ids.size(0); + at::Tensor next_token_ids = at::empty({batch_size}, token_ids.options()); + at::Tensor draft_token_ids = at::empty({batch_size, k}, token_ids.options()); + at::Tensor num_valid_draft_tokens = at::empty({batch_size}, token_ids.options()); + return std::make_tuple(token_ids, next_token_ids, draft_token_ids, num_valid_draft_tokens); +} + } // namespace meta } // namespace vllm_ascend @@ -736,6 +754,8 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta); // Lightning indexer quant ops.impl("npu_lightning_indexer_quant", &vllm_ascend::meta::npu_lightning_indexer_quant_meta); + // N-gram spec decode + ops.impl("npu_ngram_spec_decode", &vllm_ascend::meta::npu_ngram_spec_decode_meta); } } #endif diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_ngram_spec_decode.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_ngram_spec_decode.py new file mode 100644 index 00000000000..599947309f4 --- /dev/null +++ 
b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_ngram_spec_decode.py @@ -0,0 +1,503 @@ +"""E2E accuracy test for NgramSpecDecode custom operator. + +Tests the Ascend C kernel against a CPU golden reference implementation +with parametrized test cases covering various configurations. +""" + +import time +import numpy as np +import pytest +import torch + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + +SEED = 42 +PERF_WARMUP = 3 +PERF_ITERS = 20 + + +# --------------------------------------------------------------------------- +# Golden reference (CPU, pure Python/NumPy) +# --------------------------------------------------------------------------- + +def golden_ngram_spec_decode( + token_ids: np.ndarray, # [B, M], int32, + num_tokens_no_spec: np.ndarray, # [B], int32 + sampled_token_ids: np.ndarray, # [B, N], int32 + discard_request_mask: np.ndarray, # [B], int32 + vocab_size: int, + min_n: int, + max_n: int, + k: int, +): + """CPU golden reference for NgramSpecDecode. 
+ + Returns: + (token_ids_modified, next_token_ids, draft_token_ids, num_valid_draft_tokens) + """ + B = token_ids.shape[0] + M = token_ids.shape[1] + next_token_ids = np.zeros(B, dtype=np.int32) + draft_token_ids = np.full((B, k), -1, dtype=np.int32) + num_valid_draft_tokens = np.zeros(B, dtype=np.int32) + + for i in range(B): + seq_len = int(num_tokens_no_spec[i]) + discard = int(discard_request_mask[i]) + valid_count = 0 + + # Stage 1: sample token valid + backup_pos = max(seq_len - 1, 0) + backup_token = int(token_ids[i, backup_pos]) + + for j in range(sampled_token_ids.shape[1]): + val = int(sampled_token_ids[i, j]) + if discard != 0: + sampled_token_ids[i, j] = -1 + elif val != -1 and val < vocab_size: + valid_count += 1 + else: + sampled_token_ids[i, j] = -1 + + avail_space = M - seq_len + if avail_space < 0: + avail_space = 0 + if valid_count > avail_space: + valid_count = avail_space + + if valid_count > 0: + next_token_ids[i] = int(sampled_token_ids[i, valid_count - 1]) + else: + next_token_ids[i] = backup_token + + # Stage 2: scatter sampled token to token_ids tail + nt = seq_len + valid_count + for j in range(valid_count): + token_ids[i, seq_len + j] = int(sampled_token_ids[i, j]) + + # Stage 3: suffix n-gram match + best_match_pos = -1 + best_ngram_len = 0 + + if valid_count > 0 and nt >= min_n: + for ngram_len in range(min_n, max_n + 1): + if ngram_len > nt: + break + wc = nt - ngram_len + if wc <= 0: + break + + suffix = token_ids[i, nt - ngram_len: nt].tolist() + found = False + for pos in range(wc): + window = token_ids[i, pos: pos + ngram_len].tolist() + if window == suffix: + best_match_pos = pos + best_ngram_len = ngram_len + found = True + break + if found: + break + + # Stage 4: get draft tokens + if best_match_pos >= 0: + draft_start = best_match_pos + best_ngram_len + tokens_available = nt - draft_start + for j in range(k): + if j < tokens_available: + draft_token_ids[i, j] = int(token_ids[i, draft_start + j]) + else: + draft_token_ids[i, j] 
= -1 + # else: init to -1 + + # static valid draft token + valid_draft_count = 0 + for j in range(k): + if draft_token_ids[i, j] != -1: + valid_draft_count += 1 + else: + break + num_valid_draft_tokens[i] = valid_draft_count + + return token_ids, next_token_ids, draft_token_ids, num_valid_draft_tokens + + +# --------------------------------------------------------------------------- +# inputs construct helper +# --------------------------------------------------------------------------- + +def _make_inputs( + batch_size: int, + seq_len: int, + max_new_tokens: int, + k: int, + vocab_size: int = 32000, + min_n: int = 3, + max_n: int = 5, + discard_rate: float = 0.0, + invalid_rate: float = 0.0, + seed: int = SEED, +): + """ + + Returns: + (token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + """ + rng = np.random.RandomState(seed) + token_ids = rng.randint(0, vocab_size, size=(batch_size, seq_len), dtype=np.int32) + + max_valid_tokens = seq_len - max_new_tokens + if max_valid_tokens < 1: + max_valid_tokens = 1 + num_tokens_no_spec = rng.randint(1, max_valid_tokens + 1, size=(batch_size,), dtype=np.int32) + + sampled_token_ids = rng.randint(0, vocab_size, size=(batch_size, max_new_tokens), dtype=np.int32) + if invalid_rate > 0: + invalid_mask = rng.rand(batch_size, max_new_tokens) < invalid_rate + sampled_token_ids[invalid_mask] = -1 + + # discard_request_mask + discard_request_mask = np.zeros(batch_size, dtype=np.int32) + if discard_rate > 0: + discard_mask = rng.rand(batch_size) < discard_rate + discard_request_mask[discard_mask] = 1 + + return token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, min_n, max_n, k + + +def _run_npu(token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k): + token_ids_t = torch.from_numpy(token_ids).to("npu") + num_tokens_t = torch.from_numpy(num_tokens_no_spec).to("npu") + sampled_t = 
torch.from_numpy(sampled_token_ids).to("npu") + discard_t = torch.from_numpy(discard_request_mask).to("npu") + + result = torch.ops._C_ascend.npu_ngram_spec_decode( + token_ids_t, num_tokens_t, sampled_t, discard_t, + vocab_size, min_n, max_n, k) + torch.npu.synchronize() + + return result + + +def _measure_perf(token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k): + token_ids_t = torch.from_numpy(token_ids.copy()).to("npu") + num_tokens_t = torch.from_numpy(num_tokens_no_spec.copy()).to("npu") + sampled_t = torch.from_numpy(sampled_token_ids.copy()).to("npu") + discard_t = torch.from_numpy(discard_request_mask.copy()).to("npu") + + for _ in range(PERF_WARMUP): + _ = torch.ops._C_ascend.npu_ngram_spec_decode( + token_ids_t, num_tokens_t, sampled_t, discard_t, + vocab_size, min_n, max_n, k) + torch.npu.synchronize() + + t0 = time.perf_counter() + for _ in range(PERF_ITERS): + _ = torch.ops._C_ascend.npu_ngram_spec_decode( + token_ids_t, num_tokens_t, sampled_t, discard_t, + vocab_size, min_n, max_n, k) + torch.npu.synchronize() + elapsed_us = (time.perf_counter() - t0) * 1e6 / PERF_ITERS + + print(f" [perf] B={token_ids.shape[0]} M={token_ids.shape[1]} N={sampled_token_ids.shape[1]} " + f"k={k} min_n={min_n} max_n={max_n} -> {elapsed_us:.1f} us/call", flush=True) + + +# =========================================================================== +# Group 1: basic - basic function +# =========================================================================== + +@pytest.mark.parametrize("batch_size,seq_len,max_new_tokens,k", [ + (1, 16, 4, 3), + (4, 64, 8, 5), + (16, 128, 16, 5), +]) +@torch.inference_mode() +def test_ngram_spec_decode_basic(batch_size, seq_len, max_new_tokens, k): + inputs = _make_inputs(batch_size, seq_len, max_new_tokens, k) + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, min_n, max_n, k = inputs + + # CPU golden + golden_ids, golden_next, golden_draft, 
golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, min_n, max_n, k) + + # NPU + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + # compare + assert result_ids.cpu().numpy().tolist() == golden_ids.tolist(), \ + f"token_ids mismatch: B={batch_size}" + assert result_next.cpu().numpy().tolist() == golden_next.tolist(), \ + f"next_token_ids mismatch: B={batch_size}" + assert result_draft.cpu().numpy().tolist() == golden_draft.tolist(), \ + f"draft_token_ids mismatch: B={batch_size}" + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist(), \ + f"num_valid_draft_tokens mismatch: B={batch_size}" + + _measure_perf(token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + +# =========================================================================== +# Group 2: padding / optional +# =========================================================================== + +@pytest.mark.parametrize("batch_size,seq_len,max_new_tokens,k,invalid_rate", [ + (4, 64, 8, 5, 0.3), + (4, 64, 8, 5, 0.7), + (4, 128, 16, 5, 0.5), +]) +@torch.inference_mode() +def test_ngram_spec_decode_padding(batch_size, seq_len, max_new_tokens, k, invalid_rate): + inputs = _make_inputs(batch_size, seq_len, max_new_tokens, k, invalid_rate=invalid_rate) + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, min_n, max_n, k = inputs + + golden_ids, golden_next, golden_draft, golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, min_n, max_n, k) + + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + 
vocab_size, min_n, max_n, k) + + assert result_ids.cpu().numpy().tolist() == golden_ids.tolist() + assert result_next.cpu().numpy().tolist() == golden_next.tolist() + assert result_draft.cpu().numpy().tolist() == golden_draft.tolist() + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist() + + _measure_perf(token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + +@pytest.mark.parametrize("discard_rate", [0.2, 0.5, 1.0]) +@torch.inference_mode() +def test_ngram_spec_decode_discard(discard_rate): + inputs = _make_inputs(4, 64, 8, 5, discard_rate=discard_rate) + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, min_n, max_n, k = inputs + + golden_ids, golden_next, golden_draft, golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, min_n, max_n, k) + + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + assert result_ids.cpu().numpy().tolist() == golden_ids.tolist() + assert result_next.cpu().numpy().tolist() == golden_next.tolist() + assert result_draft.cpu().numpy().tolist() == golden_draft.tolist() + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist() + + +# =========================================================================== +# Group 3: min_n / max_n / k +# =========================================================================== + +@pytest.mark.parametrize("min_n,max_n,k", [ + (1, 1, 3), + (2, 4, 5), + (1, 8, 10), + (3, 3, 1), + (5, 10, 8), +]) +@torch.inference_mode() +def test_ngram_spec_decode_attrs(min_n, max_n, k): + inputs = _make_inputs(4, 128, 16, k, min_n=min_n, max_n=max_n) + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, _, _, _ = inputs + + golden_ids, golden_next, 
golden_draft, golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, min_n, max_n, k) + + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + assert result_ids.cpu().numpy().tolist() == golden_ids.tolist() + assert result_next.cpu().numpy().tolist() == golden_next.tolist() + assert result_draft.cpu().numpy().tolist() == golden_draft.tolist() + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist() + + _measure_perf(token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + +# =========================================================================== +# Group 4: large scale +# =========================================================================== + +@torch.inference_mode() +def test_ngram_spec_decode_prefill(): + inputs = _make_inputs(1, 2048, 16, 5, vocab_size=32000) + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, min_n, max_n, k = inputs + + golden_ids, golden_next, golden_draft, golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, min_n, max_n, k) + + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + assert result_ids.cpu().numpy().tolist() == golden_ids.tolist() + assert result_next.cpu().numpy().tolist() == golden_next.tolist() + assert result_draft.cpu().numpy().tolist() == golden_draft.tolist() + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist() + + _measure_perf(token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + +@torch.inference_mode() 
+def test_ngram_spec_decode_decode(): + inputs = _make_inputs(64, 32, 5, 3, vocab_size=32000) + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, min_n, max_n, k = inputs + + golden_ids, golden_next, golden_draft, golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, min_n, max_n, k) + + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + assert result_ids.cpu().numpy().tolist() == golden_ids.tolist() + assert result_next.cpu().numpy().tolist() == golden_next.tolist() + assert result_draft.cpu().numpy().tolist() == golden_draft.tolist() + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist() + + _measure_perf(token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + +# =========================================================================== +# Group 5: boundary +# =========================================================================== + +@torch.inference_mode() +def test_ngram_spec_decode_minimal(): + inputs = _make_inputs(1, 4, 1, 1, vocab_size=100) + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, min_n, max_n, k = inputs + + golden_ids, golden_next, golden_draft, golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, min_n, max_n, k) + + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + assert result_ids.cpu().numpy().tolist() == golden_ids.tolist() + assert result_next.cpu().numpy().tolist() == golden_next.tolist() + assert result_draft.cpu().numpy().tolist() == 
golden_draft.tolist() + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist() + + +@torch.inference_mode() +def test_ngram_spec_decode_no_valid_sampled(): + inputs = _make_inputs(4, 64, 8, 5) + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, min_n, max_n, k = inputs + sampled_token_ids[:] = -1 + + golden_ids, golden_next, golden_draft, golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, min_n, max_n, k) + + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + assert result_ids.cpu().numpy().tolist() == golden_ids.tolist() + assert result_next.cpu().numpy().tolist() == golden_next.tolist() + assert result_draft.cpu().numpy().tolist() == golden_draft.tolist() + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist() + + +@torch.inference_mode() +def test_ngram_spec_decode_exact_match(): + B, M, N, k = 2, 32, 4, 3 + vocab_size = 1000 + token_ids = np.array([ + [1, 2, 3, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [10, 20, 10, 20, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], dtype=np.int32) + num_tokens_no_spec = np.array([6, 5], dtype=np.int32) + # sampled tokens + sampled_token_ids = np.array([ + [1, 2, 3, -1], + [10, 20, -1, -1], + ], dtype=np.int32) + discard_request_mask = np.array([0, 0], dtype=np.int32) + + golden_ids, golden_next, golden_draft, golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, 3, 5, k) + + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, 3, 5, k) + + assert 
result_ids.cpu().numpy().tolist() == golden_ids.tolist() + assert result_next.cpu().numpy().tolist() == golden_next.tolist() + assert result_draft.cpu().numpy().tolist() == golden_draft.tolist() + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist() + + +@torch.inference_mode() +def test_ngram_spec_decode_full_capacity(): + inputs = _make_inputs(4, 48, 16, 5, min_n=2, max_n=4) + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, min_n, max_n, k = inputs + num_tokens_no_spec[:] = token_ids.shape[1] - sampled_token_ids.shape[1] + + golden_ids, golden_next, golden_draft, golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, min_n, max_n, k) + + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + assert result_ids.cpu().numpy().tolist() == golden_ids.tolist() + assert result_next.cpu().numpy().tolist() == golden_next.tolist() + assert result_draft.cpu().numpy().tolist() == golden_draft.tolist() + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist() + + +@torch.inference_mode() +def test_ngram_spec_decode_k1(): + inputs = _make_inputs(4, 64, 8, 1, min_n=2, max_n=3) + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, \ + vocab_size, min_n, max_n, k = inputs + + golden_ids, golden_next, golden_draft, golden_valid = golden_ngram_spec_decode( + token_ids.copy(), num_tokens_no_spec.copy(), sampled_token_ids.copy(), + discard_request_mask.copy(), vocab_size, min_n, max_n, k) + + result_ids, result_next, result_draft, result_valid = _run_npu( + token_ids, num_tokens_no_spec, sampled_token_ids, discard_request_mask, + vocab_size, min_n, max_n, k) + + assert result_ids.cpu().numpy().tolist() == golden_ids.tolist() + assert result_next.cpu().numpy().tolist() == 
golden_next.tolist() + assert result_draft.cpu().numpy().tolist() == golden_draft.tolist() + assert result_valid.cpu().numpy().tolist() == golden_valid.tolist() diff --git a/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py index 2203a8e054c..99260780abe 100644 --- a/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py +++ b/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py @@ -156,6 +156,51 @@ def test_ngram_correctness( assert matches > int(0.66 * len(ref_outputs)) +def test_ngram_npu_async_correctness( + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, +): + """ + Compare the outputs of a original LLM and a speculative LLM + should be the same when using ngram_npu speculative decoding + async. + """ + + with VllmRunner( + model_name, + max_model_len=1024, + cudagraph_capture_sizes=[1, 2, 4, 8], + ) as ref_llm: + ref_outputs = ref_llm.model.chat(test_prompts, sampling_config) + + with VllmRunner( + model_name, + speculative_config={ + "method": "ngram_gpu", + "prompt_lookup_max": 2, + "prompt_lookup_min": 2, + "num_speculative_tokens": 3, + }, + max_model_len=1024, + async_scheduling=True, + cudagraph_capture_sizes=[1, 2, 4, 8], + ) as runner: + spec_outputs = runner.model.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 70% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. 
+ assert matches > int(0.66 * len(ref_outputs)) + + def test_qwen3_vl_eagle_correctness( test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index e0c56166d9f..fa27d72ea8a 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -26,12 +26,15 @@ ) from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer +from vllm_ascend.spec_decode.ngram_proposer_npu import AscendNgramProposerNPU from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer def get_spec_decode_method(method, vllm_config, device, runner): if method == "ngram": return AscendNgramProposer(vllm_config, runner) + elif method == "ngram_gpu": + return AscendNgramProposerNPU(vllm_config, device, runner) elif method == "suffix": return AscendSuffixDecodingProposer(vllm_config, runner) elif method == "medusa": diff --git a/vllm_ascend/spec_decode/ngram_proposer_npu.py b/vllm_ascend/spec_decode/ngram_proposer_npu.py new file mode 100644 index 00000000000..5247f93d505 --- /dev/null +++ b/vllm_ascend/spec_decode/ngram_proposer_npu.py @@ -0,0 +1,35 @@ +import torch +from vllm.v1.spec_decode.ngram_proposer_gpu import NgramProposerGPU + + +class AscendNgramProposerNPU(NgramProposerGPU): + def __init__(self, vllm_config, device: torch.device, runner): + super().__init__(vllm_config, device=device) + + def load_model(self, *args, **kwargs): + # No model to load. 
+ pass + + @torch.inference_mode() + def dummy_run( + self, + num_tokens, + with_prefill=None, + in_graph_capturing=None, + num_reqs=None, + num_tokens_across_dp=None, + aclgraph_runtime_mode=None, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None, + is_profile=False, + ): + pass + + def propose( + self, + num_tokens_no_spec: torch.Tensor, # [batch_size] + token_ids_gpu: torch.Tensor, # [batch_size, max_len] + valid_sampled_token_ids_gpu: torch.Tensor, # [batch_size, num_spec_tokens + 1] + valid_sampled_tokens_count: torch.Tensor, # [batch_size] + ): + pass \ No newline at end of file diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 02c9496ce67..21e9098fa52 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -23,7 +23,7 @@ from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import copy, deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, replace from functools import partial from multiprocessing import Manager from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias @@ -78,6 +78,7 @@ from vllm.v1.sample.logits_processor import build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.spec_decode.ngram_proposer_gpu import copy_num_valid_draft_tokens from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import record_function_or_nullcontext @@ -125,6 +126,7 @@ ) from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer +from vllm_ascend.spec_decode.ngram_proposer_npu import AscendNgramProposerNPU from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer from vllm_ascend.spec_decode.utils 
import update_num_computed_tokens_for_batch_change from vllm_ascend.utils import ( @@ -240,6 +242,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): vllm_config.parallel_config.prefill_context_parallel_size * 2 * vllm_config.scheduler_config.max_num_seqs ) vllm_config.scheduler_config.max_num_batched_tokens += max_pcp_pad_tokens + with _torch_cuda_wrapper(): super().__init__(vllm_config, device) @@ -489,6 +492,7 @@ def _set_up_drafter(self): # Set up speculative decoding. self.drafter: ( AscendNgramProposer + | AscendNgramProposerNPU | AscendEagleProposer | AscendDraftModelProposer | AscendDflashProposer @@ -512,6 +516,7 @@ def _set_up_drafter(self): assert isinstance(self.drafter, AscendExtractHiddenStatesProposer) self.use_aux_hidden_state_outputs = True self.rejection_sampler = RejectionSampler(self.sampler) + self.discard_request_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int64) self.num_discarded_requests = 0 @@ -1325,6 +1330,51 @@ def propose_draft_token_ids( draft_token_ids = None elif isinstance(self.drafter, (AscendNgramProposer, AscendSuffixDecodingProposer)): draft_token_ids = self.drafter.propose(valid_sampled_token_ids) + elif isinstance(self.drafter, AscendNgramProposerNPU): + batch_size = min(self.input_batch.num_reqs, self.token_ids_gpu_tensor.shape[0]) + + # prepare sampled_token_ids tensor(list → padded tensor) + sampled_token_ids = valid_sampled_token_ids + if isinstance(sampled_token_ids, list): + max_len = max((len(sublist) for sublist in sampled_token_ids), default=0) + max_len = max(max_len, 1) + padded_list = [ + sublist + [-1] * (max_len - len(sublist)) + for sublist in sampled_token_ids + ] + sampled_token_ids_tensor = torch.tensor( + padded_list, dtype=torch.int32, device=self.device + ) + else: + sampled_token_ids_tensor = sampled_token_ids + + (_token_ids, next_token_ids, draft_token_ids, + num_valid_draft_tokens) = torch.ops._C_ascend.npu_ngram_spec_decode( + self.token_ids_gpu_tensor[:batch_size], 
# [B, max_seq_len], in-place + self.num_tokens_no_spec_gpu[:batch_size], # [B] + sampled_token_ids_tensor[:batch_size], # [B, max_new_tokens] + self.discard_request_mask.gpu[:batch_size], # [B] + vocab_size=self.model_config.get_vocab_size(), + min_n=self.drafter.min_n, + max_n=self.drafter.max_n, + k=self.drafter.k, + ) + + # only async scheduling, set prev_sampled_token_ids, + if self.use_async_scheduling: + self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1) + + # save num_valid_draft_tokens for scheduler trim + self._num_valid_draft_tokens = num_valid_draft_tokens + + # async D2H copy num_valid_draft_tokens + copy_num_valid_draft_tokens( + self._num_valid_draft_tokens_cpu, + self._num_valid_draft_tokens_copy_stream, + self._num_valid_draft_tokens_event, + self._num_valid_draft_tokens, + batch_size, + ) elif isinstance(self.drafter, AscendMedusaProposer): draft_token_ids = self.drafter.propose( valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states @@ -1482,6 +1532,36 @@ def propose_draft_token_ids( return draft_token_ids + def _copy_draft_token_ids_to_cpu( + self, scheduler_output: "SchedulerOutput", zeros_only: bool = False + ) -> None: + if not self.num_spec_tokens: + return + if self.use_async_scheduling and not ( + scheduler_output.has_structured_output_requests + or self.input_batch.sampling_metadata.output_token_ids + ): + return + self._draft_token_req_ids = self.input_batch.req_ids.copy() + + draft_token_ids: torch.Tensor = self._draft_token_ids # type: ignore[has-type] + if not torch.is_tensor(draft_token_ids): + return + assert self.draft_token_ids_event is not None + assert self.draft_token_ids_copy_stream is not None + assert self.draft_token_ids_cpu is not None + default_stream = torch.npu.current_stream() + num_reqs = draft_token_ids.shape[0] + with torch.npu.stream(self.draft_token_ids_copy_stream): + if not zeros_only: + self.draft_token_ids_copy_stream.wait_stream(default_stream) + 
self.draft_token_ids_cpu[:num_reqs].copy_( + draft_token_ids, non_blocking=True + ) + else: + self.draft_token_ids_cpu[:num_reqs] = 0 + self.draft_token_ids_event.record() + @torch.inference_mode() def execute_model( self, @@ -1507,6 +1587,24 @@ def execute_model( self._execution_start_time = time.perf_counter() if self.execute_model_state is not None: raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") + + # If ngram_gpu is used, we need to copy the scheduler_output to avoid + # the modification has influence on the scheduler_output in engine core process. + # The replace is much faster than deepcopy. + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy() + spec_decode_tokens_copy = ( + scheduler_output.scheduled_spec_decode_tokens.copy() + ) + scheduler_output = replace( + scheduler_output, + num_scheduled_tokens=num_scheduled_tokens_copy, + scheduled_spec_decode_tokens=spec_decode_tokens_copy, + ) + # self._draft_token_ids is None when `input_fits_in_drafter=False` # and there is no draft tokens scheduled. so it need to update the # spec_decoding info in scheduler_output with async_scheduling. @@ -1528,6 +1626,25 @@ def execute_model( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with record_function_or_nullcontext("prepare input"): with self.synchronize_input_prep(): + # Fix up prev_req_id_to_index for requests that were discarded + # in the previous sample_tokens step. If a request has + # prev_num_draft_len > 0 but is missing from + # prev_req_id_to_index, the parent _update_states would + # hit a KeyError. Reset prev_num_draft_len to 0 for such + # requests so they fall through safely. 
+ if ( + self.use_async_scheduling + and self.num_spec_tokens + and self.input_batch.prev_req_id_to_index is not None + ): + for req_id in scheduler_output.scheduled_cached_reqs.req_ids: + if ( + req_id not in self.input_batch.prev_req_id_to_index + and (req_state := self.requests.get(req_id)) is not None + and req_state.prev_num_draft_len + ): + req_state.prev_num_draft_len = 0 + # Update persistent batch states. deferred_state_corrections_fn = self._update_states(scheduler_output) @@ -1951,14 +2068,15 @@ def propose_draft_token_ids(sampled_token_ids): self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model() or self.speculative_config.uses_extract_hidden_states() + or self.speculative_config.use_ngram_gpu() ) and not self.speculative_config.disable_padded_drafter_batch ) if use_padded_batch: - # EAGLE speculative decoding can use the GPU sampled tokens + # EAGLE/ngram_gpu speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) - if self.speculative_config and not use_padded_batch: + elif self.speculative_config and not use_padded_batch: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) @@ -2110,6 +2228,7 @@ def _bookkeeping_sync( discard_sampled_tokens_req_indices, logprobs_tensors=logprobs_tensors, ) + else: valid_sampled_token_ids = [] invalid_req_indices = discard_sampled_tokens_req_indices.tolist()