Skip to content

Commit

Permalink
Expose hqq through int4_weight_only API
Browse files Browse the repository at this point in the history
Summary:
att, this is a follow up for pytorch#605 to make hqq available in quantize_ API

`quantize_(model, int4_weight_only(group_size, use_hqq=True)`

Test Plan:

python generate.py --compile --quantization int4wo-hqq-64 --precision bfloat16
Average tokens/sec: 195.24
Average Bandwidth: 729.40 GB/s
Peak Memory Usage: 5.09 GB
Model Size: 3.74 GB

python eval.py --compile --quantization int4wo-hqq-64 --precision bfloat16

wikitext: {'word_perplexity,none': 12.823631773497512, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.611400903914048, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6883154699192412, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Sep 6, 2024
1 parent e05635e commit 43ab845
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 62 deletions.
67 changes: 28 additions & 39 deletions test/hqq/test_hqq_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@
)

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_3,
)
from torchao.quantization import (
uintx_weight_only,
int4_weight_only,
)

cuda_available = torch.cuda.is_available()

#Parameters
device = 'cuda:0'
compute_dtype = torch.bfloat16
group_size = 64
group_size = 64
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size) #axis=1
preserve_zero = False
Expand All @@ -34,36 +37,24 @@

def _init_data(in_features, out_features, compute_dtype, device, torch_seed):
torch.random.manual_seed(torch_seed)
linear_layer = torch.nn.Linear(in_features, out_features, bias=False, device=device)
linear_layer = torch.nn.Linear(in_features, out_features, bias=False).to(device)
x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)/20.
y_ref = linear_layer(x)
W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
return W, x, y_ref

def _eval_hqq(nbits, layout_type):
W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)

#Plain layout
target_dtype = torch.uint8
#Tensorcore layout
if isinstance(layout_type, TensorCoreTiledLayoutType):
target_dtype = torch.uint8 if TORCH_VERSION_AT_LEAST_2_5 else torch.int32

q_tensor_hqq = to_affine_quantized_intx(
input_float=W,
mapping_type=mapping_type,
block_size=block_size,
target_dtype=target_dtype,
quant_min=0,
quant_max=2**nbits - 1,
zero_point_domain=zero_point_domain,
preserve_zero=preserve_zero,
layout_type=layout_type,
use_hqq=True,
)
def _eval_hqq(dtype):
W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)

dummy_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
dummy_linear.weight.data = W
if dtype == torch.uint4:
q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(dummy_linear).weight
else:
q_tensor_hqq = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)(dummy_linear).weight

quant_linear_layer = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
del quant_linear_layer.weight
del quant_linear_layer.weight
quant_linear_layer.weight = q_tensor_hqq
dequantize_error = (W - q_tensor_hqq.dequantize()).abs().mean().item()
dot_product_error = (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()
Expand All @@ -73,42 +64,40 @@ def _eval_hqq(nbits, layout_type):

class TestHQQBase(unittest.TestCase):
@unittest.skipIf(not cuda_available, "Need CUDA available")
def test_hqq(self, nbits=None, layout_type=None, ref_dequantize_error=None, ref_dot_product_error=None):
if(nbits is None): return
dequantize_error, dot_product_error = _eval_hqq(nbits=nbits, layout_type=layout_type)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+")
def test_hqq(self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None):
if(dtype is None): return
dequantize_error, dot_product_error = _eval_hqq(dtype)
self.assertTrue(dequantize_error < ref_dequantize_error)
self.assertTrue(dot_product_error < ref_dot_product_error)

class TestHQQ8Bit(TestHQQBase):
def test_hqq_plain_8bit(self):
self.test_hqq(nbits=8, layout_type=PlainLayoutType(), ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)
self.test_hqq(dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)

class TestHQQ7Bit(TestHQQBase):
def test_hqq_plain_7bit(self):
self.test_hqq(nbits=7, layout_type=PlainLayoutType(), ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)
self.test_hqq(dtype=torch.uint7, ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)

class TestHQQ6Bit(TestHQQBase):
def test_hqq_plain_6bit(self):
self.test_hqq(nbits=6, layout_type=PlainLayoutType(), ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)
self.test_hqq(dtype=torch.uint6, ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)

class TestHQQ5Bit(TestHQQBase):
def test_hqq_plain_5bit(self):
self.test_hqq(nbits=5, layout_type=PlainLayoutType(), ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)
self.test_hqq(dtype=torch.uint5, ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)

class TestHQQ4bit(TestHQQBase):
def test_hqq_plain_4bit(self):
self.test_hqq(nbits=4, layout_type=PlainLayoutType(), ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)

def test_hqq_tensorcore_4bit(self):
self.test_hqq(nbits=4, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles), ref_dequantize_error=0.000487, ref_dot_product_error=0.00147)
self.test_hqq(dtype=torch.uint4, ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)

class TestHQQ3Bit(TestHQQBase):
def test_hqq_plain_3bit(self):
self.test_hqq(nbits=3, layout_type=PlainLayoutType(), ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)
self.test_hqq(dtype=torch.uint3, ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)

class TestHQQ2Bit(TestHQQBase):
def test_hqq_plain_2bit(self):
self.test_hqq(nbits=2, layout_type=PlainLayoutType(), ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)
self.test_hqq(dtype=torch.uint2, ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)

if __name__ == "__main__":
unittest.main()
15 changes: 10 additions & 5 deletions torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,19 @@ def run_evaluation(
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq))
if "uintx" in quantization:
# uintx-nbits-group_size
# uintx-nbits-groupsize
# "uintx-2-64"
if "hqq" in quantization:
use_hqq = True
quantization = quantization[:-4]
else:
use_hqq = False
_quant_args = quantization.split("-")
nbits = int(_quant_args[1])
nbits = int(_quant_args[0])
_NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, uintx_weight_only(dtype, group_size))
group_size = int(_quant_args[1])
quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
if "int4wo" in quantization and "gptq" in quantization:
groupsize=int(quantization.split("-")[-2])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
Expand Down Expand Up @@ -135,7 +140,7 @@ def run_evaluation(
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<group_size>")
parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq")
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
Expand Down
17 changes: 11 additions & 6 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,20 @@ def main(
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
if "uintx" in quantization:
# uintx-nbits-group_size
# "uintx-2-64"
# uintx-nbits-groupsize, e.g. "uintx-2-64"
if "hqq" in quantization:
# uintx-nbits-groupsize-hqq
quantization = quantization[:-4]
use_hqq = True
else:
use_hqq = False
_quant_args = quantization.split("-")
nbits = int(_quant_args[1])
nbits = int(_quant_args[0])
assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
_NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, uintx_weight_only(dtype, group_size))
group_size = int(_quant_args[1])
quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
if "autoquant" in quantization:
if "autoquant-int4" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
Expand Down Expand Up @@ -451,7 +456,7 @@ def callback(x):
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<group_size>')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq')
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
Expand Down
4 changes: 2 additions & 2 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ZeroPointDomain,
MappingType,
int_scaled_matmul,
quantize_affine_hqq,
choose_qparams_and_quantize_affine_hqq,
FP8_TYPES,
choose_qparams_affine_fpx,
quantize_affine_fpx,
Expand Down Expand Up @@ -264,7 +264,7 @@ def from_hp_to_intx(
group_size = max(block_size)
compute_dtype = zero_point_dtype if (zero_point_dtype is not None) else input_float.dtype
device = input_float.device
data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
data = data.to(target_dtype)
else:
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
Expand Down
7 changes: 6 additions & 1 deletion torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,12 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
```python
# for torch 2.4+
from torchao.quantization import quantize_, int4_weight_only
quantize_(model, int4_weight_only())
group_size = 32

# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through
# use_hqq flag for `int4_weight_only` quantization
use_hqq = False
quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
Expand Down
36 changes: 29 additions & 7 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def input_quant_func(x: torch.Tensor):
return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)


def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
"""
Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
x is the number of bits specified by `dtype`
Expand All @@ -606,23 +606,44 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
`group_size`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained, defaults to 64
`pack_dim`: the dimension we use for packing, defaults to -1
`use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight
"""
def apply_uintx_weight_only_quant(weight):
layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)
from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS

SUPPORTED_DTYPES = {torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8}
assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"

def apply_uintx_weight_only_quant(weight, dtype):
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT

if use_hqq:
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
dtype = torch.uint8
eps = None
zero_point_dtype = None
zero_point_domain = ZeroPointDomain.FLOAT
preserve_zero = False
layout_type = PlainLayoutType()
else:
quant_min, quant_max = None, None
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT
preserve_zero = True
layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)

return to_affine_quantized_intx(
weight, mapping_type, block_size, dtype,
quant_min=quant_min, quant_max=quant_max,
eps=eps, zero_point_dtype=zero_point_dtype,
zero_point_domain=zero_point_domain,
preserve_zero=preserve_zero,
layout_type=layout_type,
use_hqq=use_hqq,
)

return _get_linear_subclass_inserter(apply_uintx_weight_only_quant)
return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)

def fpx_weight_only(ebits: int, mbits: int):
"""Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
Expand Down Expand Up @@ -652,5 +673,6 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
return to_affine_quantized_fpx(weight, layout_type)
return _get_linear_subclass_inserter(apply_quant_llm)


if TORCH_VERSION_AT_LEAST_2_5:
torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])
4 changes: 2 additions & 2 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"dequantize_affine_fpx",
"fake_quantize_affine",
"fake_quantize_affine_cachemask",
"quantize_affine_hqq",
"choose_qparams_and_quantize_affine_hqq",
]

class MappingType(Enum):
Expand Down Expand Up @@ -840,7 +840,7 @@ def _convert_to_affinequantized_format(W_q: torch.Tensor, scale: torch.Tensor, z
return W_q_ao, scale_ao, zero_ao

# Main hqq quantizer function
def quantize_affine_hqq(
def choose_qparams_and_quantize_affine_hqq(
tensor: torch.Tensor,
nbits: float = 4,
group_size: int = 64,
Expand Down

0 comments on commit 43ab845

Please sign in to comment.