From d584181ce6bbe72ac14d31b23b83a72a3fbb1d7e Mon Sep 17 00:00:00 2001 From: Dudi Lester Date: Thu, 21 Mar 2024 11:39:47 +0200 Subject: [PATCH] Encapsulate FSDPA in GaudiLlamaAttention * Done to allow quantization using HQT * Added use_flash_attention and flash_attention_recompute to run_lm_eval --- examples/text-generation/run_lm_eval.py | 2 ++ .../models/llama/modeling_llama.py | 19 ++++++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 8d61118890..cf174141d8 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -85,6 +85,8 @@ def __init__(self, tokenizer, model, args, options): self.model_inputs.update( { "attn_softmax_bf16": self.options.attn_softmax_bf16, + "use_flash_attention": self.options.use_flash_attention, + "flash_attention_recompute": self.options.flash_attention_recompute, } ) if args.warmup: diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 79bf8c4b05..1d998decc8 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -107,6 +107,16 @@ def gaudi_llama_repeat_kv( return query_states, key_states, value_states, attention_mask +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale) + + class Matmul(torch.nn.Module): def __init__(self): super().__init__() @@ -164,6 +174,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) @@ -309,7 +320,7 @@ def pre_attn_forward( if q_len == 1: # next token with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply( + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None ) else: @@ -317,10 +328,12 @@ def pre_attn_forward( if flash_attention_causal_mask: # causal masking on first token requires inputs to be of the same lenght with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, None, 0.0, True, None + ) else: with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None )