From 4d2b52010521f93079fe88ba1eab01916ef8afaa Mon Sep 17 00:00:00 2001 From: nidal567 Date: Tue, 31 Mar 2026 19:05:03 +0000 Subject: [PATCH 1/2] Fix nondeterministic RNG in test_fused_mxfp4_quant Tests in test_fused_mxfp4_quant.py were failing in CI, especially when executed as part of shard 3. The failures were not reproducible when running the test line in isolation. Thanks to Bruno for providing the command line. Root cause: The random seed was previously set at the top level of the module, just after the imports, via torch.manual_seed(). This caused test behaviour to depend on the global RNG state, which is affected by previously executed tests in the same shard (which explains why the tests passed in isolation but not in the shard). As a result, the test outcomes were order-dependent and non-deterministic. Fix: - Removed torch.manual_seed() from the top level of the module - Added deterministic seeding to the affected test case to ensure order-independent behaviour Validation: - Reproduced the failure locally using the CI shard 3 command - Verified the failures occurring in op_tests/triton_tests/quant/test_fused_mxfp4_quant.py::test_fused_rms_quant - After fix: - All tests pass in shard 3 with TRITON_HIP_USE_ASYNC_COPY=0 - test_fused_rms_quant also passes with ASYNC_COPY enabled (in command-line runs with shard 3 and in isolation) - Tests pass consistently in isolation and across repeated runs Additional Notes: - Remaining failures with TRITON_HIP_USE_ASYNC_COPY=1 are pre-existing (known MoE + GEMM issues with ASYNC enabled). 
This is unrelated to the current task and can be addressed separately --- op_tests/triton_tests/quant/test_fused_mxfp4_quant.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py index ea43b7a358..812eb44b1d 100644 --- a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py +++ b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py @@ -22,7 +22,6 @@ from aiter.ops.quant import per_1x32_f4_quant_hip from aiter.utility.fp4_utils import moe_mxfp4_sort, dynamic_mxfp4_quant -torch.manual_seed(0) def rmsnorm(input, weight, eps=1e-6): @@ -169,7 +168,7 @@ def test_fused_rms_quant( shuffle: bool, scale_shuffle_padding: bool, ): - + torch.manual_seed(0) if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") From 46bbfd31882db78fac22652e002286df10e2dbed Mon Sep 17 00:00:00 2001 From: nidal567 Date: Wed, 1 Apr 2026 00:08:01 +0000 Subject: [PATCH 2/2] Set RNG seed before all test cases to make everything deterministic. 
Moved the seeds after skip condition, and used black to format file --- .../triton_tests/quant/test_fused_mxfp4_quant.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py index 812eb44b1d..7cdd5cd48a 100644 --- a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py +++ b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py @@ -23,7 +23,6 @@ from aiter.utility.fp4_utils import moe_mxfp4_sort, dynamic_mxfp4_quant - def rmsnorm(input, weight, eps=1e-6): row_norm = input * input row_norm = torch.sum(row_norm, dim=-1) @@ -129,6 +128,8 @@ def test_flatten_quant(B: int, M: int, N: int, dtype): if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") + torch.manual_seed(0) + torch.cuda.empty_cache() # Helps avoid hangs in large tests x = torch.randn((B, M, N), dtype=dtype, device="cuda").transpose(0, 1) @@ -168,10 +169,11 @@ def test_fused_rms_quant( shuffle: bool, scale_shuffle_padding: bool, ): - torch.manual_seed(0) if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") + torch.manual_seed(0) + torch.cuda.empty_cache() # Helps avoid hangs in large tests x1, x2, rms1_w, rms2_w, resid1 = generate_fused_rms_quant_data( x1_shape=(M, N1), @@ -316,6 +318,8 @@ def test_fused_reduce_act_mul_mxfp4_group_quant( if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") + torch.manual_seed(0) + if shuffle and (N1 * 2) % 512 != 0: pytest.skip() @@ -401,6 +405,8 @@ def test_fuse_reduce_rms_quant( if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") + torch.manual_seed(0) + torch.cuda.empty_cache() # Helps avoid hangs in large tests x1, w1, x2, w2, res1, x3 = generate_fused_reduce_rms_quant_data( M, N1, N2, N3, SPK, dtype @@ -547,6 +553,9 @@ def test_fused_dynamic_mxfp4_quant_moe_sort( ): if not 
(arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") + + torch.manual_seed(0) + q_dtype_a = torch.float4_e2m1fn_x2 num_local_tokens = None num_valid_ids = torch.zeros(2, dtype=torch.int64, device="cuda")