 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
+import torch
+import torchao
+from packaging.version import parse
 from transformers import AutoModelForCausalLM, GenerationConfig
 
 from ..integrations import CausalLMExportableModule
@@ -71,4 +76,48 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl |
             },
         ),
     )
+
+    # TODO: Move quantization recipe out for better composability.
+    # TODO: Should switch to `TorchAoConfig` once the quant issue on the final lm_head layer is fixed.
+    qlinear_config = kwargs.get("qlinear", None)
+    qembedding_config = kwargs.get("qembedding", None)
+    if qlinear_config or qembedding_config:
+        # TODO: Update the torchao requirement to 0.11.0 once it is released.
+        if parse(torchao.__version__) < parse("0.11.0.dev0"):
+            raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.")
+
+        from torchao.quantization.granularity import PerAxis, PerGroup
+        from torchao.quantization.quant_api import (
+            Int8DynamicActivationIntxWeightConfig,
+            IntxWeightOnlyConfig,
+            quantize_,
+        )
+        from torchao.utils import unwrap_tensor_subclass
+
+        if qembedding_config:
+            logging.info("Quantizing embedding layers.")
+            # TODO: Should switch to `AOPerModuleConfig` once a fix for tied weights is available.
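+            # Int8 weight-only quantization with one scale per embedding row (PerAxis(0)).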
+            embedding_config = IntxWeightOnlyConfig(
+                weight_dtype=torch.int8,
+                granularity=PerAxis(0),
+            )
+            quantize_(
+                eager_model,
+                embedding_config,
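+                # filter_fn: restrict quantization to embedding modules only.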
+                lambda m, fqn: isinstance(m, torch.nn.Embedding),
+            )
+
+        if qlinear_config:
+            logging.info("Quantizing linear layers.")
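+            # 8da4w scheme: int8 dynamic activation quantization with int4 weights,
+            # quantized in groups of 32 elements.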
+            linear_config = Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(32),
+            )
+            quantize_(
+                eager_model,
+                linear_config,
+            )
+
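+        # Unwrap tensor subclasses back to plain tensors so the quantized model
+        # stays traceable by torch.export.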
+        unwrap_tensor_subclass(eager_model)
+
     return CausalLMExportableModule(eager_model)
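
For reference, a minimal sketch of how the new `qlinear` / `qembedding` kwargs could be exercised. The import path and model id are illustrative assumptions, not part of this diff; any truthy value enables the corresponding recipe, since the code only checks for presence:

```python
# Hypothetical usage sketch; the import path and model id below are assumptions.
from optimum.exporters.executorch.tasks.causal_lm import load_causal_lm_model

exportable = load_causal_lm_model(
    "HuggingFaceTB/SmolLM2-135M",  # illustrative model id
    qlinear=True,      # int8 dynamic activations + int4 grouped weights (8da4w)
    qembedding=True,   # int8 weight-only embeddings, per-row scales
)
```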