diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index 9aad1beecc..9998e20cb9 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -42,7 +42,7 @@ from transformers.trainer_utils import is_main_process from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments -from optimum.habana.utils import set_seed +from optimum.habana.utils import is_gaudi3, set_seed try: @@ -158,6 +158,14 @@ class ModelArguments: ) }, ) + flash_attention_fp8: bool = field( + default=False, + metadata={ + "help": ( + "Whether to enable flash attention in FP8." + ) + }, + ) use_fused_rope: bool = field( default=True, metadata={ @@ -432,6 +440,7 @@ def main(): "trust_remote_code": True if model_args.trust_remote_code else None, "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache, "token": model_args.token, + "flash_attention_fp8": model_args.flash_attention_fp8, } if model_args.config_name: config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) @@ -587,6 +596,9 @@ def main(): model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask + + if model_args.flash_attention_fp8: + assert is_gaudi3(), "Flash attention in FP8 is supported only on Gaudi3" if not model_args.use_fused_rope: model.generation_config.use_fused_rope = False diff --git a/optimum/habana/accelerate/utils/transformer_engine.py b/optimum/habana/accelerate/utils/transformer_engine.py index 823da61d5c..c4aaba69a7 100755 --- a/optimum/habana/accelerate/utils/transformer_engine.py +++ b/optimum/habana/accelerate/utils/transformer_engine.py @@ -42,6 +42,8 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True): """ Recursively converts the linear layer of a model to 
their `transformers_engine` counterpart. """ + from optimum.habana.transformers.models.llama.modeling_llama import ModuleFusedSDPA + if not is_fp8_available(): raise ImportError("Using `convert_model` requires transformer_engine to be installed.") for name, module in model.named_children(): @@ -75,6 +77,14 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True): new_module.bias.copy_(module.bias) setattr(model, name, new_module) + elif isinstance(module, ModuleFusedSDPA) and module.flash_attention_fp8 and to_transformer_engine: + from habana_frameworks.torch.hpex.experimental.transformer_engine import FusedAttention as te_FusedAttention + module._hpu_kernel_fsdpa = te_FusedAttention( + scale=module.scale, + attention_dropout=module.attention_dropout, + enable_recompute=module.enable_recompute, + ) + setattr(model, name, module) else: _convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear) diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py index 12ad78e29a..7a92a35c76 100644 --- a/optimum/habana/transformers/models/llama/configuration_llama.py +++ b/optimum/habana/transformers/models/llama/configuration_llama.py @@ -28,6 +28,7 @@ def __init__( mlp_bias=False, fused_qkv=False, parallel_strategy=None, + flash_attention_fp8=False, **kwargs, ): super().__init__( @@ -57,3 +58,4 @@ def __init__( self.mlp_bias = mlp_bias self.fused_qkv = fused_qkv self.parallel_strategy = parallel_strategy + self.flash_attention_fp8 = flash_attention_fp8 diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 9366a7bee1..ea47de38ed 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -266,12 +266,20 @@ def gaudi_llama_repeat_kv( # FusedScaledDotProductAttention 
class ModuleFusedSDPA(torch.nn.Module):
    """Module wrapper around a fused scaled-dot-product-attention kernel.

    The wrapped kernel is either an object exposing ``apply`` (the kernel
    handed in at construction time) or an instance of the Transformer Engine
    ``FusedAttention`` module that was assigned to ``_hpu_kernel_fsdpa``
    afterwards. ``forward`` picks the matching calling convention on every
    call, so the kernel may be swapped after the module has been built.

    Args:
        fusedSDPA: Kernel object; stored as ``_hpu_kernel_fsdpa``.
        scale: Stored as ``self.scale``; not used by ``forward`` itself.
        attention_dropout: Stored as ``self.attention_dropout``; not used by
            ``forward`` itself.
        enable_recompute: Stored as ``self.enable_recompute``; not used by
            ``forward`` itself.
        flash_attention_fp8: Stored as ``self.flash_attention_fp8``; not used
            by ``forward`` itself.
    """

    # Memoised result of the Transformer Engine lookup. Kept as a 1-tuple so
    # that "not looked up yet" (None) is distinguishable from "looked up but
    # unavailable" ((None,)).
    _te_fused_attention_lookup = None

    def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8):
        super().__init__()
        self._hpu_kernel_fsdpa = fusedSDPA
        # Construction-time settings are only recorded here so that code
        # replacing the kernel later can read them back from the module.
        self.scale = scale
        self.attention_dropout = attention_dropout
        self.enable_recompute = enable_recompute
        self.flash_attention_fp8 = flash_attention_fp8

    @staticmethod
    def _te_fused_attention_type():
        """Return the Transformer Engine ``FusedAttention`` class, or ``None``.

        The import is attempted once and the outcome is cached on the class.
        A missing Transformer Engine is not an error: it simply means the
        kernel cannot be a ``FusedAttention`` instance.
        """
        if ModuleFusedSDPA._te_fused_attention_lookup is None:
            try:
                from habana_frameworks.torch.hpex.experimental.transformer_engine import FusedAttention
            except ImportError:
                FusedAttention = None
            ModuleFusedSDPA._te_fused_attention_lookup = (FusedAttention,)
        return ModuleFusedSDPA._te_fused_attention_lookup[0]

    def forward(
        self,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        softmax_mode,
        recompute_mode,
        valid_sequence_lengths,
        padding_side="left",
    ):
        """Run the wrapped attention kernel and return its result.

        When the kernel is a Transformer Engine ``FusedAttention`` instance it
        is called with ``(query, key, value, attn_mask, is_causal,
        softmax_mode)`` only; ``dropout_p``, ``scale``, ``recompute_mode``,
        ``valid_sequence_lengths`` and ``padding_side`` are NOT forwarded on
        that path. Otherwise all eleven arguments are passed, in order, to
        the kernel's ``apply``.
        """
        kernel = self._hpu_kernel_fsdpa
        te_fused_attention = self._te_fused_attention_type()
        if te_fused_attention is not None and isinstance(kernel, te_fused_attention):
            return kernel(query, key, value, attn_mask, is_causal, softmax_mode)
        return kernel.apply(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            softmax_mode,
            recompute_mode,
            valid_sequence_lengths,
            padding_side,
        )
attention_dropout=self.attention_dropout, + enable_recompute=False, + flash_attention_fp8=config.flash_attention_fp8, + ) if FusedSDPA else None def get_k_proj_weight(self): """4bit quantization in GPTQ replaces the k_proj.weight with qweight.""" @@ -522,26 +536,24 @@ def pre_attn_forward( past_key_value = None if use_flash_attention and FusedSDPA: - import habana_frameworks.torch.hpu as ht - - softmax_mode = "fast" if flash_attention_fast_softmax else "None" - if q_len == 1: # next token attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode, False, None, "None" + query_states, key_states, value_states, attention_mask, 0.0, False, None, "None", False, None, "None" ) else: + softmax_mode = "fast" if flash_attention_fast_softmax else "None" + # first token if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same length attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, None, 0.0, True, None, softmax_mode, flash_attention_recompute, valid_sequence_lengths, "left" ) else: attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode,flash_attention_recompute, None, "None" + query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode, flash_attention_recompute, None, "None" ) - else: query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups @@ -597,7 +609,6 @@ def post_attn_forward(self, attn_output): self.o_proj.post_all_reduce(attn_output) return attn_output - class TPGaudiLlamaAttention(GaudiLlamaAttention, TPModule): def __init__( self, diff --git a/optimum/habana/utils.py b/optimum/habana/utils.py index f5c3345a46..2bac05e0eb 100755 --- a/optimum/habana/utils.py +++ 
def is_gaudi3() -> bool:
    """Tell whether the current HPU device reports itself as Gaudi3.

    Returns:
        ``True`` when ``habana_frameworks.torch.hpu.get_device_name()``
        returns exactly ``"GAUDI3"``. ``False`` for any other device name,
        and also when the Habana PyTorch bridge cannot be imported, so the
        predicate can be used as a plain capability check.
    """
    try:
        import habana_frameworks.torch as htorch
    except ImportError:
        # Without the Habana bridge there is no HPU device at all.
        return False
    return htorch.hpu.get_device_name() == "GAUDI3"