From 02f3b45110b4b6a3230779962288f71107279fcb Mon Sep 17 00:00:00 2001
From: mandy-li
Date: Tue, 5 Dec 2023 11:45:58 -0800
Subject: [PATCH] Enable habana flash attention for llama2 ft

---
Review notes (text in this section is not part of the commit message):
 * run_lora_clm.py: the `help` metadata of `use_flash_attention` is now a
   plain string; a trailing comma previously made it a one-element tuple.
 * trainer.py: `use_flash_attention` is read from `generation_config` with
   `getattr(..., False)`, because this patch only sets the attribute when the
   command-line flag is passed.
 * modeling_llama.py: explicit `FusedSDPA is not None` check (matches the
   `FusedSDPA = None` import fallback) and PEP 8 keyword-argument spacing.
 * All changes are line-count neutral, so hunk headers and the diffstat below
   are unchanged.

 examples/language-modeling/run_lora_clm.py    | 10 +++
 .../models/llama/modeling_llama.py            | 61 +++++++++++++------
 optimum/habana/transformers/trainer.py        | 18 ++++--
 3 files changed, 63 insertions(+), 26 deletions(-)

diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py
index f04c3e2224..7924230099 100644
--- a/examples/language-modeling/run_lora_clm.py
+++ b/examples/language-modeling/run_lora_clm.py
@@ -142,6 +142,14 @@ class ModelArguments:
             )
         },
     )
+    use_flash_attention: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to use Habana flash attention for fine-tuning. The current support is limited to Llama only."
+            )
+        },
+    )
     load_meta_device: bool = field(
         default=False,
         metadata={
@@ -493,6 +501,8 @@ def main():
         model.generation_config.eos_token_id = 2
         if model_args.attn_softmax_bf16:
             model.generation_config.attn_softmax_bf16 = True
+        if model_args.use_flash_attention:
+            model.generation_config.use_flash_attention = True
     if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None:
         tokenizer.pad_token_id = model.generation_config.pad_token_id
 
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index dfbbe29cf1..e1a7f6bb03 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -27,6 +27,11 @@
     print("Not using HPU fused kernel for RMSNorm")
     FusedRMSNorm = None
 
+try:
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+except ImportError:
+    print("Not using HPU fused scaled dot-product attention kernel.")
+    FusedSDPA = None
 
 def update(prev, cur, dim, idx, inp_seq_len):
     orig_cur = cur
@@ -136,6 +141,7 @@ def forward(
         token_idx: Optional[torch.Tensor] = None,
         attn_softmax_bf16: Optional[bool] = False,
         reuse_cache: Optional[bool] = False,
+        use_flash_attention: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """
         Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -208,30 +214,35 @@ def forward(
         key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups)
         value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
+        if use_flash_attention and FusedSDPA is not None:
+            import habana_frameworks.torch.hpu as ht
+            with ht.sdp_kernel(enable_recompute=False):
+                attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None)
+        else:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
                 )
-            attn_weights = attn_weights + attention_mask
 
-        if attn_softmax_bf16:
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
-        else:
-            # upcast attention to fp32
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
-                query_states.dtype
-            )
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+
+            if attn_softmax_bf16:
+                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
+            else:
+                # upcast attention to fp32
+                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+                    query_states.dtype
+                )
 
-        attn_output = torch.matmul(attn_weights, value_states)
+            attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
@@ -277,6 +288,7 @@ def forward(
         token_idx: Optional[torch.Tensor] = None,
         attn_softmax_bf16: Optional[bool] = False,
         reuse_cache: Optional[bool] = False,
+        use_flash_attention: Optional[bool] = False,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -284,6 +296,7 @@ def forward(
         - add new args token_idx
         - add new args attn_softmax_bf16
         - add new args reuse_cache
+        - add new args use_flash_attention
         """
         residual = hidden_states
 
@@ -300,6 +313,7 @@ def forward(
             token_idx=token_idx,
             attn_softmax_bf16=attn_softmax_bf16,
             reuse_cache=reuse_cache,
+            use_flash_attention=use_flash_attention,
         )
         hidden_states = residual + hidden_states
 
@@ -345,6 +359,7 @@ def forward(
         token_idx: Optional[torch.Tensor] = None,
         attn_softmax_bf16: Optional[bool] = False,
         reuse_cache: Optional[bool] = False,
+        use_flash_attention: Optional[bool] = False,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         """
         Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -352,6 +367,7 @@ def forward(
         - add new args token_idx
         - add new args attn_softmax_bf16
         - add new args reuse_cache
+        - add new args use_flash_attention
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -426,7 +442,8 @@ def forward(
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, past_key_value, output_attentions, attn_softmax_bf16=attn_softmax_bf16)
+                        return module(*inputs, past_key_value, output_attentions, attn_softmax_bf16=attn_softmax_bf16,
+                                      use_flash_attention=use_flash_attention)
 
                     return custom_forward
 
@@ -444,6 +461,7 @@ def custom_forward(*inputs):
                     token_idx=token_idx,
                     attn_softmax_bf16=attn_softmax_bf16,
                     reuse_cache=reuse_cache,
+                    use_flash_attention=use_flash_attention,
                 )
 
             hidden_states = layer_outputs[0]
@@ -508,6 +526,7 @@ def forward(
         trim_logits: Optional[bool] = False,
         attn_softmax_bf16: Optional[bool] = False,
         reuse_cache: Optional[bool] = False,
+        use_flash_attention: Optional[bool] = False,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -528,6 +547,7 @@ def forward(
             token_idx=token_idx,
             attn_softmax_bf16=attn_softmax_bf16,
             reuse_cache=reuse_cache,
+            use_flash_attention=use_flash_attention,
         )
         hidden_states = outputs[0]
         _, seq_len, _ = hidden_states.shape
@@ -611,6 +631,7 @@ def prepare_inputs_for_generation(
                 "trim_logits": kwargs.get("trim_logits"),
                 "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
                 "reuse_cache": reuse_cache,
+                "use_flash_attention": kwargs.get("use_flash_attention"),
             }
         )
         return model_inputs
diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index 305d019f97..27b74e155c 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -832,10 +832,13 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
 
-                # attn_softmax_bf16 is enabled only for llama
+                # attn_softmax_bf16 and use_flash_attention are enabled only for llama
                 if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
-                    if self.model.config.model_type == "llama" and self.model.generation_config.attn_softmax_bf16:
-                        inputs["attn_softmax_bf16"] = True
+                    if self.model.config.model_type == "llama":
+                        if self.model.generation_config.attn_softmax_bf16:
+                            inputs["attn_softmax_bf16"] = True
+                        if getattr(self.model.generation_config, "use_flash_attention", False):
+                            inputs["use_flash_attention"] = True
 
                 # TODO: keep syncs for fast DDP?
                 with self.accelerator.accumulate(model):
@@ -1530,10 +1533,13 @@ def evaluation_loop(
                 if batch_size is None:
                     batch_size = observed_batch_size
 
-            # attn_softmax_bf16 is enabled only for llama
+            # attn_softmax_bf16 and use_flash_attention are enabled only for llama
             if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
-                if self.model.config.model_type == "llama" and self.model.generation_config.attn_softmax_bf16:
-                    inputs["attn_softmax_bf16"] = True
+                if self.model.config.model_type == "llama":
+                    if self.model.generation_config.attn_softmax_bf16:
+                        inputs["attn_softmax_bf16"] = True
+                    if getattr(self.model.generation_config, "use_flash_attention", False):
+                        inputs["use_flash_attention"] = True
 
             # Prediction step
             loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)