diff --git a/csrc/add_rms_norm_bias/add_rms_norm_bias_torch_adpt.h b/csrc/add_rms_norm_bias/add_rms_norm_bias_torch_adpt.h
new file mode 100644
index 00000000000..dbfa81e9005
--- /dev/null
+++ b/csrc/add_rms_norm_bias/add_rms_norm_bias_torch_adpt.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ADD_RMS_NORM_BIAS_TORCH_ADPT_H
+#define ADD_RMS_NORM_BIAS_TORCH_ADPT_H
+
+namespace vllm_ascend {
+
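+// Fused residual-add + RMSNorm adapter: computes x = x1 + x2, then applies
+// RMSNorm with weight gamma and optional bias beta. Returns (y, rstd, x),
+// where rstd is the reciprocal standard deviation broadcast against x1.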
+std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_add_rms_norm_bias(
+    const at::Tensor& x1,
+    const at::Tensor& x2,
+    const at::Tensor& gamma,
+    const c10::optional<at::Tensor> &beta,
+    double epsilon)
+{
+    int64_t dim_x = x1.dim();
+    int64_t dim_gamma = gamma.dim();
+    int64_t diff = dim_x - dim_gamma;
+    std::vector<int64_t> new_shape;
+    at::Tensor rstd;
+
+    if (diff > 0) {
+        new_shape.reserve(dim_x);
+        auto x1_sizes = x1.sizes();
+        for (int64_t i = 0; i < diff; ++i) {
+            new_shape.push_back(x1_sizes[i]);
+        }
+        for (int64_t i = 0; i < dim_gamma; ++i) {
+            new_shape.push_back(1);
+        }
+    } else {
+        new_shape.assign(dim_x, 1);
+    }
+    rstd = at::empty(new_shape, x1.options().dtype(at::kFloat));
+    at::Tensor y = at::empty(x1.sizes(), x1.options());
+    at::Tensor x = at::empty(x1.sizes(), x1.options());
+    EXEC_NPU_CMD(aclnnAddRmsNormBias, x1, x2, gamma, beta, epsilon, y, rstd, x);
+    return std::tuple(y, rstd, x);
+}
+}
+#endif
\ No newline at end of file
diff --git a/csrc/apply_top_k_top_p_custom/apply_top_k_top_p_custom_torch_adpt.h b/csrc/apply_top_k_top_p_custom/apply_top_k_top_p_custom_torch_adpt.h
new file mode 100644
index 00000000000..bcb07c72e26
--- /dev/null
+++ b/csrc/apply_top_k_top_p_custom/apply_top_k_top_p_custom_torch_adpt.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef APPLY_TOP_K_TOP_P_CUSTOM_TORCH_ADPT_H
+#define APPLY_TOP_K_TOP_P_CUSTOM_TORCH_ADPT_H
+
+namespace vllm_ascend {
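+// Applies top-k and/or top-p (nucleus) filtering to sampling logits in a
+// single kernel launch; at least one of p and k must be provided.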
+at::Tensor npu_apply_top_k_top_p(
+    const at::Tensor& logits,
+    const c10::optional<at::Tensor>& p,
+    const c10::optional<at::Tensor>& k)
+{
+    TORCH_CHECK(p.has_value() || k.has_value(),
+                "apply_top_k_top_p: p and k cannot be None at the same time.");
+
+    at::Tensor out = at::empty_like(logits);
+
+    EXEC_NPU_CMD(
+        aclnnApplyTopKTopPCustom,
+        logits,
+        p,
+        k,
+        out);
+
+    return out;
+}
+}
+#endif
\ No newline at end of file
diff --git a/csrc/batch_matmul_transpose/batch_matmul_transpose_torch_adpt.h b/csrc/batch_matmul_transpose/batch_matmul_transpose_torch_adpt.h
new file mode 100644
index 00000000000..64e177b9f10
--- /dev/null
+++ b/csrc/batch_matmul_transpose/batch_matmul_transpose_torch_adpt.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef BATCH_MATMUL_TRANSPOSE_TORCH_ADPT_H
+#define BATCH_MATMUL_TRANSPOSE_TORCH_ADPT_H
+#include "op_host/batch_matmul_transpose.h"
+
+namespace vllm_ascend {
+
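+// Batched matmul with fused transpose. Tiling is computed on the host first;
+// the kernel launch is then enqueued on the current NPU stream through
+// OpCommand's custom-handler path.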
+void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
+                            c10::optional<c10::string_view> format_mode,
+                            c10::optional<c10::string_view> quant_mode)
+{
+    auto [tiling_tensor, block_dim] = bmm_trans::batch_matmul_transpose_tiling(
+        tensor_a,
+        tensor_b,
+        tensor_c,
+        format_mode,
+        quant_mode
+    );
+
+    void *gm_a = tensor_a.data_ptr();
+    void *gm_b = tensor_b.data_ptr();
+    void *gm_c = tensor_c.data_ptr();
+    void *gm_tiling_data = tiling_tensor.data_ptr();
+
+    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
+    at_npu::native::OpCommand cmd;
+    cmd.Name("batch_matmul_transpose");
+
+    cmd.SetCustomHandler([stream, gm_a, gm_b, gm_c, gm_tiling_data,
+                          block_dim]() -> int {
+        batch_matmul_transpose_impl(stream, gm_a, gm_b, gm_c, gm_tiling_data,
+                                    block_dim);
+        return 0;
+    });
+    cmd.Run();
+    return;
+}
+
+}
+#endif
\ No newline at end of file
diff --git a/csrc/dispatch_ffn_combine/dispatch_ffn_combine_torch_adpt.h b/csrc/dispatch_ffn_combine/dispatch_ffn_combine_torch_adpt.h
new file mode 100644
index 00000000000..0eb95c8c5c4
--- /dev/null
+++ b/csrc/dispatch_ffn_combine/dispatch_ffn_combine_torch_adpt.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DISPATCH_FFN_COMBINE_TORCH_ADPT_H
+#define DISPATCH_FFN_COMBINE_TORCH_ADPT_H
+
+namespace vllm_ascend {
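+// Fused MoE dispatch + expert FFN + combine over the EP communication group.
+// Selects the int8 or bf16 aclnn entry point based on the dtype of weight1.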
+std::tuple<at::Tensor, at::Tensor> dispatch_ffn_combine(
+    const at::Tensor& x,
+    const at::TensorList& weight1,
+    const at::TensorList& weight2,
+    const at::Tensor& expert_idx,
+    const at::TensorList& scale1,
+    const at::TensorList& scale2,
+    const at::Tensor& probs,
+    c10::string_view group,
+    int64_t max_output_size,
+    at::Tensor& out,
+    at::Tensor& expert_token_nums
+) {
+    char *group_ep_ptr = const_cast<char *>(group.data());
+    bool is_int8 = weight1[0].dtype() == at::kChar;
+    if (is_int8) {
+        EXEC_NPU_CMD(aclnnDispatchFFNCombine,
+                     x,
+                     weight1,
+                     weight2,
+                     expert_idx,
+                     scale1,
+                     scale2,
+                     probs,
+                     group_ep_ptr,
+                     max_output_size,
+                     out,
+                     expert_token_nums);
+    } else {
+        EXEC_NPU_CMD(aclnnDispatchFFNCombineBF16,
+                     x,
+                     weight1,
+                     weight2,
+                     expert_idx,
+                     scale1,
+                     scale2,
+                     probs,
+                     group_ep_ptr,
+                     max_output_size,
+                     out,
+                     expert_token_nums);
+    }
+    return {out, expert_token_nums};
+}
+}
+#endif
\ No newline at end of file
diff --git a/csrc/dispatch_gmm_combine_decode/dispatch_gmm_combine_decode_torch_adpt.h b/csrc/dispatch_gmm_combine_decode/dispatch_gmm_combine_decode_torch_adpt.h
new file mode 100644
index 00000000000..cbbcc711796
--- /dev/null
+++ b/csrc/dispatch_gmm_combine_decode/dispatch_gmm_combine_decode_torch_adpt.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DISPATCH_GMM_COMBINE_DECODE_TORCH_ADPT_H
+#define DISPATCH_GMM_COMBINE_DECODE_TORCH_ADPT_H
+
+namespace vllm_ascend {
+
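+// Fused MoE decode path: token dispatch, grouped matmuls (gmm1/gmm2) and
+// combine across the EP group. Ranks below shared_expert_rank_num serve a
+// single shared expert; the remaining ranks split the routed experts evenly.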
+std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
+    const at::Tensor &x,
+    const at::Tensor &expert_ids,
+    const at::TensorList &gmm1_permuted_weight,
+    const at::TensorList &gmm1_permuted_weight_scale,
+    const at::TensorList &gmm2_weight,
+    const at::TensorList &gmm2_weight_scale,
+    const at::Tensor &expert_scales,
+    const c10::optional<at::Tensor> &expert_smooth_scales,
+    const c10::optional<at::Tensor> &x_active_mask,
+    c10::string_view group_ep,
+    int64_t ep_rank_size,
+    int64_t ep_rank_id,
+    int64_t moe_expert_num,
+    int64_t shared_expert_num,
+    int64_t shared_expert_rank_num,
+    int64_t quant_mode,
+    int64_t global_bs)
+{
+    auto x_shape = x.sizes();
+    int bs = x_shape[0];
+    int h = x_shape[1];
+
+    at::Tensor output = at::empty({bs, h}, x.options());
+
+    bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
+    int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
+    auto opts = expert_ids.options().dtype(at::kLong);
+    at::Tensor expert_token_nums = at::empty({num_local_experts}, opts);
+
+    std::vector<char> group_ep_chrs(group_ep.begin(), group_ep.end());
+    group_ep_chrs.push_back('\0');
+    char *group_ep_ptr = &group_ep_chrs[0];
+    EXEC_NPU_CMD(
+        // op api
+        aclnnDispatchGmmCombineDecode,
+        // input tensors
+        x,
+        expert_ids,
+        gmm1_permuted_weight,
+        gmm1_permuted_weight_scale,
+        gmm2_weight,
+        gmm2_weight_scale,
+        expert_scales,
+        expert_smooth_scales,
+        x_active_mask,
+        // input attrs
+        group_ep_ptr,
+        ep_rank_size,
+        ep_rank_id,
+        moe_expert_num,
+        shared_expert_num,
+        shared_expert_rank_num,
+        quant_mode,
+        global_bs,
+        // output tensors
+        output,
+        expert_token_nums);
+    return {output, expert_token_nums};
+}
+
+}
+#endif
\ No newline at end of file
diff --git a/csrc/dispatch_layout/dispatch_layout_torch_adpt.h b/csrc/dispatch_layout/dispatch_layout_torch_adpt.h
new file mode 100644
index 00000000000..6523313d175
--- /dev/null
+++ b/csrc/dispatch_layout/dispatch_layout_torch_adpt.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DISPATCH_LAYOUT_TORCH_ADPT_H
+#define DISPATCH_LAYOUT_TORCH_ADPT_H
+
+namespace vllm_ascend {
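+// Derives the dispatch layout from a topk_idx routing table: per-rank and
+// per-expert token counts plus a token-by-rank membership mask (as bool).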
+std::tuple<at::Tensor, at::Tensor, at::Tensor> get_dispatch_layout(const at::Tensor& topk_idx, int64_t num_experts,
+                                                                   int64_t num_ranks) {
+    TORCH_BIND_ASSERT(topk_idx.dim() == 2);
+    TORCH_BIND_ASSERT(topk_idx.is_contiguous());
+    TORCH_BIND_ASSERT(num_experts > 0);
+
+    const int num_tokens = topk_idx.size(0);
+    const int num_topk = topk_idx.size(1);
+
+    auto device = topk_idx.device();
+    auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device));
+    auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device));
+    auto is_token_in_rank = at::zeros({num_tokens, num_ranks}, at::dtype(at::kInt).device(device));
+
+    EXEC_NPU_CMD(aclnnDispatchLayout,
+                 topk_idx,
+                 num_tokens,
+                 num_ranks,
+                 num_experts,
+                 num_topk,
+                 num_tokens_per_rank,
+                 num_tokens_per_expert,
+                 is_token_in_rank);
+
+    auto is_token_in_rank_bool = is_token_in_rank.to(at::kBool);
+
+    return std::make_tuple(num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank_bool);
+}
+
+}
+#endif
\ No newline at end of file
diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/grouped_matmul_swiglu_quant_torch_adpt.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/grouped_matmul_swiglu_quant_torch_adpt.h
new file mode 100644
index 00000000000..3b36e9384c9
--- /dev/null
+++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/grouped_matmul_swiglu_quant_torch_adpt.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef GROUPED_MATMUL_SWIGLU_QUANT_TORCH_ADPT_H
+#define GROUPED_MATMUL_SWIGLU_QUANT_TORCH_ADPT_H
+namespace vllm_ascend {
+const int64_t INT4_NUMS_IN_INT32 = 8;
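+// Grouped matmul + SwiGLU + per-token quantization with NZ-format weights.
+// On the A8W4 path eight int4 weights are packed per int32, so the logical N
+// is the stored N times INT4_NUMS_IN_INT32; SwiGLU halves N in the output.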
+std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
+    const at::Tensor &x, const at::Tensor &weight, const at::Tensor &weight_scale, const at::Tensor &x_scale,
+    const at::Tensor &group_list, const c10::optional<at::Tensor> &bias, const c10::optional<at::Tensor> &offset)
+{
+    int m = x.sizes()[0];
+    int n = weight.sizes()[2];
+    bool is_a8w4 = x.dtype() == at::kChar && weight.dtype() == at::kInt;
+    if (is_a8w4) {
+        n *= INT4_NUMS_IN_INT32;
+    }
+
+    at::Tensor output = at::empty({m, n/2}, x.options().dtype(c10::ScalarType::Char));
+    at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float));
+    at::Tensor output_offset = at::empty({}, x.options().dtype(c10::ScalarType::Float));
+
+    EXEC_NPU_CMD(
+        aclnnGroupedMatmulSwigluQuantWeightNZ,
+        x,
+        weight,
+        bias,
+        offset,
+        weight_scale,
+        x_scale,
+        group_list,
+        output,
+        output_scale,
+        output_offset);
+    return std::tuple(output, output_scale, output_offset);
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weight_nz_tensor_list(
+    const at::Tensor & x,
+    const at::TensorList & weight,
+    const at::TensorList & weight_scale,
+    const at::Tensor & x_scale,
+    const at::Tensor & group_list,
+    const c10::optional<at::Tensor> & bias,
+    const c10::optional<at::Tensor> & offset)
+{
+    auto x_size = x.sizes();
+    int n = weight[0].sizes()[1];
+    int m = x_size[0];
+    int k = x_size[1];
+
+    at::Tensor output = at::empty({m, n/2}, x.options().dtype(at::kChar));
+    at::Tensor output_scale = at::empty({m}, x.options().dtype(at::kFloat));
+    at::Tensor output_offset = at::empty({m}, x.options().dtype(at::kFloat));
+
+    EXEC_NPU_CMD(
+        aclnnGroupedMatmulSwigluQuantWeightNzTensorList,
+        x,
+        weight,
+        bias,
+        offset,
+        weight_scale,
+        x_scale,
+        group_list,
+        output,
+        output_scale,
+        output_offset);
+
+    return std::tuple(output, output_scale, output_offset);
+}
+}
+#endif
\ No newline at end of file
diff --git a/csrc/lightning_indexer_vllm/lightning_indexer_vllm_torch_adpt.h b/csrc/lightning_indexer_vllm/lightning_indexer_vllm_torch_adpt.h
new file mode 100644
index 00000000000..a3e56522ea3
--- /dev/null
+++ b/csrc/lightning_indexer_vllm/lightning_indexer_vllm_torch_adpt.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LIGHTNING_INDEXER_VLLM_TORCH_ADPT_H
+#define LIGHTNING_INDEXER_VLLM_TORCH_ADPT_H
+namespace vllm_ascend {
+
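+// Lightning indexer: scores keys against each query with learned weights and
+// returns the indices of the top sparse_count candidates. The output shape
+// follows the layouts: BSND keeps batch/sequence dims, TND collapses them.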
+at::Tensor npu_lightning_indexer(
+    const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
+    const c10::optional<at::Tensor> &actual_seq_lengths_query,
+    const c10::optional<at::Tensor> &actual_seq_lengths_key,
+    const c10::optional<at::Tensor> &block_table, c10::string_view layout_query,
+    c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode)
+{
+    // npu tensor max size
+    constexpr int32_t SIZE = 8;
+    constexpr int32_t DIM_0 = 0;
+    constexpr int32_t DIM_1 = 1;
+    constexpr int32_t DIM_2 = 2;
+    constexpr int32_t DIM_3 = 3;
+
+    TORCH_CHECK(query.numel() > 0, "Query is empty.");
+    TORCH_CHECK(key.numel() > 0, "Key is empty.");
+    TORCH_CHECK(weights.numel() > 0, "Weights is empty.");
+    for (size_t i = 0; i < query.sizes().size(); i++) {
+        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
+            "than 0, but shape[", i, "] is ", query.size(i));
+    }
+    TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);
+
+    at::SmallVector<int64_t, SIZE> output_size;
+    std::string query_layout_str = std::string(layout_query);
+    std::string key_layout_str = std::string(layout_key);
+    if (query_layout_str == "BSND") {
+        output_size = {query.size(DIM_0), query.size(DIM_1), key.size(DIM_2), sparse_count};
+    } else {
+        int n_dim_index = 0;
+        n_dim_index = (key_layout_str == "TND") ? DIM_1 : DIM_2;
+        output_size = {query.size(DIM_0), key.size(n_dim_index), sparse_count};
+    }
+    at::Tensor lightning_indexer_output = at::empty(output_size, query.options().dtype(at::kInt));
+    // convert str
+    char *query_layout_ptr = const_cast<char *>(query_layout_str.c_str());
+    char *key_layout_ptr = const_cast<char *>(key_layout_str.c_str());
+    EXEC_NPU_CMD(
+        aclnnLightningIndexerVllm,
+        query,
+        key,
+        weights,
+        actual_seq_lengths_query,
+        actual_seq_lengths_key,
+        block_table,
+        query_layout_ptr,
+        key_layout_ptr,
+        sparse_count,
+        sparse_mode,
+        lightning_indexer_output);
+    return lightning_indexer_output;
+}
+}
+#endif
\ No newline at end of file
diff --git a/csrc/matmul_allreduce_add_rmsnorm/matmul_allreduce_add_rmsnorm_torch_adpt.h b/csrc/matmul_allreduce_add_rmsnorm/matmul_allreduce_add_rmsnorm_torch_adpt.h
new file mode 100644
index 00000000000..0518306ad3b
--- /dev/null
+++ b/csrc/matmul_allreduce_add_rmsnorm/matmul_allreduce_add_rmsnorm_torch_adpt.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_TORCH_ADPT_H
+#define MATMUL_ALLREDUCE_ADD_RMSNORM_TORCH_ADPT_H
+namespace vllm_ascend {
+
+std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm(
+    const at::Tensor &x1,
+    const at::Tensor &x2,
+    const at::Tensor &residual,
+    const at::Tensor &gamma,
+    c10::string_view group_tp,
+    int64_t tp_rank_size,
+    int64_t tp_rank_id,
+    double epsilon,
+    bool is_trans_b,
+    bool is_gather_add_out)
+{
+    at::Tensor output = at::empty_like(residual);
+    at::Tensor add_out = at::empty_like(residual);
+
+    std::string group_tp_str(group_tp);
+
+    char *group_tp_ptr = group_tp_str.data();
+
+    float epsilon_f = static_cast<float>(epsilon);
+    EXEC_NPU_CMD(aclnnMatmulAllreduceAddRmsnorm,
+                 // input
+                 x1, x2, residual, gamma,
+                 // attr
+                 group_tp_ptr, tp_rank_size, tp_rank_id, epsilon_f, is_trans_b, is_gather_add_out,
+                 // output
+                 output, add_out);
+
+    return {output, add_out};
+}
+}
+#endif
\ No newline at end of file
diff --git a/csrc/mla_preprocess/mla_preprocess_torch_adpt.h b/csrc/mla_preprocess/mla_preprocess_torch_adpt.h
new file mode 100644
index 00000000000..0bcd3ca836c
--- /dev/null
+++ b/csrc/mla_preprocess/mla_preprocess_torch_adpt.h
@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MLA_PREPROCESS_TORCH_ADPT_H
+#define MLA_PREPROCESS_TORCH_ADPT_H
+
+
+#include "op_host/mla_preprocess.h"
+
+namespace vllm_ascend {
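+// MLA preprocessing pipeline: quantized down/up projections, RMSNorm, RoPE
+// and KV-cache scatter in one custom kernel. Missing optional tensors are
+// substituted with 1-element placeholders so the kernel always receives
+// valid device pointers.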
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mla_preprocess(
+    const at::Tensor &hiddenState, const at::Tensor &wdqkv,
+    const c10::optional<at::Tensor> &descale0, const at::Tensor &gamma1, const c10::optional<at::Tensor> &beta1, const at::Tensor &wuq,
+    const c10::optional<at::Tensor> &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
+    const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
+    const c10::optional<at::Tensor> &quant_scale0, const c10::optional<at::Tensor> &quant_offset0, const c10::optional<at::Tensor> &bias0,
+    const c10::optional<at::Tensor> &quant_scale1, const c10::optional<at::Tensor> &quant_offset1, const c10::optional<at::Tensor> &bias1,
+    const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
+    c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, c10::optional<bool> enable_inner_out, at::Tensor &q_out0,
+    at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out)
+{
+    at::Tensor Descale0 =
+        descale0.has_value()
+            ? descale0.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Descale1 =
+        descale1.has_value()
+            ? descale1.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Beta1 =
+        beta1.has_value()
+            ? beta1.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Quant_scale0 =
+        quant_scale0.has_value()
+            ? quant_scale0.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Quant_scale1 =
+        quant_scale1.has_value()
+            ? quant_scale1.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Quant_offset0 =
+        quant_offset0.has_value()
+            ? quant_offset0.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Quant_offset1 =
+        quant_offset1.has_value()
+            ? quant_offset1.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Bias0 =
+        bias0.has_value()
+            ? bias0.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Bias1 =
+        bias1.has_value()
+            ? bias1.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor CtkvScale =
+        ctkv_scale.has_value()
+            ? ctkv_scale.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor QnopeScale =
+        q_nope_scale.has_value()
+            ? q_nope_scale.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    bool enableInnerOut =
+        enable_inner_out.has_value()
+            ? enable_inner_out.value()
+            : false;
+
+    auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
+        hiddenState,
+        wdqkv,
+        wuk,
+        cache_mode,
+        quant_mode,
+        enableInnerOut
+    );
+
+    void *hidden_state_ptr = hiddenState.data_ptr();
+    void *quant_scale0_ptr = Quant_scale0.data_ptr();
+    void *quant_offset0_ptr = Quant_offset0.data_ptr();
+    void *wdqkv_ptr = wdqkv.data_ptr();
+    void *bias0_ptr = Bias0.data_ptr();
+    void *gamma1_ptr = gamma1.data_ptr();
+    void *beta1_ptr = Beta1.data_ptr();
+    void *quant_scale1_ptr = Quant_scale1.data_ptr();
+    void *quant_offset1_ptr = Quant_offset1.data_ptr();
+    void *gamma2_ptr = gamma2.data_ptr();
+    void *sin_ptr = sin.data_ptr();
+    void *cos_ptr = cos.data_ptr();
+    void *kv_cache_ptr = kv_cache.data_ptr();
+    void *slotmapping_ptr = slotmapping.data_ptr();
+    void *wuq_ptr = wuq.data_ptr();
+    void *bias1_ptr = Bias1.data_ptr();
+    void *wuk_ptr = wuk.data_ptr();
+    void *descale0_ptr = Descale0.data_ptr();
+    void *descale1_ptr = Descale1.data_ptr();
+    void *ctkv_scale_ptr = CtkvScale.data_ptr();
+    void *qnope_scale_ptr = QnopeScale.data_ptr();
+    void *q_out0_ptr = q_out0.data_ptr();
+    void *kv_cache_out0_ptr = kv_cache_out0.data_ptr();
+    void *q_out1_ptr = q_out1.data_ptr();
+    void *kv_cache_out1_ptr = kv_cache_out1.data_ptr();
+    void *inner_out_ptr = inner_out.data_ptr();
+    void *workspace_ptr = workspace_tensor.data_ptr();
+    void *tiling_ptr = tiling.data_ptr();
+
+    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
+    at_npu::native::OpCommand cmd;
+    cmd.Name("mla_preprocess");
+
+    cmd.SetCustomHandler([stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
+                          gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
+                          kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
+                          qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr,
+                          tiling_ptr, block_dim]() -> int {
+        mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
+                            gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
+                            kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
+                            qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr,
+                            tiling_ptr, block_dim);
+        return 0;
+    });
+    cmd.Run();
+    return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1, inner_out);
+}
+}
+#endif
\ No newline at end of file
diff --git a/csrc/moe_combine_normal/moe_combine_normal_torch_adpt.h b/csrc/moe_combine_normal/moe_combine_normal_torch_adpt.h
new file mode 100644
index 00000000000..3c9fd34f69e
--- /dev/null
+++ b/csrc/moe_combine_normal/moe_combine_normal_torch_adpt.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MOE_COMBINE_NORMAL_TORCH_ADPT_H
+#define MOE_COMBINE_NORMAL_TORCH_ADPT_H
+
+namespace vllm_ascend {
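+// Prefill-side MoE combine: reduces expert outputs back to their source
+// tokens, weighted by topk_weights, via aclnnMoeCombineNormal (TP world size
+// is fixed to 1 on this path).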
+at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights,
+                           const at::Tensor& src_idx, const at::Tensor& send_head, c10::string_view groupEp,
+                           int64_t rank, int64_t num_ranks) {
+    std::vector<char> group_ep_chrs(groupEp.begin(), groupEp.end());
+    group_ep_chrs.push_back('\0');
+    char* group_ep_ptr = &group_ep_chrs[0];
+
+    TORCH_BIND_ASSERT(x.dim() == 2 and x.is_contiguous());
+    at::Tensor recv_x = x;
+
+    at::Tensor topk_idx_p = topk_idx;
+
+    auto topk_idx_int32 = topk_idx_p.to(at::kInt);
+    at::Tensor expand_ids = topk_idx_int32;
+    at::Tensor token_src_info = src_idx;
+    at::Tensor ep_send_counts = send_head;
+    auto device = x.device();
+
+    const int num_tokens = topk_idx_p.size(0);
+    const int num_topk = topk_idx_p.size(1);
+
+    int64_t hidden = static_cast<int64_t>(recv_x.size(1));
+    at::Tensor tp_send_counts = at::empty({1}, at::dtype(at::kInt).device(device));
+    int64_t tp_world_size = 1;
+    int64_t tp_rankId = 0;
+    int64_t moe_expert_number = send_head.size(0);
+    int64_t global_bs = topk_idx_p.size(0) * num_ranks;
+
+    // Combine data
+    auto combined_x = torch::empty({topk_weights.size(0), hidden}, x.options());
+
+    EXEC_NPU_CMD(aclnnMoeCombineNormal,
+                 recv_x,
+                 token_src_info,
+                 ep_send_counts,
+                 topk_weights,
+                 tp_send_counts,
+                 group_ep_ptr,
+                 num_ranks,
+                 rank,
+                 group_ep_ptr,
+                 tp_world_size,
+                 tp_rankId,
+                 moe_expert_number,
+                 global_bs,
+                 combined_x);
+
+    return combined_x;
+}
+
+}
+
+#endif
\ No newline at end of file
diff --git a/csrc/moe_gating_top_k/moe_gating_top_k_torch_adpt.h b/csrc/moe_gating_top_k/moe_gating_top_k_torch_adpt.h
new file mode 100644
index 00000000000..0bf5b9a2fb6
--- /dev/null
+++ b/csrc/moe_gating_top_k/moe_gating_top_k_torch_adpt.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MOE_GATING_TOP_K_TORCH_ADPT_H
+#define MOE_GATING_TOP_K_TORCH_ADPT_H
+namespace vllm_ascend {
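+// Grouped top-k expert gating with optional bias: selects k experts per row
+// out of expert_num and returns (y, expert_idx, out); out holds the full
+// routing scores (see out_flag).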
"dispatch_gmm_combine_decode/dispatch_gmm_combine_decode_torch_adpt.h" +#include "dispatch_layout/dispatch_layout_torch_adpt.h" +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list/grouped_matmul_swiglu_quant_torch_adpt.h" +#include "lightning_indexer_vllm/lightning_indexer_vllm_torch_adpt.h" +#include "matmul_allreduce_add_rmsnorm/matmul_allreduce_add_rmsnorm_torch_adpt.h" +#include "mla_preprocess/mla_preprocess_torch_adpt.h" +#include "moe_combine_normal/moe_combine_normal_torch_adpt.h" +#include "moe_gating_top_k/moe_gating_top_k_torch_adpt.h" +#include "moe_init_routing_custom/moe_init_routing_custom_torch_adpt.h" +#include "sparse_flash_attention/sparse_flash_attention_torch_adpt.h" #include #include #include namespace vllm_ascend { -const int64_t INT4_NUMS_IN_INT32 = 8; void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping, aclrtStream stream) { @@ -105,124 +115,6 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType) } } -std::tuple mla_preprocess( - const at::Tensor &hiddenState, const at::Tensor &wdqkv, - const c10::optional &descale0, const at::Tensor &gamma1, const c10::optional &beta1, const at::Tensor &wuq, - const c10::optional &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin, - const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping, - const c10::optional &quant_scale0, const c10::optional &quant_offset0, const c10::optional &bias0, - const c10::optional &quant_scale1, const c10::optional &quant_offset1, const c10::optional &bias1, - const c10::optional &ctkv_scale, const c10::optional &q_nope_scale, - c10::optional cache_mode, c10::optional quant_mode, c10::optional enable_inner_out, at::Tensor &q_out0, - at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out) -{ - at::Tensor Descale0 = - descale0.has_value() - ? descale0.value() - : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); - at::Tensor Descale1 = - descale1.has_value() - ? descale1.value() - : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); - at::Tensor Beta1 = - beta1.has_value() - ? beta1.value() - : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); - at::Tensor Quant_scale0 = - quant_scale0.has_value() - ? quant_scale0.value() - : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); - at::Tensor Quant_scale1 = - quant_scale1.has_value() - ? quant_scale1.value() - : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); - at::Tensor Quant_offset0 = - quant_offset0.has_value() - ? quant_offset0.value() - : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); - at::Tensor Quant_offset1 = - quant_offset1.has_value() - ? quant_offset1.value() - : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); - at::Tensor Bias0 = - bias0.has_value() - ? bias0.value() - : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); - at::Tensor Bias1 = - bias1.has_value() - ? bias1.value() - : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); - at::Tensor CtkvScale = - ctkv_scale.has_value() - ? 
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_custom(
+    const at::Tensor &x, const at::Tensor &expert_idx,
+    const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
+    int64_t expert_capacity, int64_t expert_num, int64_t drop_pad_mode, int64_t expert_tokens_num_type,
+    bool expert_tokens_num_flag, int64_t quant_mode, at::IntArrayRef active_expert_range, int64_t row_idx_type)
+{
+    constexpr int64_t DIM_X = 2;
+    constexpr int64_t DIM_EXPERT_IDX = 2;
+    constexpr int64_t LENGTH_ACTIVE_EXPERT_RANGE = 2;
+    constexpr int64_t EXPERT_TOKENS_COUNT = 1;
+    constexpr int64_t EXPERT_TOKENS_KEY_VALUE = 2;
+    constexpr int64_t QUANT_MODE_UNQUANT = -1;
+    constexpr int64_t QUANT_MODE_DYNAMIC_QUANT = 1;
+    constexpr int64_t CUMSUM = 0;
+    constexpr int64_t COUNT = 1;
+    constexpr int64_t KEY_VALUE = 2;
+
+    // Backing storage must outlive the IntArrayRef view assigned below.
+    int64_t default_range[2] = {0, expert_num};
+    if (active_expert_range.empty()) {
+        active_expert_range = at::IntArrayRef(default_range, 2);
+    }
+
+    int64_t x_dim = x.dim();
+    TORCH_CHECK(x_dim == DIM_X, "The x should be ", DIM_X,
+                "-Dimension, current is ", x_dim, "-Dimension.");
+
+    int64_t expert_idx_dim = expert_idx.dim();
+    TORCH_CHECK(expert_idx_dim == DIM_EXPERT_IDX, "The expert_idx should be ", DIM_EXPERT_IDX,
+                "-Dimension, current is ", expert_idx_dim, "-Dimension.");
+
+    int64_t active_expert_range_length = active_expert_range.size();
+    TORCH_CHECK(active_expert_range_length == LENGTH_ACTIVE_EXPERT_RANGE, "The active_expert_range should have ",
+                LENGTH_ACTIVE_EXPERT_RANGE, " elements, current is ", active_expert_range_length, ".");
+
+    int expert_length = active_expert_range[1] - active_expert_range[0];
+    auto x_size = x.sizes();
+    auto expert_idx_size = expert_idx.sizes();
+
+    int bs = x_size[0];
+    int h = x_size[1];
+    int k = expert_idx_size[1];
+    int64_t expanded_scale_len = 0;
+    at::Tensor expanded_x;
+
+    if (drop_pad_mode == 1) { // Drop/Pad
+        if (quant_mode == QUANT_MODE_UNQUANT) {
+            expanded_x = at::empty({expert_num, expert_capacity, h}, x.options());
+        } else {
+            expanded_x = at::empty({expert_num, expert_capacity, h}, x.options().dtype(at::kChar));
+        }
+        expanded_scale_len = expert_num * expert_capacity;
+    } else { // Dropless / Active
+        if (active_num > 0) { // Active
+            int64_t num_out_tokens = std::min((int64_t)bs * k, active_num);
+            if (quant_mode == QUANT_MODE_UNQUANT) {
+                expanded_x = at::empty({num_out_tokens, h}, x.options());
+            } else {
+                expanded_x = at::empty({num_out_tokens, h}, x.options().dtype(at::kChar));
+            }
+            expanded_scale_len = num_out_tokens;
+        } else { // Dropless
+            if (quant_mode == QUANT_MODE_UNQUANT) {
+                expanded_x = at::empty({bs * k, h}, x.options());
+            } else {
+                expanded_x = at::empty({bs * k, h}, x.options().dtype(at::kChar));
+            }
+            expanded_scale_len = bs * k;
+        }
+    }
+
+    at::Tensor expanded_row_idx = at::empty({bs * k}, expert_idx.options());
+    at::Tensor expert_tokens_count_or_cumsum;
+    if (expert_tokens_num_type >= CUMSUM && expert_tokens_num_type <= COUNT) {
+        // expert_tokens_count_or_cumsum in [end-start, ]
+        expert_tokens_count_or_cumsum = at::empty({expert_length}, x.options().dtype(at::kLong));
+    } else if (expert_tokens_num_type == KEY_VALUE) {
+        // key_value in [2, end-start]
+        expert_tokens_count_or_cumsum = at::empty({expert_num, 2}, x.options().dtype(at::kLong));
+    }
+    at::Tensor expanded_scale = at::empty({expanded_scale_len}, x.options().dtype(at::kFloat));
+    EXEC_NPU_CMD(aclnnMoeInitRoutingCustom,
+                 x,
+                 expert_idx,
+                 scale,
+                 offset,
+                 active_num,
+                 expert_capacity,
+                 expert_num,
+                 drop_pad_mode,
+                 expert_tokens_num_type,
+                 expert_tokens_num_flag,
+                 quant_mode,
+                 active_expert_range,
+                 row_idx_type,
+                 expanded_x,
+                 expanded_row_idx,
+                 expert_tokens_count_or_cumsum,
+                 expanded_scale);
+    return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale);
+}
+}
+#endif
\ No newline at end of file
diff --git a/csrc/sparse_flash_attention/sparse_flash_attention_torch_adpt.h b/csrc/sparse_flash_attention/sparse_flash_attention_torch_adpt.h
new file mode 100644
index 00000000000..425c29ee6ef
--- /dev/null
+++ b/csrc/sparse_flash_attention/sparse_flash_attention_torch_adpt.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SPARSE_FLASH_ATTENTION_TORCH_ADPT_H
+#define SPARSE_FLASH_ATTENTION_TORCH_ADPT_H
+namespace vllm_ascend {
+
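+// Sparse flash attention over pre-selected KV indices; the output tensor
+// matches the query's shape, dtype and layout.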
+at::Tensor npu_sparse_flash_attention(
+    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
+    const at::Tensor &sparse_indices, double scale_value, int64_t sparse_block_size,
+    const c10::optional<at::Tensor> &block_table,
+    const c10::optional<at::Tensor> &actual_seq_lengths_query,
+    const c10::optional<at::Tensor> &actual_seq_lengths_kv,
+    const c10::optional<at::Tensor> &query_rope,
+    const c10::optional<at::Tensor> &key_rope, c10::string_view layout_query,
+    c10::string_view layout_kv,
+    int64_t sparse_mode)
+{
+    std::string layout_query_str = std::string(layout_query);
+    std::string layout_kv_str = std::string(layout_kv);
+
+    for (size_t i = 0; i < query.sizes().size(); i++) {
+        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
+            "than 0, but shape[", i, "] is ", query.size(i));
+    }
+    // construct the output tensor
+    at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype()));
+    // convert str
+    char *layout_query_ptr = const_cast<char *>(layout_query_str.c_str());
+    char *layout_kv_ptr = const_cast<char *>(layout_kv_str.c_str());
+
+    EXEC_NPU_CMD(
+        aclnnSparseFlashAttention,
+        query,
+        key,
+        value,
+        sparse_indices,
+        block_table,
+        actual_seq_lengths_query,
+        actual_seq_lengths_kv,
+        query_rope,
+        key_rope,
+        scale_value,
+        sparse_block_size,
+        layout_query_ptr,
+        layout_kv_ptr,
+        sparse_mode,
+        output);
+    return output;
+}
+}
+#endif
\ No newline at end of file
diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index c43e1a9c742..77eaf4d5016 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -27,16 +27,26 @@
 #include "acl/acl_rt.h"
 #include "ops.h"
 #include "utils.h"
-#include "mla_preprocess/op_host/mla_preprocess.h"
-#include "batch_matmul_transpose/op_host/batch_matmul_transpose.h"
 #include "aclnn_torch_adapter/op_api_common.h"
-
+#include "add_rms_norm_bias/add_rms_norm_bias_torch_adpt.h"
+#include "apply_top_k_top_p_custom/apply_top_k_top_p_custom_torch_adpt.h"
+#include "batch_matmul_transpose/batch_matmul_transpose_torch_adpt.h"
+#include "dispatch_ffn_combine/dispatch_ffn_combine_torch_adpt.h"
+#include "dispatch_gmm_combine_decode/dispatch_gmm_combine_decode_torch_adpt.h"
+#include "dispatch_layout/dispatch_layout_torch_adpt.h"
+#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list/grouped_matmul_swiglu_quant_torch_adpt.h"
+#include "lightning_indexer_vllm/lightning_indexer_vllm_torch_adpt.h"
+#include "matmul_allreduce_add_rmsnorm/matmul_allreduce_add_rmsnorm_torch_adpt.h"
+#include "mla_preprocess/mla_preprocess_torch_adpt.h"
+#include "moe_combine_normal/moe_combine_normal_torch_adpt.h"
+#include "moe_gating_top_k/moe_gating_top_k_torch_adpt.h"
+#include "moe_init_routing_custom/moe_init_routing_custom_torch_adpt.h"
+#include "sparse_flash_attention/sparse_flash_attention_torch_adpt.h"
 #include
 #include
 #include
 
 namespace vllm_ascend {
-const int64_t INT4_NUMS_IN_INT32 = 8;
 
 void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping,
                       aclrtStream stream)
 {
@@ -105,124 +115,6 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
     }
 }
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mla_preprocess(
-    const at::Tensor &hiddenState, const at::Tensor &wdqkv,
-    const c10::optional<at::Tensor> &descale0, const at::Tensor &gamma1, const c10::optional<at::Tensor> &beta1, const at::Tensor &wuq,
-    const c10::optional<at::Tensor> &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
-    const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
-    const c10::optional<at::Tensor> &quant_scale0, const c10::optional<at::Tensor> &quant_offset0, const c10::optional<at::Tensor> &bias0,
-    const c10::optional<at::Tensor> &quant_scale1, const c10::optional<at::Tensor> &quant_offset1, const c10::optional<at::Tensor> &bias1,
-    const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
-    c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, c10::optional<bool> enable_inner_out, at::Tensor &q_out0,
-    at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out)
-{
-    at::Tensor Descale0 =
-        descale0.has_value()
-            ? descale0.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    at::Tensor Descale1 =
-        descale1.has_value()
-            ? descale1.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    at::Tensor Beta1 =
-        beta1.has_value()
-            ? beta1.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    at::Tensor Quant_scale0 =
-        quant_scale0.has_value()
-            ? quant_scale0.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    at::Tensor Quant_scale1 =
-        quant_scale1.has_value()
-            ? quant_scale1.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    at::Tensor Quant_offset0 =
-        quant_offset0.has_value()
-            ? quant_offset0.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    at::Tensor Quant_offset1 =
-        quant_offset1.has_value()
-            ? quant_offset1.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    at::Tensor Bias0 =
-        bias0.has_value()
-            ? bias0.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    at::Tensor Bias1 =
-        bias1.has_value()
-            ? bias1.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    at::Tensor CtkvScale =
-        ctkv_scale.has_value()
-            ? ctkv_scale.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    at::Tensor QnopeScale =
-        q_nope_scale.has_value()
-            ? q_nope_scale.value()
-            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
-    bool enableInnerOut =
-        enable_inner_out.has_value()
-            ? enable_inner_out.value()
-            : false;
-
-    auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
-        hiddenState,
-        wdqkv,
-        wuk,
-        cache_mode,
-        quant_mode,
-        enableInnerOut
-    );
-
-    void *hidden_state_ptr = hiddenState.data_ptr();
-    void *quant_scale0_ptr = Quant_scale0.data_ptr();
-    void *quant_offset0_ptr = Quant_offset0.data_ptr();
-    void *wdqkv_ptr = wdqkv.data_ptr();
-    void *bias0_ptr = Bias0.data_ptr();
-    void *gamma1_ptr = gamma1.data_ptr();
-    void *beta1_ptr = Beta1.data_ptr();
-    void *quant_scale1_ptr = Quant_scale1.data_ptr();
-    void *quant_offset1_ptr = Quant_offset1.data_ptr();
-    void *gamma2_ptr = gamma2.data_ptr();
-    void *sin_ptr = sin.data_ptr();
-    void *cos_ptr = cos.data_ptr();
-    void *kv_cache_ptr = kv_cache.data_ptr();
-    void *slotmapping_ptr = slotmapping.data_ptr();
-    void *wuq_ptr = wuq.data_ptr();
-    void *bias1_ptr = Bias1.data_ptr();
-    void *wuk_ptr = wuk.data_ptr();
-    void *descale0_ptr = Descale0.data_ptr();
-    void *descale1_ptr = Descale1.data_ptr();
-    void *ctkv_scale_ptr = CtkvScale.data_ptr();
-    void *qnope_scale_ptr = QnopeScale.data_ptr();
-    void *q_out0_ptr = q_out0.data_ptr();
-    void *kv_cache_out0_ptr = kv_cache_out0.data_ptr();
-    void *q_out1_ptr = q_out1.data_ptr();
-    void *kv_cache_out1_ptr = kv_cache_out1.data_ptr();
-    void *inner_out_ptr = inner_out.data_ptr();
-    void *workspace_ptr = workspace_tensor.data_ptr();
-    void *tiling_ptr = tiling.data_ptr();
-
-    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
-    at_npu::native::OpCommand cmd;
-    cmd.Name("mla_preprocess");
-
-    cmd.SetCustomHandler([stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
-                          gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
-                          kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
-                          qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr,
-                          tiling_ptr, block_dim]() -> int {
-        mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
-                            gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
-                            kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
-                            qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr,
-                            tiling_ptr, block_dim);
-        return 0;
-    });
-    cmd.Run();
-    return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1, inner_out);
-}
-
 std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
     at::Tensor &input,
     const int64_t org_vocab_start_index,
@@ -500,363 +392,6 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
     return y_out;
 }
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(
-    const at::Tensor &x, const at::Tensor &weight, const at::Tensor &weight_scale, const at::Tensor &x_scale,
-    const at::Tensor &group_list, const c10::optional<at::Tensor> &bias, const c10::optional<at::Tensor> &offset)
-{
-    int m = x.sizes()[0];
-    int n = weight.sizes()[2];
-    bool is_a8w4 = x.dtype() == at::kChar && weight.dtype() == at::kInt;
-    if (is_a8w4) {
-        n *= INT4_NUMS_IN_INT32;
-    }
-
-    at::Tensor output = at::empty({m, n/2}, x.options().dtype(c10::ScalarType::Char));
-    at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float));
-    at::Tensor output_offset = at::empty({}, x.options().dtype(c10::ScalarType::Float));
-
-    EXEC_NPU_CMD(
-        aclnnGroupedMatmulSwigluQuantWeightNZ,
-        x,
-        weight,
-        bias,
-        offset,
-        weight_scale,
-        x_scale,
-        group_list,
-        output,
-        output_scale,
-        output_offset);
-    return std::tuple(output, output_scale, output_offset);
-}
-
-std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weight_nz_tensor_list(
-    const at::Tensor & x,
-    const at::TensorList & weight,
-    const at::TensorList & weight_scale,
-    const at::Tensor & x_scale,
-    const at::Tensor & group_list,
-    const c10::optional<at::Tensor> & bias,
-    const c10::optional<at::Tensor> & offset)
-{
-    auto x_size = x.sizes();
-    int n = weight[0].sizes()[1];
-    int m = x_size[0];
-    int k = x_size[1];
-
-    at::Tensor output = at::empty({m, n/2}, x.options().dtype(at::kChar));
-    at::Tensor output_scale = at::empty({m}, x.options().dtype(at::kFloat));
-    at::Tensor output_offset = at::empty({m}, x.options().dtype(at::kFloat));
-
-    EXEC_NPU_CMD(
-        aclnnGroupedMatmulSwigluQuantWeightNzTensorList,
-        x,
-        weight,
-        bias,
-        offset,
-        weight_scale,
-        x_scale,
-        group_list,
-        output,
-        output_scale,
-        output_offset);
-
-    return std::tuple(output, output_scale, output_offset);
-}
-
-std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
-    const at::Tensor &x,
-    const at::Tensor &expert_ids,
-    const at::TensorList &gmm1_permuted_weight,
-    const at::TensorList &gmm1_permuted_weight_scale,
-    const at::TensorList &gmm2_weight,
-    const at::TensorList &gmm2_weight_scale,
-    const at::Tensor &expert_scales,
-    const c10::optional<at::Tensor> &expert_smooth_scales,
-    const c10::optional<at::Tensor> &x_active_mask,
-    c10::string_view group_ep,
-    int64_t ep_rank_size,
-    int64_t ep_rank_id,
-    int64_t moe_expert_num,
-    int64_t shared_expert_num,
-    int64_t shared_expert_rank_num,
-    int64_t quant_mode,
-    int64_t global_bs)
-{
-    auto x_shape = x.sizes();
-    int bs = x_shape[0];
-    int h = x_shape[1];
-
-    at::Tensor output = at::empty({bs, h}, x.options());
-
-    bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
-    int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
-    auto opts = expert_ids.options().dtype(at::kLong);
-    at::Tensor expert_token_nums = at::empty({num_local_experts}, opts);
-
-    std::vector<char> group_ep_chrs(group_ep.begin(), group_ep.end());
-    group_ep_chrs.push_back('\0');
-    char *group_ep_ptr = &group_ep_chrs[0];
-    EXEC_NPU_CMD(
-        // op api
-        aclnnDispatchGmmCombineDecode,
-        // input tensors
-        x,
-        expert_ids,
-        gmm1_permuted_weight,
-        gmm1_permuted_weight_scale,
-        gmm2_weight,
-        gmm2_weight_scale,
-        expert_scales,
-        expert_smooth_scales,
-        x_active_mask,
-        // input attrs
-        group_ep_ptr,
-        ep_rank_size,
-        ep_rank_id,
-        moe_expert_num,
-        shared_expert_num,
-        shared_expert_rank_num,
-        quant_mode,
-        global_bs,
-        // output tensors
-        output,
-        expert_token_nums);
-    return {output, expert_token_nums};
-}
-
-void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
-                            c10::optional<c10::string_view> format_mode,
-                            c10::optional<c10::string_view> quant_mode)
-{
-    auto [tiling_tensor, block_dim] = bmm_trans::batch_matmul_transpose_tiling(
-        tensor_a,
-        tensor_b,
-        tensor_c,
-        format_mode,
-        quant_mode
-    );
-
-    void *gm_a = tensor_a.data_ptr();
-    void *gm_b = tensor_b.data_ptr();
-    void *gm_c = tensor_c.data_ptr();
-    void *gm_tiling_data = tiling_tensor.data_ptr();
-
-    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
-    at_npu::native::OpCommand cmd;
-    cmd.Name("batch_matmul_transpose");
-
-    cmd.SetCustomHandler([stream, gm_a, gm_b, gm_c, gm_tiling_data,
-                          block_dim]() -> int {
-        batch_matmul_transpose_impl(stream, gm_a, gm_b, gm_c, gm_tiling_data,
-                                    block_dim);
-        return 0;
-    });
-    cmd.Run();
-    return;
-}
-
-std::tuple<at::Tensor, at::Tensor> dispatch_ffn_combine(
-    const at::Tensor& x,
-    const at::TensorList& weight1,
-    const at::TensorList& weight2,
-    const at::Tensor& expert_idx,
-    const at::TensorList& scale1,
-    const at::TensorList& scale2,
-    const at::Tensor& probs,
-    c10::string_view group,
-    int64_t max_output_size,
-    at::Tensor& out,
-    at::Tensor& expert_token_nums
-) {
-    char *group_ep_ptr = const_cast<char *>(group.data());
-    bool is_int8 = weight1[0].dtype() == at::kChar;
-    if (is_int8) {
-        EXEC_NPU_CMD(aclnnDispatchFFNCombine,
-                     x,
-                     weight1,
-                     weight2,
-                     expert_idx,
-                     scale1,
-                     scale2,
-                     probs,
-                     group_ep_ptr,
-                     max_output_size,
-                     out,
-                     expert_token_nums);
-    } else {
-        EXEC_NPU_CMD(aclnnDispatchFFNCombineBF16,
-                     x,
-                     weight1,
-                     weight2,
-                     expert_idx,
-                     scale1,
-                     scale2,
-                     probs,
-                     group_ep_ptr,
-                     max_output_size,
-                     out,
-                     expert_token_nums);
-    }
-    return {out, expert_token_nums};
-}
-
-at::Tensor npu_lightning_indexer(
-    const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
-    const c10::optional<at::Tensor> &actual_seq_lengths_query,
-    const c10::optional<at::Tensor> &actual_seq_lengths_key,
-    const c10::optional<at::Tensor> &block_table, c10::string_view layout_query,
-    c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode)
-{
-    // npu tensor max size
-    constexpr int32_t SIZE = 8;
-    constexpr int32_t DIM_0 = 0;
-    constexpr int32_t DIM_1 = 1;
-    constexpr int32_t DIM_2 = 2;
-    constexpr int32_t DIM_3 = 3;
-
-    TORCH_CHECK(query.numel() > 0, "Query is empty.");
-    TORCH_CHECK(key.numel() > 0, "Key is empty.");
-    TORCH_CHECK(weights.numel() > 0, "Weights is empty.");
-    for (size_t i = 0; i < query.sizes().size(); i++) {
-        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
-            "than 0, but shape[", i, "] is ", query.size(i));
-    }
-    TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);
-
-    at::SmallVector<int64_t, SIZE> output_size;
-    std::string query_layout_str = std::string(layout_query);
-    std::string key_layout_str = std::string(layout_key);
-    if (query_layout_str == "BSND") {
-        output_size = {query.size(DIM_0), query.size(DIM_1), key.size(DIM_2), sparse_count};
-    } else {
-        int n_dim_index = 0;
-        n_dim_index = (key_layout_str == "TND") ? DIM_1 : DIM_2;
-        output_size = {query.size(DIM_0), key.size(n_dim_index), sparse_count};
-    }
-    at::Tensor lightning_indexer_output = at::empty(output_size, query.options().dtype(at::kInt));
-    // convert str
-    char *query_layout_ptr = const_cast<char *>(query_layout_str.c_str());
-    char *key_layout_ptr = const_cast<char *>(key_layout_str.c_str());
-    EXEC_NPU_CMD(
-        aclnnLightningIndexerVllm,
-        query,
-        key,
-        weights,
-        actual_seq_lengths_query,
-        actual_seq_lengths_key,
-        block_table,
-        query_layout_ptr,
-        key_layout_ptr,
-        sparse_count,
-        sparse_mode,
-        lightning_indexer_output);
-    return lightning_indexer_output;
-}
-
-at::Tensor npu_sparse_flash_attention(
-    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
-    const at::Tensor &sparse_indices, double scale_value, int64_t sparse_block_size,
-    const c10::optional<at::Tensor> &block_table,
-    const c10::optional<at::Tensor> &actual_seq_lengths_query,
-    const c10::optional<at::Tensor> &actual_seq_lengths_kv,
-    const c10::optional<at::Tensor> &query_rope,
-    const c10::optional<at::Tensor> &key_rope, c10::string_view layout_query,
-    c10::string_view layout_kv,
-    int64_t sparse_mode)
-{
-    std::string layout_query_str = std::string(layout_query);
-    std::string layout_kv_str = std::string(layout_kv);
-
-    for (size_t i = 0; i < query.sizes().size(); i++) {
-        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
-            "than 0, but shape[", i, "] is ", query.size(i));
-    }
-    // construct the output tensor
-    at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype()));
-    // convert str
-    char *layout_query_ptr = const_cast<char *>(layout_query_str.c_str());
-    char *layout_kv_ptr = const_cast<char *>(layout_kv_str.c_str());
-
-    EXEC_NPU_CMD(
-        aclnnSparseFlashAttention,
-        query,
-        key,
-        value,
-        sparse_indices,
-        block_table,
-        actual_seq_lengths_query,
-        actual_seq_lengths_kv,
-        query_rope,
-        key_rope,
-        scale_value,
-        sparse_block_size,
-        layout_query_ptr,
-        layout_kv_ptr,
-        sparse_mode,
-        output);
-    return output;
-}
-
-
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_custom(
-    const at::Tensor &x, const at::Tensor &expert_idx,
-    const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
-    int64_t expert_capacity, int64_t expert_num, int64_t drop_pad_mode, int64_t expert_tokens_num_type,
-    bool expert_tokens_num_flag, int64_t quant_mode, at::IntArrayRef active_expert_range, int64_t row_idx_type)
-{
-    constexpr int64_t DIM_X = 2;
-    constexpr int64_t DIM_EXPERT_IDX = 2;
-    constexpr int64_t LENGTH_ACTIVE_EXPERT_RANGE = 2;
-    constexpr int64_t EXPERT_TOKENS_COUNT = 1;
-    constexpr int64_t EXPERT_TOKENS_KEY_VALUE = 2;
-    constexpr int64_t QUANT_MODE_UNQUANT = -1;
-    constexpr int64_t QUANT_MODE_DYNAMIC_QUANT = 1;
-    constexpr int64_t CUMSUM = 0;
-    constexpr int64_t COUNT = 1;
-    constexpr int64_t KEY_VALUE = 2;
-
-    // keep the default range alive; an IntArrayRef over a temporary would dangle
-    std::vector<int64_t> default_range = {0, expert_num};
-    if (active_expert_range.empty()) {
-        active_expert_range = at::IntArrayRef(default_range);
-    }
-
-    int64_t x_dim = x.dim();
-    TORCH_CHECK(x_dim == DIM_X, "The x should be ", DIM_X,
-                "-Dimension, current is ", x_dim, "-Dimension.");
-
-    int64_t expert_idx_dim = expert_idx.dim();
-    TORCH_CHECK(expert_idx_dim == DIM_EXPERT_IDX, "The expert_idx should be ", DIM_EXPERT_IDX,
-                "-Dimension, current is ", expert_idx_dim, "-Dimension.");
-
-    int64_t active_expert_range_length = active_expert_range.size();
-    TORCH_CHECK(active_expert_range_length == LENGTH_ACTIVE_EXPERT_RANGE,
-                "The active_expert_range should have length ", LENGTH_ACTIVE_EXPERT_RANGE,
-                ", current length is ", active_expert_range_length, ".");
-
-    int expert_length = active_expert_range[1] - active_expert_range[0];
-    auto x_size = x.sizes();
-    auto expert_idx_size = expert_idx.sizes();
-
-    int bs = x_size[0];
-    int h = x_size[1];
-    int k = expert_idx_size[1];
-    int64_t expanded_scale_len = 0;
-    at::Tensor expanded_x;
-
-    if (drop_pad_mode == 1) {  // Drop/Pad
-        if (quant_mode == QUANT_MODE_UNQUANT) {
-            expanded_x = at::empty({expert_num, expert_capacity, h}, x.options());
-        } else {
-            expanded_x = at::empty({expert_num, expert_capacity, h}, x.options().dtype(at::kChar));
-        }
-        expanded_scale_len = expert_num * expert_capacity;
-    } else {  // Dropless / Active
-        if (active_num > 0) {  // Active
-            int64_t num_out_tokens = std::min((int64_t)bs * k, active_num);
-            if (quant_mode == QUANT_MODE_UNQUANT) {
-                expanded_x = at::empty({num_out_tokens, h}, x.options());
-            } else {
-                expanded_x = at::empty({num_out_tokens, h}, x.options().dtype(at::kChar));
-            }
-            expanded_scale_len = num_out_tokens;
-        } else {  // Dropless
-            if (quant_mode == QUANT_MODE_UNQUANT) {
-                expanded_x = at::empty({bs * k, h}, x.options());
-            } else {
-                expanded_x = at::empty({bs * k, h}, x.options().dtype(at::kChar));
-            }
-            expanded_scale_len = bs * k;
-        }
-    }
-
-    at::Tensor expanded_row_idx = at::empty({bs * k}, expert_idx.options());
-    at::Tensor expert_tokens_count_or_cumsum;
-    if (expert_tokens_num_type >= CUMSUM && expert_tokens_num_type <= COUNT) {
-        // expert_tokens_count_or_cumsum in [end-start, ]
-        expert_tokens_count_or_cumsum = at::empty({expert_length}, x.options().dtype(at::kLong));
-    } else if (expert_tokens_num_type == KEY_VALUE) {
-        // key_value in [2, end-start]
-        expert_tokens_count_or_cumsum = at::empty({expert_num, 2}, x.options().dtype(at::kLong));
-    }
-    at::Tensor expanded_scale = at::empty({expanded_scale_len}, x.options().dtype(at::kFloat));
-    EXEC_NPU_CMD(aclnnMoeInitRoutingCustom,
-                 x,
-                 expert_idx,
-                 scale,
-                 offset,
-                 active_num,
-                 expert_capacity,
-                 expert_num,
-                 drop_pad_mode,
-                 expert_tokens_num_type,
-                 expert_tokens_num_flag,
-                 quant_mode,
-                 active_expert_range,
-                 row_idx_type,
-                 expanded_x,
-                 expanded_row_idx,
-                 expert_tokens_count_or_cumsum,
-                 expanded_scale);
-    return {expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale};
-}
-
-at::Tensor npu_apply_top_k_top_p(
-    const at::Tensor& logits,
-    const c10::optional<at::Tensor>& p,
-    const c10::optional<at::Tensor>& k)
-{
-    TORCH_CHECK(p.has_value() || k.has_value(),
-                "apply_top_k_top_p: p and k cannot be None at the same time.");
-
-    at::Tensor out = at::empty_like(logits);
-
-    EXEC_NPU_CMD(
-        aclnnApplyTopKTopPCustom,
-        logits,
-        p,
-        k,
-        out);
-
-    return out;
-}
-
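The output-buffer sizing in `npu_moe_init_routing_custom` above is the subtle part of that deletion. A compact restatement as a pure function, pulled out purely as a review aid (the helper name `expanded_x_rows` is illustrative; in Drop/Pad mode the count is the flattened size of the `{expert_num, expert_capacity, h}` allocation):

```cpp
#include <algorithm>
#include <cstdint>

// Row count of expanded_x (and length of expanded_scale), same logic as above.
int64_t expanded_x_rows(int64_t bs, int64_t k, int64_t active_num,
                        int64_t expert_num, int64_t expert_capacity,
                        int64_t drop_pad_mode) {
    if (drop_pad_mode == 1) {                // Drop/Pad: fixed per-expert capacity
        return expert_num * expert_capacity;
    }
    if (active_num > 0) {                    // Active: cap the expanded rows
        return std::min(bs * k, active_num);
    }
    return bs * k;                           // Dropless: every (token, expert) pair
}
```
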
-std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k(
-    const at::Tensor& x,
-    int64_t k,
-    int64_t k_group,
-    int64_t group_count,
-    int64_t group_select_mode,
-    int64_t renorm,
-    int64_t norm_type,
-    bool out_flag,
-    double routed_scaling_factor,
-    double eps,
-    const c10::optional<at::Tensor>& bias_opt)
-{
-    TORCH_CHECK(x.dim() == 2, "The x should be 2D");
-    TORCH_CHECK(
-        x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
-        "float16, float32 or bfloat16 tensor expected but got a tensor with dtype: ",
-        x.scalar_type());
-
-    auto x_size = x.sizes();
-    auto rows = x_size[0];
-    auto expert_num = x_size[1];
-    const at::Tensor &bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); });
-    if (bias.defined()) {
-        TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be the same");
-        TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
-        auto bias_size = bias.sizes();
-        TORCH_CHECK(bias_size[0] == expert_num, "The first dim of bias should equal the second dim of x");
-    }
-    at::Tensor y = at::empty({rows, k}, x.options());
-    at::Tensor expert_idx = at::empty({rows, k}, x.options().dtype(at::kInt));
-    at::Tensor out = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
-
-    EXEC_NPU_CMD(aclnnMoeGatingTopK,
-                 x,
-                 bias,
-                 k,
-                 k_group,
-                 group_count,
-                 group_select_mode,
-                 renorm,
-                 norm_type,
-                 out_flag,
-                 routed_scaling_factor,
-                 eps,
-                 y,
-                 expert_idx,
-                 out);
-
-    return std::tuple(y, expert_idx, out);
-}
-
-std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_add_rms_norm_bias(
-    const at::Tensor& x1,
-    const at::Tensor& x2,
-    const at::Tensor& gamma,
-    const c10::optional<at::Tensor> &beta,
-    double epsilon)
-{
-    int64_t dim_x = x1.dim();
-    int64_t dim_gamma = gamma.dim();
-    int64_t diff = dim_x - dim_gamma;
-    std::vector<int64_t> new_shape;
-    at::Tensor rstd;
-
-    if (diff > 0) {
-        new_shape.reserve(dim_x);
-        auto x1_sizes = x1.sizes();
-        for (int64_t i = 0; i < diff; ++i) {
-            new_shape.push_back(x1_sizes[i]);
-        }
-        for (int64_t i = 0; i < dim_gamma; ++i) {
-            new_shape.push_back(1);
-        }
-    } else {
-        new_shape.assign(dim_x, 1);
-    }
-    rstd = at::empty(new_shape, x1.options().dtype(at::kFloat));
-    at::Tensor y = at::empty(x1.sizes(), x1.options());
-    at::Tensor x = at::empty(x1.sizes(), x1.options());
-    EXEC_NPU_CMD(aclnnAddRmsNormBias, x1, x2, gamma, beta, epsilon, y, rstd, x);
-    return std::tuple(y, rstd, x);
-}
-
 void transpose_kv_cache_by_block(
     const at::TensorList &kCache,
     const at::TensorList &vCache,