op_tests/triton_tests/quant/test_fused_mxfp4_quant.py (11 additions, 3 deletions)
@@ -22,8 +22,6 @@
 from aiter.ops.quant import per_1x32_f4_quant_hip
 from aiter.utility.fp4_utils import moe_mxfp4_sort, dynamic_mxfp4_quant
 
-torch.manual_seed(0)
-
 
 def rmsnorm(input, weight, eps=1e-6):
     row_norm = input * input
@@ -130,6 +128,8 @@ def test_flatten_quant(B: int, M: int, N: int, dtype):
     if not (arch_info.is_fp4_avail()):
         pytest.skip("MXFP4 not supported on this architecture")
 
+    torch.manual_seed(0)
+
     torch.cuda.empty_cache()  # Helps avoid hangs in large tests
 
     x = torch.randn((B, M, N), dtype=dtype, device="cuda").transpose(0, 1)
@@ -169,10 +169,11 @@ def test_fused_rms_quant(
     shuffle: bool,
     scale_shuffle_padding: bool,
 ):
-
     if not (arch_info.is_fp4_avail()):
         pytest.skip("MXFP4 not supported on this architecture")
 
+    torch.manual_seed(0)
+
     torch.cuda.empty_cache()  # Helps avoid hangs in large tests
     x1, x2, rms1_w, rms2_w, resid1 = generate_fused_rms_quant_data(
         x1_shape=(M, N1),
@@ -317,6 +318,8 @@ def test_fused_reduce_act_mul_mxfp4_group_quant(
     if not (arch_info.is_fp4_avail()):
         pytest.skip("MXFP4 not supported on this architecture")
 
+    torch.manual_seed(0)
+
     if shuffle and (N1 * 2) % 512 != 0:
         pytest.skip()
 
@@ -402,6 +405,8 @@ def test_fuse_reduce_rms_quant(
     if not (arch_info.is_fp4_avail()):
         pytest.skip("MXFP4 not supported on this architecture")
 
+    torch.manual_seed(0)
+
     torch.cuda.empty_cache()  # Helps avoid hangs in large tests
     x1, w1, x2, w2, res1, x3 = generate_fused_reduce_rms_quant_data(
         M, N1, N2, N3, SPK, dtype
@@ -548,6 +553,9 @@ def test_fused_dynamic_mxfp4_quant_moe_sort(
 ):
     if not (arch_info.is_fp4_avail()):
         pytest.skip("MXFP4 not supported on this architecture")
+
+    torch.manual_seed(0)
+
     q_dtype_a = torch.float4_e2m1fn_x2
     num_local_tokens = None
     num_valid_ids = torch.zeros(2, dtype=torch.int64, device="cuda")
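For context, the pattern this diff applies: a module-level torch.manual_seed(0) fixes the RNG state only once, at import time, so the random tensors a given test draws depend on which tests ran before it (and on test selection or ordering). Seeding at the top of each test makes every test deterministic in isolation. A minimal sketch of the pattern, not taken from this repo; the test name, parametrization, and shapes are hypothetical:

    import pytest
    import torch


    @pytest.mark.parametrize("shape", [(4, 32), (8, 64)])  # hypothetical shapes
    def test_example_quant(shape):
        # Seed inside the test, not at module scope, so the inputs below are
        # identical no matter which tests ran first or which subset was selected.
        torch.manual_seed(0)
        x = torch.randn(shape)  # same tensor on every run and in any test order
        assert x.shape == shape

With per-test seeding, running test_example_quant alone, with -k filters, or under a randomized/parallel test order produces the same inputs, which is what the moved torch.manual_seed(0) calls above achieve for each test function.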