From 96d271ec05023204e3dc967530d6194092bc2ec2 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 26 Nov 2024 19:33:42 -0800 Subject: [PATCH] Update hardware check conditions --- test/dtypes/test_affine_quantized.py | 6 ++--- test/dtypes/test_affine_quantized_float.py | 23 ++++++++-------- test/float8/test_base.py | 22 +++++++--------- test/float8/test_compile.py | 26 ++++++++----------- test/float8/test_fsdp2/test_fsdp2.py | 5 ++-- .../test_fsdp2/test_fsdp2_fp8_comm_only.py | 5 ++-- test/float8/test_numerics_integration.py | 9 +++---- test/integration/test_integration.py | 10 +++---- test/kernel/test_autotuner.py | 4 +-- test/prototype/mx_formats/test_mx_linear.py | 9 +++---- test/prototype/mx_formats/test_mx_tensor.py | 7 ++--- 11 files changed, 53 insertions(+), 73 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index e049500e3..048533fdb 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -17,9 +17,7 @@ int8_weight_only, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89 def get_quantization_functions(do_sparse: bool, do_int4: bool): @@ -37,7 +35,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool): int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) ) - if is_cuda_8_9: + if is_sm_89(): base_functions.append(float8_weight_only()) return base_functions diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 74c130dc5..7ed39d534 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -37,13 +37,14 @@ MappingType, choose_qparams_affine, ) +from torchao.utils import ( + is_sm_89, + is_sm_90, +) random.seed(0) torch.manual_seed(0) -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - class ToyLinearModel(torch.nn.Module): def __init__(self, in_features, out_features): @@ -59,12 +60,12 @@ def forward(self, x): class TestAffineQuantizedFloat8Compile(InductorTestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize( - "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()] + "granularity", [PerTensor(), PerRow()] if is_sm_90() else [PerTensor()] ) # Inputs are (M,..), K, N @common_utils.parametrize( @@ -134,12 +135,12 @@ def test_fp8_linear_variants( compute_error(output_original, output_quantized) > 20 ), f"Quantization error is too high got a SQNR of {error}" - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") def test_invalid_granularity(self): with pytest.raises(ValueError, match="Invalid granularity specification"): float8_dynamic_activation_float8_weight(granularity="invalid") - @unittest.skipIf(not 
is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") def test_mismatched_granularity(self): with pytest.raises( ValueError, @@ -147,7 +148,7 @@ def test_mismatched_granularity(self): ): float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow())) - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") def test_unsupported_granularity(self): class UnsupportedGranularity: pass @@ -158,7 +159,7 @@ class UnsupportedGranularity: ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") def test_per_row_with_float32(self): with pytest.raises( AssertionError, @@ -170,7 +171,7 @@ def test_per_row_with_float32(self): ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) def test_serialization(self, mode: str): # Create and quantize the model @@ -240,7 +241,7 @@ def test_serialization(self, mode: str): ), f"Scales do not match for {layer_name}" @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") def test_fp8_weight_dimension_warning(self): # Create model with incompatible dimensions (not multiples of 16) model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights diff --git a/test/float8/test_base.py b/test/float8/test_base.py index d00b96d3b..bec515b67 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -60,10 +60,6 @@ torch.manual_seed(0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) - - def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: assert torch.all(a._scale == b._scale).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" @@ -219,7 +215,7 @@ def test_axiswise_reshape(self): ], ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0") + @unittest.skipIf(not is_sm_90(), "Requires CUDA capability >= 9.0") def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") @@ -333,7 +329,7 @@ def _test_linear_impl( # verify initialization flags got updated assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else 
[True]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( "scaling_type_input", @@ -415,7 +411,7 @@ def test_linear_from_recipe( config, ) - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True]) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) @@ -462,7 +458,7 @@ def test_autocast_outputs( @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) @@ -523,7 +519,7 @@ def test_repr(self): s = m.__repr__() assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s - @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available") + @unittest.skipIf(not is_sm_89(), "CUDA 8.9 not available") def test_inference_mode(self): x = torch.randn(32, 32, device="cuda") m = nn.Sequential(nn.Linear(32, 32)).cuda() @@ -534,7 +530,7 @@ def test_inference_mode(self): class TestScaledMM: @unittest.skipIf( - not is_cuda_8_9, + not is_sm_89(), "CUDA not available", ) @pytest.mark.parametrize( @@ -579,7 +575,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): atol, rtol = 2e-3, 2e-3 torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(not is_cuda_8_9, "CUDA not available") + @unittest.skipIf(not is_sm_89(), "CUDA not available") def test_different_configs_error(self): x_fp32 = torch.randn(16, 16, device="cuda") x_scale = torch.tensor(1.0, device="cuda") @@ -615,7 +611,7 @@ def test_different_configs_error(self): a @ b @unittest.skipIf( - not is_cuda_8_9, + not is_sm_89(), "CUDA not available", ) @pytest.mark.parametrize( diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index ced5db7ff..0df5cd633 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -11,7 +11,7 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -46,10 +46,6 @@ from torchao.float8.float8_utils import e4m3_dtype from torchao.testing.float8.test_utils import get_test_float8_linear_config -# TODO(future PR): standardize IS_H100 with the rest of the codebase -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - def _test_compile_base( backend: str, @@ -99,7 +95,7 @@ def _test_compile_base( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_eager_only( @@ -126,7 +122,7 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) 
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True]) @pytest.mark.parametrize( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @@ -177,7 +173,7 @@ def test_aot_eager( [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_89(), "CUDA with float8 support not available", ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @@ -215,7 +211,7 @@ def test_inductor_from_config_params( Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, ], ) -@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available") +@unittest.skipIf(not is_sm_90(), "CUDA with capability 9.0 or greater not available") def test_inductor_from_recipe(recipe_name): torch._dynamo.reset() config = recipe_name_to_linear_config(recipe_name) @@ -253,7 +249,7 @@ def forward(self, x): # TODO(future): figure out why the test below fails on CUDA capability 8.9 @unittest.skipIf( - not torch.cuda.is_available() or not is_H100, + not torch.cuda.is_available() or not is_sm_90(), "CUDA with capability 9.0 or greater not available", ) def test_float8_with_graph_break_in_the_middle(self): @@ -269,7 +265,7 @@ def test_float8_with_graph_break_in_the_middle(self): torch.testing.assert_close(y_eager, y_compiled) @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_89(), "CUDA with float8 support not available", ) def test_float8_graph_input(self): @@ -293,7 +289,7 @@ def to_float(x): torch.testing.assert_close(y2_eager, y2_compiled) @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_89(), "CUDA with float8 support not available", ) def test_float8_graph_output(self): @@ -323,7 +319,7 @@ def test_float8_graph_output(self): @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_89(), "CUDA with float8 support not available", ) def test_sync_amax_func(): @@ -364,7 +360,7 @@ def __exit__(self, *args): @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_89(), "CUDA with float8 support not available", ) def test_sync_amax_func_cuda_graph_success(): @@ -396,7 +392,7 @@ def test_sync_amax_func_cuda_graph_success(): @unittest.skipIf( - not is_cuda_8_9, + not is_sm_89(), "CUDA not available", ) @pytest.mark.parametrize( diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index c3e31816a..70ac82fbf 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -6,7 +6,7 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -40,8 +40,7 @@ from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -if not is_cuda_8_9: +if not is_sm_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py index d5c0d7b85..faa8866fb 
100644 --- a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py +++ b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py @@ -3,7 +3,7 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -30,8 +30,7 @@ from torchao.float8.float8_tensor import GemmInputRole from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -if not is_cuda_8_9: +if not is_sm_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index e9028c871..37b7817e1 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -11,7 +11,7 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -34,9 +34,6 @@ from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.float8.test_utils import get_test_float8_linear_config -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) - torch.manual_seed(0) @@ -176,7 +173,7 @@ def _test_impl(self, config: Float8LinearConfig) -> None: "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) - @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") + @pytest.mark.skipif(not is_sm_89(), reason="requires SM89 compatible machine") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw_from_config_params( self, @@ -199,7 +196,7 @@ def test_encoder_fw_bw_from_config_params( Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, ], ) - @pytest.mark.skipif(not is_cuda_9_0, reason="requires SM90 compatible machine") + @pytest.mark.skipif(not is_sm_90(), reason="requires SM90 compatible machine") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw_from_recipe( self, diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index ac2403d6d..75581b441 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -91,7 +91,8 @@ TORCH_VERSION_AT_LEAST_2_6, unwrap_tensor_subclass, is_fbcode, - benchmark_model + benchmark_model, + is_sm_90, ) logger = logging.getLogger("INFO") @@ -104,7 +105,6 @@ COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) def _int8wo_api(mod): if TORCH_VERSION_AT_LEAST_2_4: @@ -775,7 +775,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): 
self._test_lin_weight_subclass_impl( AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype @@ -795,7 +795,7 @@ def test_autoquantizable_flatten_unflatten(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): if dtype != torch.bfloat16: with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"): @@ -809,7 +809,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py index 4ed097417..f64940b07 100644 --- a/test/kernel/test_autotuner.py +++ b/test/kernel/test_autotuner.py @@ -13,10 +13,10 @@ import pytest import torch from parameterized import parameterized +from torchao.utils import is_sm_90 logging.basicConfig(level=logging.INFO) -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) class TestQuantFlow(unittest.TestCase): @@ -56,7 +56,7 @@ def test_int_mm(self, device, dtype): ("cuda", torch.float16), ] ) - @unittest.skipIf(not is_H100, "Needs H100") + @unittest.skipIf(not is_sm_90(), "Needs H100") def test_int_mm_float8(self, device, dtype): from torchao.kernel import intmm diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index bc9b02deb..5cf78cce2 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -20,11 +20,8 @@ ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_89 -# trying to outsmart flake8 -__has_cuda = torch.cuda.is_available() -IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9) torch.manual_seed(2) @@ -102,7 +99,7 @@ def test_linear_compile(elem_dtype, bias): Verify that compile does not change numerics of MX linear fw + bw """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") input_shape = (2, 4) grad_shape = (2, 6) @@ -173,7 +170,7 @@ def test_inference_compile_simple(elem_dtype): Smoke test for inference compile """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16)) m = m.cuda() diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 964a57541..d7b60f384 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -24,11 +24,8 @@ ) from torchao.quantization.utils import compute_error -from torchao.utils import 
TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_89 -# trying to outsmart flake8 -__has_cuda = torch.cuda.is_available() -IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9) torch.manual_seed(2) @@ -225,7 +222,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): Verifies that compile does not change numerics of MX casts """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_89(): # separate ifs because flake8 is outsmarting me pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
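
Note: this patch imports is_sm_89() and is_sm_90() from torchao.utils but does not show their definitions. They are assumed to wrap the same compute-capability predicates that the patch removes inline (is_cuda_8_9, is_H100, IS_CUDA_GE_89); the sketch below mirrors those removed checks and is only an illustration — the actual torchao.utils implementations may differ.

import torch

def is_sm_89() -> bool:
    # True on GPUs with compute capability >= 8.9 (Ada Lovelace and newer),
    # i.e. hardware with native float8 support; mirrors the removed
    # is_cuda_8_9 / IS_CUDA_GE_89 module-level constants.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

def is_sm_90() -> bool:
    # True on compute capability >= 9.0 (Hopper, e.g. H100); mirrors the
    # removed is_H100 / is_cuda_9_0 constants.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

The change itself is mechanical: each test file's ad-hoc module-level constant is replaced by a call to the shared helper (e.g. @unittest.skipIf(not is_sm_89(), ...)), so the capability thresholds are defined once in torchao.utils instead of being re-derived per test module.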