@@ -1834,6 +1834,32 @@ def cache_permute_indices():
     return _cache_permute_indices


+@pytest.fixture(autouse=True)
+def clear_cache_between_test_functions(cache_permute_indices, request):
+    """Automatically clear cache when switching between different test functions.
+
+    This keeps the cache within the same test function (across all parametrized runs)
+    but clears it when moving to a different test function.
+    """
+    # Get the base test function name (without parameters)
+    test_function_name = request.node.originalname or request.node.name.split("[")[0]
+
+    # Store the current test function name in the module
+    if not hasattr(request.module, "_current_test_function"):
+        request.module._current_test_function = test_function_name
+    elif request.module._current_test_function != test_function_name:
+        # We've switched to a different test function, clear the cache and GPU state
+        cache_permute_indices.clear()
+        request.module._current_test_function = test_function_name
+
+    # Synchronize and clear GPU memory/cache
+    torch.cuda.synchronize()
+    torch.cuda.empty_cache()
+
+    yield  # Run the test
+    # No cleanup needed here - we clear at the start of the next different function
+
+
 def skip_checks(
     moe_impl,
     routing_config,
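For context, a self-contained sketch of the fixture's pattern, using hypothetical names (_cache, clear_cache_between_tests, test_a, test_b) and omitting the CUDA calls: request.node.originalname strips the "[param]" suffix, so state accumulates across parametrized runs of one test function and is cleared when the next test function starts.

import pytest

_cache = {}  # stand-in for the real _cache_permute_indices dict


@pytest.fixture
def cache():
    return _cache


@pytest.fixture(autouse=True)
def clear_cache_between_tests(cache, request):
    # originalname is the test function name without the "[param]" id suffix
    name = request.node.originalname or request.node.name.split("[")[0]
    if getattr(request.module, "_current_test", None) != name:
        cache.clear()
        request.module._current_test = name
    yield


@pytest.mark.parametrize("n", [1, 2, 3])
def test_a(cache, n):
    cache[n] = n
    assert len(cache) == n  # entries accumulate across parametrized runs


def test_b(cache):
    assert not cache  # cleared once a different test function begins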
@@ -2312,18 +2338,21 @@ def test_renormalize_routing(
 @pytest.mark.parametrize(
     "routing_config",
     [
-        {
-            "num_experts": 16,
-            "top_k": 2,
-            "padding": 8,
-            "n_groups": None,
-            "top_k_groups": None,
-            "routed_scaling": None,
-            "has_routing_bias": False,
-            "routing_method_type": RoutingMethodType.TopK,
-            "compatible_moe_impls": [FP4Moe],
-            "compatible_intermediate_size": [384, 512, 768, 1024],
-        },
+        pytest.param(
+            {
+                "num_experts": 16,
+                "top_k": 2,
+                "padding": 8,
+                "n_groups": None,
+                "top_k_groups": None,
+                "routed_scaling": None,
+                "has_routing_bias": False,
+                "routing_method_type": RoutingMethodType.TopK,
+                "compatible_moe_impls": [FP4Moe],
+                "compatible_intermediate_size": [384, 512, 768, 1024],
+            },
+            id="TopK",
+        ),
     ],
 )
 @pytest.mark.parametrize(
@@ -2382,18 +2411,21 @@ def test_topk_routing(
 @pytest.mark.parametrize(
     "routing_config",
     [
-        {
-            "num_experts": 128,
-            "top_k": 1,
-            "padding": 8,
-            "n_groups": 0,
-            "top_k_groups": 0,
-            "routed_scaling": 2.5,
-            "has_routing_bias": True,
-            "routing_method_type": RoutingMethodType.Llama4,
-            "compatible_moe_impls": [FP8PerTensorMoe],
-            "compatible_intermediate_size": [1024, 2048],
-        },
+        pytest.param(
+            {
+                "num_experts": 128,
+                "top_k": 1,
+                "padding": 8,
+                "n_groups": 0,
+                "top_k_groups": 0,
+                "routed_scaling": 2.5,
+                "has_routing_bias": True,
+                "routing_method_type": RoutingMethodType.Llama4,
+                "compatible_moe_impls": [FP8PerTensorMoe],
+                "compatible_intermediate_size": [1024, 2048],
+            },
+            id="Llama4",
+        ),
     ],
 )
 @pytest.mark.parametrize(
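The switch from bare dicts to pytest.param(..., id=...) gives the routing-config cases readable ids instead of pytest's auto-generated "routing_config0"-style names. A minimal, self-contained illustration using a hypothetical test name (test_demo):

import pytest


@pytest.mark.parametrize(
    "routing_config",
    [
        pytest.param({"num_experts": 16, "top_k": 2}, id="TopK"),
    ],
)
def test_demo(routing_config):
    assert routing_config["top_k"] <= routing_config["num_experts"]

This collects as test_demo[TopK], so the case can be selected with pytest -k TopK.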