diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index b20a222fc9..9aad1beecc 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -158,14 +158,6 @@ class ModelArguments: ) }, ) - flash_attention_fp8: bool = field( - default=False, - metadata={ - "help": ( - "Whether to enable flash attention in FP8." - ) - }, - ) use_fused_rope: bool = field( default=True, metadata={ @@ -595,7 +587,6 @@ def main(): model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask - model.generation_config.flash_attention_fp8 = model_args.flash_attention_fp8 if not model_args.use_fused_rope: model.generation_config.use_fused_rope = False diff --git a/optimum/habana/accelerate/utils/transformer_engine.py b/optimum/habana/accelerate/utils/transformer_engine.py index 55fa570ac6..823da61d5c 100755 --- a/optimum/habana/accelerate/utils/transformer_engine.py +++ b/optimum/habana/accelerate/utils/transformer_engine.py @@ -16,7 +16,6 @@ import functools import torch -from optimum.habana.transformers.models.llama.modeling_llama import ModuleFusedSDPA has_transformer_engine = False @@ -76,10 +75,6 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True): new_module.bias.copy_(module.bias) setattr(model, name, new_module) - elif isinstance(module, ModuleFusedSDPA) and to_transformer_engine: - from habana_frameworks.torch.hpex.experimental.transformer_engine import FusedAttention as te_FusedAttention - module._hpu_kernel_fsdpa = te_FusedAttention(scale=module.scale, attention_dropout=module.attention_dropout, enable_recompute=False) - setattr(model, name, module) else: _convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear) diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 7c33c84bef..ce38a07ed9 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -37,8 +37,6 @@ class GaudiGenerationConfig(GenerationConfig): Whether to enable causal_mask if use Habana flash attention. flash_attention_fast_softmax_mode (`bool`, *optional*): Whether to use fast softmax with reduced precision if use Habana flash attention. - flash_attention_fp8 (`bool`, *optional*): - Whether to use flash attention in FP8. """ def __init__(self, **kwargs): @@ -56,5 +54,4 @@ def __init__(self, **kwargs): self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None) - self.flash_attention_fp8 = kwargs.get("flash_attention_fp8", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index eaa39c5984..0c51d6ce83 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -50,9 +50,6 @@ import habana_frameworks.torch.core as htcore -flash_attention_in_fp8 = False - - def gaudi_llama_rmsnorm_forward(self, hidden_states): """ Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -228,19 +225,12 @@ def gaudi_llama_repeat_kv( # FusedScaledDotProductAttention class ModuleFusedSDPA(torch.nn.Module): - def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute=False): + def __init__(self, fusedSDPA): super().__init__() self._hpu_kernel_fsdpa = fusedSDPA - self.scale = scale - self.attention_dropout = attention_dropout - self.enable_recompute = enable_recompute - - def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, fast_softmax_mode): - from habana_frameworks.torch.hpex.experimental.transformer_engine import FusedAttention - if isinstance(self._hpu_kernel_fsdpa, FusedAttention): - return self._hpu_kernel_fsdpa(query, key, value, attn_mask, is_causal, fast_softmax_mode) - else: - return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, fast_softmax_mode) + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) class Matmul(torch.nn.Module): @@ -301,6 +291,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None if config.fused_qkv: self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // self.num_heads @@ -316,12 +307,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.v_proj = None self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) - self.fused_scaled_dot_product_attention = ModuleFusedSDPA( - FusedSDPA, - scale=self.norm_factor, - attention_dropout=self.attention_dropout, - enable_recompute=False, - ) if FusedSDPA else None def get_k_proj_weight(self): """4bit quantization in GPTQ replaces the k_proj.weight with qweight.""" @@ -499,54 +484,27 @@ def pre_attn_forward( softmax_mode = "fast" if flash_attention_fast_softmax else "None" - global flash_attention_in_fp8 - if flash_attention_in_fp8 is True: - query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( - query_states, - key_states, - value_states, - attention_mask, - self.num_key_value_groups, - ) - if q_len == 1: # next token use_recompute = True if os.getenv("QUANT_CONFIG", "") else False with ht.sdp_kernel(enable_recompute=use_recompute): - if flash_attention_in_fp8 is True: - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode - ) - else: - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode - ) + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode + ) else: # first token if flash_attention_causal_mask: # causal masking on first token requires inputs to be of the same length with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - if flash_attention_in_fp8 is True: - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, None, 0.0, True, None, softmax_mode - ) - else: - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, None, 0.0, True, None, softmax_mode - ) + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, None, 0.0, True, None, softmax_mode + ) else: with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - if flash_attention_in_fp8 is True: - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode - ) - else: - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode - ) - - if flash_attention_in_fp8 is True: - attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode + ) + else: query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups @@ -1056,10 +1014,6 @@ def forward( global has_fused_rope has_fused_rope = False - if self.generation_config.flash_attention_fp8 is True: - global flash_attention_in_fp8 - flash_attention_in_fp8 = True - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids,