fix memory being held by autograd
drisspg committed Aug 29, 2024
1 parent e559f2a commit ffeeb9a
Showing 6 changed files with 108 additions and 108 deletions.
70 changes: 45 additions & 25 deletions test/dtypes/test_affine_quantized.py
@@ -10,12 +10,32 @@
int8_dynamic_activation_int8_semi_sparse_weight,
float8_weight_only,
)
from torch.testing._internal import common_utils
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

import torch
import unittest
import tempfile

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def get_quantization_functions(do_sparse: bool, do_int4: bool):
base_functions = [
int8_weight_only(),
int8_dynamic_activation_int4_weight(),
int8_dynamic_activation_int8_weight(),
]
if do_int4:
base_functions.append(int4_weight_only(group_size=32))
if do_sparse:
base_functions.append(int8_dynamic_activation_int8_semi_sparse_weight())

if is_cuda_8_9:
base_functions.append(float8_weight_only())

return base_functions
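For context, here is a minimal standalone sketch of the parametrization mechanism this helper feeds, assuming only that `torch.testing._internal.common_utils` provides `parametrize`, `instantiate_parametrized_tests`, and `run_tests` as used in this diff; it is illustrative, not part of the change:

```python
# Hedged sketch: how a list returned by a helper like get_quantization_functions
# fans out into one generated test per entry.
from torch.testing._internal import common_utils


def get_example_configs():
    # Stand-in for get_quantization_functions(); plain ints keep it runnable anywhere.
    return [1, 2, 3]


class ExampleParametrized(common_utils.TestCase):
    @common_utils.parametrize("cfg", get_example_configs())
    def test_doubling(self, cfg):
        # One test method is generated per config in the list.
        self.assertEqual(cfg * 2, cfg + cfg)


common_utils.instantiate_parametrized_tests(ExampleParametrized)

if __name__ == "__main__":
    common_utils.run_tests()
```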


class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -38,36 +58,36 @@ def test_tensor_core_layout_transpose(self):
self.assertEqual(aqt_shape, shape)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weights_only(self):
for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(),
int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight(), float8_weight_only()]:
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AT_LEAST_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)
@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
def test_weights_only(self, apply_quant):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AT_LEAST_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_to_device(self):
from torchao.quantization import quantize_
for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]:
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.to("cuda")
@common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
def test_to_device(self, apply_quant):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.to("cuda")

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.to(device="cuda")

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.to(device="cuda")
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.cuda()

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.cuda()

common_utils.instantiate_parametrized_tests(TestAffineQuantized)

if __name__ == "__main__":
run_tests()
50 changes: 6 additions & 44 deletions test/dtypes/test_affine_quantized_float.py
@@ -13,10 +13,6 @@
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)
from torch._dynamo.testing import CompileCounterWithBackend

from torchao.quantization import (
@@ -54,46 +50,9 @@ def forward(self, x):
return x


class TestAffineQuantizedFloat8Basic(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tensor_core_layout_transpose(self):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = l.weight
shape = t.shape
apply_float8_weight_only_quant = float8_weight_only()
ql = apply_float8_weight_only_quant(l)
aqt = ql.weight
aqt_shape = aqt.shape
assert aqt_shape == shape

# transpose shape test
for _ in range(10):
t = t.t()
aqt = aqt.t()
shape = t.shape
aqt_shape = aqt.shape
assert aqt_shape == shape

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weights_only_save_load(self):
with torch.no_grad():
for apply_quant in [float8_weight_only()]:
# TODO Fails when l requires grad
l = torch.nn.Linear(128, 256).eval().to(torch.bfloat16).to("cuda")
ql = apply_quant(l)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AT_LEAST_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)


class TestAffineQuantizedFloat8Compile(InductorTestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_cuda_8_9, "Need H100")
@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("compile", [True, False])
@@ -108,7 +67,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
((64, 256), 512, 128),
],
)
def test_dynamic_fp8_linear(
def test_fp8_linear_variants(
self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
):
M, N, K = sizes
@@ -132,7 +91,10 @@ def test_dynamic_fp8_linear(
output_original = model(input_tensor)
output_quantized = quantized_model(input_tensor)

assert compute_error(output_original, output_quantized) > 20, "Error is too low"
error = compute_error(output_original, output_quantized)
assert (
error > 20
), f"Quantization error is too high: got an SQNR of {error}"


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
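The SQNR threshold above compares the quantized output to the reference in decibels. A minimal sketch of an equivalent metric, assumed to match the spirit of the `compute_error` helper used in the test rather than copied from it:

```python
# Hedged sketch: signal-to-quantization-noise ratio in dB. An SQNR above ~20 dB
# means the quantized output tracks the reference output closely.
import torch


def sqnr(reference: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    signal_power = torch.linalg.norm(reference)
    noise_power = torch.linalg.norm(reference - quantized)
    return 20 * torch.log10(signal_power / noise_power)


ref = torch.randn(64, 64)
noisy = ref + 0.01 * torch.randn_like(ref)
print(sqnr(ref, noisy))  # large positive value -> low quantization error
```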
44 changes: 29 additions & 15 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -214,12 +214,12 @@ def from_hp_to_intx(
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
layout_type: LayoutType = PlainLayoutType(),
use_hqq: bool = False,
):
@@ -237,6 +237,8 @@ def from_hp_to_intx(
data = data.to(target_dtype)
else:
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
if zero_point_domain is None:
zero_point = None
data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
# Note: output will be uint8 tensor for sub byte tensors for now

@@ -262,7 +264,7 @@ def from_hp_to_intx_static(
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
layout_type: LayoutType = PlainLayoutType(),
):
@@ -291,8 +293,8 @@ def from_hp_to_floatx(
input_float: torch.Tensor,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
scale_dtype: Optional[torch.dtype] = None,
layout_type: LayoutType = PlainLayoutType(),
scale_dtype: Optional[torch.dtype],
layout_type: LayoutType,
):

if target_dtype in FP8_TYPES:
@@ -400,10 +402,8 @@ def extra_repr(self):

@dataclass(frozen=True)
class Float8LayoutType(LayoutType):
mm_config: ScaledMMConfig
mm_config: Optional[ScaledMMConfig]

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
return input

@register_layout_cls(PlainLayoutType)
class PlainAQTLayout(AQTLayout):
@@ -602,9 +602,18 @@ def _apply_fn_to_data(self, fn):
fn(self.scale)
return self

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self.__class__(
self.float8_data.to(kwargs["device"]),
self.scale.to(kwargs["device"]),
self.transposed,
self.layout_type,
)

def __tensor_flatten__(self):
return ["float8_data", "scale"], [self.transposed, self.layout_type]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
@@ -621,6 +630,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)
if func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)
if func is aten.t.default:
"""we don't need to repack the weight and just rely on external
shape being changed and record the status of transpose/no-transpose
@@ -650,6 +663,7 @@ def from_plain(
):
""" Main entrypoint for constructing Float8Layout Tensor"""
assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}"
assert isinstance(layout_type, Float8LayoutType), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}"
return cls(data, scale, False, layout_type)

def __repr__(self):
@@ -1027,14 +1041,14 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):


def _linear_fp_act_fp8_tensor_wise_weight_check(
input_tensor: torch.Tensor,
weight_tensor: AffineQuantizedTensor,
input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
bias: Optional[torch.Tensor],
) -> bool:
def check_aqt_tensorwise(aqt: AffineQuantizedTensor) -> bool:
def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
return (
isinstance(aqt, AffineQuantizedTensor) and
isinstance(aqt.layout_tensor, Float8AQTLayout)
isinstance(aqt.layout_type, Float8LayoutType)
and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
and aqt.shape == aqt.block_size
)
@@ -1047,7 +1061,7 @@ def _linear_fp_act_fp8_weight_impl(
bias: Optional[torch.Tensor],
):
"""Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
from torchao.float8.inference import cast_to_float8_e4m3_inference, preprocess_data
from torchao.float8.inference import preprocess_data
from torchao.float8.float8_tensor import ScaledMMConfig
from torchao.float8.float8_python_api import addmm_float8_unwrapped

@@ -1066,7 +1080,7 @@ def _linear_fp_act_fp8_weight_impl(
# Handle case where input tensor is more than 2D
inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1])
input_scale = input_tensor.layout_tensor.scale
if input_scale.dim() >= 2:
if input_scale.dim() > 2:
input_scale = input_scale.reshape(-1, input_scale.shape[-1])

inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
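To make the intent of the new `to()` method and the `aten.detach`/`aten.clone` handling easier to see in isolation, here is a hedged, self-contained sketch of the same wrapper-subclass pattern. The names (`ToyLayoutTensor`, `data_tensor`) are illustrative and not torchao APIs; the point is that rebuilding the wrapper from its inner tensors keeps device moves, clones, and detaches from carrying autograd references to the original weights.

```python
# Hedged sketch of the wrapper-tensor pattern used by Float8AQTLayout above.
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

aten = torch.ops.aten


class ToyLayoutTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data_tensor, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, data_tensor.shape, dtype=data_tensor.dtype,
            device=data_tensor.device, requires_grad=False,
        )

    def __init__(self, data_tensor, scale):
        self.data_tensor = data_tensor
        self.scale = scale

    def _apply_fn_to_data(self, fn):
        # Rebuild the wrapper from transformed copies of the inner tensors.
        return self.__class__(fn(self.data_tensor), fn(self.scale))

    def to(self, device):  # simplified vs. the _get_to_kwargs handling in the diff
        return self.__class__(self.data_tensor.to(device), self.scale.to(device))

    def __tensor_flatten__(self):
        return ["data_tensor", "scale"], []

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, attrs, outer_size, outer_stride):
        return cls(inner_tensors["data_tensor"], inner_tensors["scale"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )
        if func is aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )
        raise NotImplementedError(f"{func} is not handled in this sketch")


t = ToyLayoutTensor(torch.randn(4, 4), torch.tensor(2.0))
print(type(t.detach()).__name__, t.clone().scale)  # ToyLayoutTensor tensor(2.)
```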
8 changes: 6 additions & 2 deletions torchao/float8/inference.py
@@ -243,10 +243,14 @@ def quantize_to_float8(
module_filter_fn=module_filter_fn,
)


from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul

def preprocess_data(a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig) -> Tuple[torch.Tensor, torch.Tensor]:
""" Preprocess the inner fp8 data tensors for admmm

def preprocess_data(
a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Preprocess the inner fp8 data tensors for admmm
Args:
a_data: Input tensor A.
b_data: Input tensor B.
15 changes: 9 additions & 6 deletions torchao/quantization/quant_api.py
@@ -512,16 +512,17 @@ def apply_float8wo_quant(weight):
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
scale_dtype=None,
layout_type=Float8LayoutType(mm_config=None),
)

return _get_linear_subclass_inserter(apply_float8wo_quant)


def float8_dynamic_activation_float8_weight(
target_dtype: torch.dtype = torch.float8_e4m3fn,
activation_dtype: torch.dtype = torch.float8_e4m3fn,
mm_config: ScaledMMConfig = ScaledMMConfig(use_fast_accum=True)
weight_dtype: torch.dtype = torch.float8_e4m3fn,
mm_config: Optional[ScaledMMConfig] = None
):
"""
Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers.
@@ -532,17 +533,19 @@ def float8_dynamic_activation_float8_weight(
mm_config (ScaledMMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
"""

from torchao.dtypes import to_affine_quantized_floatx

if mm_config is None:
mm_config = ScaledMMConfig(use_fast_accum=True)

#TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling
def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
quantized_weight = to_affine_quantized_floatx(
input_float=weight,
block_size=weight.shape,
target_dtype=target_dtype,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
layout_type=Float8LayoutType(mm_config=None),
layout_type=Float8LayoutType(mm_config=mm_config),
)

def input_quant_func(x: torch.Tensor):
Expand All @@ -551,7 +554,7 @@ def input_quant_func(x: torch.Tensor):
block_size=x.shape,
target_dtype=activation_dtype,
scale_dtype=torch.float32,
layout_type=Float8LayoutType(mm_config=None),
layout_type=Float8LayoutType(mm_config=mm_config),
)
return activation
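For orientation, a hedged usage sketch of the two user-facing configs touched in this file, assuming they are exported from `torchao.quantization` alongside `quantize_` (as the test imports above suggest); exact export paths may differ by torchao version:

```python
# Hedged sketch: applying the float8 configs from this file to a model.
# Requires a CUDA device with compute capability >= 8.9 (see is_cuda_8_9 above).
import torch
from torchao.quantization import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16).to("cuda")

# Weight-only float8: weights are quantized now, activations stay in bf16.
quantize_(model, float8_weight_only())

# Or: dynamic activation + weight float8 with the default fast-accum mm config.
# quantize_(model, float8_dynamic_activation_float8_weight())

with torch.no_grad():
    out = model(torch.randn(32, 128, dtype=torch.bfloat16, device="cuda"))
```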
