Commit cfab9b7

Add skips if not hopper+
Signed-off-by: Neta Zmora <[email protected]>
1 parent d0b11d8 · commit cfab9b7

File tree

1 file changed: +14 -0 lines changed


tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py

Lines changed: 14 additions & 0 deletions
@@ -7,14 +7,20 @@
 
 import pytest
 import torch
+from _torch_test_utils import fp8_compatible, trtllm_ops_available  # noqa: F401
 from torch.nn import functional as F
 
+import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
 from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType
 
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
 FP8_DTYPE = torch.float8_e4m3fn
 
 
+def _is_hopper_or_later():
+    return torch.cuda.get_device_capability(0) >= (8, 9)
+
+
 def dynamic_per_tensor_fp8_quant(x: torch.tensor) -> tuple[torch.tensor, torch.tensor]:
     fp8_traits_max = FLOAT8_E4M3_MAX
     fp8_traits_min = -FLOAT8_E4M3_MAX
@@ -179,6 +185,10 @@ def _print_diff_if(
 @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
 @pytest.mark.parametrize("itype, otype, wtype", F16_TEST_DTYPES)
 @pytest.mark.parametrize("activation_func", ["silu", "relu2"])
+@pytest.mark.skipif(
+    not _is_hopper_or_later() or not trtllm_ops_available(),
+    reason="Requires Hopper or later and trtllm support",
+)
 def test_trtllm_fused_moe(
     batch_size,
     hidden_size,
@@ -286,6 +296,10 @@ def get_fc1_expert_weights(
 @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
 @pytest.mark.parametrize("itype, otype, wtype", FP8_TEST_DTYPES)
 @pytest.mark.parametrize("activation_func", ["silu", "relu2"])
+@pytest.mark.skipif(
+    not fp8_compatible() or not trtllm_ops_available(),
+    reason="Requires fp8 and trtllm support",
+)
 def test_trtllm_fused_fp8moe(
     batch_size,
     hidden_size,
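
For context, a minimal sketch of what the new guard checks on a local machine, assuming PyTorch with CUDA support and at least one visible GPU (device 0, as in the helper above). The capability threshold mirrors _is_hopper_or_later(); in the actual decorators, trtllm_ops_available() and fp8_compatible() from _torch_test_utils additionally gate the tests:

import torch

# Compute capability of the first visible GPU:
# (9, 0) on H100 (Hopper), (8, 9) on Ada-class GPUs, (8, 0) on A100 (Ampere).
cap = torch.cuda.get_device_capability(0)

# Same threshold used by _is_hopper_or_later() in the test file.
if cap >= (8, 9):
    print(f"capability {cap}: fused-MoE tests can run (if trtllm ops are available)")
else:
    print(f"capability {cap}: fused-MoE tests will be reported as skipped")

Running the test file with pytest -rs lists any skipped cases along with the skip reasons added in this commit.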
