
Commit 46cf8ca

committed
fix test
Signed-off-by: jiahanc <[email protected]>
1 parent 3a5f898 commit 46cf8ca

File tree

1 file changed: +56 -24 lines changed

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 56 additions & 24 deletions
@@ -1834,6 +1834,32 @@ def cache_permute_indices():
     return _cache_permute_indices


+@pytest.fixture(autouse=True)
+def clear_cache_between_test_functions(cache_permute_indices, request):
+    """Automatically clear cache when switching between different test functions.
+
+    This keeps the cache within the same test function (across all parametrized runs)
+    but clears it when moving to a different test function.
+    """
+    # Get the base test function name (without parameters)
+    test_function_name = request.node.originalname or request.node.name.split("[")[0]
+
+    # Store the current test function name in the module
+    if not hasattr(request.module, "_current_test_function"):
+        request.module._current_test_function = test_function_name
+    elif request.module._current_test_function != test_function_name:
+        # We've switched to a different test function, clear the cache and GPU state
+        cache_permute_indices.clear()
+        request.module._current_test_function = test_function_name
+
+        # Synchronize and clear GPU memory/cache
+        torch.cuda.synchronize()
+        torch.cuda.empty_cache()
+
+    yield  # Run the test
+    # No cleanup needed here - we clear at the start of the next different function
+
+
 def skip_checks(
     moe_impl,
     routing_config,
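
For context on the pattern the new autouse fixture relies on: request.node.originalname is the test function's name without the "[...]" parametrize suffix, so comparing it between runs detects when collection has moved on to a different test function. Below is a minimal, self-contained sketch of that detection logic; the test names test_a and test_b and the _seen dict are hypothetical and not part of this commit.

import pytest

_seen = {"current": None, "switches": 0}


@pytest.fixture(autouse=True)
def track_function_switches(request):
    # Base test function name, with the "[...]" parameter suffix stripped.
    name = request.node.originalname or request.node.name.split("[")[0]
    if _seen["current"] != name:
        # A per-function cache would be cleared here.
        _seen["switches"] += 1
        _seen["current"] = name
    yield


@pytest.mark.parametrize("x", [1, 2])
def test_a(x):
    assert x > 0


@pytest.mark.parametrize("x", [3, 4])
def test_b(x):
    assert x > 0

Running this module, the two parametrized runs of test_a share the recorded name and register no switch between them; a switch is registered only when test_b begins, which is the per-function reset behavior the commit wants for cache_permute_indices.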
@@ -2312,18 +2338,21 @@ def test_renormalize_routing(
 @pytest.mark.parametrize(
     "routing_config",
     [
-        {
-            "num_experts": 16,
-            "top_k": 2,
-            "padding": 8,
-            "n_groups": None,
-            "top_k_groups": None,
-            "routed_scaling": None,
-            "has_routing_bias": False,
-            "routing_method_type": RoutingMethodType.TopK,
-            "compatible_moe_impls": [FP4Moe],
-            "compatible_intermediate_size": [384, 512, 768, 1024],
-        },
+        pytest.param(
+            {
+                "num_experts": 16,
+                "top_k": 2,
+                "padding": 8,
+                "n_groups": None,
+                "top_k_groups": None,
+                "routed_scaling": None,
+                "has_routing_bias": False,
+                "routing_method_type": RoutingMethodType.TopK,
+                "compatible_moe_impls": [FP4Moe],
+                "compatible_intermediate_size": [384, 512, 768, 1024],
+            },
+            id="TopK",
+        ),
     ],
 )
 @pytest.mark.parametrize(
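
Wrapping the routing_config dict in pytest.param(..., id="TopK") changes only how the case is named in collected test IDs, not the values under test. A minimal sketch of the effect, using a hypothetical test name test_cfg that is not taken from this file:

import pytest


@pytest.mark.parametrize(
    "cfg",
    [
        # With an explicit id, this case collects as test_cfg[TopK]
        # instead of an auto-generated id such as test_cfg[cfg0].
        pytest.param({"num_experts": 16, "top_k": 2}, id="TopK"),
    ],
)
def test_cfg(cfg):
    assert cfg["top_k"] <= cfg["num_experts"]

The stable id also makes command-line selection predictable, e.g. pytest -k TopK.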
@@ -2382,18 +2411,21 @@ def test_topk_routing(
 @pytest.mark.parametrize(
     "routing_config",
     [
-        {
-            "num_experts": 128,
-            "top_k": 1,
-            "padding": 8,
-            "n_groups": 0,
-            "top_k_groups": 0,
-            "routed_scaling": 2.5,
-            "has_routing_bias": True,
-            "routing_method_type": RoutingMethodType.Llama4,
-            "compatible_moe_impls": [FP8PerTensorMoe],
-            "compatible_intermediate_size": [1024, 2048],
-        },
+        pytest.param(
+            {
+                "num_experts": 128,
+                "top_k": 1,
+                "padding": 8,
+                "n_groups": 0,
+                "top_k_groups": 0,
+                "routed_scaling": 2.5,
+                "has_routing_bias": True,
+                "routing_method_type": RoutingMethodType.Llama4,
+                "compatible_moe_impls": [FP8PerTensorMoe],
+                "compatible_intermediate_size": [1024, 2048],
+            },
+            id="Llama4",
+        ),
     ],
 )
 @pytest.mark.parametrize(
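
The Llama4 case gets the same id treatment, and it interacts with the new autouse fixture: the shared cache persists across all parametrized runs of one test function and is reset when the next function starts. A small sketch of that caching pattern under assumed names (shared_cache, test_build; the real fixture is cache_permute_indices, whose scope is not shown in this diff):

import pytest


@pytest.fixture(scope="module")
def shared_cache():
    # Long-lived dict standing in for cache_permute_indices.
    return {}


@pytest.mark.parametrize("n", [1, 2, 3])
def test_build(shared_cache, n):
    # The expensive table is built on the first run and reused by the
    # later parametrized runs; an autouse fixture like the one added in
    # this commit would clear it before the next test function starts.
    table = shared_cache.setdefault("permute", list(range(8)))
    assert len(table) == 8 and n > 0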
