Skip to content

Commit

Permalink
mixed-precision quantization milestone1: naive_intNwo + eval/benchmar…
Browse files Browse the repository at this point in the history
…k framework (#531)

* milestone1: naive_intNwo + eval/benchmark

* remove experiment scripts

* remove exp files

* use default ZeroPointDomain.INT for int2/3/5/6

* renamed test_naive_intNwo.py to test_mixed_precision.py

* updated intNwo with _get_linear_subclass_inserter

* adjust sqnr threshold according to bit width

* fixed test for int4wo and add __init__.py

* skip test_aq_int8_weight_only_quant_3_subclass due to seg fault on nightly

* edit the sqnr threshold

* add unittest

* correct import path
  • Loading branch information
Hanxian97 authored Aug 1, 2024
1 parent 013cce3 commit c023f71
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 0 deletions.
32 changes: 32 additions & 0 deletions test/quantization/test_mixed_precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import unittest

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
from torchao.quantization.utils import compute_error
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

_CUDA_IS_AVAILABLE = torch.cuda.is_available()

class TestWeightOnlyQuantNaive(unittest.TestCase):

def test_quantization_intNwo(self):
#skip test int4wo for now since it is under development in torchao
for quantization_bit in [2, 3, 5, 6, 8]:
for symmetric in [False, True]:
with self.subTest(quantization_bit=quantization_bit, symmetric=symmetric):
for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]:
x = torch.randn(*x_shape, dtype=torch.bfloat16)
m = nn.Sequential(nn.Linear(32, 80)).bfloat16()
y_ref = m(x)
quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric))
y_wo = m(x)
sqnr = compute_error(y_ref, y_wo)
# SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization
# e.g., we set sqnr threshold = 44 for 8-bit, so that 6.02 * 8= 48.16 fullfills
expected_sqnr_threshold = 44.0 - (8 - quantization_bit) * 6.02
self.assertGreater(sqnr, expected_sqnr_threshold, f"sqnr: {sqnr} is too low")


if __name__ == '__main__':
unittest.main()
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .naive_intNwo import intN_weight_only
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn

from naive_intNwo import intN_weight_only
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval.models.huggingface import HFLM
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict

from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight
from torchao._models._eval import TransformerEvalWrapper

from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)

from torchao.quantization.quant_api import autoquant


torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True


def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length, sensi_bit, non_sensi_bit, quant_sym, group_size):

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)

if quantization == "autoquant":
model = autoquant(model.to(device=device))

# naive implementation of uniform precision quantization all layers
elif quantization in ["2","3","4","5","6","8"]:
quantize_(model.to(device=device), intN_weight_only(n=int(quantization), group_size=group_size, symmetric=quant_sym))

# mix precision quantization for Llama3
elif quantization == "MP_llama3":

# filter for sensitive layers (the first 3 and last 2 layers for Llama3)
def filter_fn_sen(child: torch.nn.Module, cur_fqn:str) -> bool:
return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

# filter for non-sensitive layers (other 27 layers for Llama3)
def filter_fn_nonsen(child: torch.nn.Module, cur_fqn:str) -> bool:
return isinstance(child, nn.Linear) and not(any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.']))

# quantize the sensitive layers
if sensi_bit != 16:
quantize_(model.to(device=device), intN_weight_only(n=sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_sen)

# quantize the less-sensitive layers
if sensi_bit == 4:
quantize_(model, intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
else:
quantize_(model.to(device=device), intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)

if compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)

with torch.no_grad():

result = evaluate(
HFLM(
pretrained=model,
tokenizer=tokenizer,
batch_size=batch_size,
max_length=max_length),
get_task_dict(tasks),
limit = limit,
)

for task, res in result["results"].items():
print(f"{task}: {res}")


if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.')
parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2')
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument('-q', '--quantization', default = "None", choices = ["2", "3", "4", "5", "6", "8", "MP_llama3", "None"], help='Which quantization technique to apply, choose from ["2", "3", "4", "5", "6", "8"] for uniform quantizatoin, choose "MP_llama3" for mixed-precision for Llama3 and need to set corresponding sensi_bit and non_sensi_bit, choose "None" for no quantization')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
parser.add_argument('--sensi_bit', type=int, default=16, choices = [16, 8, 6, 5, 4, 3], help='Bit setting for sensitive layers')
parser.add_argument('--non_sensi_bit', type=int, default=8, choices = [8, 6, 5, 4, 3, 2], help='Bit setting for non-sensitive layers')
parser.add_argument('--quant_sym', type=bool, default=False, help='Symmetric or asymmetric quantization, asymmetric by default')
parser.add_argument('--group_size', type=int, default=32, help='Group size to perform quantization on')
args = parser.parse_args()
run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.sensi_bit, args.non_sensi_bit, args.quant_sym, args.group_size)
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch

from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)

from torchao.quantization import int8_weight_only, int4_weight_only
from torchao.quantization.quant_api import _get_linear_subclass_inserter

def intN_weight_only(group_size=32, n=8, symmetric=False):
'''
Apply int N-bit weight only quantization to a linear layer.
Args:
`groupsize`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
`n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
Usage:
from torchao.quantization import quantize_
quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
'''
# for asymmetric quantization
def apply_intN_weight_only_quant_asym(weight):
# avoid circular dependency
from torchao.dtypes import to_affine_quantized
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
target_dtype = torch.uint8
quant_min = 0
quant_max = 2**n-1
eps = 1e-6
preserve_zero = True
zero_point_dtype = torch.int64
zero_point_domain = ZeroPointDomain.INT
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)#, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain)

# for symmetric quantization
def apply_intN_weight_only_quant_sym(weight):
# avoid circular dependency
from torchao.dtypes import to_affine_quantized
mapping_type = MappingType.SYMMETRIC
block_size = (1, group_size)
target_dtype = torch.int8
eps = 1e-6
zero_point_dtype = torch.int64
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

try:
assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
if n == 8:
return int8_weight_only()
elif n == 4:
return int4_weight_only(group_size=group_size)
else:
if symmetric:
return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
else:
return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
except Exception as e:
raise

0 comments on commit c023f71

Please sign in to comment.