-
Notifications
You must be signed in to change notification settings - Fork 617
[Perf] Optimize fused_experts quantization code to save npu memory #784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| # | ||
|
|
||
| import os | ||
| from typing import Any, Callable, Dict, Optional | ||
| from typing import Any, Callable, Dict, List, Optional | ||
|
|
||
| import torch | ||
| import torch_npu | ||
|
|
@@ -25,7 +25,7 @@ | |
| from vllm_ascend.ops.fused_moe import select_experts | ||
|
|
||
|
|
||
| def apply_mlp(x: torch.Tensor, | ||
| def apply_mlp(hidden_states_wrapper: List[torch.Tensor], | ||
| w1: torch.Tensor, | ||
| w1_scale: torch.Tensor, | ||
| w2: torch.Tensor, | ||
|
|
@@ -37,7 +37,7 @@ def apply_mlp(x: torch.Tensor, | |
| apply MLP: gate_up_proj -> swiglu -> down_proj | ||
|
|
||
| Args: | ||
| x: input hidden states with shape (num_tokens, hidden_size). | ||
| hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). | ||
| w1: expert weights1 with shape | ||
| (num_experts, hidden_size, intermediate_size * 2) | ||
| w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) | ||
|
|
@@ -56,23 +56,27 @@ def apply_mlp(x: torch.Tensor, | |
| hidden_states: output hidden states after MLP. | ||
| """ | ||
|
|
||
| assert len(hidden_states_wrapper) == 1 | ||
| hidden_states = hidden_states_wrapper.pop() | ||
| if dynamic_scale is None: | ||
| h, pertoken_scale = torch_npu.npu_dynamic_quant(x) | ||
| hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( | ||
| hidden_states) | ||
| else: | ||
| h = x | ||
| pertoken_scale = dynamic_scale | ||
|
|
||
| # gmm1: gate_up_proj | ||
| gate_up_out = torch_npu.npu_grouped_matmul(x=[h], | ||
| weight=[w1], | ||
| split_item=3, | ||
| group_list_type=group_list_type, | ||
| group_type=0, | ||
| group_list=group_list, | ||
| output_dtype=torch.int32)[0] | ||
|
|
||
| swiglu_out, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( | ||
| x=gate_up_out, | ||
| hidden_states = torch_npu.npu_grouped_matmul( | ||
| x=[hidden_states], | ||
| weight=[w1], | ||
| split_item=3, | ||
| group_list_type=group_list_type, | ||
| group_type=0, | ||
| group_list=group_list, | ||
| output_dtype=torch.int32)[0] | ||
|
|
||
| # act_fn: swiglu | ||
| hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( | ||
| x=hidden_states, | ||
| weight_scale=w1_scale, | ||
| activation_scale=pertoken_scale, | ||
| bias=None, | ||
|
|
@@ -83,17 +87,18 @@ def apply_mlp(x: torch.Tensor, | |
| quant_mode=1, | ||
| ) | ||
|
|
||
| # down_proj | ||
| down_out = torch_npu.npu_grouped_matmul(x=[swiglu_out], | ||
| weight=[w2], | ||
| scale=[w2_scale], | ||
| per_token_scale=[swiglu_out_scale], | ||
| split_item=2, | ||
| group_list_type=group_list_type, | ||
| group_type=0, | ||
| group_list=group_list, | ||
| output_dtype=w2_scale.dtype)[0] | ||
| return down_out | ||
| # gmm2: down_proj | ||
| hidden_states = torch_npu.npu_grouped_matmul( | ||
| x=[hidden_states], | ||
| weight=[w2], | ||
| scale=[w2_scale], | ||
| per_token_scale=[swiglu_out_scale], | ||
| split_item=2, | ||
| group_list_type=group_list_type, | ||
| group_type=0, | ||
| group_list=group_list, | ||
| output_dtype=w2_scale.dtype)[0] | ||
| return hidden_states | ||
|
|
||
|
|
||
| def fused_experts_with_mc2( | ||
|
|
@@ -152,7 +157,11 @@ def fused_experts_with_mc2( | |
| if quant_mode == 0: | ||
| dynamic_scale = None | ||
|
|
||
| down_out_list = apply_mlp(expand_x, | ||
| # place hidden_states in a list to transfer its ownership into the `apply_mlp` function | ||
| hidden_states_wrapper = [expand_x] | ||
| del expand_x | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There still a lot of del action everywhere. We can create a new func like |
||
|
|
||
| down_out_list = apply_mlp(hidden_states_wrapper, | ||
| w1, | ||
| w1_scale, | ||
| w2, | ||
|
|
@@ -246,7 +255,7 @@ def fused_experts(hidden_states: torch.Tensor, | |
| token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) | ||
| expert_tokens = token_counts[:num_experts] | ||
| # Rearrange hidden_states | ||
| sorted_hidden_states = hidden_states[sorted_token_indices] | ||
| hidden_states = hidden_states[sorted_token_indices] | ||
| group_list_type = 1 | ||
| else: | ||
| row_idx_len = num_tokens * top_k | ||
|
|
@@ -255,19 +264,22 @@ def fused_experts(hidden_states: torch.Tensor, | |
| dtype=torch.int32, | ||
| device=topk_weights.device).view( | ||
| top_k, -1).permute(1, 0).contiguous() | ||
| sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( | ||
| hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( | ||
| hidden_states, | ||
| row_idx=row_idx, | ||
| expert_idx=topk_ids, | ||
| active_num=num_tokens) | ||
| del hidden_states | ||
|
|
||
| expert_tokens = torch_npu.npu_moe_compute_expert_tokens( | ||
| expanded_expert_idx, num_experts) | ||
| expert_tokens = expert_tokens.to(torch.int64) | ||
| group_list_type = 0 | ||
|
|
||
| down_out_list = apply_mlp(sorted_hidden_states, | ||
| # place hidden_states in a list to transfer its ownership into the `apply_mlp` function | ||
| hidden_states_wrapper = [hidden_states] | ||
| del hidden_states | ||
|
|
||
| hidden_states = apply_mlp(hidden_states_wrapper, | ||
| w1, | ||
| w1_scale, | ||
| w2, | ||
|
|
@@ -276,31 +288,31 @@ def fused_experts(hidden_states: torch.Tensor, | |
| group_list_type=group_list_type) | ||
|
|
||
| if expert_map is not None: | ||
| down_out_list.mul_(sorted_weights.unsqueeze(1)) | ||
| hidden_states.mul_(sorted_weights.unsqueeze(1)) | ||
| final_hidden_states = torch.zeros(*original_shape, | ||
| device=hidden_states.device, | ||
| device=device, | ||
| dtype=dtype) | ||
|
|
||
| num_valid_tokens = mask.sum() | ||
| valid_token_mask = torch.arange( | ||
| 0, sorted_token_indices.shape[0], | ||
| device=device).unsqueeze(1) < num_valid_tokens | ||
| down_out_list = down_out_list.masked_fill_(~valid_token_mask, | ||
| hidden_states = hidden_states.masked_fill_(~valid_token_mask, | ||
| 0).to(dtype) | ||
| final_hidden_states.index_add_(0, sorted_token_indices, down_out_list) | ||
| final_hidden_states.index_add_(0, sorted_token_indices, hidden_states) | ||
| else: | ||
| # TODO: Reorder device memory 2 times here, replace the current | ||
| # implementation here when suitable operators become available. | ||
| final_hidden_states = torch_npu.npu_moe_finalize_routing( | ||
| down_out_list, | ||
| hidden_states, | ||
| skip1=None, | ||
| skip2=None, | ||
| bias=None, | ||
| scales=topk_weights, | ||
| expanded_src_to_dst_row=expanded_row_idx, | ||
| export_for_source_row=topk_ids, | ||
| ) | ||
| del down_out_list | ||
|
|
||
| if len(original_shape) == 3: | ||
| final_hidden_states = final_hidden_states.view(original_shape) | ||
| return final_hidden_states | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with this quick fix. you have describe the case more clear later. For example, this
copy to listaction aims to deal with the lifecycle for the tensor so that the memory usage can be improved.