|
15 | 15 |
|
16 | 16 | import pytest |
17 | 17 |
|
| 18 | +from vllm_ascend.ascend_forward_context import MoECommType |
18 | 19 | from vllm_ascend.utils import AscendSocVersion |
19 | 20 | from vllm_ascend.worker.model_runner_v1 import NPUModelRunner |
20 | 21 |
|
|
24 | 25 | "soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method", |
25 | 26 | [ |
26 | 27 | # Case 1: Expert parallel is disabled, should always be 'allgather' |
27 | | - (AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"), |
28 | | - (AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"), |
| 28 | + (AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER), |
| 29 | + (AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER), |
29 | 30 |
|
30 | 31 | # Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2 |
31 | | - (AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"), |
32 | | - (AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"), |
33 | | - (AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition |
| 32 | + (AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL), |
| 33 | + (AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL), |
| 34 | + (AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition |
34 | 35 |
|
35 | 36 | # Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather |
36 | | - (AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"), |
37 | | - (AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"), |
| 37 | + (AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER), |
| 38 | + (AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER), |
38 | 39 |
|
39 | 40 | # Case 4: A3 SOC |
40 | | - (AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"), |
41 | | - (AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"), |
| 41 | + (AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2), |
| 42 | + (AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLTOALL), |
42 | 43 | ]) |
43 | 44 | # yapf: enable |
44 | 45 | def test_select_moe_comm_method(soc_version, enable_expert_parallel, |
|
0 commit comments