From 83b36053e2a4b14a6c7aea2865aec0f1bbccc36e Mon Sep 17 00:00:00 2001 From: Dudi Lester <160421192+dudilester@users.noreply.github.com> Date: Thu, 21 Mar 2024 14:05:10 +0200 Subject: [PATCH 1/4] added text-generation quantization_config example file with a name that matches its scale method (#92) --- .../act_maxabs_pow2_weights_pcs_opt_pow2_quant.json | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json diff --git a/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json b/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json new file mode 100644 index 0000000000..602a147baa --- /dev/null +++ b/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} From e231aa5829843b9742f21b5b5485970d6801fca7 Mon Sep 17 00:00:00 2001 From: Dudi Lester <160421192+dudilester@users.noreply.github.com> Date: Sun, 24 Mar 2024 18:25:06 +0200 Subject: [PATCH 2/4] Encapsulate FSDPA in GaudiLlamaAttention (#129) * Done to allow quantization using HQT * Added use_flash_attention and flash_attention_recompute to run_lm_eval --- examples/text-generation/run_lm_eval.py | 2 ++ .../models/llama/modeling_llama.py | 19 ++++++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 8682a28d35..4b9aef65bf 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -85,6 +85,8 @@ def __init__(self, tokenizer, model, args, options): self.model_inputs.update( { "attn_softmax_bf16": self.options.attn_softmax_bf16, + "use_flash_attention": self.options.use_flash_attention, + "flash_attention_recompute": self.options.flash_attention_recompute, } ) if args.warmup: diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 7546e74d3c..4febf5c6b9 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -140,6 +140,16 @@ def gaudi_llama_repeat_kv( return query_states, key_states, value_states, attention_mask +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale) + + class Matmul(torch.nn.Module): def __init__(self): super().__init__() @@ -267,6 +277,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) @@ -424,7 +435,7 @@ def pre_attn_forward( if q_len == 1: # next token with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply( + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None ) else: @@ -432,10 +443,12 @@ def pre_attn_forward( if flash_attention_causal_mask: # causal masking on first token requires inputs to be of the same length with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, None, 0.0, True, None + ) else: with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None ) From 86fa5b66f21f161c21500eb560762da197fb535a Mon Sep 17 00:00:00 2001 From: Dudi Lester <160421192+dudilester@users.noreply.github.com> Date: Thu, 28 Mar 2024 16:43:04 +0200 Subject: [PATCH 3/4] enforce recompute flag on fsdpa quantization (#133) --- optimum/habana/transformers/models/llama/modeling_llama.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 4febf5c6b9..ea60585405 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,3 +1,4 @@ +import os import math import warnings from typing import List, Optional, Tuple, Union @@ -434,7 +435,8 @@ def pre_attn_forward( if q_len == 1: # next token - with ht.sdp_kernel(enable_recompute=False): + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None ) From 659b2d160693ddf45244e31f8e9dae3e647b9885 Mon Sep 17 00:00:00 2001 From: Dudi Lester <160421192+dudilester@users.noreply.github.com> Date: Wed, 3 Apr 2024 18:46:40 +0300 Subject: [PATCH 4/4] add flash_attention_causal_mask to run_lm_eval.py (#142) --- examples/text-generation/run_lm_eval.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 4b9aef65bf..7c79b2c13a 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -87,6 +87,7 @@ def __init__(self, tokenizer, model, args, options): "attn_softmax_bf16": self.options.attn_softmax_bf16, "use_flash_attention": self.options.use_flash_attention, "flash_attention_recompute": self.options.flash_attention_recompute, + "flash_attention_causal_mask": self.options.flash_attention_causal_mask, } ) if args.warmup: