Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions op_tests/triton_tests/attention/test_fav3_sage.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def input_helper(
k_shape = (BATCH, N_CTX_K, HK, D_HEAD)
v_shape = (BATCH, N_CTX_K, HK, D_HEAD_V)

torch.manual_seed(20)
q = torch.randn(q_shape, device="cuda", dtype=dtype)
k = torch.randn(k_shape, device="cuda", dtype=dtype)
v = torch.randn(v_shape, device="cuda", dtype=dtype)
Expand Down Expand Up @@ -214,6 +215,8 @@ def test_sage(
dtype=torch.bfloat16,
):
HEAD_SZ = 128

torch.manual_seed(20)
torch.cuda.empty_cache()

softmax_scale = 1.0 / math.sqrt(HEAD_SZ)
Expand Down
2 changes: 1 addition & 1 deletion op_tests/triton_tests/attention/test_fp8_mqa_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def test_fp8_mqa_logits(
head_dim: int,
disable_cp: bool,
) -> None:
torch.manual_seed(0)
if s_q > s_k:
pytest.skip()
torch.manual_seed(0)
q = torch.randn(s_q, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
kv = torch.randn(s_k, head_dim, device="cuda", dtype=torch.bfloat16)
kv_fp8, scales = per_custom_dims_cast_to_fp8(kv, (0,), False)
Expand Down
1 change: 1 addition & 0 deletions op_tests/triton_tests/attention/test_la_paged.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_persistent_lean_attention(
torch.cuda.empty_cache() # Helps avoid hangs in large tests

torch.manual_seed(20)
random.seed(20)
# Long sequence lengths (>512K) can hit a memory access fault; a compiler issue is suspected.
# Workaround: use a shorter d and a longer BLOCK_N.
if any(item > 524288 for item in n_ctx):
Expand Down
10 changes: 6 additions & 4 deletions op_tests/triton_tests/attention/test_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,6 @@ def test_mha_backward(
dtype=torch.float16,
):
HAS_DROPOUT = DROPOUT > 0.0
torch.cuda.empty_cache()
torch.manual_seed(20)

if FP8 and not _supports_fp8:
pytest.skip(f"FP8 not supported on {arch}")
Expand All @@ -629,7 +627,10 @@ def test_mha_backward(
if FP8 and CAUSAL:
pytest.skip("FP8+CAUSAL results in random precision errors")

torch.cuda.empty_cache()
torch.manual_seed(20)
mha_set_use_fused_bwd_kernel(FUSED)

q = torch.randn(BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
k = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
v = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
Expand Down Expand Up @@ -778,8 +779,6 @@ def test_mha_backward_varlen(
HEAD_SZ = 128
NUM_K_HEADS = 8
HAS_DROPOUT = DROPOUT > 0.0
torch.cuda.empty_cache()
torch.manual_seed(20)

if FP8 and not _supports_fp8:
pytest.skip(f"FP8 not supported on {arch}")
Expand All @@ -790,7 +789,10 @@ def test_mha_backward_varlen(
if CAUSAL and HAS_DROPOUT:
pytest.skip("CAUSAL+DROPOUT backward results in NaNs")

torch.cuda.empty_cache()
torch.manual_seed(20)
mha_set_use_fused_bwd_kernel(FUSED)

q = torch.randn(BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
k = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
v = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
Expand Down
11 changes: 8 additions & 3 deletions op_tests/triton_tests/attention/test_pa_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def input_helper(
random_seed: int = 0,
):
"""Helper function to generate input tensors for paged attention testing."""
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
random.seed(random_seed)

Expand Down Expand Up @@ -188,10 +189,13 @@ def test_paged_attn(
):

head_size = 128
torch.cuda.empty_cache() # Helps avoid hangs in large tests

if SEQ_LEN >= 8192 and B >= 16:
pytest.skip("B>={4} and SEQ_LEN>={8192} tests are too slow")

torch.cuda.empty_cache() # Helps avoid hangs in large tests
torch.set_printoptions(threshold=100000)

num_blocks = NUM_BLK

(
Expand Down Expand Up @@ -277,14 +281,15 @@ def test_paged_attn_per_token_quant(
compute_type,
output_type,
):
torch.cuda.empty_cache() # Helps avoid hangs in large tests
torch.set_printoptions(precision=5, threshold=10000)
if D == 128 and KV_BLK_SZ == 512: # Causes Shared Memory out of resources on Mi300
pytest.skip("D={128} and KV_BLK_SZ={512} causes shared memory out of resources")

if SEQ_LEN >= 8192 and B >= 16:
pytest.skip("B>={4} and SEQ_LEN>={8192} tests are too slow")

torch.cuda.empty_cache() # Helps avoid hangs in large tests
torch.set_printoptions(precision=5, threshold=10000)

num_blocks = NUM_BLK

(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_fused_fp4_bmm_rope_cat_and_cache_mla(
if not arch_info.is_fp4_avail():
pytest.skip("MXFP4 is not available on this device")

torch.manual_seed(0)
_, w_k, _, w_k_scale, _ = generate_batched_gemm_a16wfp4_inputs(
QH_per_KH * KH, T, D_lora, D_q_nope, dtype, layout="TN", output=False
)
Expand Down Expand Up @@ -220,7 +221,7 @@ def test_fused_fp8_bmm_rope_cat_and_cache_mla(
pytest.skip("MXFP8 is not available on this device")

QH = QH_per_KH * KH

torch.manual_seed(0)
q_nope, w_k, w_k_scale, _, _ = generate_batched_gemm_a16w8_inputs(
QH,
T,
Expand Down
7 changes: 7 additions & 0 deletions op_tests/triton_tests/fusions/test_fused_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_fused_qk_rope_cat_and_cache_mla(
cache_dtype: bool,
dtype: torch.dtype,
):
torch.manual_seed(0)
pos = True
_, _, _, _, freqs, positions, offsets, cos, sin = generate_rope_inputs(
1,
Expand Down Expand Up @@ -207,6 +208,7 @@ def test_fused_qk_rope_reshape_and_cache(
offs: bool,
dtype: torch.dtype,
):
torch.manual_seed(0)
pos = True
q, k, _, _, freqs, positions, offsets, cos, sin = generate_rope_inputs(
1,
Expand All @@ -231,6 +233,7 @@ def test_fused_qk_rope_reshape_and_cache(
else:
cache_dtype_actual = torch.float8_e4m3fnuz
pytest.skip("Skipping FP8 dtype cases non-gfx950")
torch.manual_seed(0)

if cache_flash:
key_cache = torch.zeros(
Expand Down Expand Up @@ -441,6 +444,7 @@ def test_fused_qk_rope_reshape_and_cache_value_shuffle_layout(
"""Test fused_qk_rope_reshape_and_cache with value_cache in shuffle layout
[num_blocks, num_kv_heads, block_size // x, head_size, x].
"""
torch.manual_seed(0)
assert D % x_size == 0
pos = True
offs = False
Expand Down Expand Up @@ -584,6 +588,7 @@ def test_fused_qk_rope_reshape_and_cache_gpt_oss_120b_config_value_shuffle_preci
pos = True
offs = False

torch.manual_seed(0)
q, k, _, _, freqs, positions, offsets, cos, sin = generate_rope_inputs(
1,
T,
Expand Down Expand Up @@ -731,6 +736,7 @@ def test_fused_qk_rope_cosine_cache_llama(
offs: bool,
dtype: torch.dtype,
):
torch.manual_seed(0)
pos = True
q, k, _, _, freqs, positions, offsets, cos, sin = generate_rope_inputs(
1,
Expand Down Expand Up @@ -764,6 +770,7 @@ def test_fused_qk_rope_cosine_cache_llama(
)
else:
pytest.skip()
torch.manual_seed(0)

if cache_dtype == torch.uint8:
k_scale = torch.randn(
Expand Down
1 change: 1 addition & 0 deletions op_tests/triton_tests/fusions/test_fused_qk_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


def generate_qk_inputs(B: int, QH_PER_KH: int, KH: int, D_nope: int, D_pe: int, dtype):
torch.manual_seed(0)
q_nope = torch.randn((B, QH_PER_KH * KH, D_nope), dtype=dtype, device="cuda")
q_pe = torch.randn((B, QH_PER_KH * KH, D_pe), dtype=dtype, device="cuda")
k_nope = torch.randn((B, KH, D_nope), dtype=dtype, device="cuda")
Expand Down
1 change: 1 addition & 0 deletions op_tests/triton_tests/gemm/basic/test_gemm_a16w16.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True, bias=False):
torch.manual_seed(0)
if isinstance(dtype, str):
dtype = str_to_torch_dtype[dtype]

Expand Down
1 change: 1 addition & 0 deletions op_tests/triton_tests/gemm/basic/test_gemm_a16w16_gated.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


def generate_gemm_a16w16_gated_inputs(M, N, K, dtype, layout="TN", output=True):
torch.manual_seed(0)
if isinstance(dtype, str):
dtype = str_to_torch_dtype[dtype]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def generate_gemm_a16w8_blockscale_inputs(
- x: (M, K) -> row-major format
- w: (N, K) -> column-major format
"""
torch.manual_seed(0)
scale_n = (N + block_shape_n - 1) // block_shape_n
scale_k = (K + block_shape_k - 1) // block_shape_k

Expand Down
1 change: 1 addition & 0 deletions op_tests/triton_tests/gemm/basic/test_gemm_a8w8.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def generate_gemm_a8w8_inputs(
- x: (M, K) -> row-major format
- w: (N, K) -> column-major format
"""
torch.manual_seed(0)
if layout[0] == "T":
# T (transposed) in Fortran notation equals row-major
x = torch.randn((M, K), dtype=torch.float32, device="cuda")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def generate_gemm_a8w8_blockscale_inputs(
- x: (M, K) -> row-major format
- w: (N, K) -> column-major format
"""
torch.manual_seed(0)
scale_n = (N + block_shape_n - 1) // block_shape_n
scale_k = (K + block_shape_k - 1) // block_shape_k

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def generate_gemm_a8w8_per_token_scale_inputs(
layout: str = "TN",
output=False,
):
torch.manual_seed(0)

if layout[0] == "T":
x = (torch.rand((M, K), dtype=torch.float16, device="cuda") / 10).to(e4m3_type)
Expand Down
5 changes: 3 additions & 2 deletions op_tests/triton_tests/gemm/basic/test_gemm_a8wfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,12 +343,13 @@ def test_gemm_a8wfp4(M: int, N: int, K: int, CLEAR_GPUS=True):
a_dtype = e4m3_type
layout = "TN" # Kernel will occasionally crash for layouts other than TN.
out_dtype = torch.bfloat16
torch.cuda.empty_cache() # Helps avoid hangs in large tests

torch.manual_seed(42) # for reproducibility
if not (arch_info.is_fp4_avail()):
pytest.skip("MXFP4 not supported on this architecture")

torch.cuda.empty_cache() # Helps avoid hangs in large tests
torch.manual_seed(42) # for reproducibility

# clean up to avoid hangs in large tests
if CLEAR_GPUS:
torch.cuda.empty_cache()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def generate_batched_gemm_a8w8_inputs(
- x_scale: shape (B, M, 1)
- w_scale: shape (B, 1, N)
"""
torch.manual_seed(0)
if isinstance(dtype, str):
dtype = str_to_torch_dtype[dtype]
if layout[0] == "T":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def generate_batched_gemm_a16w8_inputs(
- x_scale: shape (B, M, 1)
- w_scale: shape (B, 1, N)
"""
torch.manual_seed(0)
if isinstance(dtype, str):
dtype = str_to_torch_dtype[dtype]
if layout[0] == "T":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def generate_batched_gemm_a16w16_inputs(
output: bool,
layout: str = "TN",
):
torch.manual_seed(0)
if isinstance(dtype, str):
dtype = str_to_torch_dtype[dtype]
if layout[0] == "T":
Expand Down
2 changes: 2 additions & 0 deletions op_tests/triton_tests/gemm/feed_forward/test_ff_a16w16.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
def test_ff_a16w16_ungated(
batch: int, hidden_dim: int, intermediate_dim: int, dtype, output, activation
):
torch.manual_seed(0)
ff_ungated_test(
ff_a16w16_nogate,
batch=batch,
Expand All @@ -37,6 +38,7 @@ def test_ff_a16w16_ungated(
def test_ff_a16w16_gated(
batch: int, hidden_dim: int, intermediate_dim: int, dtype, output, activation
):
torch.manual_seed(0)
ff_gated_test(
ff_a16w16_gated,
batch=batch,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_ff_a16w16_fused_ungated(
pytest.skip(
"Small differences in implementation between Triton & Torch activations accumulate to beyond test bounds w/large matrices."
)
torch.manual_seed(0)
ff_ungated_test(
ff_a16w16_fused_ungated,
batch=batch,
Expand All @@ -47,6 +48,7 @@ def test_ff_a16w16_fused_gated(
pytest.skip(
"Small differences in implementation between Triton & Torch activations accumulate to beyond test bounds w/large matrices."
)
torch.manual_seed(0)

ff_gated_test(
ff_a16w16_fused_gated,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def get_x_vals():
@pytest.mark.parametrize("output", [True, False])
@pytest.mark.parametrize("skip_reduce", [True, False])
def test_gemm(dtype, M, N1, N2, K, output, skip_reduce):
torch.manual_seed(0)
block_shape_n, block_shape_k = block_shape

x_fp8, w_fp8, _, x_fp8_scale, _, w_fp8_scale, y_fp8 = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_fused_gemm_a8w8_blockscale_mul_add(
b_type_is_scalar,
fuse_type,
):
torch.manual_seed(0)

(
x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def generate_fused_gemm_a8w8_blockscale_split_cat_inputs(
- w: (N, K) -> column-major format
- y: (M, D, S3)
"""
torch.manual_seed(0)
scale_n = (N + block_shape_n - 1) // block_shape_n
scale_k = (K + block_shape_k - 1) // block_shape_k

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def test_gemm(dtype, M, N1, N2, K, output, skip_reduce, fp4_shuffle):

if not (arch_info.is_fp4_avail()):
pytest.skip("MXFP4 not supported on this architecture")
torch.manual_seed(0)

(
x_fp4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def test_fused_gemm_afp4wfp4_mul_add(
pytest.skip(
f"K = {K} is not divisible by 256, skip this test for preshuffled weight/scales tests"
)
torch.manual_seed(0)

torch.cuda.empty_cache() # Helps avoid hangs in large tests

Expand Down
1 change: 1 addition & 0 deletions op_tests/triton_tests/moe/test_moe_align_block_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def torch_moe_align_block_size(


def input_helper(M: int, E: int, top_k: int):
torch.manual_seed(0)
values = torch.randn(M, E, dtype=torch.float16, device="cuda")

softmax_vals = torch.softmax(values, dim=1)
Expand Down
6 changes: 3 additions & 3 deletions op_tests/triton_tests/moe/test_moe_mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,11 @@ def test_fused_moe(
routed_weight: bool,
swizzle_mx_scale: bool,
):
torch.cuda.empty_cache() # Helps avoid hangs in large tests
torch.manual_seed(20)
if not (arch_info.is_fp4_avail()):
pytest.skip("MXFP4 not supported on this architecture")
pytest.skip("MXFP4 not supported on this architecture")

torch.cuda.empty_cache() # Helps avoid hangs in large tests
torch.manual_seed(20)

(
a_tri,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


def generate_inputs(M, N, has_res, dtype):
torch.manual_seed(0)
x = torch.randn((M, N), dtype=dtype, device="cuda")
weight = torch.randn((N,), dtype=dtype, device="cuda")
res = torch.randn((M, N), dtype=dtype, device="cuda") if has_res else None
Expand Down
Loading
Loading