Update is_sm_ -> is_sm_at_least
ruff
jainapurva committed Nov 27, 2024
1 parent 73957a5 commit f292646
Showing 13 changed files with 102 additions and 60 deletions.
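
Context for the diff below: the commit renames the compute-capability helpers in torchao/utils.py from is_sm_89/is_sm_90 to is_sm_at_least_89/is_sm_at_least_90, making explicit that they test for a capability of at least 8.9 or 9.0 rather than an exact match; the extra line wrapping at the call sites comes from the ruff formatting mentioned in the commit description. The sketch below is an assumption of what such helpers typically look like, not the verbatim torchao.utils source:

import torch


def is_sm_at_least_89() -> bool:
    # Assumed implementation: True on CUDA GPUs with compute capability >= 8.9
    # (Ada-class cards), the minimum for the hardware float8 paths these tests use.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def is_sm_at_least_90() -> bool:
    # Assumed implementation: True on compute capability >= 9.0 (Hopper-class cards),
    # needed for the PerRow()/rowwise float8 cases gated in the tests below.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
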
8 changes: 6 additions & 2 deletions test/dtypes/test_affine_quantized.py
@@ -17,7 +17,11 @@
int8_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_sm_89
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_89,
)


def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
@@ -40,7 +40,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
)

if is_sm_89():
if is_sm_at_least_89():
base_functions.append(float8_weight_only())

return base_functions
34 changes: 24 additions & 10 deletions test/dtypes/test_affine_quantized_float.py
@@ -38,8 +38,8 @@
choose_qparams_affine,
)
from torchao.utils import (
is_sm_89,
is_sm_90,
is_sm_at_least_89,
is_sm_at_least_90,
)

random.seed(0)
@@ -60,12 +60,14 @@ def forward(self, x):

class TestAffineQuantizedFloat8Compile(InductorTestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize(
"granularity", [PerTensor(), PerRow()] if is_sm_90() else [PerTensor()]
"granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
)
# Inputs are (M,..), K, N
@common_utils.parametrize(
@@ -135,20 +137,26 @@ def test_fp8_linear_variants(
compute_error(output_original, output_quantized) > 20
), f"Quantization error is too high got a SQNR of {error}"

@unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_invalid_granularity(self):
with pytest.raises(ValueError, match="Invalid granularity specification"):
float8_dynamic_activation_float8_weight(granularity="invalid")

@unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_mismatched_granularity(self):
with pytest.raises(
ValueError,
match="Different granularities for activation and weight are not supported",
):
float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

@unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_unsupported_granularity(self):
class UnsupportedGranularity:
pass
@@ -159,7 +167,9 @@ class UnsupportedGranularity:
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_per_row_with_float32(self):
with pytest.raises(
AssertionError,
@@ -171,7 +181,9 @@ def test_per_row_with_float32(self):
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
def test_serialization(self, mode: str):
# Create and quantize the model
@@ -241,7 +253,9 @@ def test_serialization(self, mode: str):
), f"Scales do not match for {layer_name}"

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_fp8_weight_dimension_warning(self):
# Create model with incompatible dimensions (not multiples of 16)
model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights
28 changes: 19 additions & 9 deletions test/float8/test_base.py
@@ -14,7 +14,11 @@
import torch
import torch.nn as nn

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
is_sm_at_least_89,
is_sm_at_least_90,
)

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -215,7 +219,7 @@ def test_axiswise_reshape(self):
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_sm_90(), "Requires CUDA capability >= 9.0")
@unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
@@ -329,7 +333,9 @@ def _test_linear_impl(
# verify initialization flags got updated
assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

@pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True])
@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
)
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize(
"scaling_type_input",
@@ -411,7 +417,9 @@ def test_linear_from_recipe(
config,
)

@pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True])
@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
)
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@@ -458,7 +466,9 @@ def test_autocast_outputs(
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True])
@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
@@ -519,7 +529,7 @@ def test_repr(self):
s = m.__repr__()
assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s

@unittest.skipIf(not is_sm_89(), "CUDA 8.9 not available")
@unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
def test_inference_mode(self):
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()
@@ -530,7 +540,7 @@

class TestScaledMM:
@unittest.skipIf(
not is_sm_89(),
not is_sm_at_least_89(),
"CUDA not available",
)
@pytest.mark.parametrize(
@@ -575,7 +585,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
atol, rtol = 2e-3, 2e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

@unittest.skipIf(not is_sm_89(), "CUDA not available")
@unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
def test_different_configs_error(self):
x_fp32 = torch.randn(16, 16, device="cuda")
x_scale = torch.tensor(1.0, device="cuda")
@@ -611,7 +621,7 @@ def test_different_configs_error(self):
a @ b

@unittest.skipIf(
not is_sm_89(),
not is_sm_at_least_89(),
"CUDA not available",
)
@pytest.mark.parametrize(
28 changes: 17 additions & 11 deletions test/float8/test_compile.py
@@ -11,7 +11,11 @@

import pytest

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
is_sm_at_least_89,
is_sm_at_least_90,
)

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -95,7 +99,7 @@ def _test_compile_base(
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True])
@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_eager_only(
@@ -122,7 +126,7 @@ def test_eager_only(


@pytest.mark.parametrize("fullgraph", [True])
@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True])
@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
@pytest.mark.parametrize(
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@@ -173,7 +177,7 @@ def test_aot_eager(
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@unittest.skipIf(
not torch.cuda.is_available() or not is_sm_89(),
not torch.cuda.is_available() or not is_sm_at_least_89(),
"CUDA with float8 support not available",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -211,7 +215,9 @@ def test_inductor_from_config_params(
Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
],
)
@unittest.skipIf(not is_sm_90(), "CUDA with capability 9.0 or greater not available")
@unittest.skipIf(
not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available"
)
def test_inductor_from_recipe(recipe_name):
torch._dynamo.reset()
config = recipe_name_to_linear_config(recipe_name)
@@ -249,7 +255,7 @@ def forward(self, x):

# TODO(future): figure out why the test below fails on CUDA capability 8.9
@unittest.skipIf(
not torch.cuda.is_available() or not is_sm_90(),
not torch.cuda.is_available() or not is_sm_at_least_90(),
"CUDA with capability 9.0 or greater not available",
)
def test_float8_with_graph_break_in_the_middle(self):
@@ -265,7 +271,7 @@ def test_float8_with_graph_break_in_the_middle(self):
torch.testing.assert_close(y_eager, y_compiled)

@unittest.skipIf(
not torch.cuda.is_available() or not is_sm_89(),
not torch.cuda.is_available() or not is_sm_at_least_89(),
"CUDA with float8 support not available",
)
def test_float8_graph_input(self):
@@ -289,7 +295,7 @@ def to_float(x):
torch.testing.assert_close(y2_eager, y2_compiled)

@unittest.skipIf(
not torch.cuda.is_available() or not is_sm_89(),
not torch.cuda.is_available() or not is_sm_at_least_89(),
"CUDA with float8 support not available",
)
def test_float8_graph_output(self):
@@ -319,7 +325,7 @@ def test_float8_graph_output(self):


@unittest.skipIf(
not torch.cuda.is_available() or not is_sm_89(),
not torch.cuda.is_available() or not is_sm_at_least_89(),
"CUDA with float8 support not available",
)
def test_sync_amax_func():
@@ -360,7 +366,7 @@ def __exit__(self, *args):


@unittest.skipIf(
not torch.cuda.is_available() or not is_sm_89(),
not torch.cuda.is_available() or not is_sm_at_least_89(),
"CUDA with float8 support not available",
)
def test_sync_amax_func_cuda_graph_success():
@@ -392,7 +398,7 @@ def test_sync_amax_func_cuda_graph_success():


@unittest.skipIf(
not is_sm_89(),
not is_sm_at_least_89(),
"CUDA not available",
)
@pytest.mark.parametrize(
4 changes: 2 additions & 2 deletions test/float8/test_fsdp2/test_fsdp2.py
@@ -6,7 +6,7 @@

import pytest

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -40,7 +40,7 @@
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp

if not is_sm_89():
if not is_sm_at_least_89():
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)


4 changes: 2 additions & 2 deletions test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py
@@ -3,7 +3,7 @@

import pytest

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -30,7 +30,7 @@
from torchao.float8.float8_tensor import GemmInputRole
from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only

if not is_sm_89():
if not is_sm_at_least_89():
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)


14 changes: 11 additions & 3 deletions test/float8/test_numerics_integration.py
@@ -11,7 +11,11 @@

import pytest

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
is_sm_at_least_89,
is_sm_at_least_90,
)

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -173,7 +177,9 @@ def _test_impl(self, config: Float8LinearConfig) -> None:
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.skipif(not is_sm_89(), reason="requires SM89 compatible machine")
@pytest.mark.skipif(
not is_sm_at_least_89(), reason="requires SM89 compatible machine"
)
@pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
def test_encoder_fw_bw_from_config_params(
self,
@@ -196,7 +202,9 @@ def test_encoder_fw_bw_from_config_params(
Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
],
)
@pytest.mark.skipif(not is_sm_90(), reason="requires SM90 compatible machine")
@pytest.mark.skipif(
not is_sm_at_least_90(), reason="requires SM90 compatible machine"
)
@pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
def test_encoder_fw_bw_from_recipe(
self,
(The remaining changed files in this commit are not shown here.)
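
As a usage illustration (not part of this commit), a standalone test gated on the renamed helpers could look like the sketch below; the test class and body are hypothetical, and only the skip pattern mirrors the hunks above:

import unittest

import torch

from torchao.utils import is_sm_at_least_89, is_sm_at_least_90


class ToyCapabilityGatingTest(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(
        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
    )
    def test_fp8_capable_gpu(self):
        # The "at least" naming means 9.0 hardware also passes the 8.9 gate.
        self.assertGreaterEqual(torch.cuda.get_device_capability(), (8, 9))
        if is_sm_at_least_90():
            # Hopper-only paths (e.g. rowwise-scaled float8) would be exercised here.
            self.assertGreaterEqual(torch.cuda.get_device_capability(), (9, 0))


if __name__ == "__main__":
    unittest.main()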
