Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions examples/text-generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,30 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \

For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).

### Running with UINT4 using AutoGPTQ


Llama2-7b in UINT4 is enabled using [AutoGPTQ Fork](https://github.com/HabanaAI/AutoGPTQ), which provides quantization capabilities in PyTorch.
Currently, the support is for UINT4 inference of pre-quantized models only.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HolyFalafel please add the AutoGPTQ installation here,

BUILD_CUDA_EXT=0 pip install auto-gptq --no-build-isolation


You can run a *UINT4 quantized* model using AutoGPTQ with the argument `--gptq`.

Here is an example to run a quantized model on Llama2-7b `TheBloke/Llama-2-7b-Chat-GPTQ`:
```bash
python run_generation.py \
--attn_softmax_bf16 \
--model_name_or_path TheBloke/Llama-2-7b-Chat-GPTQ \
--use_hpu_graphs \
--limit_hpu_graphs \
--use_kv_cache \
--bucket_size 128 \
--bucket_internal \
--trim_logits \
--max_new_tokens 128 \
--batch_size 1 \
--bf16 \
--gptq
```

## Language Model Evaluation Harness

Expand Down
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ def setup_parser(parser):
help="Preprocess on cpu, and some other optimizations. Useful to prevent recompilations when using dynamic prompts (simulate_dyn_prompt)",
)

parser.add_argument("--gptq", action="store_true", help="Enable Quantization to 4 bit with AutoGPTQ")

parser.add_argument(
"--use_flash_attention",
action="store_true",
Expand Down Expand Up @@ -296,6 +298,9 @@ def setup_parser(parser):
args.limit_hpu_graphs = False

args.quant_config = os.getenv("QUANT_CONFIG", "")
if args.quant_config and args.gptq:
raise RuntimeError("Setting both quant_config and gptq is unsupported. ")

if args.quant_config == "" and args.disk_offload:
logger.warning(
"`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag."
Expand Down
4 changes: 4 additions & 0 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ def setup_model(args, model_dtype, model_kwargs, logger):
torch_dtype=model_dtype,
**model_kwargs,
)
elif args.gptq:
from transformers import GPTQConfig
quantization_config = GPTQConfig(bits=4, use_exllama=False)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs)
else:
if args.assistant_model is not None:
assistant_model = AutoModelForCausalLM.from_pretrained(
Expand Down
23 changes: 18 additions & 5 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,22 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)

def get_k_proj_weight(self):
""" 4bit quantization in GPTQ replaces the k_proj.weight with qweight. """
if hasattr(self.k_proj, 'qweight'):
return self.k_proj.qweight
return self.k_proj.weight

def get_k_proj_weight_dtype(self):
""" 4bit quantization in GPTQ replaces the k_proj.weight with qweight.
Scales tensor gets the weight dtype. """
if hasattr(self.k_proj, 'qweight'):
return self.k_proj.scales.dtype
return self.k_proj.weight.dtype

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
device = self.k_proj.weight.device
device = self.get_k_proj_weight().device
dtype = self.config.torch_dtype
self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
Expand All @@ -306,7 +319,7 @@ def update_sincos_cache(self, seq_len):
# reduce memory consumption and improve performance.
if seq_len > self.max_position_embeddings:
self.max_position_embeddings = seq_len
_, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)
_, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len)

def reorder(self, tensor, beam_idx, dim_a, dim_b):
updated = tensor.index_select(0, beam_idx)
Expand Down Expand Up @@ -362,7 +375,7 @@ def pre_attn_forward(
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
key_slices = self.get_k_proj_weight().split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
Expand Down Expand Up @@ -414,9 +427,9 @@ def pre_attn_forward(
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
else:
if past_key_value is None:
past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
past_key = torch.zeros(key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device)
past_value = torch.zeros(
key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
)
# Return list instead of tuple
past_key_value = [past_key, past_value]
Expand Down