From b705633cce42a327f71f00b547bad8799c9ed5be Mon Sep 17 00:00:00 2001 From: Dudi Lester <160421192+dudilester@users.noreply.github.com> Date: Thu, 21 Mar 2024 14:05:10 +0200 Subject: [PATCH 1/6] added text-generation quantization_config example file with a name that matches its scale method (#92) --- .../act_maxabs_pow2_weights_pcs_opt_pow2_quant.json | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json diff --git a/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json b/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json new file mode 100644 index 0000000000..602a147baa --- /dev/null +++ b/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} From eb8b435222b293e93ff4a1703f4bbe8ba80b1b42 Mon Sep 17 00:00:00 2001 From: Dudi Lester <160421192+dudilester@users.noreply.github.com> Date: Sun, 24 Mar 2024 18:25:06 +0200 Subject: [PATCH 2/6] Encapsulate FSDPA in GaudiLlamaAttention (#129) * Done to allow quantization using HQT * Added use_flash_attention and flash_attention_recompute to run_lm_eval --- examples/text-generation/run_lm_eval.py | 2 ++ .../models/llama/modeling_llama.py | 19 ++++++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 43a1c46ef5..7d9db3985d 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -85,6 +85,8 @@ def __init__(self, tokenizer, model, args, options): self.model_inputs.update( { "attn_softmax_bf16": self.options.attn_softmax_bf16, + "use_flash_attention": self.options.use_flash_attention, + "flash_attention_recompute": self.options.flash_attention_recompute, } ) if args.warmup: diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 7546e74d3c..4febf5c6b9 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -140,6 +140,16 @@ def gaudi_llama_repeat_kv( return query_states, key_states, value_states, attention_mask +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale) + + class Matmul(torch.nn.Module): def __init__(self): super().__init__() @@ -267,6 +277,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) @@ -424,7 +435,7 @@ def pre_attn_forward( if q_len == 1: # next token with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply( + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None ) else: @@ -432,10 +443,12 @@ def pre_attn_forward( if flash_attention_causal_mask: # causal masking on first token requires inputs to be of the same length with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, None, 0.0, True, None + ) else: with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None ) From a963932d1735dd2292a421fd3c16a372792c75ea Mon Sep 17 00:00:00 2001 From: Dudi Lester <160421192+dudilester@users.noreply.github.com> Date: Thu, 28 Mar 2024 16:43:04 +0200 Subject: [PATCH 3/6] enforce recompute flag on fsdpa quantization (#133) --- optimum/habana/transformers/models/llama/modeling_llama.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 4febf5c6b9..ea60585405 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,3 +1,4 @@ +import os import math import warnings from typing import List, Optional, Tuple, Union @@ -434,7 +435,8 @@ def pre_attn_forward( if q_len == 1: # next token - with ht.sdp_kernel(enable_recompute=False): + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None ) From e6bae7fb6e285c849247756cf1a693a26e9c7e41 Mon Sep 17 00:00:00 2001 From: Dudi Lester <160421192+dudilester@users.noreply.github.com> Date: Wed, 3 Apr 2024 18:46:40 +0300 Subject: [PATCH 4/6] add flash_attention_causal_mask to run_lm_eval.py (#142) --- examples/text-generation/run_lm_eval.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 7d9db3985d..b39fcf8a71 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -87,6 +87,7 @@ def __init__(self, tokenizer, model, args, options): "attn_softmax_bf16": self.options.attn_softmax_bf16, "use_flash_attention": self.options.use_flash_attention, "flash_attention_recompute": self.options.flash_attention_recompute, + "flash_attention_causal_mask": self.options.flash_attention_causal_mask, } ) if args.warmup: From d638f6a54598a41042bfd120bb33c0a805ee7034 Mon Sep 17 00:00:00 2001 From: Dudi Lester Date: Sun, 12 May 2024 15:13:30 +0300 Subject: [PATCH 5/6] Document FusedScaledDotProductAttention quantization --- examples/text-generation/README.md | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 734ac3d2e3..31d1dc797b 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -266,7 +266,10 @@ QUANT_CONFIG=./quantization_config/maxabs_measure.json python ../gaudi_spawn.py --use_hpu_graphs \ --trim_logits \ --use_kv_cache \ ---reuse_cache \ +--bucket_size=128 \ +--bucket_internal \ +--use_flash_attention \ +--flash_attention_recompute \ --bf16 \ --batch_size 1 ``` @@ -281,7 +284,10 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ --use_hpu_graphs \ --trim_logits \ --use_kv_cache \ ---reuse_cache \ +--bucket_size=128 \ +--bucket_internal \ +--use_flash_attention \ +--flash_attention_recompute \ --bf16 \ --batch_size 1 \ --fp8 @@ -297,8 +303,10 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ --trim_logits \ --use_kv_cache \ --reuse_cache \ +--use_flash_attention \ +--flash_attention_recompute \ --bf16 \ ---batch_size 277 \ +--batch_size 350 \ --max_new_tokens 2048 \ --max_input_tokens 2048 \ --limit_hpu_graphs \ From 4b9addd153ed552166021e06e05462e00587618d Mon Sep 17 00:00:00 2001 From: Yeonsil Yoon Date: Tue, 4 Jun 2024 18:05:59 +0000 Subject: [PATCH 6/6] Fix code style --- optimum/habana/transformers/models/llama/modeling_llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index ea60585405..1c0a16c181 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,5 +1,5 @@ -import os import math +import os import warnings from typing import List, Optional, Tuple, Union @@ -148,7 +148,7 @@ def __init__(self, fusedSDPA): self._hpu_kernel_fsdpa = fusedSDPA def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale): - return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale) + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale) class Matmul(torch.nn.Module):