-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[Bugfix] bugfix for moe_mlp in vllm-ascend/v0.11.0-dev #4885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ce1e68e
1b47203
559fc6b
1fc4005
42ab319
ac2a72c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,31 +26,39 @@ | |||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def cumsum_group_list(group_list: torch.Tensor, | ||||||||||||||||||
| group_list_type: int, | ||||||||||||||||||
| src_list_type: int, | ||||||||||||||||||
| dst_list_type: int, | ||||||||||||||||||
| active_num: int = 0, | ||||||||||||||||||
| expert_num: int = 0) -> torch.Tensor: | ||||||||||||||||||
| if group_list_type not in [0, 1, 2]: | ||||||||||||||||||
| if src_list_type not in [0, 1, 2]: | ||||||||||||||||||
| raise ValueError( | ||||||||||||||||||
| f"group_list_type should be in [0, 1, 2], but received {group_list_type}" | ||||||||||||||||||
| f"group_list_type should be in [0, 1, 2], but received {src_list_type}" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| if group_list_type == 0: | ||||||||||||||||||
| if src_list_type == dst_list_type: | ||||||||||||||||||
| return group_list | ||||||||||||||||||
| if group_list_type == 1: | ||||||||||||||||||
| if src_list_type == 1 and dst_list_type == 0: | ||||||||||||||||||
| return group_list.cumsum(dim=0) | ||||||||||||||||||
| if src_list_type == 0 and dst_list_type == 1: | ||||||||||||||||||
| group_diff = torch.diff(group_list) | ||||||||||||||||||
| new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0) | ||||||||||||||||||
| return new_group | ||||||||||||||||||
|
Comment on lines
+42
to
+45
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is a bug in the logic for converting `src_list_type=0` (cumulative sum) to `dst_list_type=1` (counts): the first element of the result is taken from `group_diff[0]`, but the first group's count is the first cumulative-sum value itself. The previous implementation before this refactoring was correct. You should use `group_list[0].unsqueeze(0)` instead of `group_diff[0].unsqueeze(0)`.
Suggested change
|
||||||||||||||||||
| if src_list_type == 2 and dst_list_type == 0: | ||||||||||||||||||
| experts = pad(group_list[:, 0], (1, 0)) | ||||||||||||||||||
| tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0)) | ||||||||||||||||||
| cumsum_group_list = torch.full(size=(expert_num, ), | ||||||||||||||||||
| fill_value=active_num, | ||||||||||||||||||
| dtype=group_list.dtype, | ||||||||||||||||||
| device=group_list.device) | ||||||||||||||||||
|
|
||||||||||||||||||
| experts = pad(group_list[:, 0], (1, 0)) | ||||||||||||||||||
| tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0)) | ||||||||||||||||||
| cumsum_group_list = torch.full(size=(expert_num, ), | ||||||||||||||||||
| fill_value=active_num, | ||||||||||||||||||
| dtype=group_list.dtype, | ||||||||||||||||||
| device=group_list.device) | ||||||||||||||||||
| for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])): | ||||||||||||||||||
| if end > start: | ||||||||||||||||||
| cumsum_group_list[start:end] = tokens[i] | ||||||||||||||||||
|
|
||||||||||||||||||
| for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])): | ||||||||||||||||||
| if end > start: | ||||||||||||||||||
| cumsum_group_list[start:end] = tokens[i] | ||||||||||||||||||
|
|
||||||||||||||||||
| return cumsum_group_list | ||||||||||||||||||
| return cumsum_group_list | ||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||
| f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. " | ||||||||||||||||||
| "This feature is under development.") | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def quant_apply_mlp(hidden_states: torch.Tensor, | ||||||||||||||||||
|
|
@@ -89,7 +97,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor, | |||||||||||||||||
| hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( | ||||||||||||||||||
| x=hidden_states, | ||||||||||||||||||
| weight=w1, | ||||||||||||||||||
| group_list=cumsum_group_list(group_list, group_list_type), | ||||||||||||||||||
| group_list=cumsum_group_list(group_list, group_list_type, 0), | ||||||||||||||||||
| weight_scale=w1_scale, | ||||||||||||||||||
| x_scale=pertoken_scale) | ||||||||||||||||||
| else: | ||||||||||||||||||
|
|
@@ -105,17 +113,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor, | |||||||||||||||||
| group_list=group_list, | ||||||||||||||||||
| output_dtype=torch.int32)[0] | ||||||||||||||||||
| # act_fn: swiglu | ||||||||||||||||||
| group_diff = torch.diff(group_list) | ||||||||||||||||||
| new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], | ||||||||||||||||||
| dim=0) | ||||||||||||||||||
| hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( | ||||||||||||||||||
| x=hidden_states, | ||||||||||||||||||
| weight_scale=w1_scale, | ||||||||||||||||||
| activation_scale=pertoken_scale, | ||||||||||||||||||
| bias=None, | ||||||||||||||||||
| quant_scale=None, | ||||||||||||||||||
| quant_offset=None, | ||||||||||||||||||
| group_index=new_group, | ||||||||||||||||||
| group_index=cumsum_group_list(group_list, group_list_type, 1), | ||||||||||||||||||
| activate_left=True, | ||||||||||||||||||
| quant_mode=1, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
@@ -148,7 +153,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor, | |||||||||||||||||
| x=hidden_states, | ||||||||||||||||||
| weight=w1, | ||||||||||||||||||
| bias=bias1, | ||||||||||||||||||
| group_list=cumsum_group_list(group_list, group_list_type), | ||||||||||||||||||
| group_list=cumsum_group_list(group_list, group_list_type, 0), | ||||||||||||||||||
| weight_scale=w1_scale, | ||||||||||||||||||
| x_scale=pertoken_scale) | ||||||||||||||||||
| else: | ||||||||||||||||||
|
|
@@ -258,4 +263,4 @@ def unified_apply_mlp(hidden_states: torch.Tensor, | |||||||||||||||||
| group_list=group_list, | ||||||||||||||||||
| group_list_type=group_list_type, | ||||||||||||||||||
| topk_scales=topk_scales, | ||||||||||||||||||
| need_trans=need_trans) | ||||||||||||||||||
| need_trans=need_trans) | ||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test suite for
`cumsum_group_list` is incomplete. It's missing a test case for converting from `src_list_type=0` (cumulative sum) to `dst_list_type=1` (counts). Adding this test case would have caught the critical bug introduced in `moe_mlp.py`. Please add a test to cover this conversion. For example: