-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Refactor the ops PyTorch adapter,cleanup for csrc/torch_binding.cpp #6732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
da98c71
[P/D] layerwise connector support recompute scheduler (#5900)
liziyu179 11c63d7
test
luomin2005 075bd99
Add Worker Interface:check_health
luomin2005 e7f5357
Merge branch 'main' of https://github.com/luomin2005/vllm-ascend
luomin2005 d259d14
Add Worker Interface:check_health
luomin2005 f441a03
Add Worker Interface:check_health
luomin2005 e327caf
Merge branch 'vllm-project:main' into main
luomin2005 32fe4fd
Refactor the ops pytorch adpater
luomin2005 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| /* | ||
| * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #ifndef ADD_RMS_NORM_BIAS_TORCH_ADPT_H | ||
| #define ADD_RMS_NORM_BIAS_TORCH_ADPT_H | ||
|
|
||
| namespace vllm_ascend { | ||
|
|
||
| std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias( | ||
| const at::Tensor& x1, | ||
| const at::Tensor& x2, | ||
| const at::Tensor& gamma, | ||
| const c10::optional<at::Tensor> &beta, | ||
| double epsilon) | ||
| { | ||
| int64_t dim_x = x1.dim(); | ||
| int64_t dim_gamma = gamma.dim(); | ||
| int64_t diff = dim_x - dim_gamma; | ||
| std::vector<int64_t> new_shape; | ||
| at::Tensor rstd; | ||
|
|
||
| if (diff > 0) { | ||
| new_shape.reserve(dim_x); | ||
| auto x1_sizes = x1.sizes(); | ||
| for (int64_t i = 0; i < diff; ++i) { | ||
| new_shape.push_back(x1_sizes[i]); | ||
| } | ||
| for (int64_t i = 0; i < dim_gamma; ++i) { | ||
| new_shape.push_back(1); | ||
| } | ||
| } else { | ||
| new_shape.assign(dim_x, 1); | ||
| } | ||
| rstd = at::empty(new_shape, x1.options().dtype(at::kFloat)); | ||
| at::Tensor y = at::empty(x1.sizes(), x1.options()); | ||
| at::Tensor x = at::empty(x1.sizes(), x1.options()); | ||
| EXEC_NPU_CMD(aclnnAddRmsNormBias, x1, x2, gamma, beta, epsilon, y, rstd, x); | ||
| return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x); | ||
| } | ||
| } | ||
| #endif |
40 changes: 40 additions & 0 deletions
40
csrc/apply_top_k_top_p_custom/apply_top_k_top_p_custom_torch_adpt.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| /* | ||
| * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #ifndef APPLY_TOP_K_TOP_P_CUSTOM_TORCH_ADPT_H | ||
| #define APPLY_TOP_K_TOP_P_CUSTOM_TORCH_ADPT_H | ||
|
|
||
| namespace vllm_ascend { | ||
| at::Tensor npu_apply_top_k_top_p( | ||
| const at::Tensor& logits, | ||
| const c10::optional<at::Tensor>& p, | ||
| const c10::optional<at::Tensor>& k) | ||
| { | ||
| TORCH_CHECK(p.has_value() || k.has_value(), | ||
| "apply_top_k_top_p: p and k cannot be None at the same time."); | ||
|
|
||
| at::Tensor out = at::empty_like(logits); | ||
|
|
||
| EXEC_NPU_CMD( | ||
| aclnnApplyTopKTopPCustom, | ||
| logits, | ||
| p, | ||
| k, | ||
| out); | ||
|
|
||
| return out; | ||
| } | ||
| } | ||
| #endif |
54 changes: 54 additions & 0 deletions
54
csrc/batch_matmul_transpose/batch_matmul_transpose_torch_adpt.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| /* | ||
| * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #ifndef BATCH_MATMUL_TRANSPOSE_TORCH_ADPT_H | ||
| #define BATCH_MATMUL_TRANSPOSE_TORCH_ADPT_H | ||
| #include "op_host/batch_matmul_transpose.h" | ||
|
|
||
| namespace vllm_ascend { | ||
|
|
||
| void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c, | ||
| c10::optional<c10::string_view> format_mode, | ||
| c10::optional<c10::string_view> quant_mode) | ||
| { | ||
| auto [tiling_tensor, block_dim] = bmm_trans::batch_matmul_transpose_tiling( | ||
| tensor_a, | ||
| tensor_b, | ||
| tensor_c, | ||
| format_mode, | ||
| quant_mode | ||
| ); | ||
|
|
||
| void *gm_a = tensor_a.data_ptr(); | ||
| void *gm_b = tensor_b.data_ptr(); | ||
| void *gm_c = tensor_c.data_ptr(); | ||
| void *gm_tiling_data = tiling_tensor.data_ptr(); | ||
|
|
||
| aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); | ||
| at_npu::native::OpCommand cmd; | ||
| cmd.Name("batch_matmul_transpose"); | ||
|
|
||
| cmd.SetCustomHandler([stream, gm_a, gm_b, gm_c, gm_tiling_data, | ||
| block_dim]() -> int { | ||
| batch_matmul_transpose_impl(stream, gm_a, gm_b, gm_c, gm_tiling_data, | ||
| block_dim); | ||
| return 0; | ||
| }); | ||
| cmd.Run(); | ||
| return; | ||
| } | ||
|
|
||
| } | ||
| #endif |
65 changes: 65 additions & 0 deletions
65
csrc/dispatch_ffn_combine/dispatch_ffn_combine_torch_adpt.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| /* | ||
| * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #ifndef DISPATCH_FFN_COMBINE_TORCH_ADPT_H | ||
| #define DISPATCH_FFN_COMBINE_TORCH_ADPT_H | ||
|
|
||
| namespace vllm_ascend { | ||
| std::tuple<at::Tensor&, at::Tensor&> dispatch_ffn_combine( | ||
| const at::Tensor& x, | ||
| const at::TensorList& weight1, | ||
| const at::TensorList& weight2, | ||
| const at::Tensor& expert_idx, | ||
| const at::TensorList& scale1, | ||
| const at::TensorList& scale2, | ||
| const at::Tensor& probs, | ||
| c10::string_view group, | ||
| int64_t max_output_size, | ||
| at::Tensor& out, | ||
| at::Tensor& expert_token_nums | ||
| ) { | ||
| char *group_ep_ptr = const_cast<char *>(group.data()); | ||
| bool is_int8 = weight1[0].dtype() == at::kChar; | ||
| if (is_int8) { | ||
| EXEC_NPU_CMD(aclnnDispatchFFNCombine, | ||
| x, | ||
| weight1, | ||
| weight2, | ||
| expert_idx, | ||
| scale1, | ||
| scale2, | ||
| probs, | ||
| group_ep_ptr, | ||
| max_output_size, | ||
| out, | ||
| expert_token_nums); | ||
| } else { | ||
| EXEC_NPU_CMD(aclnnDispatchFFNCombineBF16, | ||
| x, | ||
| weight1, | ||
| weight2, | ||
| expert_idx, | ||
| scale1, | ||
| scale2, | ||
| probs, | ||
| group_ep_ptr, | ||
| max_output_size, | ||
| out, | ||
| expert_token_nums); | ||
| } | ||
| return {out, expert_token_nums}; | ||
| } | ||
| } | ||
| #endif | ||
83 changes: 83 additions & 0 deletions
83
csrc/dispatch_gmm_combine_decode/dispatch_gmm_combine_decode_torch_adpt.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| /* | ||
| * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #ifndef DISPATCH_GMM_COMBINE_TORCH_ADPT_H | ||
| #define DISPATCH_GMM_COMBINE_TORCH_ADPT_H | ||
|
|
||
| namespace vllm_ascend { | ||
|
|
||
| std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode( | ||
| const at::Tensor &x, | ||
| const at::Tensor &expert_ids, | ||
| const at::TensorList &gmm1_permuted_weight, | ||
| const at::TensorList &gmm1_permuted_weight_scale, | ||
| const at::TensorList &gmm2_weight, | ||
| const at::TensorList &gmm2_weight_scale, | ||
| const at::Tensor &expert_scales, | ||
| const c10::optional<at::Tensor> &expert_smooth_scales, | ||
| const c10::optional<at::Tensor> &x_active_mask, | ||
| c10::string_view group_ep, | ||
| int64_t ep_rank_size, | ||
| int64_t ep_rank_id, | ||
| int64_t moe_expert_num, | ||
| int64_t shared_expert_num, | ||
| int64_t shared_expert_rank_num, | ||
| int64_t quant_mode, | ||
| int64_t global_bs) | ||
| { | ||
| auto x_shape = x.sizes(); | ||
| int bs = x_shape[0]; | ||
| int h = x_shape[1]; | ||
|
|
||
| at::Tensor output = at::empty({bs, h}, x.options()); | ||
|
|
||
| bool is_shared_expert = (ep_rank_id < shared_expert_rank_num); | ||
| int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num); | ||
| auto opts = expert_ids.options().dtype(at::kLong); | ||
| at::Tensor expert_token_nums = at::empty({num_local_experts}, opts); | ||
|
|
||
| vector<char> group_ep_chrs(group_ep.begin(), group_ep.end()); | ||
| group_ep_chrs.push_back('\0'); | ||
| char *group_ep_ptr = &group_ep_chrs[0]; | ||
| EXEC_NPU_CMD( | ||
| // op api | ||
| aclnnDispatchGmmCombineDecode, | ||
| // input tensors | ||
| x, | ||
| expert_ids, | ||
| gmm1_permuted_weight, | ||
| gmm1_permuted_weight_scale, | ||
| gmm2_weight, | ||
| gmm2_weight_scale, | ||
| expert_scales, | ||
| expert_smooth_scales, | ||
| x_active_mask, | ||
| //input attrs | ||
| group_ep_ptr, | ||
| ep_rank_size, | ||
| ep_rank_id, | ||
| moe_expert_num, | ||
| shared_expert_num, | ||
| shared_expert_rank_num, | ||
| quant_mode, | ||
| global_bs, | ||
| // output tensors | ||
| output, | ||
| expert_token_nums); | ||
| return {output, expert_token_nums}; | ||
| } | ||
|
|
||
| } | ||
| #endif |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| /* | ||
| * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #ifndef DISPATCH_LAYOUT_TORCH_ADPT_H | ||
| #define DISPATCH_LAYOUT_TORCH_ADPT_H | ||
|
|
||
| namespace vllm_ascend { | ||
| std::tuple<at::Tensor, at::Tensor, at::Tensor> get_dispatch_layout(const at::Tensor& topk_idx, int64_t num_experts, | ||
| int64_t num_ranks) { | ||
| TORCH_BIND_ASSERT(topk_idx.dim() == 2); | ||
| TORCH_BIND_ASSERT(topk_idx.is_contiguous()); | ||
| TORCH_BIND_ASSERT(num_experts > 0); | ||
|
|
||
| const int num_tokens = topk_idx.size(0); | ||
| const int num_topk = topk_idx.size(1); | ||
|
|
||
| auto device = topk_idx.device(); | ||
| auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device)); | ||
| auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device)); | ||
| auto is_token_in_rank = at::zeros({num_tokens, num_ranks}, at::dtype(at::kInt).device(device)); | ||
|
|
||
| EXEC_NPU_CMD(aclnnDispatchLayout, | ||
| topk_idx, | ||
| num_tokens, | ||
| num_ranks, | ||
| num_experts, | ||
| num_topk, | ||
| num_tokens_per_rank, | ||
| num_tokens_per_expert, | ||
| is_token_in_rank); | ||
|
|
||
| auto is_token_in_rank_bool = is_token_in_rank.to(at::kBool); | ||
|
|
||
| return std::make_tuple(num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank_bool); | ||
| } | ||
|
|
||
| } | ||
| #endif |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
const_caston a pointer obtained fromstring_view::data()is unsafe. It can lead to undefined behavior if the underlying data is modified through the resulting pointer. A safer approach is to create a mutable copy of the string data.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed