Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 1 addition & 18 deletions tests/ut/models/test_deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@

from vllm_ascend.models.deepseek_v2 import (
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2MLP, CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)

Expand Down Expand Up @@ -213,22 +212,6 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
quant_config=None)


def test_custom_deepseek_v2_moe(mock_distributed, base_config,
mock_forward_context):
base_config.n_shared_experts = 1
moe = CustomDeepseekV2MoE(config=base_config,
quant_config=None,
prefix="mlp")
assert moe.top_k == 2

x = torch.randn(2, 4, 128)
attn_metadata = Mock(num_prefills=1)
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
output = moe(x, attn_metadata)
assert output.shape == (2, 4, 128)


@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
base_config):
Expand Down
22 changes: 0 additions & 22 deletions tests/ut/ops/test_ascend_forwad_context.py

This file was deleted.

68 changes: 67 additions & 1 deletion tests/ut/ops/test_fused_moe_prepare_and_finalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2)
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
FusedMoEPrepareAndFinalizeWithNaiveMulticast)


class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
Expand Down Expand Up @@ -216,3 +217,68 @@ def mock_reduce_scatter_func(tensor, dim):
mock_tp_all_reduce.return_value = result
result_with_tp = layer.finalize(h_out, reduce_results=True)
self.assertEqual(result_with_tp.shape[0], 3)

@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
mock_tp_all_reduce,
mock_get_dp_group):
# Mock forward context with DP metadata
mock_context = MagicMock()
mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
[2, 5, 7])
mock_get_forward_context.return_value = mock_context

# Setup DP group mock
mock_dp_group = MagicMock()
mock_dp_group.broadcast = MagicMock()
mock_dp_group.all_reduce = MagicMock()
mock_get_dp_group.return_value = mock_dp_group

# Mock all_reduce to just return input (simulate sum)
def mock_all_reduce(tensor):
return tensor * 2

mock_dp_group.all_reduce.side_effect = mock_all_reduce

# Setup config
self.moe_config.dp_size = 3
self.moe_config.dp_rank = 1
self.moe_config.tp_size = 1
self.moe_config.ep_size = 1

layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)

# Local inputs
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)

# Mock gate for router logits recomputation
mock_gate = MagicMock()
mock_gate.return_value = (torch.randn(7, 2), None)

# Run prepare
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
rm_router_logits=False,
gate=mock_gate)

# Should be global tensor: [7, 8] and [7, 2]
self.assertEqual(h_out.shape, (7, 8))
self.assertEqual(r_out.shape, (7, 2))

# Run finalize
result = layer.finalize(h_out, reduce_results=False)

# Should slice back to local: [3, 8]
self.assertEqual(result.shape, (3, 8))

# Test with reduce_results=True and TP/EP > 1
mock_tp_all_reduce.return_value = result
result_with_tp = layer.finalize(h_out, reduce_results=True)
self.assertEqual(result_with_tp.shape, (3, 8))
143 changes: 32 additions & 111 deletions tests/ut/ops/test_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import (FusedMoEState,
_get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.moe.experts_selector import select_experts
Expand Down Expand Up @@ -60,68 +57,24 @@ def mock_npu_format_cast(weight_data, format):

@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
mock_setup_token_dispatchers = MagicMock()
mock_token_dispatcher_with_allgather = MagicMock()
mock_token_dispatcher_with_all2allv = MagicMock()
mock_token_dispatcher_with_mc2 = MagicMock()

mock_dispatch_result_allgather = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([8, 16], dtype=torch.int64),
"group_list_type": 0,
}
mock_combine_result_allgather = torch.randn(16, 2)

mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather

mock_dispatch_result_all2allv = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
}
mock_combine_result_all2allv = torch.randn(16, 2)
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv

mock_dispatch_result_mc2 = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
"assist_info_for_combine": torch.randn(16, 2),
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
}
mock_combine_result_mc2 = torch.randn(16, 2)
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
mock_moe_comm_method = MagicMock()

captured_dispatchers = {}
def mock_prepare(hidden_states, router_logits, **kwargs):
return hidden_states, router_logits

def capture_register(dispatcher_instance):
key = dispatcher_instance.__class__.__name__
captured_dispatchers[key] = dispatcher_instance
if key == 'TokenDispatcherWithAllGather':
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
elif key == 'TokenDispatcherWithAll2AllV':
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
elif key == 'TokenDispatcherWithMC2':
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
mock_moe_comm_method.prepare.side_effect = mock_prepare

mock_register_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher',
side_effect=capture_register)
mock_fused_experts_result = torch.randn(16, 2)
mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result

mock_get_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher',
side_effect=lambda name: captured_dispatchers.get(name))
def mock_finalize(hidden_states, **kwargs):
return hidden_states

default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
mock_moe_comm_method.finalize.side_effect = mock_finalize

mock_forward_context_obj = MagicMock(
fused_moe_state=FusedMoEState.AllGather,
token_dispatcher=default_mock_token_dispatcher,
moe_comm_method=mock_moe_comm_method,
moe_comm_method_name="mc2commimpl",
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
Expand All @@ -131,14 +84,12 @@ def capture_register(dispatcher_instance):
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather'), \
patch('torch.distributed.all_to_all_single'), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
Expand All @@ -150,29 +101,29 @@ def capture_register(dispatcher_instance):
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj):
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
return_value=None):

yield {
'mock_forward_context_obj': mock_forward_context_obj,
'mock_token_dispatcher_with_allgather':
mock_token_dispatcher_with_allgather,
'mock_token_dispatcher_with_all2allv':
mock_token_dispatcher_with_all2allv,
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
'mock_moe_comm_method': mock_moe_comm_method,
}

mock_register_token_dispatcher_patcher.stop()
mock_get_token_dispatcher_patcher.stop()


@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
Expand Down Expand Up @@ -338,9 +289,7 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
moe.moe_parallel_config.ep_size = 1

moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
dtype=torch.bool),
padded_num_tokens=num_tokens)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
Expand Down Expand Up @@ -394,25 +343,10 @@ def test_process_weights_after_loading(self, moe_method, mock_dist_env):
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):

global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256

if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']

forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
forward_context = mock_dist_env['mock_forward_context_obj']

with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
Expand All @@ -438,35 +372,22 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
global_num_experts=global_num_experts,
is_prefill=is_prefill)

expected_shape = (16, 2)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()

expected_shape = (16, 2)
assert result.shape == expected_shape

@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):

ep_size = others_param
is_prefill = False

if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']

forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, True),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
forward_context = mock_dist_env['mock_forward_context_obj']

with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):

expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
Expand All @@ -493,8 +414,10 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
expert_map=expert_map,
is_prefill=is_prefill)

expected_shape = (16, 2)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()

expected_shape = (16, 2)
assert result.shape == expected_shape


Expand Down Expand Up @@ -574,7 +497,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
mock_get_forward_context):

mock_forward_context = MagicMock()
mock_forward_context.fused_moe_state = FusedMoEState.MC2
mock_forward_context.moe_comm_method_name = "mc2commimpl"
mock_get_forward_context.return_value = mock_forward_context

mock_is_310p.return_value = False
Expand Down Expand Up @@ -618,8 +541,6 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
with_quant=True)

mock_get_forward_context.assert_called()
self.assertEqual(mock_forward_context.fused_moe_state,
FusedMoEState.MC2)

mock_npu_dynamic_quant.assert_called()

Expand Down
Loading
Loading