From f2926462a0e663cbd2e9b22c42d8bea2cdc92f2a Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 27 Nov 2024 12:01:35 -0800 Subject: [PATCH] RUpdate is_sm_ -> is_sm_at_least ruff --- test/dtypes/test_affine_quantized.py | 8 +++-- test/dtypes/test_affine_quantized_float.py | 34 +++++++++++++------ test/float8/test_base.py | 28 ++++++++++----- test/float8/test_compile.py | 28 +++++++++------ test/float8/test_fsdp2/test_fsdp2.py | 4 +-- .../test_fsdp2/test_fsdp2_fp8_comm_only.py | 4 +-- test/float8/test_numerics_integration.py | 14 ++++++-- test/integration/test_integration.py | 8 ++--- test/kernel/test_autotuner.py | 4 +-- test/prototype/mx_formats/test_mx_linear.py | 6 ++-- test/prototype/mx_formats/test_mx_tensor.py | 4 +-- torchao/quantization/quant_api.py | 12 +++---- torchao/utils.py | 8 ++--- 13 files changed, 102 insertions(+), 60 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 52ee29c43..43d57b7d1 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -17,7 +17,11 @@ int8_weight_only, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_sm_89 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + is_sm_at_least_89, +) def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"): @@ -40,7 +44,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cu int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) ) - if is_sm_89(): + if is_sm_at_least_89(): base_functions.append(float8_weight_only()) return base_functions diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 7ed39d534..4d8312b42 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -38,8 +38,8 @@ choose_qparams_affine, ) from torchao.utils import ( - is_sm_89, - is_sm_90, + is_sm_at_least_89, + is_sm_at_least_90, ) random.seed(0) @@ -60,12 +60,14 @@ def forward(self, x): class TestAffineQuantizedFloat8Compile(InductorTestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize( - "granularity", [PerTensor(), PerRow()] if is_sm_90() else [PerTensor()] + "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()] ) # Inputs are (M,..), K, N @common_utils.parametrize( @@ -135,12 +137,16 @@ def test_fp8_linear_variants( compute_error(output_original, output_quantized) > 20 ), f"Quantization error is too high got a SQNR of {error}" - @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_invalid_granularity(self): with pytest.raises(ValueError, match="Invalid granularity specification"): float8_dynamic_activation_float8_weight(granularity="invalid") - @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), 
"Requires GPU with compute capability >= 8.9" + ) def test_mismatched_granularity(self): with pytest.raises( ValueError, @@ -148,7 +154,9 @@ def test_mismatched_granularity(self): ): float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow())) - @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_unsupported_granularity(self): class UnsupportedGranularity: pass @@ -159,7 +167,9 @@ class UnsupportedGranularity: ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_per_row_with_float32(self): with pytest.raises( AssertionError, @@ -171,7 +181,9 @@ def test_per_row_with_float32(self): ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) def test_serialization(self, mode: str): # Create and quantize the model @@ -241,7 +253,9 @@ def test_serialization(self, mode: str): ), f"Scales do not match for {layer_name}" @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_fp8_weight_dimension_warning(self): # Create model with incompatible dimensions (not multiples of 16) model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights diff --git a/test/float8/test_base.py b/test/float8/test_base.py index bec515b67..86ff16427 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -14,7 +14,11 @@ import torch import torch.nn as nn -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -215,7 +219,7 @@ def test_axiswise_reshape(self): ], ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @unittest.skipIf(not is_sm_90(), "Requires CUDA capability >= 9.0") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") @@ -329,7 +333,9 @@ def _test_linear_impl( # verify initialization flags got updated assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( "scaling_type_input", @@ -411,7 +417,9 @@ def test_linear_from_recipe( config, ) - @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @pytest.mark.parametrize( 
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) @@ -458,7 +466,9 @@ def test_autocast_outputs( @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) @@ -519,7 +529,7 @@ def test_repr(self): s = m.__repr__() assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s - @unittest.skipIf(not is_sm_89(), "CUDA 8.9 not available") + @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available") def test_inference_mode(self): x = torch.randn(32, 32, device="cuda") m = nn.Sequential(nn.Linear(32, 32)).cuda() @@ -530,7 +540,7 @@ def test_inference_mode(self): class TestScaledMM: @unittest.skipIf( - not is_sm_89(), + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( @@ -575,7 +585,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): atol, rtol = 2e-3, 2e-3 torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(not is_sm_89(), "CUDA not available") + @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available") def test_different_configs_error(self): x_fp32 = torch.randn(16, 16, device="cuda") x_scale = torch.tensor(1.0, device="cuda") @@ -611,7 +621,7 @@ def test_different_configs_error(self): a @ b @unittest.skipIf( - not is_sm_89(), + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 0df5cd633..6d21686e3 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -11,7 +11,11 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -95,7 +99,7 @@ def _test_compile_base( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) -@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True]) +@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_eager_only( @@ -122,7 +126,7 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True]) +@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @@ -173,7 +177,7 @@ def test_aot_eager( [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_89(), + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @@ -211,7 +215,9 @@ def test_inductor_from_config_params( Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, ], ) -@unittest.skipIf(not is_sm_90(), "CUDA with capability 9.0 or 
greater not available") +@unittest.skipIf( + not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available" +) def test_inductor_from_recipe(recipe_name): torch._dynamo.reset() config = recipe_name_to_linear_config(recipe_name) @@ -249,7 +255,7 @@ def forward(self, x): # TODO(future): figure out why the test below fails on CUDA capability 8.9 @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_90(), + not torch.cuda.is_available() or not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available", ) def test_float8_with_graph_break_in_the_middle(self): @@ -265,7 +271,7 @@ def test_float8_with_graph_break_in_the_middle(self): torch.testing.assert_close(y_eager, y_compiled) @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_89(), + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_float8_graph_input(self): @@ -289,7 +295,7 @@ def to_float(x): torch.testing.assert_close(y2_eager, y2_compiled) @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_89(), + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_float8_graph_output(self): @@ -319,7 +325,7 @@ def test_float8_graph_output(self): @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_89(), + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_sync_amax_func(): @@ -360,7 +366,7 @@ def __exit__(self, *args): @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_89(), + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_sync_amax_func_cuda_graph_success(): @@ -392,7 +398,7 @@ def test_sync_amax_func_cuda_graph_success(): @unittest.skipIf( - not is_sm_89(), + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index 70ac82fbf..fbe5c9b50 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -6,7 +6,7 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -40,7 +40,7 @@ from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp -if not is_sm_89(): +if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py index faa8866fb..d2e9a51c7 100644 --- a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py +++ b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py @@ -3,7 +3,7 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -30,7 +30,7 @@ from torchao.float8.float8_tensor import GemmInputRole from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only -if not is_sm_89(): +if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git 
a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 37b7817e1..311964d83 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -11,7 +11,11 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -173,7 +177,9 @@ def _test_impl(self, config: Float8LinearConfig) -> None: "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) - @pytest.mark.skipif(not is_sm_89(), reason="requires SM89 compatible machine") + @pytest.mark.skipif( + not is_sm_at_least_89(), reason="requires SM89 compatible machine" + ) @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw_from_config_params( self, @@ -196,7 +202,9 @@ def test_encoder_fw_bw_from_config_params( Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, ], ) - @pytest.mark.skipif(not is_sm_90(), reason="requires SM90 compatible machine") + @pytest.mark.skipif( + not is_sm_at_least_90(), reason="requires SM90 compatible machine" + ) @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw_from_recipe( self, diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 54b93f729..10f2d157f 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -92,7 +92,7 @@ unwrap_tensor_subclass, is_fbcode, benchmark_model, - is_sm_90, + is_sm_at_least_90, ) from torchao.dtypes.utils import is_device @@ -779,7 +779,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_sm_90(), "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype @@ -799,7 +799,7 @@ def test_autoquantizable_flatten_unflatten(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_sm_90(), "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): if dtype != torch.bfloat16: with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"): @@ -813,7 +813,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_sm_90(), "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py index f64940b07..3e8c9b0a0 100644 --- a/test/kernel/test_autotuner.py +++ 
b/test/kernel/test_autotuner.py @@ -13,7 +13,7 @@ import pytest import torch from parameterized import parameterized -from torchao.utils import is_sm_90 +from torchao.utils import is_sm_at_least_90 logging.basicConfig(level=logging.INFO) @@ -56,7 +56,7 @@ def test_int_mm(self, device, dtype): ("cuda", torch.float16), ] ) - @unittest.skipIf(not is_sm_90(), "Needs H100") + @unittest.skipIf(not is_sm_at_least_90(), "Needs H100") def test_int_mm_float8(self, device, dtype): from torchao.kernel import intmm diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 5cf78cce2..4cac94031 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -20,7 +20,7 @@ ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_89 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89 torch.manual_seed(2) @@ -99,7 +99,7 @@ def test_linear_compile(elem_dtype, bias): Verify that compile does not change numerics of MX linear fw + bw """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not is_sm_89(): + if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") input_shape = (2, 4) grad_shape = (2, 6) @@ -170,7 +170,7 @@ def test_inference_compile_simple(elem_dtype): Smoke test for inference compile """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not is_sm_89(): + if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16)) m = m.cuda() diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index d7b60f384..522785ae6 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -24,7 +24,7 @@ ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_89 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89 torch.manual_seed(2) @@ -222,7 +222,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): Verifies that compile does not change numerics of MX casts """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not is_sm_89(): + if not is_sm_at_least_89(): # separate ifs because flake8 is outsmarting me pytest.skip("CUDA capability >= 8.9 required for float8 in triton") diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index ddeb4ef2f..a213ffd90 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -52,8 +52,8 @@ TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_MI300, - is_sm_89, - is_sm_90, + is_sm_at_least_89, + is_sm_at_least_90, ) from .autoquant import AutoQuantizableLinearWeight, autoquant @@ -857,11 +857,11 @@ def _normalize_granularity( for _granularity in processed_granularity: if isinstance(_granularity, PerTensor): assert ( - is_sm_89() or is_MI300() + is_sm_at_least_89() or is_MI300() ), "PerTensor quantization only works for CUDA>=8.9 and MI300+" elif isinstance(_granularity, PerRow): assert ( - is_sm_90() or is_MI300() + is_sm_at_least_90() or is_MI300() ), "PerRow quantization only works for CUDA>=9.0 and MI300+" else: raise ValueError(f"Invalid granularity type: {_granularity}") @@ -959,7 +959,7 @@ def float8_dynamic_activation_float8_weight( """ assert ( - is_sm_89() or 
is_MI300() + is_sm_at_least_89() or is_MI300() ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" if mm_config is None: mm_config = Float8MMConfig(use_fast_accum=True) @@ -1016,7 +1016,7 @@ def float8_static_activation_float8_weight( mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ assert ( - is_sm_89() or is_MI300() + is_sm_at_least_89() or is_MI300() ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" if mm_config is None: mm_config = Float8MMConfig(use_fast_accum=True) diff --git a/torchao/utils.py b/torchao/utils.py index ba91fb3fe..d56191ed6 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -33,8 +33,8 @@ "TORCH_VERSION_AFTER_2_4", "TORCH_VERSION_AFTER_2_5", "is_MI300", - "is_sm_89", - "is_sm_90", + "is_sm_at_least_89", + "is_sm_at_least_90", ] @@ -612,7 +612,7 @@ def is_MI300(): return False -def is_sm_89(): +def is_sm_at_least_89(): return ( torch.cuda.is_available() and torch.version.cuda @@ -620,7 +620,7 @@ def is_sm_89(): ) -def is_sm_90(): +def is_sm_at_least_90(): return ( torch.cuda.is_available() and torch.version.cuda
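
Note for reviewers (not part of the patch): a minimal sketch of how a capability-gated test reads after this rename, mirroring the decorator pattern used throughout the diff. The test class and test name below are made up for illustration; it only assumes torchao exposes is_sm_at_least_89 as shown in the torchao/utils.py hunks above.

import unittest

import torch

from torchao.utils import is_sm_at_least_89


class TestCapabilityGating(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(
        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
    )
    def test_runs_on_sm89_or_newer(self):
        # The helper also returns True on newer hardware (e.g. SM 9.0),
        # which the old is_sm_89 name obscured; skip messages in the diff
        # spell this out as "compute capability >= 8.9".
        x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
        self.assertEqual(x.device.type, "cuda")


if __name__ == "__main__":
    unittest.main()

The renamed helpers make the ">= threshold" semantics explicit at the call site, so gates like the PerTensor (>= 8.9) and PerRow (>= 9.0) assertions in quant_api.py read unambiguously.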