Add static quantization as an example for calibration flow #487

Merged: 1 commit merged on Jul 17, 2024
6 changes: 4 additions & 2 deletions test/dtypes/test_affine_quantized.py
@@ -10,10 +10,12 @@
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_tensor_core_layout_transpose(self):
-        t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        t = l.weight
         shape = t.shape
         apply_int4_weight_only_quant = int4_weight_only(group_size=32)
-        aqt = apply_int4_weight_only_quant(t)
+        ql = apply_int4_weight_only_quant(l)
+        aqt = ql.weight
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)

2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -4,6 +4,7 @@
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
     to_affine_quantized,
+    to_affine_quantized_static,
     LayoutType,
     PlainLayoutType,
     TensorCoreTiledLayoutType,
@@ -15,6 +16,7 @@
"UInt4Tensor"
"AffineQuantizedTensor",
"to_affine_quantized",
"to_affine_quantized_static",
"LayoutType",
"PlainLayoutType",
"TensorCoreTiledLayoutType",
36 changes: 35 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -232,6 +232,7 @@ def from_float(

         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
         int_data = layout_type.post_process(int_data)
 
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
@@ -246,8 +247,40 @@ def from_float(
             dtype=input_float.dtype
         )
 
+    @classmethod
+    def from_float_static(
+        cls,
+        input_float: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        layout_type: LayoutType = PlainLayoutType(),
+    ):
+        original_shape = input_float.shape
+        input_float = layout_type.pre_process(input_float)
+
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
+        int_data = layout_type.post_process(int_data)
+
+        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
+        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
+        return cls(
+            layout_tensor,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
     @property
-    def layout_type(self) -> str:
+    def layout_type(self) -> LayoutType:
         return self.layout_tensor.layout_type
 
     @classmethod
@@ -809,3 +842,4 @@ def t(func, *args, **kwargs):
         return return_and_correct_aliasing(func, args, kwargs, new)
 
 to_affine_quantized = AffineQuantizedTensor.from_float
+to_affine_quantized_static = AffineQuantizedTensor.from_float_static
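
For context on the new entry point: to_affine_quantized_static mirrors to_affine_quantized, except that scale and zero_point are supplied by the caller (for example, computed by observers during calibration) rather than derived from the input tensor. A minimal sketch of how it might be called, mirroring the per-channel weight setup used in the tutorial added by this PR; the observer choice and shapes here are illustrative assumptions, not part of this diff:

import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torchao.dtypes import to_affine_quantized_static

# a float weight and a per-channel observer that has already seen it during calibration
weight = torch.randn(32, 64)
weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
weight_obs(weight)
weight_scale, weight_zero_point = weight_obs.calculate_qparams()

# one quantization block per output channel: each block spans a full row of the weight
block_size = (1, weight.shape[1])
qweight = to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, torch.uint8)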
5 changes: 3 additions & 2 deletions torchao/prototype/quant_llm/quant_llm.py
@@ -7,6 +7,7 @@
 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones
 from torchao.ops import quant_llm_linear
 from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
+from torchao.quantization.quant_api import _get_linear_subclass_inserter
 
 
 _ONES_TABLE = [_n_ones(i) for i in range(8)]
@@ -456,8 +457,8 @@ def apply_quant_llm(weight: Tensor) -> Tensor:
         if (in_dim % 64 != 0) or (out_dim % 256 != 0):
             return weight
         return QuantLlmLinearWeight.from_float(weight, ebits, mbits)
-    return apply_quant_llm
+    return _get_linear_subclass_inserter(apply_quant_llm)
 
 
 def fp6_llm_weight_only():
-    return quant_llm_fpx_weight_only(3, 2)
+    return _get_linear_subclass_inserter(quant_llm_fpx_weight_only(3, 2))
21 changes: 13 additions & 8 deletions torchao/quantization/quant_api.py
@@ -259,12 +259,12 @@ def insert_subclass(lin):
 
     return insert_subclass
 
-def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
+def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
Contributor Author (inline review comment): @msaroufim I made some changes to apply_tensor_subclass to accommodate static quant use cases, please take a look again

"""Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace

Args:
model (torch.nn.Module): input model
apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance (e.g. affine quantized tensor instance)
apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor)
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
the weight of the module
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
@@ -300,19 +300,24 @@ def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Ten
x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")

def apply_weight_quant_to_linear(linear):
linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False)
return linear

# apply to modules under block0 submodule
def filter_fn(module: nn.Module, fqn: str) -> bool:
return isinstance(module, nn.Linear)

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m, apply_weight_quant, filter_fn)
quantize_(m, apply_weight_quant_to_linear, filter_fn)

"""
if set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

_replace_with_custom_fn_if_matches_filter(
model,
_get_linear_subclass_inserter(apply_tensor_subclass),
apply_tensor_subclass,
_is_linear if filter_fn is None else filter_fn,
)

@@ -356,7 +361,7 @@ def get_per_token_block_size(x):
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight
 
-    return apply_int8_dynamic_activation_int4_weight_quant
+    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant)


def int4_weight_only(group_size=128, inner_k_tiles=8):
@@ -394,7 +399,7 @@ def apply_int4_weight_only_quant(weight):
         layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)
 
-    return apply_int4_weight_only_quant
+    return _get_linear_subclass_inserter(apply_int4_weight_only_quant)


def int8_weight_only():
@@ -412,7 +417,7 @@ def apply_int8wo_quant(weight):
         block_size = (1, weight.shape[1])
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
 
-    return apply_int8wo_quant
+    return _get_linear_subclass_inserter(apply_int8wo_quant)

def int8_dynamic_activation_int8_weight():
"""
@@ -454,4 +459,4 @@ def get_per_token_block_size(x):
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight
 
-    return apply_int8_dynamic_activation_int8_weight_quant
+    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)
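
Taken together, the quant_api.py changes redefine the contract of the callable passed to quantize_: apply_tensor_subclass is now module-in/module-out rather than tensor-in/tensor-out, and the built-in configs keep working because their tensor-level quantizers are wrapped with _get_linear_subclass_inserter. A rough usage sketch of the updated contract, assuming a toy CPU model (not part of this diff):

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import int8_weight_only

m = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.Linear(64, 32))

# after this change the built-in configs already return a module-level callable
# (their tensor-level quantizers are wrapped by _get_linear_subclass_inserter),
# so they are passed to quantize_ unchanged and applied to every matched Linear
quantize_(m, int8_weight_only())

# a custom apply_tensor_subclass must follow the same module-in/module-out contract,
# e.g. the apply_weight_quant_to_linear helper shown in the docstring example above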
145 changes: 145 additions & 0 deletions tutorials/calibration_flow/static_quant.py
@@ -0,0 +1,145 @@
"""
Demo for static quantization flow
"""
import torch
import copy

# TODO: use the generalized observer for affine quantization in the future
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
import torch.nn.functional as F
from torch import Tensor
from torchao.dtypes import to_affine_quantized_static
from torchao.quantization.utils import compute_error
from torchao.quantization import quantize_
from torchao.quantization.subclass import to_linear_act_quantized
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter



class ObservedLinear(torch.nn.Linear):
    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: Tensor):
        observed_input = self.act_obs(input)
        observed_weight = self.weight_obs(self.weight)
        return F.linear(observed_input, observed_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
        observed_linear.weight = float_linear.weight
        observed_linear.bias = float_linear.bias
        return observed_linear

def insert_observers_(model, act_obs, weight_obs):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
    replacement_fn = lambda m: ObservedLinear.from_float(m, act_obs, weight_obs)
    act_obs = copy.deepcopy(act_obs)
    weight_obs = copy.deepcopy(weight_obs)
    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
def apply_static_quant(observed_linear):
    target_dtype = torch.uint8

    # weight quantization
    weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
    def weight_quant_func(weight):
        block_size = (1, weight.shape[1])
        return to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
    linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
    linear.weight = observed_linear.weight
    linear.bias = observed_linear.bias

    linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)

    # activation quantization
    act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
    input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype)
    linear.weight = torch.nn.Parameter(to_linear_act_quantized(linear.weight, input_quant_func), requires_grad=False)

    return linear


# alternative for converting observed linear module to quantized linear module
class QuantizedLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        block_size = (1, weight.shape[1])
        target_dtype = torch.uint8
        self.qweight = to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
        self.bias = bias

    def forward(self, input: Tensor):
        block_size = input.shape
        target_dtype = torch.uint8
        qinput = to_affine_quantized_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype)
        return F.linear(qinput, self.qweight, self.bias)

    @classmethod
    def from_observed(cls, observed_linear):
        quantized_linear = cls(observed_linear.in_features, observed_linear.out_features, observed_linear.act_obs, observed_linear.weight_obs, observed_linear.weight, observed_linear.bias)
        return quantized_linear

def apply_static_quant2(observed_linear):
    return QuantizedLinear.from_observed(observed_linear)

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

dtype = torch.bfloat16
m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode='max-autotune')

# TODO: use the generalized observer for affine quantization in the future
act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")

before_quant = m(*example_inputs)

insert_observers_(m, act_obs, weight_obs)
# calibrating / training
for _ in range(10):
    m(*example_inputs)

after_obs = m(*example_inputs)

m2 = copy.deepcopy(m)

is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

# quantized linear represented as an nn.Linear with modified tensor subclass weights
# for both activation and weight quantization
quantize_(m, apply_static_quant, is_observed_linear)
print("quantized model (applying tensor subclass to weight):", m)
after_quant = m(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")

# quantized linear as a standalone module
quantize_(m2, apply_static_quant2, is_observed_linear)
print("quantized model (quantized module):", m2)
after_quant = m2(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")