Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/ut/eplb/core/test_eplb_device_transfer_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_generate_task_and_state_flow(mock_adaptor):
loader_obj.state = loader.ExpertWeightUpdateState.WAITING

loader_obj.generate_expert_d2d_transfer_task([], [], {}, 0)
assert loader_obj.comm_op_list is None
assert loader_obj.state == loader.ExpertWeightUpdateState.WAITING
assert not loader_obj.comm_op_list
assert loader_obj.state == loader.ExpertWeightUpdateState.READY


def test_asyn_transfer_and_update(mock_adaptor):
Expand Down
12 changes: 6 additions & 6 deletions tests/ut/ops/test_expert_load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ def test_generate_expert_placement_map(self):
)
self.assertEqual(expert_placement_map.shape,
(self.expert_load_balancer.layers_num,
self.expert_load_balancer.ranks_num, 10))
self.expert_load_balancer.ranks_num, 8))
self.assertTrue(torch.all(expert_placement_map >= -1))

def test_generate_log2phy_expert_map(self):
layer_id = 0
log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
layer_id)
self.assertEqual(log2phy_map.shape,
(self.expert_load_balancer.ranks_num, 10))
(self.expert_load_balancer.ranks_num, 8))
self.assertTrue(torch.all(log2phy_map >= -1))

@mock.patch("torch_npu.npu._lazy_init")
Expand All @@ -101,15 +101,15 @@ def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init):
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
layer_id, rank_id)
self.assertEqual(rank_local_expert_num, 5)
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0, -1, -1],
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
dtype=torch.int32).to(
rank_expert_map.device)
self.assertTrue(rank_expert_map.equal(expected_tensor))

rank_id = 1
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
layer_id, rank_id)
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3, -1, -1],
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
dtype=torch.int32).to(
rank_expert_map.device)
self.assertTrue(rank_expert_map.equal(expected_tensor))
Expand All @@ -119,15 +119,15 @@ def test_get_rank_log2phy_map(self):
rank_id = 0
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
layer_id, rank_id)
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0, -1, -1],
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
dtype=torch.int32).to(
log2phy_map.device)
self.assertTrue(log2phy_map.equal(expected_tensor))

rank_id = 1
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
layer_id, rank_id)
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8, -1, -1],
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
dtype=torch.int32).to(
log2phy_map.device)
self.assertTrue(log2phy_map.equal(expected_tensor))
Expand Down
5 changes: 3 additions & 2 deletions tests/ut/ops/test_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,13 @@ def setUp(self):
def test_cumsum_group_list_with_type_0(self):
group_list = self.experts.cumsum(dim=0)
group_list_type = 0
result = cumsum_group_list(group_list, group_list_type)
result = cumsum_group_list(group_list, group_list_type, 0)
self.assertTrue(torch.equal(result, self.group_list))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

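The test suite for cumsum_group_list is incomplete: it has no test case for converting from src_list_type=0 (cumulative sums) to dst_list_type=1 (per-group counts). Such a test would have caught the critical bug this PR introduces in that branch of cumsum_group_list in moe_mlp.py, provided the first two counts in the fixture differ (if they are equal, group_diff[0] and group_list[0] coincide and the bug is masked).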

Please add a test to cover this conversion. For example:

def test_cumsum_group_list_from_type_0_to_1(self):
    group_list_cumsum = self.experts.cumsum(dim=0)
    result = cumsum_group_list(group_list_cumsum, src_list_type=0, dst_list_type=1)
    self.assertTrue(torch.equal(result, self.experts))


def test_cumsum_group_list_with_type_1(self):
group_list = self.experts
group_list_type = 1
result = cumsum_group_list(group_list, group_list_type)
result = cumsum_group_list(group_list, group_list_type, 0)
self.assertTrue(torch.equal(result, self.group_list))

def test_cumsum_group_list_with_type_2(self):
Expand All @@ -312,6 +312,7 @@ def test_cumsum_group_list_with_type_2(self):
group_list_type = 2
result = cumsum_group_list(group_list,
group_list_type,
0,
active_num=self.active_num,
expert_num=self.expert_num)
self.assertTrue(torch.equal(result, self.group_list))
Expand Down
3 changes: 2 additions & 1 deletion tests/ut/ops/test_token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.output = torch.randint(0, 8, (10, 1))

self.dispatcher.moe_expert_num = len(self.dispatcher.expert_map)
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
self.assertIn("tp_send_counts", kwargs)

Expand All @@ -148,6 +148,7 @@ def test_token_combine_with_shared_experts(self):
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
self.dispatcher.output = torch.randint(0, 8, (10, 1))
self.dispatcher.moe_expert_num = len(self.dispatcher.expert_map)
self.hidden_states = torch.randn(10, 128)

with patch("torch_npu.npu_moe_distribute_combine_v2",
Expand Down
51 changes: 28 additions & 23 deletions vllm_ascend/ops/moe/moe_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,39 @@


def cumsum_group_list(group_list: torch.Tensor,
group_list_type: int,
src_list_type: int,
dst_list_type: int,
active_num: int = 0,
expert_num: int = 0) -> torch.Tensor:
if group_list_type not in [0, 1, 2]:
if src_list_type not in [0, 1, 2]:
raise ValueError(
f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
f"group_list_type should be in [0, 1, 2], but received {src_list_type}"
)

if group_list_type == 0:
if src_list_type == dst_list_type:
return group_list
if group_list_type == 1:
if src_list_type == 1 and dst_list_type == 0:
return group_list.cumsum(dim=0)
if src_list_type == 0 and dst_list_type == 1:
group_diff = torch.diff(group_list)
new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
return new_group
Comment on lines +42 to +45
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

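There is a bug in the logic for converting group_list from cumulative sums (src_list_type=0) to per-group counts (dst_list_type=1). The new tensor is built with group_diff[0] as its first element instead of group_list[0], so the first count is replaced by a copy of the second. For example, counts [3, 1, 2] have cumulative sums [3, 4, 6], and torch.diff gives [1, 2]; the current code returns [1, 1, 2] instead of [3, 1, 2]. This will lead to incorrect counts and subsequent errors in npu_dequant_swiglu_quant, which receives the result as group_index.

The inline implementation that this refactoring removed from quant_apply_mlp was correct. You should use group_list[0], the first element of the cumulative sum, because it equals the first count.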

Suggested change
if src_list_type == 0 and dst_list_type == 1:
group_diff = torch.diff(group_list)
new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
return new_group
if src_list_type == 0 and dst_list_type == 1:
group_diff = torch.diff(group_list)
new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
return new_group

if src_list_type == 2 and dst_list_type == 0:
experts = pad(group_list[:, 0], (1, 0))
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
cumsum_group_list = torch.full(size=(expert_num, ),
fill_value=active_num,
dtype=group_list.dtype,
device=group_list.device)

experts = pad(group_list[:, 0], (1, 0))
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
cumsum_group_list = torch.full(size=(expert_num, ),
fill_value=active_num,
dtype=group_list.dtype,
device=group_list.device)
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
if end > start:
cumsum_group_list[start:end] = tokens[i]

for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
if end > start:
cumsum_group_list[start:end] = tokens[i]

return cumsum_group_list
return cumsum_group_list
raise NotImplementedError(
f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
"This feature is under development.")


def quant_apply_mlp(hidden_states: torch.Tensor,
Expand Down Expand Up @@ -89,7 +97,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1,
group_list=cumsum_group_list(group_list, group_list_type),
group_list=cumsum_group_list(group_list, group_list_type, 0),
weight_scale=w1_scale,
x_scale=pertoken_scale)
else:
Expand All @@ -105,17 +113,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list=group_list,
output_dtype=torch.int32)[0]
# act_fn: swiglu
group_diff = torch.diff(group_list)
new_group = torch.cat([group_list[0].unsqueeze(0), group_diff],
dim=0)
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=new_group,
group_index=cumsum_group_list(group_list, group_list_type, 1),
activate_left=True,
quant_mode=1,
)
Expand Down Expand Up @@ -148,7 +153,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
x=hidden_states,
weight=w1,
bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type),
group_list=cumsum_group_list(group_list, group_list_type, 0),
weight_scale=w1_scale,
x_scale=pertoken_scale)
else:
Expand Down Expand Up @@ -258,4 +263,4 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
group_list=group_list,
group_list_type=group_list_type,
topk_scales=topk_scales,
need_trans=need_trans)
need_trans=need_trans)
Loading