From b1462757a7bc65e48208a207db68a3e22b4d48ab Mon Sep 17 00:00:00 2001 From: Piotr Bielak Date: Tue, 1 Oct 2024 14:56:45 +0300 Subject: [PATCH] Enable FusedSDPA fp8 in Llama FT - Add --flash_attention_fp8 flag - Add assert for Gaudi 3 - Update ModuleFusedSDPA - While converting model replace whole ModuleFusedSDPA object with class based on Transformer Engine attention (FP8) - Fix error: `LlamaConfig object has no attribute flash_attention_fp8` Co-authored-by: Yaser Afshar Co-authored-by: Harish Subramony <81822986+hsubramony@users.noreply.github.com> --- examples/language-modeling/run_lora_clm.py | 10 ++++++++++ .../accelerate/utils/transformer_engine.py | 20 +++++++++++++++++++ .../models/llama/configuration_llama.py | 2 ++ .../models/llama/modeling_llama.py | 18 +++++++++++++++-- 4 files changed, 48 insertions(+), 2 deletions(-) diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index ebbcc2e4d0..0b16be0725 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -172,6 +172,10 @@ class ModelArguments: ) }, ) + flash_attention_fp8: bool = field( + default=False, + metadata={"help": ("Whether to enable flash attention in FP8.")}, + ) use_fused_rope: bool = field( default=True, metadata={ @@ -509,6 +513,7 @@ def main(): "trust_remote_code": True if model_args.trust_remote_code else None, "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache, "token": model_args.token, + "flash_attention_fp8": model_args.flash_attention_fp8, } if model_args.config_name: config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) @@ -705,6 +710,11 @@ def main(): model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask + + if model_args.flash_attention_fp8: + import habana_frameworks.torch.hpu as hthpu + + assert hthpu.get_device_name() == "GAUDI3", "Flash attention in FP8 is supported only on Gaudi3" if not model_args.use_fused_rope: model.generation_config.use_fused_rope = False diff --git a/optimum/habana/accelerate/utils/transformer_engine.py b/optimum/habana/accelerate/utils/transformer_engine.py index 823da61d5c..b40b3b2110 100755 --- a/optimum/habana/accelerate/utils/transformer_engine.py +++ b/optimum/habana/accelerate/utils/transformer_engine.py @@ -42,6 +42,8 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True): """ Recursively converts the linear layer of a model to their `transformers_engine` counterpart. """ + from optimum.habana.transformers.models.llama.modeling_llama import ModuleFusedSDPA + if not is_fp8_available(): raise ImportError("Using `convert_model` requires transformer_engine to be installed.") for name, module in model.named_children(): @@ -75,6 +77,24 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True): new_module.bias.copy_(module.bias) setattr(model, name, new_module) + elif isinstance(module, ModuleFusedSDPA) and module.flash_attention_fp8 and to_transformer_engine: + from habana_frameworks.torch.hpex.experimental.transformer_engine import ( + FusedAttention as TE_FusedAttention, + ) + + class TE_ModuleFusedSDPA(torch.nn.Module): + def __init__(self): + super().__init__() + self._hpu_kernel_fsdpa = TE_FusedAttention( + scale=module.scale, + attention_dropout=module.attention_dropout, + enable_recompute=module.enable_recompute, + ) + + def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode): + return self._hpu_kernel_fsdpa(query, key, value, attn_mask, is_causal, softmax_mode) + + setattr(model, name, TE_ModuleFusedSDPA()) else: _convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear) diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py index ce754dadb5..8e7c767a0f 100644 --- a/optimum/habana/transformers/models/llama/configuration_llama.py +++ b/optimum/habana/transformers/models/llama/configuration_llama.py @@ -27,6 +27,7 @@ def __init__( mlp_bias=False, fused_qkv=False, parallel_strategy=None, + flash_attention_fp8=False, **kwargs, ): super().__init__( @@ -56,3 +57,4 @@ def __init__( self.fused_qkv = fused_qkv self.parallel_strategy = parallel_strategy + self.flash_attention_fp8 = flash_attention_fp8 diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index c4a8dc6c85..b810669e7f 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -347,9 +347,13 @@ def gaudi_llama_repeat_kv( # FusedScaledDotProductAttention class ModuleFusedSDPA(torch.nn.Module): - def __init__(self, fusedSDPA): + def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8): super().__init__() self._hpu_kernel_fsdpa = fusedSDPA + self.scale = scale + self.attention_dropout = attention_dropout + self.enable_recompute = enable_recompute + self.flash_attention_fp8 = flash_attention_fp8 def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode): return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode) @@ -412,7 +416,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() - self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None if hasattr(config, "fused_qkv") and config.fused_qkv: self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // self.num_heads @@ -428,6 +431,17 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.v_proj = None self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) + self.fused_scaled_dot_product_attention = ( + ModuleFusedSDPA( + FusedSDPA, + scale=self.norm_factor, + attention_dropout=self.attention_dropout, + enable_recompute=False, + flash_attention_fp8=getattr(config, "flash_attention_fp8", False), + ) + if FusedSDPA + else None + ) def get_k_proj_weight(self): """4bit quantization in GPTQ replaces the k_proj.weight with qweight."""