Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/attention/test_trtllm_gen_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ def test_trtllm_batch_prefill(
max_kv_len,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] in [11, 12]:
pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
if compute_capability[0] != 10:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
# Set up test parameters
torch.manual_seed(0)
head_dim = 128
Expand Down Expand Up @@ -918,8 +918,8 @@ def test_trtllm_gen_prefill_deepseek(
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] in [11, 12]:
pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
if compute_capability[0] != 10:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
if s_qo > s_kv:
pytest.skip("s_qo > s_kv, skipping test as causal")

Expand Down
4 changes: 2 additions & 2 deletions tests/attention/test_trtllm_gen_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def test_trtllm_batch_decode_mla(
enable_pdl: bool,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] in [11, 12]:
pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
if compute_capability[0] != 10:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
if dynamic_scale and dtype != torch.float8_e4m3fn:
pytest.skip("Dynamic scale is not supported for non-fp8 dtype")

Expand Down
7 changes: 6 additions & 1 deletion tests/gemm/test_bmm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@
@pytest.mark.parametrize("backend", ["cudnn", "cublas", "cutlass", "auto"])
@pytest.mark.parametrize("auto_tuning", [True, False])
def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_tuning):
if get_compute_capability(torch.device("cuda"))[0] == 12 and backend in [
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] == 12 and backend in [
"cutlass",
"auto",
]:
# TODO(yongwwww): enable all test cases for SM120/121 CUTLASS bmm_fp8 backend
pytest.xfail(
"Not all test cases for CUTLASS bmm_fp8 on SM120/121 are passing at this moment"
)
if backend == "cutlass" and compute_capability[0] not in [10, 11, 12]:
pytest.skip(
"bmm_fp8 with cutlass backend is only supported on SM100, SM110, and SM120/121 GPUs."
)
if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2:
pytest.skip("Invalid combination: both input and mat2 are e5m2")
if input_dtype == torch.float8_e5m2 or mat2_dtype == torch.float8_e5m2:
Expand Down
19 changes: 16 additions & 3 deletions tests/gemm/test_groupwise_scaled_gemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def test_fp8_blockscale_gemm(
scale_major_mode,
out_dtype,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] not in [10, 11, 12]:
pytest.skip(
"gemm_fp8_nt_blockscaled is only supported on SM100/103, SM110, and SM120/121 GPUs."
)
torch.random.manual_seed(0)
tile_size = 128

Expand Down Expand Up @@ -83,8 +88,8 @@ def test_fp8_groupwise_gemm(
scale_major_mode,
backend,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
if backend == "trtllm":
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] != 10:
pytest.skip(
"gemm_fp8_nt_groupwise is only supported on SM100, SM103 in trtllm backend."
Expand All @@ -93,7 +98,10 @@ def test_fp8_groupwise_gemm(
pytest.skip("trtllm only supports MN scale_major_mode")
if k < 256:
pytest.skip("k < 256")

if backend == "cutlass" and compute_capability[0] not in [10, 11, 12]:
pytest.skip(
"gemm_fp8_nt_groupwise with cutlass backend is only supported on SM100/103, SM110, and SM120/121 GPUs."
)
torch.random.manual_seed(0)
tile_size = 128
out_dtype = torch.bfloat16
Expand Down Expand Up @@ -146,12 +154,17 @@ def test_fp8_groupwise_group_gemm(
scale_major_mode,
out_dtype,
):
if group_size > 1 and torch.cuda.get_device_capability()[0] in [
compute_capability = get_compute_capability(torch.device(device="cuda"))
if group_size > 1 and compute_capability[0] in [
12,
]:
pytest.skip(
"group_gemm_fp8_nt_groupwise has correctness issues for num_groups > 1 on SM120/121"
)
if compute_capability[0] not in [10, 12]:
pytest.skip(
"group_gemm_fp8_nt_groupwise is only supported on SM100/103, and SM120/121 GPUs."
)
torch.random.manual_seed(0)
tile_size = 128

Expand Down
6 changes: 4 additions & 2 deletions tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,10 @@ def test_mxfp8_mxfp4_groupwise_group_gemm(
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
# TODO: We need to add gemm_mxfp4_nt_groupwise support for sm120/121 at some point.
if compute_capability[0] == 12:
pytest.skip("gemm_mxfp4_nt_groupwise is not supported in SM120/SM121.")
if compute_capability[0] not in [10]:
pytest.skip(
"gemm_mxfp4_nt_groupwise is only supported on SM100 and SM103 GPUs."
)
torch.random.manual_seed(0)
tile_size = 32
alignment_n = 8
Expand Down
6 changes: 6 additions & 0 deletions tests/gemm/test_mm_fp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def test_mm_fp4(
use_nvfp4 = fp4_type == "nvfp4"

compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
if not mm_fp4.is_backend_supported(backend, compute_capability_number):
pytest.skip(
f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
)

if backend == "trtllm":
if res_dtype == torch.float16:
pytest.skip("Skipping test for trtllm fp4 with float16")
Expand Down
4 changes: 2 additions & 2 deletions tests/moe/test_trtllm_gen_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,8 +2035,8 @@ def test_moe_quantization_classes(
Each quantization class clearly shows which precision is being used.
"""
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] in [11, 12]:
pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
if compute_capability[0] not in [10]:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
# Skip incompatible combinations
if gated_act_type == GatedActType.GeGlu and (
type(moe_impl) is not FP4Moe
Expand Down