diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 9f8569c2865..bb67057b4ba 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -119,7 +119,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--quantization_mode",
         type=str,
         default=None,
-        choices=["int8", "8da4w", "8da4w-gptq"],
+        choices=["int8", "8da4w", "8da4w-gptq", "16a4w-hqq"],
         help="type of quantization",
     )
 
diff --git a/examples/models/llama2/source_transformation/hqq_16a4w.py b/examples/models/llama2/source_transformation/hqq_16a4w.py
new file mode 100644
index 00000000000..89aaafaeaf8
--- /dev/null
+++ b/examples/models/llama2/source_transformation/hqq_16a4w.py
@@ -0,0 +1,205 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
+from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
+
+########################## Run HQQ ###############################
+
+
+def _replace_linear_4w_hqq(
+    module: torch.nn.Module,
+    quant_config,
+    compute_dtype,
+    del_orig=False,
+):
+    """
+    Recursively replace all Linear layers with HQQLinear, using 4-bit quantized weights.
+    """
+    for name, child in module.named_children():
+        if isinstance(child, torch.nn.Linear):
+            new_linear = HQQLinear(
+                child,
+                quant_config,
+                compute_dtype=compute_dtype,
+                del_orig=True,
+                device="cpu",
+            )
+            setattr(module, name, new_linear)
+        else:
+            _replace_linear_4w_hqq(
+                child,
+                quant_config,
+                compute_dtype,
+                del_orig=False,
+            )
+
+
+def replace_linear_4w_hqq(
+    module: torch.nn.Module,
+    quant_config: BaseQuantizeConfig,
+    compute_dtype,
+    del_orig=False,
+):
+    """
+    Replace all Linear layers with HQQLinear, using 4-bit quantized weights.
+    """
+    _replace_linear_4w_hqq(
+        module,
+        quant_config,
+        compute_dtype,
+        del_orig=False,
+    )
+
+
+def run_hqq_quantize(model: torch.nn.Module) -> None:
+    """
+    Update the model in place with the HQQ-quantized weights.
+    """
+
+    quant_config = BaseQuantizeConfig(
+        quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
+    )
+
+    replace_linear_4w_hqq(model, quant_config=quant_config, compute_dtype=torch.float32)
+
+
+########################## Use static quantization with HQQ Linear ###############################
+
+
+def calibrate(
+    model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length
+):
+    print("run calibration...")
+    eval_wrapper = EagerEvalWrapper(
+        model=model,
+        tokenizer=tokenizer,
+        max_seq_length=calibration_seq_length,
+        use_kv_cache=False,
+    )
+    eval_results = evaluate_model(
+        eval_wrapper,
+        tasks=calibration_tasks,
+        limit=calibration_limit,
+    )
+    for task, res in eval_results["results"].items():
+        print(f"Reference result with hqq model: {task}: {res}")
+
+
+class LinearActivationFakeQuant(torch.nn.Module):
+    def __init__(self, linear):
+        super().__init__()
+        self.linear = linear
+        self.input_activation_fake_quant = torch.quantization.FakeQuantize(
+            observer=torch.quantization.MovingAverageMinMaxObserver,
+            dtype=torch.int32,
+            quant_min=torch.iinfo(torch.uint16).min,
+            quant_max=torch.iinfo(torch.uint16).max,
+        )
+        self.output_activation_fake_quant = torch.quantization.FakeQuantize(
+            observer=torch.quantization.MovingAverageMinMaxObserver,
+            dtype=torch.int32,
+            quant_min=torch.iinfo(torch.uint16).min,
+            quant_max=torch.iinfo(torch.uint16).max,
+        )
+
+    def forward(self, x):
+        x = self.input_activation_fake_quant(x)
+        return self.output_activation_fake_quant(self.linear(x))
+
+
+def get_quant_params(activation_fake_quant):
+    quant_min = activation_fake_quant.quant_min
+    quant_max = activation_fake_quant.quant_max
+    qparams = activation_fake_quant.calculate_qparams()
+    scale = qparams[0]
+    zero_point = qparams[1]
+    return (quant_min, quant_max, scale, zero_point)
+
+
+class LinearActivationQuant(torch.nn.Module):
+
+    def __init__(self, linear_fake_quant):
+        super().__init__()
+        self.linear_fake_quant = linear_fake_quant
+        (
+            self.input_quant_min,
+            self.input_quant_max,
+            self.input_scale,
+            self.input_zero_point,
+        ) = get_quant_params(linear_fake_quant.input_activation_fake_quant)
+
+        (
+            self.output_quant_min,
+            self.output_quant_max,
+            self.output_scale,
+            self.output_zero_point,
+        ) = get_quant_params(linear_fake_quant.output_activation_fake_quant)
+
+    def forward(self, x):
+        # Manually quantize the input tensor using the observed scale and zero point
+        q_tensor = torch.round(x / self.input_scale + self.input_zero_point)
+        # Clamp to keep the values within [quant_min, quant_max]
+        q_tensor = torch.clamp(q_tensor, self.input_quant_min, self.input_quant_max)
+        # Dequantize back to the original scale
+        dequantized_tensor = (q_tensor - self.input_zero_point) * self.input_scale
+
+        linear_output = self.linear_fake_quant.linear(dequantized_tensor)
+
+        # Quantize the linear output tensor
+        q_linear_output = torch.round(
+            linear_output / self.output_scale + self.output_zero_point
+        )
+        q_linear_output = torch.clamp(
+            q_linear_output, self.output_quant_min, self.output_quant_max
+        )
+        # Dequantize the linear output tensor
+        dq_linear_output = (
+            q_linear_output - self.output_zero_point
+        ) * self.output_scale
+
+        return dq_linear_output
+
+
+def _replace_linear_quant_activation(module: torch.nn.Module, stage: str):
+    for name, child in module.named_children():
+        if stage == "convert":
+            if isinstance(child, LinearActivationFakeQuant):
+                new_linear = LinearActivationQuant(child)
+                setattr(module, name, new_linear)
+            else:
+                _replace_linear_quant_activation(child, stage)
+        elif stage == "prepare":
+            if isinstance(child, HQQLinear):
+                new_linear = LinearActivationFakeQuant(child)
+                setattr(module, name, new_linear)
+            else:
+                _replace_linear_quant_activation(child, stage)
+        else:
+            raise ValueError(f"Unsupported stage {stage}")
+
+
+def replace_linear_quant_activation(module: torch.nn.Module, stage: str):
+    _replace_linear_quant_activation(
+        module,
+        stage,
+    )
+
+
+def prepare(model):
+    """
+    Prepare the model for quantization by manually inserting the observers.
+    """
+    replace_linear_quant_activation(model, "prepare")
+
+
+def convert(model):
+    """
+    Convert the observers to actual quant/dequant nodes. In this implementation we manually
+    call round, add, mul, and clamp for quick prototyping.
+    """
+    replace_linear_quant_activation(model, "convert")
diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py
index c5472668ca0..7ab24ec2e53 100644
--- a/examples/models/llama2/source_transformation/quantize.py
+++ b/examples/models/llama2/source_transformation/quantize.py
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from executorch.examples.models.llama2.tokenizer.tokenizer import Tokenizer
 from sentencepiece import SentencePieceProcessor
 
 
@@ -127,6 +128,44 @@ def quantize(
             group_size,
         )
         model = gptq_quantizer.quantize(model, inputs)
+        return model
+    elif qmode == "16a4w-hqq":
+        try:
+            from executorch.examples.models.llama2.source_transformation import (
+                hqq_16a4w,
+            )
+        except ImportError:
+            print(
+                "Please follow the instructions at https://github.com/mobiusml/hqq to install the latest version."
+            )
+        if calibration_tasks is None:
+            calibration_tasks = ["wikitext"]
+        if calibration_limit is None:
+            calibration_limit = 5
+        if calibration_seq_length is None:
+            calibration_seq_length = 128
+        if tokenizer_path is None:
+            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = Tokenizer(model_path=str(tokenizer_path))  # pyre-ignore[28]
+
+        # Step 1: Run HQQ weight quantization; the Linear layers inside the model are replaced with HQQLinear.
+        hqq_16a4w.run_hqq_quantize(model)
+
+        # Step 2: Insert activation observers around each HQQLinear.
+        # The HQQ weight quantization above must run first so the observers can find them.
+        hqq_16a4w.prepare(model)
+        # Step 3: Calibrate the observers on the calibration tasks.
+        hqq_16a4w.calibrate(
+            model=model,
+            tokenizer=tokenizer,
+            calibration_tasks=calibration_tasks,
+            calibration_limit=calibration_limit,
+            calibration_seq_length=calibration_seq_length,
+        )
+        # Step 4: Convert the observed modules into manual quant/dequant with the calibrated parameters.
+        hqq_16a4w.convert(model)
+        return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
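
Note: with this change, the new mode is selected through --quantization_mode 16a4w-hqq in the llama2 export flow. For reference, the following is a minimal standalone sketch (not part of the diff) of the uint16 affine quant/dequant round trip that LinearActivationQuant.forward performs after calibration; the scale and zero_point values below are made-up placeholders standing in for the observer-derived parameters.

import torch

# Placeholder quantization parameters; illustrative only, not from the PR.
scale, zero_point = 0.05, 32768
quant_min = torch.iinfo(torch.uint16).min  # 0
quant_max = torch.iinfo(torch.uint16).max  # 65535

x = torch.randn(4, 8)

# Quantize: map float activations onto the uint16 integer grid.
q = torch.round(x / scale + zero_point)
# Clamp into the representable range, as LinearActivationQuant does.
q = torch.clamp(q, quant_min, quant_max)
# Dequantize back to float; the difference from x is the quantization error.
x_dq = (q - zero_point) * scale

print((x - x_dq).abs().max())  # bounded by scale / 2 for in-range values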