Skip to content

Commit

Permalink
[reland] Add support for Int8DynActInt4WeightQuantizer (pytorch#66) (p…
Browse files Browse the repository at this point in the history
…ytorch#74)

Summary:
att

Test Plan: python test/quantization/test_quant_api.py -k test_8da4w_quantizer

Reviewed By: cpuhrsch

Differential Revision: D55101038

Pulled By: jerryzh168

[ghstack-poisoned]
  • Loading branch information
jerryzh168 authored Mar 21, 2024
1 parent 530f71b commit e980f49
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 21 deletions.
27 changes: 20 additions & 7 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
Quantizer,
TwoStepQuantizer,
Int8DynActInt4WeightGPTQQuantizer,
Int8DynActInt4WeightQuantizer,
Int8DynActInt4WeightLinear,
)
from pathlib import Path
from sentencepiece import SentencePieceProcessor
Expand Down Expand Up @@ -85,8 +87,11 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5).to(torch.float)
self.linear2 = torch.nn.Linear(5, 5).to(torch.float)
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 64).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
Expand All @@ -97,8 +102,7 @@ class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = M().eval()
m = _apply_dynamic_quant(m)
example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
quantized = m(*example_inputs)
quantized = m(*m.example_inputs())
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
# m = torch.compile(m, mode="max-autotune")
Expand All @@ -110,9 +114,9 @@ def test_dynamic_quant_gpu_singleline(self):
def test_dynamic_quant_gpu_unified_api_unified_impl(self):
quantizer = XNNPackDynamicQuantizer()
m = M().eval()
example_inputs = m.example_inputs()
m = quantizer.prepare(m)
m = quantizer.convert(m)
example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
quantized = m(*example_inputs)
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
Expand All @@ -125,15 +129,24 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self):
def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
quantizer = TorchCompileDynamicQuantizer()
m = M().eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
quantized = m(*example_inputs)
m = torch.compile(m, mode="max-autotune")
compiled = m(*example_inputs)
torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

def test_8da4w_quantizer(self):
quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
m = M().eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
m(*example_inputs)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq(self):
def test_gptq_quantizer(self):
# should be similar to TorchCompileDynamicQuantizer
precision = torch.bfloat16
device = "cpu"
Expand Down
3 changes: 2 additions & 1 deletion torchao/quantization/GPTQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# from model import Transformer # pyre-ignore[21]
from torch.utils._pytree import tree_flatten, tree_unflatten
import logging

# pyre-fixme[5]: Global expression must be annotated.
aten = torch.ops.aten
Expand Down Expand Up @@ -63,7 +64,7 @@ def model_forward(model, x, input_pos):
get_task_dict = tasks.get_task_dict
evaluate = evaluator.evaluate
else:
print("lm_eval is not installed, GPTQ may not be usable")
logging.info("lm_eval is not installed, GPTQ may not be usable")

# pyre-fixme[3]: Return type must be annotated.
def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
Expand Down
106 changes: 93 additions & 13 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@
from .quant_primitives import (
get_group_qparams_symmetric,
per_token_dynamic_quant,
group_quantize_tensor_symmetric,
)
from typing import Dict, Tuple
from typing import Dict, Tuple, Any
import logging

__all__ = [
"apply_weight_only_int8_quant",
Expand All @@ -54,21 +56,18 @@
############################# Unified Quantization APIs ##############################
# API 1, single quantize call to create a quantized model with quantized state_dict
class Quantizer:
# pyre-fixme[2]: Parameter must be annotated.
def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
# pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
pass


# API 2, flow that needs calibration or training
class TwoStepQuantizer:
# pyre-fixme[2]: Parameter must be annotated.
def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
# pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
pass

# pyre-fixme[2]: Parameter must be annotated.
def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
# pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
pass

Expand Down Expand Up @@ -260,7 +259,7 @@ def replace_conv2d_1x1(conv):
MultiInput,
)
else:
print("lm_eval not available, skip defining GPTQQuantizer")
logging.info("lm_eval not available, skip defining GPTQQuantizer")


class GPTQQuantizer(Quantizer):
Expand Down Expand Up @@ -442,11 +441,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":

@torch.no_grad()
# pyre-fixme[14]: `quantize` overrides method defined in `Quantizer` inconsistently.
def quantize(
self,
# pyre-fixme[2]: Parameter must be annotated.
model,
) -> torch.nn.Module:
def quantize(self, model: torch.nn.Module, **kwargs: Any) -> torch.nn.Module:
state_dict = self._create_quantized_state_dict(
model,
# pyre-fixme[16]: `GPTQQuantizer` has no attribute `tokenizer`.
Expand Down Expand Up @@ -686,6 +681,91 @@ def replace_linear_8da4w(
)


class Int8DynActInt4WeightQuantizer(Quantizer):
def __init__(
self,
group_size: int = 256,
padding_allowed: bool = False,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
) -> None:
self.group_size: int = group_size
self.padding_allowed: bool = padding_allowed
self.precision: torch.dtype = precision
self.scales_precision: torch.dtype = scales_precision
# assert group_size in [32, 64, 128, 256]

@torch.no_grad()
def _create_quantized_state_dict(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
cur_state_dict = model.state_dict()
for fqn, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear):
assert not mod.bias
out_features = mod.out_features
in_features = mod.in_features
# assert out_features % 8 == 0, "require out_features % 8 == 0"
print(f"linear: {fqn}, in={in_features}, out={out_features}")

assert (
in_features % self.group_size == 0
), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"

weight = mod.weight.data
"""
if not _check_linear_int4_k(
in_features, self.group_size
):
if self.padding_allowed:
print(
f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
)
padded_in_features = _calc_padded_size_linear_int4(
in_features, self.group_size
)
weight = F.pad(
weight, pad=(0, padded_in_features - in_features)
)
else:
raise RuntimeError(
f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+ "and that group_size"
)
"""
(
weight_int8,
scales,
zeros,
) = group_quantize_tensor_symmetric(
weight.to(self.precision),
4, # n_bit
self.group_size,
self.scales_precision,
)
cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
# TODO: support bias?

return cur_state_dict

def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
replace_linear_8da4w(
model,
self.group_size,
self.padding_allowed,
self.precision,
self.scales_precision,
)
return model

def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
state_dict = self._create_quantized_state_dict(model)
model = self._convert_for_runtime(model)
# TODO: make it strict
model.load_state_dict(state_dict, strict=False)
return model


class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
# pyre-fixme[3]: Return type must be annotated.
def __init__(
Expand Down

0 comments on commit e980f49

Please sign in to comment.