
Commit a54cfef

fix test
Signed-off-by: jiahanc <[email protected]>
1 parent: 20bd60b

1 file changed (+56 −24 lines)

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 56 additions & 24 deletions
@@ -1834,6 +1834,32 @@ def cache_permute_indices():
     return _cache_permute_indices


+@pytest.fixture(autouse=True)
+def clear_cache_between_test_functions(cache_permute_indices, request):
+    """Automatically clear cache when switching between different test functions.
+
+    This keeps the cache within the same test function (across all parametrized runs)
+    but clears it when moving to a different test function.
+    """
+    # Get the base test function name (without parameters)
+    test_function_name = request.node.originalname or request.node.name.split("[")[0]
+
+    # Store the current test function name in the module
+    if not hasattr(request.module, "_current_test_function"):
+        request.module._current_test_function = test_function_name
+    elif request.module._current_test_function != test_function_name:
+        # We've switched to a different test function, clear the cache and GPU state
+        cache_permute_indices.clear()
+        request.module._current_test_function = test_function_name
+
+    # Synchronize and clear GPU memory/cache
+    torch.cuda.synchronize()
+    torch.cuda.empty_cache()
+
+    yield  # Run the test
+    # No cleanup needed here - we clear at the start of the next different function
+
+
 def skip_checks(
     moe_impl,
     routing_config,
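Note (not part of the diff): the grouping above hinges on request.node.originalname, which pytest sets to the undecorated test function name, so every parametrized variant of one test shares the same base name. A minimal, self-contained sketch of that behavior; the fixture and test names here are made up for illustration and are not in this file:

# For a collected node id like "test_topk_routing[TopK-...]", pytest reports:
#   request.node.name         -> "test_topk_routing[TopK-...]"
#   request.node.originalname -> "test_topk_routing"
import pytest

@pytest.fixture(autouse=True)
def report_grouping(request):
    # Fall back to splitting the node name for older pytest versions where
    # originalname may be unset on non-parametrized tests.
    base_name = request.node.originalname or request.node.name.split("[")[0]
    # All parametrized variants of one test function share base_name, so
    # per-function state can be kept across variants and reset on a change.
    print(f"running {request.node.name} (grouped under {base_name})")
    yield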
@@ -2313,18 +2339,21 @@ def test_renormalize_routing(
 @pytest.mark.parametrize(
     "routing_config",
     [
-        {
-            "num_experts": 16,
-            "top_k": 2,
-            "padding": 8,
-            "n_groups": None,
-            "top_k_groups": None,
-            "routed_scaling": None,
-            "has_routing_bias": False,
-            "routing_method_type": RoutingMethodType.TopK,
-            "compatible_moe_impls": [FP4Moe],
-            "compatible_intermediate_size": [384, 512, 768, 1024],
-        },
+        pytest.param(
+            {
+                "num_experts": 16,
+                "top_k": 2,
+                "padding": 8,
+                "n_groups": None,
+                "top_k_groups": None,
+                "routed_scaling": None,
+                "has_routing_bias": False,
+                "routing_method_type": RoutingMethodType.TopK,
+                "compatible_moe_impls": [FP4Moe],
+                "compatible_intermediate_size": [384, 512, 768, 1024],
+            },
+            id="TopK",
+        ),
     ],
 )
 @pytest.mark.parametrize(
@@ -2383,18 +2412,21 @@ def test_topk_routing(
 @pytest.mark.parametrize(
     "routing_config",
     [
-        {
-            "num_experts": 128,
-            "top_k": 1,
-            "padding": 8,
-            "n_groups": 0,
-            "top_k_groups": 0,
-            "routed_scaling": 2.5,
-            "has_routing_bias": True,
-            "routing_method_type": RoutingMethodType.Llama4,
-            "compatible_moe_impls": [FP8PerTensorMoe],
-            "compatible_intermediate_size": [1024, 2048],
-        },
+        pytest.param(
+            {
+                "num_experts": 128,
+                "top_k": 1,
+                "padding": 8,
+                "n_groups": 0,
+                "top_k_groups": 0,
+                "routed_scaling": 2.5,
+                "has_routing_bias": True,
+                "routing_method_type": RoutingMethodType.Llama4,
+                "compatible_moe_impls": [FP8PerTensorMoe],
+                "compatible_intermediate_size": [1024, 2048],
+            },
+            id="Llama4",
+        ),
     ],
 )
 @pytest.mark.parametrize(
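Note (not part of the diff): wrapping the dicts in pytest.param(..., id=...) in these two hunks only changes how the cases are labelled; the generated node id uses the given id (e.g. "TopK", "Llama4") instead of an auto-generated placeholder like "routing_config0", which makes -k selection and test reports easier to read. A minimal sketch of the effect, unrelated to this file and with made-up values:

import pytest

@pytest.mark.parametrize(
    "routing_config",
    [
        # Without id= this case would appear as "routing_config0" in the node id.
        pytest.param({"num_experts": 16, "top_k": 2}, id="TopK"),
    ],
)
def test_routing_label(routing_config):
    # Collected as: test_routing_label[TopK]
    assert routing_config["top_k"] <= routing_config["num_experts"]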
