diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 7fce5c88bd..4869e67ae6 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -499,6 +499,33 @@ python run_generation.py \ ``` +### Loading 4 Bit Checkpoints from Hugging Face + +You can load pre-quantized 4bit models with the argument `--load_quantized_model`. +Currently, uint4 checkpoints and single device are supported. +More information on enabling 4 bit inference in SynapseAI is available here: +https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_INT4.html. + +Below is an example to load a model with 4bit checkpoints from Hugging Face. +Please note that model name is denoted as ``. +Additionally, the below env vars are used for performance optimizations, and are planned to be removed in future version: +`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1` +```bash +SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1 \ +python run_lm_eval.py \ +-o acc_load_uint4_model.txt \ +--model_name_or_path \ +--use_hpu_graphs \ +--use_kv_cache \ +--trim_logits \ +--batch_size 1 \ +--bf16 \ +--attn_softmax_bf16 \ +--bucket_size=128 \ +--bucket_internal \ +--load_quantized_model +``` + ### Using Habana Flash Attention Habana Flash Attention addresses large sequence lengths on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes. diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index b6a4b44a3b..0f63940d6e 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -300,6 +300,11 @@ def setup_parser(parser): help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.", ) parser.add_argument( + "--load_quantized_model", + action="store_true", + help="Whether to load model from hugging face checkpoint.", + ) + "--parallel_strategy", type=str, choices=["tp", "none"], # Add other strategies as needed diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index ee1c624230..5dae91f673 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -246,6 +246,14 @@ def setup_model(args, model_dtype, model_kwargs, logger): torch_dtype=model_dtype, **model_kwargs, ) + elif args.load_quantized_model: + from neural_compressor.torch.quantization import load + model = load( + model_name_or_path=args.model_name_or_path, + format="huggingface", + device="hpu", + **model_kwargs + ) else: if args.assistant_model is not None: assistant_model = AutoModelForCausalLM.from_pretrained( @@ -619,6 +627,9 @@ def initialize_model(args, logger): "trust_remote_code": args.trust_remote_code, } + if args.load_quantized_model: + model_kwargs["torch_dtype"] = torch.bfloat16 + if args.trust_remote_code: logger.warning("`trust_remote_code` is set, there is no guarantee this model works properly and it may fail") diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index a094da0e01..ab1b3f58ee 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -423,9 +423,22 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) + def get_k_proj_weight(self): + """ 4bit quantization in GPTQ replaces the k_proj.weight with qweight. """ + if hasattr(self.k_proj, 'qweight'): + return self.k_proj.qweight + return self.k_proj.weight + + def get_k_proj_weight_dtype(self): + """ 4bit quantization in GPTQ replaces the k_proj.weight with qweight. + Scales tensor gets the weight dtype. """ + if hasattr(self.k_proj, 'qweight'): + return self.k_proj.scales.dtype + return self.k_proj.weight.dtype + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) - device = self.k_proj.weight.device + device = self.get_k_proj_weight().device dtype = self.config.torch_dtype self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) @@ -436,7 +449,7 @@ def update_sincos_cache(self, seq_len): # reduce memory consumption and improve performance. if seq_len > self.max_position_embeddings: self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + _, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) @@ -493,7 +506,7 @@ def pre_attn_forward( query_slices = self.q_proj.weight.split( (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + key_slices = self.get_k_proj_weight().split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] @@ -565,9 +578,9 @@ def pre_attn_forward( past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: - past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key = torch.zeros(key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device) past_value = torch.zeros( - key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device ) # Return list instead of tuple past_key_value = [past_key, past_value]