Commit
mixed-precision quantization milestone1: naive_intNwo + eval/benchmark framework (#531)

* milestone1: naive_intNwo + eval/benchmark
* remove experiment scripts
* remove exp files
* use default ZeroPointDomain.INT for int2/3/5/6
* renamed test_naive_intNwo.py to test_mixed_precision.py
* updated intNwo with _get_linear_subclass_inserter
* adjust sqnr threshold according to bit width
* fixed test for int4wo and add __init__.py
* skip test_aq_int8_weight_only_quant_3_subclass due to seg fault on nightly
* edit the sqnr threshold
* add unittest
* correct import path
Showing 5 changed files with 188 additions and 0 deletions.
32 changes: 32 additions & 0 deletions
test_mixed_precision.py
import unittest

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
from torchao.quantization.utils import compute_error
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

_CUDA_IS_AVAILABLE = torch.cuda.is_available()

class TestWeightOnlyQuantNaive(unittest.TestCase):

    def test_quantization_intNwo(self):
        # skip int4wo for now since it is under development in torchao
        for quantization_bit in [2, 3, 5, 6, 8]:
            for symmetric in [False, True]:
                with self.subTest(quantization_bit=quantization_bit, symmetric=symmetric):
                    for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]:
                        x = torch.randn(*x_shape, dtype=torch.bfloat16)
                        m = nn.Sequential(nn.Linear(32, 80)).bfloat16()
                        y_ref = m(x)
                        quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric))
                        y_wo = m(x)
                        sqnr = compute_error(y_ref, y_wo)
                        # SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization
                        # e.g., we set the sqnr threshold to 44 for 8-bit, so that 6.02 * 8 = 48.16 fulfills it
                        expected_sqnr_threshold = 44.0 - (8 - quantization_bit) * 6.02
                        self.assertGreater(sqnr, expected_sqnr_threshold, f"sqnr: {sqnr} is too low")


if __name__ == '__main__':
    unittest.main()
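For reference, a small illustrative sketch of the thresholds this formula yields (plain arithmetic based on the 6.02 dB-per-bit rule used in the test; nothing here is specific to torchao):

# Illustrative: thresholds implied by expected_sqnr_threshold = 44.0 - (8 - n) * 6.02
for n in [2, 3, 5, 6, 8]:
    print(f"{n}-bit: SQNR must exceed {44.0 - (8 - n) * 6.02:.2f} dB")
# -> 2-bit: 7.88, 3-bit: 13.90, 5-bit: 25.94, 6-bit: 31.96, 8-bit: 44.00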
Empty file.
1 change: 1 addition & 0 deletions
torchao/quantization/prototype/mixed_precision/scripts/__init__.py
from .naive_intNwo import intN_weight_only
95 changes: 95 additions & 0 deletions
torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py
import torch
import torch.nn as nn

from naive_intNwo import intN_weight_only
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval.models.huggingface import HFLM
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict

from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight
from torchao._models._eval import TransformerEvalWrapper

from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
)

from torchao.quantization.quant_api import autoquant


torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True


def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length, sensi_bit, non_sensi_bit, quant_sym, group_size):

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)

    if quantization == "autoquant":
        model = autoquant(model.to(device=device))

    # naive implementation of uniform-precision quantization for all layers
    elif quantization in ["2", "3", "4", "5", "6", "8"]:
        quantize_(model.to(device=device), intN_weight_only(n=int(quantization), group_size=group_size, symmetric=quant_sym))

    # mixed-precision quantization for Llama3
    elif quantization == "MP_llama3":

        # filter for sensitive layers (the first 3 and last 2 layers for Llama3)
        def filter_fn_sen(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

        # filter for non-sensitive layers (the other 27 layers for Llama3)
        def filter_fn_nonsen(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, nn.Linear) and not any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

        # quantize the sensitive layers
        if sensi_bit != 16:
            quantize_(model.to(device=device), intN_weight_only(n=sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_sen)

        # quantize the less-sensitive layers
        if sensi_bit == 4:
            quantize_(model, intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
        else:
            quantize_(model.to(device=device), intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)

    if compile:
        model = torch.compile(model, mode="max-autotune", fullgraph=True)

    with torch.no_grad():

        result = evaluate(
            HFLM(
                pretrained=model,
                tokenizer=tokenizer,
                batch_size=batch_size,
                max_length=max_length),
            get_task_dict(tasks),
            limit=limit,
        )

        for task, res in result["results"].items():
            print(f"{task}: {res}")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
    parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.')
    parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eval tasks to evaluate, usage: --tasks task1 task2')
    parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
    parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
    parser.add_argument('-q', '--quantization', default="None", choices=["2", "3", "4", "5", "6", "8", "MP_llama3", "None"], help='Which quantization technique to apply: choose from ["2", "3", "4", "5", "6", "8"] for uniform quantization, "MP_llama3" for mixed precision for Llama3 (set sensi_bit and non_sensi_bit accordingly), or "None" for no quantization')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation; note int8wo and int4wo work best with small batch sizes, while int8dq works better with large batch sizes')
    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
    parser.add_argument('--sensi_bit', type=int, default=16, choices=[16, 8, 6, 5, 4, 3], help='Bit setting for sensitive layers')
    parser.add_argument('--non_sensi_bit', type=int, default=8, choices=[8, 6, 5, 4, 3, 2], help='Bit setting for non-sensitive layers')
    parser.add_argument('--quant_sym', type=bool, default=False, help='Symmetric or asymmetric quantization, asymmetric by default')
    parser.add_argument('--group_size', type=int, default=32, help='Group size to perform quantization on')
    args = parser.parse_args()
    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.sensi_bit, args.non_sensi_bit, args.quant_sym, args.group_size)
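For orientation, a minimal sketch of driving run_evaluation directly from Python for the mixed-precision path. The checkpoint path follows the script's default, and the 8-bit/4-bit split is an illustrative assumption; the equivalent CLI would be along the lines of python mp_quant_eval.py -q MP_llama3 --sensi_bit 8 --non_sensi_bit 4.

# Illustrative sketch, assuming mp_quant_eval.py is importable and a Llama-3 checkpoint
# exists at the default path; the 8-bit/4-bit split below is an example choice.
import torch
from mp_quant_eval import run_evaluation

run_evaluation(
    repo_id="checkpoints/meta-llama/Meta-Llama-3-8B",  # default path from the argparse setup above
    tasks=["wikitext"],
    limit=None,
    device="cuda",
    precision=torch.bfloat16,
    quantization="MP_llama3",   # mixed-precision path
    compile=False,
    batch_size=1,
    max_length=None,
    sensi_bit=8,                # first 3 and last 2 layers at 8 bits
    non_sensi_bit=4,            # remaining layers at 4 bits
    quant_sym=False,            # asymmetric quantization
    group_size=32,
)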
60 changes: 60 additions & 0 deletions
torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py
import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
)

from torchao.quantization import int8_weight_only, int4_weight_only
from torchao.quantization.quant_api import _get_linear_subclass_inserter

def intN_weight_only(group_size=32, n=8, symmetric=False):
    '''
    Apply int N-bit weight-only quantization to a linear layer.
    Args:
        `group_size`: controls the granularity of quantization; a smaller size is more fine-grained, choices are [512, 256, 128, 64, 32]
        `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
        `symmetric`: use symmetric (True) or asymmetric (False) quantization
    Usage:
        from torchao.quantization import quantize_
        quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
    '''
    # for asymmetric quantization
    def apply_intN_weight_only_quant_asym(weight):
        # avoid circular dependency
        from torchao.dtypes import to_affine_quantized
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.uint8
        quant_min = 0
        quant_max = 2**n - 1
        eps = 1e-6
        preserve_zero = True
        zero_point_dtype = torch.int64
        zero_point_domain = ZeroPointDomain.INT
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)  # , preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

    # for symmetric quantization
    def apply_intN_weight_only_quant_sym(weight):
        # avoid circular dependency
        from torchao.dtypes import to_affine_quantized
        mapping_type = MappingType.SYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.int8
        eps = 1e-6
        zero_point_dtype = torch.int64
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

    try:
        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
        if n == 8:
            return int8_weight_only()
        elif n == 4:
            return int4_weight_only(group_size=group_size)
        else:
            if symmetric:
                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
            else:
                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
    except Exception as e:
        raise
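A minimal usage sketch, following the Usage note in the docstring and mirroring the unit test above; the toy bfloat16 model and the 6-bit choice are illustrative:

# Illustrative usage: quantize a toy bfloat16 linear layer to 6-bit asymmetric
# weights with per-group scales over groups of 32 columns.
import torch
import torch.nn as nn
from torchao.quantization import quantize_

model = nn.Sequential(nn.Linear(32, 80)).bfloat16()
quantize_(model, intN_weight_only(n=6, group_size=32, symmetric=False))
out = model(torch.randn(16, 32, dtype=torch.bfloat16))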