[reland] Add support for Int8DynActInt4WeightQuantizer (#66) #74

Merged: 1 commit, Mar 21, 2024
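
For context, a minimal usage sketch of the quantizer this PR adds, mirroring the new test_8da4w_quantizer test in the diff below. The model shapes and group_size=32 come from that test; the import path and the final eager call are assumptions for illustration, not part of the PR.

import torch
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# Toy module matching the shapes used in the test: bias-free linears whose
# in_features are divisible by the chosen group size.
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(64, 32, bias=False)
        self.linear2 = torch.nn.Linear(32, 64, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))

model = ToyModel().eval()
quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
model = quantizer.quantize(model)  # linears are swapped for Int8DynActInt4WeightLinear
out = model(torch.randn(1, 64))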
27 changes: 20 additions & 7 deletions test/quantization/test_quant_api.py
@@ -24,6 +24,8 @@
Quantizer,
TwoStepQuantizer,
Int8DynActInt4WeightGPTQQuantizer,
Int8DynActInt4WeightQuantizer,
Int8DynActInt4WeightLinear,
)
from pathlib import Path
from sentencepiece import SentencePieceProcessor
@@ -85,8 +87,11 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5).to(torch.float)
self.linear2 = torch.nn.Linear(5, 5).to(torch.float)
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 64).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
@@ -97,8 +102,7 @@ class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = M().eval()
m = _apply_dynamic_quant(m)
example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
quantized = m(*example_inputs)
quantized = m(*m.example_inputs())
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
# m = torch.compile(m, mode="max-autotune")
@@ -110,9 +114,9 @@ def test_dynamic_quant_gpu_singleline(self):
def test_dynamic_quant_gpu_unified_api_unified_impl(self):
quantizer = XNNPackDynamicQuantizer()
m = M().eval()
example_inputs = m.example_inputs()
m = quantizer.prepare(m)
m = quantizer.convert(m)
example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
quantized = m(*example_inputs)
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -125,15 +129,24 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self):
def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
quantizer = TorchCompileDynamicQuantizer()
m = M().eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
quantized = m(*example_inputs)
m = torch.compile(m, mode="max-autotune")
compiled = m(*example_inputs)
torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

def test_8da4w_quantizer(self):
quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
m = M().eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
m(*example_inputs)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq(self):
def test_gptq_quantizer(self):
# should be similar to TorchCompileDynamicQuantizer
precision = torch.bfloat16
device = "cpu"
3 changes: 2 additions & 1 deletion torchao/quantization/GPTQ.py
@@ -19,6 +19,7 @@

# from model import Transformer # pyre-ignore[21]
from torch.utils._pytree import tree_flatten, tree_unflatten
import logging

# pyre-fixme[5]: Global expression must be annotated.
aten = torch.ops.aten
@@ -63,7 +64,7 @@ def model_forward(model, x, input_pos):
get_task_dict = tasks.get_task_dict
evaluate = evaluator.evaluate
else:
print("lm_eval is not installed, GPTQ may not be usable")
logging.info("lm_eval is not installed, GPTQ may not be usable")

# pyre-fixme[3]: Return type must be annotated.
def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
106 changes: 93 additions & 13 deletions torchao/quantization/quant_api.py
@@ -36,8 +36,10 @@
from .quant_primitives import (
get_group_qparams_symmetric,
per_token_dynamic_quant,
group_quantize_tensor_symmetric,
)
from typing import Dict, Tuple
from typing import Dict, Tuple, Any
import logging

__all__ = [
"apply_weight_only_int8_quant",
@@ -54,21 +56,18 @@
############################# Unified Quantization APIs ##############################
# API 1, single quantize call to create a quantized model with quantized state_dict
class Quantizer:
# pyre-fixme[2]: Parameter must be annotated.
def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
# pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
pass


# API 2, flow that needs calibration or training
class TwoStepQuantizer:
# pyre-fixme[2]: Parameter must be annotated.
def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
# pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
pass

# pyre-fixme[2]: Parameter must be annotated.
def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
# pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
pass

@@ -260,7 +259,7 @@ def replace_conv2d_1x1(conv):
MultiInput,
)
else:
print("lm_eval not available, skip defining GPTQQuantizer")
logging.info("lm_eval not available, skip defining GPTQQuantizer")


class GPTQQuantizer(Quantizer):
@@ -442,11 +441,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":

@torch.no_grad()
# pyre-fixme[14]: `quantize` overrides method defined in `Quantizer` inconsistently.
def quantize(
self,
# pyre-fixme[2]: Parameter must be annotated.
model,
) -> torch.nn.Module:
def quantize(self, model: torch.nn.Module, **kwargs: Any) -> torch.nn.Module:
state_dict = self._create_quantized_state_dict(
model,
# pyre-fixme[16]: `GPTQQuantizer` has no attribute `tokenizer`.
@@ -686,6 +681,91 @@ def replace_linear_8da4w(
)


class Int8DynActInt4WeightQuantizer(Quantizer):
def __init__(
self,
group_size: int = 256,
padding_allowed: bool = False,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
) -> None:
self.group_size: int = group_size
self.padding_allowed: bool = padding_allowed
self.precision: torch.dtype = precision
self.scales_precision: torch.dtype = scales_precision
# assert group_size in [32, 64, 128, 256]

@torch.no_grad()
def _create_quantized_state_dict(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
cur_state_dict = model.state_dict()
for fqn, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear):
assert not mod.bias
out_features = mod.out_features
in_features = mod.in_features
# assert out_features % 8 == 0, "require out_features % 8 == 0"
print(f"linear: {fqn}, in={in_features}, out={out_features}")

assert (
in_features % self.group_size == 0
), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"

weight = mod.weight.data
"""
if not _check_linear_int4_k(
in_features, self.group_size
):
if self.padding_allowed:
print(
f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
)
padded_in_features = _calc_padded_size_linear_int4(
in_features, self.group_size
)
weight = F.pad(
weight, pad=(0, padded_in_features - in_features)
)
else:
raise RuntimeError(
f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+ "and that group_size"
)
"""
(
weight_int8,
scales,
zeros,
) = group_quantize_tensor_symmetric(
weight.to(self.precision),
4, # n_bit
self.group_size,
self.scales_precision,
)
cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
# TODO: support bias?

return cur_state_dict

def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
replace_linear_8da4w(
model,
self.group_size,
self.padding_allowed,
self.precision,
self.scales_precision,
)
return model

def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
state_dict = self._create_quantized_state_dict(model)
model = self._convert_for_runtime(model)
# TODO: make it strict
model.load_state_dict(state_dict, strict=False)
return model


class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
# pyre-fixme[3]: Return type must be annotated.
def __init__(
Expand Down