From 862fb22f7f53b4d328891c9d62cfe81bb6d89160 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 28 Aug 2024 20:59:44 -0700 Subject: [PATCH] fix memory being held by autograd --- test/dtypes/test_affine_quantized.py | 69 ++++++++++++++-------- test/dtypes/test_affine_quantized_float.py | 50 ++-------------- torchao/dtypes/affine_quantized_tensor.py | 44 +++++++++----- torchao/float8/inference.py | 8 ++- torchao/quantization/quant_api.py | 15 +++-- torchao/quantization/quant_primitives.py | 30 +++++----- 6 files changed, 109 insertions(+), 107 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index e019468888..206b53e17a 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -10,12 +10,31 @@ int8_dynamic_activation_int8_semi_sparse_weight, float8_weight_only, ) +from torch.testing._internal import common_utils from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 import torch import unittest import tempfile +is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) + + +def get_quantization_functions(do_sparse: bool, do_float8: bool): + base_functions = [ + int4_weight_only(group_size=32), + int8_weight_only(), + int8_dynamic_activation_int4_weight(), + int8_dynamic_activation_int8_weight(), + ] + if do_sparse: + base_functions.append(int8_dynamic_activation_int8_semi_sparse_weight()) + + if is_cuda_8_9 and do_float8: # You need to define this function + base_functions.append(float8_weight_only()) + + return base_functions + class TestAffineQuantized(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @@ -38,36 +57,36 @@ def test_tensor_core_layout_transpose(self): self.assertEqual(aqt_shape, shape) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_weights_only(self): - for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(), - int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight(), float8_weight_only()]: - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(l) - with tempfile.NamedTemporaryFile() as f: - torch.save(ql.state_dict(), f) - f.seek(0) - # `weights_only=True` is enabled for torch 2.5+ - if TORCH_VERSION_AT_LEAST_2_5: - _ = torch.load(f, weights_only=True) - else: - _ = torch.load(f, weights_only=False) + @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) + def test_weights_only(self, apply_quant): + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + ql = apply_quant(l) + with tempfile.NamedTemporaryFile() as f: + torch.save(ql.state_dict(), f) + f.seek(0) + # `weights_only=True` is enabled for torch 2.5+ + if TORCH_VERSION_AT_LEAST_2_5: + _ = torch.load(f, weights_only=True) + else: + _ = torch.load(f, weights_only=False) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_to_device(self): - from torchao.quantization import quantize_ - for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]: - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) - ql.to("cuda") + @common_utils.parametrize("apply_quant", get_quantization_functions(False, True)) + def test_to_device(self, apply_quant): + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(l) + ql.to("cuda") + + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(l) + ql.to(device="cuda") - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) - ql.to(device="cuda") + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(l) + ql.cuda() - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) - ql.cuda() +common_utils.instantiate_parametrized_tests(TestAffineQuantized) if __name__ == "__main__": run_tests() diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 8e700116b0..7e2ce278d5 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -13,10 +13,6 @@ ) from torch._inductor.test_case import TestCase as InductorTestCase from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import ( - TestCase, - run_tests, -) from torch._dynamo.testing import CompileCounterWithBackend from torchao.quantization import ( @@ -54,46 +50,9 @@ def forward(self, x): return x -class TestAffineQuantizedFloat8Basic(TestCase): - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_tensor_core_layout_transpose(self): - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - t = l.weight - shape = t.shape - apply_float8_weight_only_quant = float8_weight_only() - ql = apply_float8_weight_only_quant(l) - aqt = ql.weight - aqt_shape = aqt.shape - assert aqt_shape == shape - - # transpose shape test - for _ in range(10): - t = t.t() - aqt = aqt.t() - shape = t.shape - aqt_shape = aqt.shape - assert aqt_shape == shape - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_weights_only_save_load(self): - with torch.no_grad(): - for apply_quant in [float8_weight_only()]: - # TODO Fails when l requires grad - l = torch.nn.Linear(128, 256).eval().to(torch.bfloat16).to("cuda") - ql = apply_quant(l) - with tempfile.NamedTemporaryFile() as f: - torch.save(ql.state_dict(), f) - f.seek(0) - # `weights_only=True` is enabled for torch 2.5+ - if TORCH_VERSION_AT_LEAST_2_5: - _ = torch.load(f, weights_only=True) - else: - _ = torch.load(f, weights_only=False) - - class TestAffineQuantizedFloat8Compile(InductorTestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Need H100") + @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only"]) @common_utils.parametrize("compile", [True, False]) @@ -108,7 +67,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase): ((64, 256), 512, 128), ], ) - def test_dynamic_fp8_linear( + def test_fp8_linear_variants( self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple ): M, N, K = sizes @@ -132,7 +91,10 @@ def test_dynamic_fp8_linear( output_original = model(input_tensor) output_quantized = quantized_model(input_tensor) - assert compute_error(output_original, output_quantized) > 20, "Error is too low" + error = compute_error(output_original, output_quantized) + assert ( + compute_error(output_original, output_quantized) > 20 + ), f"Quantization error is too high got a SQNR of {error}" common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 5395a9102a..350bcbfd26 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -214,12 +214,12 @@ def from_hp_to_intx( block_size: Tuple[int, ...], target_dtype: torch.dtype, quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_max: Optional[int] = None, eps: Optional[float] = None, scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, layout_type: LayoutType = PlainLayoutType(), use_hqq: bool = False, ): @@ -237,6 +237,8 @@ def from_hp_to_intx( data = data.to(target_dtype) else: scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) + if zero_point_domain is None: + zero_point = None data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) # Note: output will be uint8 tensor for sub byte tensors for now @@ -262,7 +264,7 @@ def from_hp_to_intx_static( block_size: Tuple[int, ...], target_dtype: torch.dtype, quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_max: Optional[int] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, layout_type: LayoutType = PlainLayoutType(), ): @@ -291,8 +293,8 @@ def from_hp_to_floatx( input_float: torch.Tensor, block_size: Tuple[int, ...], target_dtype: torch.dtype, - scale_dtype: Optional[torch.dtype] = None, - layout_type: LayoutType = PlainLayoutType(), + scale_dtype: Optional[torch.dtype], + layout_type: LayoutType, ): if target_dtype in FP8_TYPES: @@ -400,10 +402,8 @@ def extra_repr(self): @dataclass(frozen=True) class Float8LayoutType(LayoutType): - mm_config: ScaledMMConfig + mm_config: Optional[ScaledMMConfig] - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - return input @register_layout_cls(PlainLayoutType) class PlainAQTLayout(AQTLayout): @@ -602,9 +602,18 @@ def _apply_fn_to_data(self, fn): fn(self.scale) return self + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.float8_data.to(kwargs["device"]), + self.scale.to(kwargs["device"]), + self.transposed, + self.layout_type, + ) + def __tensor_flatten__(self): return ["float8_data", "scale"], [self.transposed, self.layout_type] - + @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride @@ -621,6 +630,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) if func is aten.t.default: """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose @@ -650,6 +663,7 @@ def from_plain( ): """ Main entrypoint for constructing Float8Layout Tensor""" assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}" + assert isinstance(layout_type, Float8LayoutType), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}" return cls(data, scale, False, layout_type) def __repr__(self): @@ -1027,14 +1041,14 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): def _linear_fp_act_fp8_tensor_wise_weight_check( - input_tensor: torch.Tensor, - weight_tensor: AffineQuantizedTensor, + input_tensor: Union[torch.Tensor, AffineQuantizedTensor], + weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], bias: Optional[torch.Tensor], ) -> bool: - def check_aqt_tensorwise(aqt: AffineQuantizedTensor) -> bool: + def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( isinstance(aqt, AffineQuantizedTensor) and - isinstance(aqt.layout_tensor, Float8AQTLayout) + isinstance(aqt.layout_type, Float8LayoutType) and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and aqt.shape == aqt.block_size ) @@ -1047,7 +1061,7 @@ def _linear_fp_act_fp8_weight_impl( bias: Optional[torch.Tensor], ): """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" - from torchao.float8.inference import cast_to_float8_e4m3_inference, preprocess_data + from torchao.float8.inference import preprocess_data from torchao.float8.float8_tensor import ScaledMMConfig from torchao.float8.float8_python_api import addmm_float8_unwrapped @@ -1066,7 +1080,7 @@ def _linear_fp_act_fp8_weight_impl( # Handle case where input tensor is more than 2D inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1]) input_scale = input_tensor.layout_tensor.scale - if input_scale.dim() >= 2: + if input_scale.dim() > 2: input_scale = input_scale.reshape(-1, input_scale.shape[-1]) inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index 66f83d933a..b3d5de1440 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -243,10 +243,14 @@ def quantize_to_float8( module_filter_fn=module_filter_fn, ) + from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul -def preprocess_data(a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig) -> Tuple[torch.Tensor, torch.Tensor]: - """ Preprocess the inner fp8 data tensors for admmm + +def preprocess_data( + a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig +) -> Tuple[torch.Tensor, torch.Tensor]: + """Preprocess the inner fp8 data tensors for admmm Args: a_data: Input tensor A. b_data: Input tensor B. diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3ebbc83117..204238607a 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -512,6 +512,7 @@ def apply_float8wo_quant(weight): input_float=weight, block_size=block_size, target_dtype=target_dtype, + scale_dtype=None, layout_type=Float8LayoutType(mm_config=None), ) @@ -519,9 +520,9 @@ def apply_float8wo_quant(weight): def float8_dynamic_activation_float8_weight( - target_dtype: torch.dtype = torch.float8_e4m3fn, activation_dtype: torch.dtype = torch.float8_e4m3fn, - mm_config: ScaledMMConfig = ScaledMMConfig(use_fast_accum=True) + weight_dtype: torch.dtype = torch.float8_e4m3fn, + mm_config: Optional[ScaledMMConfig] = None ): """ Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers. @@ -532,17 +533,19 @@ def float8_dynamic_activation_float8_weight( mm_config (ScaledMMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ - from torchao.dtypes import to_affine_quantized_floatx + if mm_config is None: + mm_config = ScaledMMConfig(use_fast_accum=True) + #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling def apply_float8_dynamic_activation_quant(weight: torch.Tensor): quantized_weight = to_affine_quantized_floatx( input_float=weight, block_size=weight.shape, - target_dtype=target_dtype, + target_dtype=weight_dtype, scale_dtype=torch.float32, - layout_type=Float8LayoutType(mm_config=None), + layout_type=Float8LayoutType(mm_config=mm_config), ) def input_quant_func(x: torch.Tensor): @@ -551,7 +554,7 @@ def input_quant_func(x: torch.Tensor): block_size=x.shape, target_dtype=activation_dtype, scale_dtype=torch.float32, - layout_type=Float8LayoutType(mm_config=None), + layout_type=Float8LayoutType(mm_config=mm_config), ) return activation diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index c90fc403ea..fa4460cd0f 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -159,7 +159,7 @@ def _get_reduction_params(block_size, input_size): cur_dim += 1 return shape_for_reduction, reduction_dims - +@torch.no_grad() def quantize_affine( input: torch.Tensor, block_size: Tuple[int, ...], @@ -267,10 +267,11 @@ def _quantize_affine_no_dtype_cast( if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) - if zero_point_domain is None: - quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) - quant = quant.view(original_shape) - return quant + if zero_point is None: + assert zero_point_domain is None, "zero_point_domain should be None when zero_point is None" + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + quant = quant.view(original_shape) + return quant if zero_point_domain == ZeroPointDomain.INT.name: quant = torch.clamp( @@ -385,15 +386,7 @@ def _dequantize_affine_no_dtype_check( shape_after_reduction = shape_for_reduction for i in reduction_dims: shape_after_reduction[i] = 1 - scale = scale.view(shape_after_reduction) - - # This case handles dequantization for float8 - if zero_point_domain is None: - assert zero_point is None, "zero_point should be None when zero_point_domain is None" - assert _is_float8_type(input.dtype), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" - dequant = input.to(output_dtype) - dequant = dequant * scale - return dequant.view(original_shape).to(output_dtype) + scale = scale.view(shape_after_reduction) if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) @@ -406,6 +399,13 @@ def _dequantize_affine_no_dtype_check( dequant = dequant - zero_point.to(torch.int32) dequant = dequant.to(output_dtype) dequant = dequant * scale + elif zero_point_domain is None: + # This case handles dequantization for float8 we expect no zero point and no zero point domain + assert zero_point is None, "zero_point should be None when zero_point_domain is None" + assert _is_float8_type(input.dtype), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" + dequant = input.to(output_dtype) + dequant = dequant * scale + return dequant.view(original_shape).to(output_dtype) else: assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}" # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this) @@ -544,7 +544,7 @@ def _do_fake_quantize_affine( ) return (q, dq) - +@torch.no_grad() def choose_qparams_affine( input: torch.Tensor, mapping_type: MappingType,