7 | 7 |
8 | 8 | import pytest |
9 | 9 | import torch |
| 10 | +from _torch_test_utils import fp8_compatible, trtllm_ops_available # noqa: F401 |
10 | 11 | from torch.nn import functional as F |
11 | 12 |
| 13 | +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 |
12 | 14 | from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType |
13 | 15 |
14 | 16 | FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max |
15 | 17 | FP8_DTYPE = torch.float8_e4m3fn |
16 | 18 |
17 | 19 |
| 20 | +def _is_hopper_or_later(): |
| 21 | + return torch.cuda.get_device_capability(0) >= (8, 9) |
| 22 | + |
| 23 | + |
18 | 24 | def dynamic_per_tensor_fp8_quant(x: torch.tensor) -> tuple[torch.tensor, torch.tensor]: |
19 | 25 | fp8_traits_max = FLOAT8_E4M3_MAX |
20 | 26 | fp8_traits_min = -FLOAT8_E4M3_MAX |
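
Note: this hunk ends before the body of dynamic_per_tensor_fp8_quant. For context, a minimal sketch of what an amax-based dynamic per-tensor FP8 quantization helper typically looks like, assuming the usual scale-from-absolute-maximum approach (an illustration, not the PR's actual implementation):

import torch

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
FP8_DTYPE = torch.float8_e4m3fn

def dynamic_per_tensor_fp8_quant_sketch(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Sketch only: derive the per-tensor dequantization scale from the absolute maximum.
    fp8_traits_max = FLOAT8_E4M3_MAX
    fp8_traits_min = -FLOAT8_E4M3_MAX
    amax = x.abs().max().clamp(min=1e-12).float()
    scale = amax / fp8_traits_max
    # Quantize: divide by the scale, clamp to the representable e4m3 range, cast to fp8.
    x_q = (x.float() / scale).clamp(fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
    return x_q, scale
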
@@ -179,6 +185,10 @@ def _print_diff_if( |
179 | 185 | @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) |
180 | 186 | @pytest.mark.parametrize("itype, otype, wtype", F16_TEST_DTYPES) |
181 | 187 | @pytest.mark.parametrize("activation_func", ["silu", "relu2"]) |
| 188 | +@pytest.mark.skipif( |
| 189 | + not _is_hopper_or_later() or not trtllm_ops_available(), |
| 190 | + reason="Requires Hopper or later and trtllm support", |
| 191 | +) |
182 | 192 | def test_trtllm_fused_moe( |
183 | 193 | batch_size, |
184 | 194 | hidden_size, |
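
A similar guard is added to the fp8 variant in the next hunk. As a design note, the same condition could also be built once as a reusable marker and applied to both tests; a sketch (not what this diff does), assuming _is_hopper_or_later and trtllm_ops_available are in scope as above:

import pytest

# Hypothetical reusable marker; the diff instead repeats the skipif on each test.
requires_trtllm_hopper = pytest.mark.skipif(
    not _is_hopper_or_later() or not trtllm_ops_available(),
    reason="Requires Hopper or later and trtllm support",
)

@requires_trtllm_hopper
def test_guarded_example():
    assert True
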
@@ -286,6 +296,10 @@ def get_fc1_expert_weights( |
286 | 296 | @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) |
287 | 297 | @pytest.mark.parametrize("itype, otype, wtype", FP8_TEST_DTYPES) |
288 | 298 | @pytest.mark.parametrize("activation_func", ["silu", "relu2"]) |
| 299 | +@pytest.mark.skipif( |
| 300 | + not fp8_compatible() or not trtllm_ops_available(), |
| 301 | + reason="Requires fp8 and trtllm support", |
| 302 | +) |
289 | 303 | def test_trtllm_fused_fp8moe( |
290 | 304 | batch_size, |
291 | 305 | hidden_size, |
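
The helpers fp8_compatible and trtllm_ops_available come from _torch_test_utils and are not shown in this diff. A minimal sketch of how such capability predicates are commonly written, assuming a device-capability probe and an import probe (these are assumptions, not the repo's actual definitions):

import importlib.util

import torch

def fp8_compatible() -> bool:
    # float8_e4m3 kernels need compute capability 8.9 (Ada) or newer; Hopper is 9.0.
    return torch.cuda.is_available() and torch.cuda.get_device_capability(0) >= (8, 9)

def trtllm_ops_available() -> bool:
    # One plausible check: True when the TensorRT-LLM package (and its custom ops) imports.
    return importlib.util.find_spec("tensorrt_llm") is not None
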