Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ def test_token_dispatcher_with_all_gather(
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)

sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
context_metadata = dispatch_output["context_metadata"]
sorted_hidden_states = dispatch_output.hidden_states
group_list = dispatch_output.group_list
group_list_type = dispatch_output.group_list_type
context_metadata = dispatch_output.context_metadata

expert_output = apply_mlp(hidden_states=sorted_hidden_states,
w1=w1_local,
Expand All @@ -155,7 +155,7 @@ def test_token_dispatcher_with_all_gather(
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
expert_map)

torch.testing.assert_close(combined_output,
torch.testing.assert_close(combined_output.routed_out,
torch_output,
atol=4e-2,
rtol=1)
Expand Down Expand Up @@ -216,11 +216,11 @@ def test_token_dispatcher_with_all_gather_quant(
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=True)

sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
dynamic_scale = dispatch_output["dynamic_scale"]
context_metadata = dispatch_output["context_metadata"]
sorted_hidden_states = dispatch_output.hidden_states
group_list = dispatch_output.group_list
group_list_type = dispatch_output.group_list_type
dynamic_scale = dispatch_output.dynamic_scale
context_metadata = dispatch_output.context_metadata

expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
w1=w1,
Expand All @@ -235,7 +235,7 @@ def test_token_dispatcher_with_all_gather_quant(
hidden_states=expert_output,
context_metadata=context_metadata,
bias=None)
assert combined_output.shape == (m, k)
assert combined_output.routed_out.shape == (m, k)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
Expand Down
16 changes: 9 additions & 7 deletions tests/ut/ops/test_moe_comm_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
AlltoAllCommImpl,
MC2CommImpl)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
TokenDispatchResult)


class TestMoECommMethod(TestBase):
Expand Down Expand Up @@ -178,12 +180,12 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,

# Mock token dispatcher
mock_td_instance = MagicMock()
mock_td_instance.token_dispatch.return_value = {
"hidden_states": torch.randn(6, 8),
"group_list": torch.tensor([2, 2, 2]),
"group_list_type": 1
}
mock_td_instance.token_combine.return_value = torch.randn(4, 8)
mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
hidden_states=torch.randn(6, 8),
group_list=torch.tensor([2, 2, 2]),
group_list_type=1)
mock_td_instance.token_combine.return_value = TokenCombineResult(
routed_out=torch.randn(4, 8))
mock_token_dispatcher.return_value = mock_td_instance

# Mock unified_apply_mlp
Expand Down Expand Up @@ -213,7 +215,7 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,
activation="silu")

# Verify result shape
self.assertEqual(result.shape, (4, 8))
self.assertEqual(result.routed_out.shape, (4, 8))

# Verify token_dispatch was called
mock_td_instance.token_dispatch.assert_called_once()
Expand Down
90 changes: 26 additions & 64 deletions tests/ut/ops/test_token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def test_token_permutation_dispatch(self):
topk_weights, topk_ids,
expert_map)
mock_dispatch.assert_called_once()
self.assertEqual(output["group_list_type"],
0) # group_list_type == 0
self.assertEqual(output.group_list_type, 0) # group_list_type == 0

def test_token_dispatch_with_shared_experts_and_quant(self):
self.shared_experts = MagicMock()
Expand Down Expand Up @@ -149,43 +148,6 @@ def test_get_combine_mc_kwargs_with_quant(self):
context_metadata)
self.assertIn("tp_send_counts", kwargs)

def test_token_combine_with_shared_experts(self):
shared_experts = MagicMock()
shared_experts.down_proj.return_value = (torch.randn(10, 128),
torch.tensor(1.0))

topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
assist_info_for_combine = torch.arange(10)
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"mc2_mask": None,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": None,
"shared_experts": shared_experts,
"shared_act": torch.randn(10, 128),
"swiglu_out_scale": torch.randn(10, 1),
"tp_recv_counts": tp_recv_counts
}

self.dispatcher.with_quant = True
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True

hidden_states = torch.randn(10, 128)
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
result = self.dispatcher.token_combine(hidden_states,
context_metadata)
self.assertIsInstance(result, tuple)


class TestTokenDispatcherWithAllGather(TestBase):

Expand Down Expand Up @@ -233,7 +195,7 @@ def test_token_dispatch_without_expert_map(self):
self.mock_npu_moe_init_routing_v2.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args

self.assertEqual(results["group_list_type"], 1)
self.assertEqual(results.group_list_type, 1)

def test_token_dispatch_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
Expand All @@ -248,7 +210,7 @@ def test_token_dispatch_with_expert_map(self):
self.mock_npu_moe_init_routing_v2.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args

self.assertEqual(results["group_list_type"], 1)
self.assertEqual(results.group_list_type, 1)

def test_token_dispatch_without_quant(self):
kwargs = {
Expand All @@ -268,7 +230,7 @@ def test_token_dispatch_without_quant(self):
topk_weights, topk_ids,
None)

self.assertEqual(results["group_list_type"], 1)
self.assertEqual(results.group_list_type, 1)

def test_token_dispatch_with_quant(self):
kwargs = {
Expand All @@ -290,10 +252,10 @@ def test_token_dispatch_with_quant(self):
None,
with_quant=True)

self.assertIsNotNone(results["hidden_states"])
self.assertIsNotNone(results["group_list"])
self.assertIsNotNone(results["dynamic_scale"])
self.assertEqual(results["group_list_type"], 1)
self.assertIsNotNone(results.hidden_states)
self.assertIsNotNone(results.group_list)
self.assertIsNotNone(results.dynamic_scale)
self.assertEqual(results.group_list_type, 1)

def test_token_combine_with_expert_map(self):
hidden_states = torch.randn(6, 128)
Expand All @@ -303,7 +265,7 @@ def test_token_combine_with_expert_map(self):
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
hidden_states, context_metadata).routed_out
self.assertEqual(final_hidden_states.shape, (6, 128))

def test_token_combine_without_expert_map(self):
Expand All @@ -314,7 +276,7 @@ def test_token_combine_without_expert_map(self):
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
hidden_states, context_metadata).routed_out
self.mock_npu_moe_token_unpermute.assert_called_once()
self.assertEqual(final_hidden_states.shape, (6, 128))

Expand All @@ -326,7 +288,7 @@ def test_token_dispatch_with_router_weight(self):

results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, None)
self.assertEqual(results["hidden_states"].shape, (6, 128))
self.assertEqual(results.hidden_states.shape, (6, 128))


class TestTokenDispatcherWithAll2AllV(TestBase):
Expand Down Expand Up @@ -437,9 +399,9 @@ def test_token_dispatch(self):
topk_ids=topk_ids,
expert_map=expert_map)

self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertEqual(result.group_list_type, 1)

def test_token_combine(self):
hidden_states = torch.randn(16, 16)
Expand All @@ -458,7 +420,7 @@ def test_token_combine(self):

output = self.dispatcher.token_combine(hidden_states, context_metadata)
self.assertIsNotNone(output)
self.assertEqual(output.shape, (8, 16))
self.assertEqual(output.routed_out.shape, (8, 16))

def test_token_dispatch_with_quant(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
Expand All @@ -480,10 +442,10 @@ def test_token_dispatch_with_quant(self):
expert_map=expert_map,
with_quant=True)

self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertIsNotNone(result["dynamic_scale"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1)

def test_token_dispatch_with_quant_no_active_tokens(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
Expand All @@ -508,10 +470,10 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
expert_map=expert_map,
with_quant=True)

self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertIsNotNone(result["dynamic_scale"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1)

def test_token_dispatch_with_log2phy(self):
hidden_states = torch.randn(8, 16)
Expand All @@ -530,6 +492,6 @@ def test_token_dispatch_with_log2phy(self):
expert_map=expert_map,
log2phy=log2phy)

self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertEqual(result.group_list_type, 1)
39 changes: 20 additions & 19 deletions vllm_ascend/ops/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
set_flash_common3_context)
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
FusedExpertsResult,
setup_moe_comm_method)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.quantization.w4a8_dynamic import \
Expand Down Expand Up @@ -325,7 +326,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
pertoken_scale = None

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
Expand All @@ -350,25 +351,25 @@ def forward_impl(self, hidden_states: torch.Tensor,
global_redundant_expert_num=self.global_redundant_expert_num,
mc2_mask=mc2_mask)

if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
if self.dynamic_eplb:

moe_load_stream = moe_load_async_stream()
cur_stream = torch.npu.current_stream()

moe_load_stream.wait_stream(cur_stream)
with npu_stream_switch(moe_load_stream):
self.moe_load += expert_tokens if group_list_type == 1 else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
cur_stream.wait_stream(moe_load_stream)

final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states,
if self.dynamic_eplb:
expert_tokens = fused_experts_results.expert_tokens
group_list_type = fused_experts_results.group_list_type
assert expert_tokens is not None and group_list_type is not None, \
"expert_tokens and group_list_type should not be None when dynamic_eplb is enabled."
moe_load_stream = moe_load_async_stream()
cur_stream = torch.npu.current_stream()
moe_load_stream.wait_stream(cur_stream)
with npu_stream_switch(moe_load_stream):
self.moe_load += expert_tokens if group_list_type == 1 else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
cur_stream.wait_stream(moe_load_stream)

routed_out = forward_context.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results,
context_metadata=context_metadata)

return final_hidden_states
return routed_out


class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
Expand Down Expand Up @@ -439,7 +440,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
else:
set_flash_common3_context(shared_experts=self._shared_experts)

fused_output = AscendFusedMoE.forward_impl(
routed_out = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
Expand All @@ -462,4 +463,4 @@ def forward_impl(self, hidden_states: torch.Tensor,
assert fc3_context is not None
shared_out = fc3_context.shared_out

return shared_out, fused_output
return shared_out, routed_out
Loading
Loading