Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tests/ut/ops/test_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,13 @@ def setUp(self):
def test_cumsum_group_list_with_type_0(self):
    """src and dst list types are both 0 (cumulative): input is returned as-is."""
    # self.experts holds per-expert token counts; its cumsum is the
    # cumulative (type-0) layout, which must round-trip unchanged.
    group_list = self.experts.cumsum(dim=0)
    group_list_type = 0
    result = cumsum_group_list(group_list, group_list_type, 0)
    self.assertTrue(torch.equal(result, self.group_list))

def test_cumsum_group_list_with_type_1(self):
    """Converting type-1 (per-expert counts) to type-0 yields the cumsum."""
    group_list = self.experts
    group_list_type = 1
    result = cumsum_group_list(group_list, group_list_type, 0)
    self.assertTrue(torch.equal(result, self.group_list))

def test_cumsum_group_list_with_type_2(self):
Expand All @@ -314,6 +314,7 @@ def test_cumsum_group_list_with_type_2(self):
group_list_type = 2
result = cumsum_group_list(group_list,
group_list_type,
0,
active_num=self.active_num,
expert_num=self.expert_num)
self.assertTrue(torch.equal(result, self.group_list))
Expand Down Expand Up @@ -593,4 +594,4 @@ def test_unified_apply_mlp_with_quantization_and_fusion_mlp(

self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(result.shape, hidden_states_shape)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.dtype, torch.bfloat16)
52 changes: 31 additions & 21 deletions vllm_ascend/ops/fused_moe/moe_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,39 @@ def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):


def cumsum_group_list(group_list: torch.Tensor,
                      src_list_type: int,
                      dst_list_type: int,
                      active_num: int = 0,
                      expert_num: int = 0) -> torch.Tensor:
    """Convert a grouped-matmul group list from one layout to another.

    Layouts (as used by the callers passing ``group_list_type``):
      * 0 — cumulative token counts per expert (running sum),
      * 1 — per-expert token counts,
      * 2 — rows of (expert boundary, token count) pairs.
        NOTE(review): type-2 semantics inferred from the conversion code
        below — confirm against the producers of type-2 lists.

    Args:
        group_list: the source group list tensor.
        src_list_type: layout of ``group_list`` (0, 1 or 2).
        dst_list_type: desired output layout.
        active_num: fill value for experts past the last boundary
            (type-2 → type-0 conversion only).
        expert_num: total number of experts (type-2 → type-0 only).

    Returns:
        ``group_list`` converted to ``dst_list_type``.

    Raises:
        ValueError: if ``src_list_type`` is not 0, 1 or 2.
        NotImplementedError: for conversions not implemented yet.
    """
    if src_list_type not in (0, 1, 2):
        raise ValueError(
            f"src_list_type should be in [0, 1, 2], but received {src_list_type}"
        )

    if src_list_type == dst_list_type:
        # No conversion needed.
        return group_list
    if src_list_type == 1 and dst_list_type == 0:
        # counts -> cumulative counts.
        return group_list.cumsum(dim=0)
    if src_list_type == 0 and dst_list_type == 1:
        # cumulative -> counts. The first count is the first cumulative
        # value itself (fix: the previous code prepended the first
        # *difference*, i.e. cumsum[1]-cumsum[0], producing a wrong first
        # element).
        return torch.cat([group_list[:1], torch.diff(group_list)], dim=0)
    if src_list_type == 2 and dst_list_type == 0:
        # Prepend a leading zero so consecutive pairs form [start, end)
        # ranges; tokens[i] is the cumulative token count after row i-1.
        experts = pad(group_list[:, 0], (1, 0))
        tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
        # Experts beyond the last recorded boundary default to active_num.
        # (Renamed from `cumsum_group_list`, which shadowed this function.)
        cumulative = torch.full(size=(expert_num, ),
                                fill_value=active_num,
                                dtype=group_list.dtype,
                                device=group_list.device)
        for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
            if end > start:
                cumulative[start:end] = tokens[i]
        return cumulative
    raise NotImplementedError(
        f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
        "This feature is under development.")


def quant_apply_mlp(hidden_states: torch.Tensor,
Expand Down Expand Up @@ -100,14 +108,15 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
weight=w1,
weight_scale=w1_scale,
x_scale=pertoken_scale,
group_list=cumsum_group_list(group_list, group_list_type),
group_list=cumsum_group_list(group_list, group_list_type,
0),
))
elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1[0],
group_list=cumsum_group_list(group_list, group_list_type),
group_list=cumsum_group_list(group_list, group_list_type, 0),
weight_scale=w1_scale[0],
x_scale=pertoken_scale)
if quantized_hidden_states is not None:
Expand All @@ -134,7 +143,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
group_index=cumsum_group_list(group_list, group_list_type, 1),
activate_left=True,
quant_mode=1,
)
Expand Down Expand Up @@ -170,7 +179,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
weight=w1,
weight_scale=w1_scale,
x_scale=pertoken_scale,
group_list=cumsum_group_list(group_list, group_list_type),
group_list=cumsum_group_list(group_list, group_list_type,
0),
bias=bias1,
))
elif fusion and not dynamic_eplb:
Expand All @@ -179,7 +189,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
x=hidden_states,
weight=w1[0],
bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type),
group_list=cumsum_group_list(group_list, group_list_type, 0),
weight_scale=w1_scale[0],
x_scale=pertoken_scale)
if quantized_hidden_states is not None:
Expand Down
Loading