diff --git a/tests/fused_moe/test_remap_hidden_states.py b/tests/fused_moe/test_remap_hidden_states.py index 210b6cc73..5bde882f2 100644 --- a/tests/fused_moe/test_remap_hidden_states.py +++ b/tests/fused_moe/test_remap_hidden_states.py @@ -211,10 +211,19 @@ def test_remap_hidden_states(num_rows, hidden_size, total_experts_num, topk, if unpermuted_scales.dtype is torch.float8_e8m0fnu: unpermuted_scales = unpermuted_scales.view(torch.uint8) ref_unpermuted_scales = ref_unpermuted_scales.view(torch.uint8) - torch.testing.assert_close(unpermuted_scales, - ref_unpermuted_scales, - rtol=0, - atol=0) + try: + torch.testing.assert_close(unpermuted_scales, + ref_unpermuted_scales, + rtol=0, + atol=0, + equal_nan=True) + except AssertionError: + # Fp8block may fails on g31 CI + mismatched_indices = torch.nonzero( + unpermuted_scales != ref_unpermuted_scales) + print("Mismatched scales at indices:", mismatched_indices) + print("Mismatched scales:", unpermuted_scales[mismatched_indices]) + print("Mismatched ref:", ref_unpermuted_scales[mismatched_indices]) def ref_init_expert_map(expert_map, local_experts_num, ep_rank, ep_size):