 EP_SIZE = [1, 4]
 TOP_KS = [2, 6]

+FUSED_MOE_MNK_FACTORS = [
+    (1, 128, 128),
+    (1, 2048, 128),
+    (33, 2048, 128),
+    (222, 1024, 1024),
+    (32768, 128, 128),
+    (32768, 2048, 511),
+    (40000, 1024, 1024),
+]
+
+FUSED_MOE_WN16_MNK_FACTORS = [
+    (1, 128, 128),
+    (1, 1024, 1024),
+    (32, 2048, 128),
+    (32, 1024, 1024),
+    (222, 2048, 1024),
+]
+
 vllm_config = VllmConfig()
 vllm_config.scheduler_config.max_num_seqs = 128
 vllm_config.scheduler_config.max_model_len = 8192
@@ -116,13 +134,11 @@ def run_moe_test( |
     return baseline_output


-@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("padding", [True, False])
 @pytest.mark.parametrize("chunk_size", [8192])
 def test_fused_moe(
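The combined `"m,n,k"` marker unpacks each tuple from `FUSED_MOE_MNK_FACTORS` into a single test case, so the 6 × 3 × 3 = 54 shape combinations produced by the three separate markers shrink to the 7 curated shapes in the list. Below is a minimal standalone sketch of that pytest pattern, not part of the diff; `MNK_FACTORS` and the test name are hypothetical stand-ins.

import pytest

# Hypothetical subset used only for illustration; the real lists are the
# FUSED_MOE_MNK_FACTORS / FUSED_MOE_WN16_MNK_FACTORS added above.
MNK_FACTORS = [
    (1, 128, 128),
    (222, 1024, 1024),
]

@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
def test_mnk_unpacking(m: int, n: int, k: int):
    # One case per tuple: (m=1, n=128, k=128) and (m=222, n=1024, k=1024),
    # instead of the cross product of independent m, n, and k lists.
    assert m > 0 and n > 0 and k > 0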
@@ -235,13 +251,11 @@ def m_fused_moe( |
     use_cudagraph=use_cudagraph)


-@pytest.mark.parametrize("m", [1, 32, 222])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 1024])
+@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("group_size", [64, 128])
 @pytest.mark.parametrize("has_zp", [True, False])
 @pytest.mark.parametrize("weight_bits", [4, 8])
@@ -352,8 +366,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, |
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)


-@pytest.mark.parametrize("dtype",
-                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("padding", [True, False])
 @pytest.mark.parametrize(
     "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
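The `use_rocm_aiter` marker above shows that a parametrize value list can be computed at collection time: on non-ROCm platforms only `False` is generated, so the AITER path is never collected there. A minimal sketch of the same pattern with a stand-in platform check; the `is_rocm` helper and test name are hypothetical, not vLLM's `current_platform` API.

import pytest

def is_rocm() -> bool:
    # Stand-in for a real platform probe; hard-coded for this sketch.
    return False

@pytest.mark.parametrize("use_rocm_aiter",
                         [True, False] if is_rocm() else [False])
def test_rocm_aiter_gating(use_rocm_aiter: bool):
    # Off ROCm, pytest collects only the use_rocm_aiter=False case.
    assert isinstance(use_rocm_aiter, bool)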
|