@@ -1834,6 +1834,32 @@ def cache_permute_indices():
     return _cache_permute_indices


+@pytest.fixture(autouse=True)
+def clear_cache_between_test_functions(cache_permute_indices, request):
+    """Automatically clear the cache when switching between different test functions.
+
+    This keeps the cache within the same test function (across all parametrized runs)
+    but clears it when moving to a different test function.
+    """
+    # Get the base test function name (without parameters)
+    test_function_name = request.node.originalname or request.node.name.split("[")[0]
+
+    # Store the current test function name in the module
+    if not hasattr(request.module, "_current_test_function"):
+        request.module._current_test_function = test_function_name
+    elif request.module._current_test_function != test_function_name:
+        # We've switched to a different test function, clear the cache and GPU state
+        cache_permute_indices.clear()
+        request.module._current_test_function = test_function_name
+
+    # Synchronize and clear GPU memory/cache
+    torch.cuda.synchronize()
+    torch.cuda.empty_cache()
+
+    yield  # Run the test
+    # No cleanup needed here - we clear at the start of the next different test function
+
+
 def skip_checks(
     moe_impl,
     routing_config,
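For context on the fixture above: `request.node.originalname` is the test function name with the parametrize suffix stripped, and the `or` fallback covers pytest versions where it can be `None`. Below is a minimal, standalone sketch of the same grouping pattern, assuming a plain dict in place of the real `cache_permute_indices` fixture; the `shared_cache` name and the simplified fixture body are illustrative only, not part of the actual test file.

```python
# conftest.py -- minimal sketch only; shared_cache stands in for the real
# cache_permute_indices fixture and is not part of the original test file.
import pytest


@pytest.fixture(scope="module")
def shared_cache():
    return {}


@pytest.fixture(autouse=True)
def clear_cache_between_tests(shared_cache, request):
    # originalname drops the parametrize suffix, so "test_foo[a]" and
    # "test_foo[b]" both resolve to "test_foo" and keep sharing the cache.
    name = request.node.originalname or request.node.name.split("[")[0]
    if getattr(request.module, "_current_test_function", None) != name:
        shared_cache.clear()  # new test function: start from an empty cache
        request.module._current_test_function = name
    yield
```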
@@ -2313,18 +2339,21 @@ def test_renormalize_routing(
 @pytest.mark.parametrize(
     "routing_config",
     [
-        {
-            "num_experts": 16,
-            "top_k": 2,
-            "padding": 8,
-            "n_groups": None,
-            "top_k_groups": None,
-            "routed_scaling": None,
-            "has_routing_bias": False,
-            "routing_method_type": RoutingMethodType.TopK,
-            "compatible_moe_impls": [FP4Moe],
-            "compatible_intermediate_size": [384, 512, 768, 1024],
-        },
+        pytest.param(
+            {
+                "num_experts": 16,
+                "top_k": 2,
+                "padding": 8,
+                "n_groups": None,
+                "top_k_groups": None,
+                "routed_scaling": None,
+                "has_routing_bias": False,
+                "routing_method_type": RoutingMethodType.TopK,
+                "compatible_moe_impls": [FP4Moe],
+                "compatible_intermediate_size": [384, 512, 768, 1024],
+            },
+            id="TopK",
+        ),
     ],
 )
 @pytest.mark.parametrize(
@@ -2383,18 +2412,21 @@ def test_topk_routing(
 @pytest.mark.parametrize(
     "routing_config",
     [
-        {
-            "num_experts": 128,
-            "top_k": 1,
-            "padding": 8,
-            "n_groups": 0,
-            "top_k_groups": 0,
-            "routed_scaling": 2.5,
-            "has_routing_bias": True,
-            "routing_method_type": RoutingMethodType.Llama4,
-            "compatible_moe_impls": [FP8PerTensorMoe],
-            "compatible_intermediate_size": [1024, 2048],
-        },
+        pytest.param(
+            {
+                "num_experts": 128,
+                "top_k": 1,
+                "padding": 8,
+                "n_groups": 0,
+                "top_k_groups": 0,
+                "routed_scaling": 2.5,
+                "has_routing_bias": True,
+                "routing_method_type": RoutingMethodType.Llama4,
+                "compatible_moe_impls": [FP8PerTensorMoe],
+                "compatible_intermediate_size": [1024, 2048],
+            },
+            id="Llama4",
+        ),
     ],
 )
 @pytest.mark.parametrize(
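Wrapping each routing config in `pytest.param(..., id=...)` changes only how the parametrized cases are labelled, not what is tested. A hedged illustration of the effect on generated node ids (the exact ids depend on the other parametrize axes in the file):

```python
# Illustrative only -- exact node ids depend on the other parametrize axes.
#
# Without an explicit id, pytest labels a dict parameter by position:
#     test_topk_routing[routing_config0-...]
# With pytest.param(..., id="TopK") the label becomes readable:
#     test_topk_routing[TopK-...]
#
# which also makes keyword selection straightforward:
#     pytest -k "TopK"
```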