Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions examples/text-generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,33 @@ python run_generation.py \
```


### Loading 4 Bit Checkpoints from Hugging Face

You can load pre-quantized 4bit models with the argument `--load_quantized_model`.
Currently, only uint4 checkpoints and a single device are supported.
More information on enabling 4 bit inference in SynapseAI is available here:
https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_INT4.html.

Below is an example of loading a model from 4-bit checkpoints on Hugging Face.
Please note that the model name is denoted as `<model_path_in_hugging_face>`.
Additionally, the environment variables below are used for performance optimizations and are planned to be removed in a future version:
`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1`
```bash
SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1 \
python run_lm_eval.py \
-o acc_load_uint4_model.txt \
--model_name_or_path <model_path_in_hugging_face> \
--use_hpu_graphs \
--use_kv_cache \
--trim_logits \
--batch_size 1 \
--bf16 \
--attn_softmax_bf16 \
--bucket_size=128 \
--bucket_internal \
--load_quantized_model
```

### Using Habana Flash Attention

Habana Flash Attention addresses large sequence lengths on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes.
Expand Down
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,11 @@ def setup_parser(parser):
help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.",
)
parser.add_argument(
"--load_quantized_model",
action="store_true",
help="Whether to load model from hugging face checkpoint.",
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one line missing here

"--parallel_strategy",
type=str,
choices=["tp", "none"], # Add other strategies as needed
Expand Down
11 changes: 11 additions & 0 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,14 @@ def setup_model(args, model_dtype, model_kwargs, logger):
torch_dtype=model_dtype,
**model_kwargs,
)
elif args.load_quantized_model:
from neural_compressor.torch.quantization import load
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ulivne sounds like neural_compressor is missing from the requirements consider adding it!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ulivne sounds like neural_compressor is missing from the requirements consider adding it!

neural_compressor is installed automatically as part of habana software stack. it replaces habana_quantization_toolkit which was also not part of requirements.

model = load(
model_name_or_path=args.model_name_or_path,
format="huggingface",
device="hpu",
**model_kwargs
)
else:
if args.assistant_model is not None:
assistant_model = AutoModelForCausalLM.from_pretrained(
Expand Down Expand Up @@ -619,6 +627,9 @@ def initialize_model(args, logger):
"trust_remote_code": args.trust_remote_code,
}

if args.load_quantized_model:
model_kwargs["torch_dtype"] = torch.bfloat16

if args.trust_remote_code:
logger.warning("`trust_remote_code` is set, there is no guarantee this model works properly and it may fail")

Expand Down
23 changes: 18 additions & 5 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,22 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)

def get_k_proj_weight(self):
    """Return the tensor that stores the key projection's weights.

    4-bit GPTQ quantization replaces ``k_proj.weight`` with ``qweight``, so
    ``qweight`` is returned whenever the projection exposes it; otherwise the
    regular ``weight`` is returned.
    """
    projection = self.k_proj
    return projection.qweight if hasattr(projection, "qweight") else projection.weight

def get_k_proj_weight_dtype(self):
    """Return the dtype associated with the key projection's weights.

    4-bit GPTQ quantization replaces ``k_proj.weight`` with ``qweight``. In
    that case the ``scales`` tensor carries the weight dtype, so its dtype is
    reported instead.
    """
    projection = self.k_proj
    if not hasattr(projection, "qweight"):
        return projection.weight.dtype
    return projection.scales.dtype

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
    """Allocate the key and value caches for this attention layer.

    Args:
        batch_size: First dimension of the cache shape.
        max_seq_len: Third dimension of the cache shape.
        inp_seq_len: Forwarded unchanged to each cache's ``allocate`` call.

    Both caches are allocated with the shape
    ``(batch_size, num_key_value_heads, max_seq_len, head_dim)``, the dtype
    from ``self.config.torch_dtype`` and the device of the key projection's
    weight tensor.
    """
    cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
    # Look the device up through the accessor only: a GPTQ-quantized
    # projection has ``qweight`` instead of ``weight``, so reading
    # ``self.k_proj.weight`` directly would raise AttributeError.
    device = self.get_k_proj_weight().device
    dtype = self.config.torch_dtype
    self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
    self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
Expand All @@ -436,7 +449,7 @@ def update_sincos_cache(self, seq_len):
# reduce memory consumption and improve performance.
if seq_len > self.max_position_embeddings:
self.max_position_embeddings = seq_len
_, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)
_, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len)

def reorder(self, tensor, beam_idx, dim_a, dim_b):
updated = tensor.index_select(0, beam_idx)
Expand Down Expand Up @@ -493,7 +506,7 @@ def pre_attn_forward(
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
key_slices = self.get_k_proj_weight().split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
Expand Down Expand Up @@ -565,9 +578,9 @@ def pre_attn_forward(
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
else:
if past_key_value is None:
past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
past_key = torch.zeros(key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device)
past_value = torch.zeros(
key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
)
# Return list instead of tuple
past_key_value = [past_key, past_value]
Expand Down