Skip to content

Commit

Permalink
Added support for Per Tensor Scaling for Float8 Dynamic Autoquant (#1175)
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva authored Oct 28, 2024
1 parent 4ff8784 commit 79ea660
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
19 changes: 16 additions & 3 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
AutoQuantizableLinearWeight,
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
Expand Down Expand Up @@ -770,11 +771,23 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
    """Exercise the per-row-scaled float8 dynamic autoquant weight subclass.

    For ``torch.bfloat16`` the shared linear-weight-subclass check is run
    directly. For every other dtype the same call is expected to raise an
    ``AssertionError`` whose message matches
    "PerRow quantization only works for bfloat16 precision".
    """
    if dtype != torch.bfloat16:
        # Non-bfloat16 inputs are expected to be rejected, so the test asserts
        # on the failure instead of skipping.
        with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"):
            self._test_lin_weight_subclass_impl(
                AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
            )
    else:
        self._test_lin_weight_subclass_impl(
            AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
        )

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype):
    """Exercise the per-tensor-scaled float8 dynamic autoquant weight subclass.

    Runs the shared linear-weight-subclass check for every parameterized
    device/dtype pair; no dtype is special-cased here.
    """
    self._test_lin_weight_subclass_impl(
        AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
    )

@parameterized.expand(COMMON_DEVICE_DTYPE)
Expand Down
40 changes: 39 additions & 1 deletion torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Callable
import torch
import torchao
from torchao.quantization.quant_primitives import (
Expand Down Expand Up @@ -500,7 +501,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActiv
"""
AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling
"""
activation_granularity: str = PerRow()
activation_granularity = PerRow()
@classmethod
def from_float(cls, weight):

Expand Down Expand Up @@ -537,6 +538,42 @@ def get_per_token_block_size(x):
weight = super(AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
return weight

class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
    """
    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per tensor scaling
    """

    # Granularity handed to the activation quantization function in from_float.
    activation_granularity = PerTensor()

    @classmethod
    def from_float(cls, weight):
        """Build a float8 dynamically-quantized weight from a float weight.

        The weight is quantized to ``torch.float8_e4m3fn`` using a single block
        that spans the whole 2D tensor, and is then wrapped together with an
        activation quantization function that uses ``cls.activation_granularity``.

        Args:
            weight: 2D floating point weight tensor.

        Returns:
            The result of the parent ``from_float`` called with the quantized
            weight and the activation quantization function.

        Raises:
            AssertionError: if ``weight`` is not 2-dimensional.
        """
        # avoid circular dep
        from torchao.dtypes import to_affine_quantized_floatx
        from torchao.quantization.quant_api import _input_activation_quant_func_fp8

        # weight settings
        def get_weight_block_size(x):
            # One block covering the full tensor, i.e. the block size is the shape itself.
            assert x.ndim == 2, "Only works for 2D tensors"
            return x.shape

        target_dtype = torch.float8_e4m3fn

        # activation settings
        input_target_dtype = torch.float8_e4m3fn
        _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))

        def input_quant_func(x):
            # Quantize activations at call time with this class's granularity.
            return _input_activation_quant_func_fp8(
                x=x,
                activation_granularity=cls.activation_granularity,
                activation_dtype=input_target_dtype,
            )

        block_size = get_weight_block_size(weight)
        weight = to_affine_quantized_floatx(
            input_float=weight,
            block_size=block_size,
            target_dtype=target_dtype,
            _layout=_layout,
            scale_dtype=torch.float32,
        )
        weight = super(AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
        return weight


# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
DEFAULT_AUTOQUANT_CLASS_LIST = [
Expand All @@ -557,6 +594,7 @@ def get_per_token_block_size(x):
OTHER_AUTOQUANT_CLASS_LIST = [
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
]


Expand Down

0 comments on commit 79ea660

Please sign in to comment.