From 6c8a1e5242da70278867ffdcc50fb1374115d0a5 Mon Sep 17 00:00:00 2001 From: Witold Szczurek Date: Wed, 6 Dec 2023 11:36:18 +0200 Subject: [PATCH 1/3] Support for FlashAttention in Llama2 --- examples/text-generation/run_generation.py | 5 ++ .../generation/configuration_utils.py | 3 + .../habana/transformers/generation/utils.py | 3 + .../models/llama/modeling_llama.py | 66 +++++++++++++------ 4 files changed, 58 insertions(+), 19 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 996cd48db3..f6ec837fde 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -212,6 +212,11 @@ def setup_parser(parser): help="Store kv-cache in float8 when kv-cache is used", ) parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8") + parser.add_argument( + "--use_flash_attention", + action="store_true", + help="Whether to enable Habana Flash Attention.", + ) args = parser.parse_args() diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 83829f6013..0df9363197 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -29,6 +29,8 @@ class GaudiGenerationConfig(GenerationConfig): Only active if `static_shapes` is used. Can't be used with `reuse_cache`. kv_cache_fp8 (`bool`, *optional*): Store kv-cache in float8 when kv-cache is used + use_flash_attention (`bool`, *optional*): + Whether to use flash attention optimization. """ def __init__(self, **kwargs): @@ -41,3 +43,4 @@ def __init__(self, **kwargs): self.reuse_cache = kwargs.get("reuse_cache", None) self.bucket_size = kwargs.get("bucket_size", -1) self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) + self.use_flash_attention = kwargs.get("use_flash_attention", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index ba7ca3c3ed..7c739051bd 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -670,6 +670,9 @@ def generate( # prepare for allocate kv cache model_kwargs["reuse_cache"] = generation_config.reuse_cache + # determine whether flash attention needs to be used + model_kwargs["use_flash_attention"] = generation_config.use_flash_attention + if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[-1] if not generation_config.static_shapes and generation_config.max_new_tokens is not None: diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index e8109a9433..ce46dc0c02 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -30,6 +30,11 @@ print("Not using HPU fused kernel for RMSNorm") FusedRMSNorm = None +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None def update(prev, cur, dim, idx, inp_seq_len): orig_cur = cur @@ -150,7 +155,8 @@ def pre_attn_forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, - ): + use_flash_attention: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: @@ -222,30 +228,42 @@ def pre_attn_forward( key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups) value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups) - attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor + if use_flash_attention and FusedSDPA: + import habana_frameworks.torch.hpu as ht + if q_len == 1: + # next token + with ht.sdp_kernel(enable_recompute=False): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None) + else: + # first token + with ht.sdp_kernel(enable_recompute=False): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) + else: + attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) - attn_weights = attn_weights + attention_mask - if attn_softmax_bf16: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) - else: - # upcast attention to fp32 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) - attn_output = self.matmul_av(attn_weights, value_states) + attn_output = self.matmul_av(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( @@ -330,6 +348,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -337,6 +356,7 @@ def forward( - add new args token_idx - add new args attn_softmax_bf16 - add new args reuse_cache + - add new args use_flash_attention """ residual = hidden_states output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( @@ -349,6 +369,7 @@ def forward( token_idx, attn_softmax_bf16, reuse_cache, + use_flash_attention=use_flash_attention, ) self.self_attn.attention_all_reduce(output_pre_attn) output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) @@ -375,6 +396,7 @@ def pre_attn( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( @@ -387,6 +409,7 @@ def pre_attn( token_idx, attn_softmax_bf16, reuse_cache, + use_flash_attention ) return output_attn, attn_weights, present_key_value @@ -433,6 +456,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -440,6 +464,7 @@ def forward( - add new args token_idx - add new args attn_softmax_bf16 - add new args reuse_cache + - add new args use_flash_attention """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -532,6 +557,7 @@ def custom_forward(*inputs): token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, ) hidden_states = layer_outputs[0] @@ -596,6 +622,7 @@ def forward( trim_logits: Optional[bool] = False, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -616,6 +643,7 @@ def forward( token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape From 0adb157e0c6f49822b0943bcb2d135743d46504c Mon Sep 17 00:00:00 2001 From: Witold Szczurek Date: Thu, 7 Dec 2023 18:06:13 +0200 Subject: [PATCH 2/3] Align text for help in argument and pass additional arguments --- examples/text-generation/run_generation.py | 2 +- .../models/llama/modeling_llama.py | 4 +++- optimum/habana/transformers/trainer.py | 18 ++++++++++++------ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index f6ec837fde..f2ba957217 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -215,7 +215,7 @@ def setup_parser(parser): parser.add_argument( "--use_flash_attention", action="store_true", - help="Whether to enable Habana Flash Attention.", + help="Whether to enable Habana Flash Attention, provided that the model supports it.", ) args = parser.parse_args() diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index ce46dc0c02..b57d5b0697 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -539,7 +539,8 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions, attn_softmax_bf16=attn_softmax_bf16) + return module(*inputs, past_key_value, output_attentions, attn_softmax_bf16=attn_softmax_bf16, + use_flash_attention=use_flash_attention) return custom_forward @@ -727,6 +728,7 @@ def prepare_inputs_for_generation( "trim_logits": kwargs.get("trim_logits"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), "reuse_cache": reuse_cache, + "use_flash_attention": kwargs.get("use_flash_attention") } ) return model_inputs diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 305d019f97..a865862f75 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -832,10 +832,13 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - # attn_softmax_bf16 is enabled only for llama + # attn_softmax_bf16 and use_flash_attention is enabled only for llama if hasattr(self.model, "generation_config") and self.model.generation_config is not None: - if self.model.config.model_type == "llama" and self.model.generation_config.attn_softmax_bf16: - inputs["attn_softmax_bf16"] = True + if self.model.config.model_type == "llama": + if self.model.generation_config.attn_softmax_bf16: + inputs["attn_softmax_bf16"] = True + if self.model.generation_config.use_flash_attention: + inputs["use_flash_attention"] = False # TODO: keep syncs for fast DDP? with self.accelerator.accumulate(model): @@ -1530,10 +1533,13 @@ def evaluation_loop( if batch_size is None: batch_size = observed_batch_size - # attn_softmax_bf16 is enabled only for llama + # attn_softmax_bf16 and use_flash_attention are enabled only for llama if hasattr(self.model, "generation_config") and self.model.generation_config is not None: - if self.model.config.model_type == "llama" and self.model.generation_config.attn_softmax_bf16: - inputs["attn_softmax_bf16"] = True + if self.model.config.model_type == "llama": + if self.model.generation_config.attn_softmax_bf16: + inputs["attn_softmax_bf16"] = True + if self.model.generation_config.use_flash_attention: + inputs["use_flash_attention"] = False # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) From cc9c8c17b75f1bf2272247b20dab15baad9c4777 Mon Sep 17 00:00:00 2001 From: Witold Szczurek Date: Mon, 11 Dec 2023 21:50:11 +0200 Subject: [PATCH 3/3] Code formatting; fix trainer parameter --- .../models/llama/modeling_llama.py | 20 ++++++++++++++----- optimum/habana/transformers/trainer.py | 4 ++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index b57d5b0697..6d3dae0a29 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -36,6 +36,7 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None + def update(prev, cur, dim, idx, inp_seq_len): orig_cur = cur cur = cur.to(dtype=prev.dtype) @@ -164,6 +165,7 @@ def pre_attn_forward( - optimize KV cache - add new args attn_softmax_bf16 - add new args reuse_cache + - add new args use_flash_attention """ bsz, q_len, _ = hidden_states.size() @@ -230,10 +232,13 @@ def pre_attn_forward( if use_flash_attention and FusedSDPA: import habana_frameworks.torch.hpu as ht + if q_len == 1: # next token with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None) + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) else: # first token with ht.sdp_kernel(enable_recompute=False): @@ -409,7 +414,7 @@ def pre_attn( token_idx, attn_softmax_bf16, reuse_cache, - use_flash_attention + use_flash_attention, ) return output_attn, attn_weights, present_key_value @@ -539,8 +544,13 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions, attn_softmax_bf16=attn_softmax_bf16, - use_flash_attention=use_flash_attention) + return module( + *inputs, + past_key_value, + output_attentions, + attn_softmax_bf16=attn_softmax_bf16, + use_flash_attention=use_flash_attention, + ) return custom_forward @@ -728,7 +738,7 @@ def prepare_inputs_for_generation( "trim_logits": kwargs.get("trim_logits"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), "reuse_cache": reuse_cache, - "use_flash_attention": kwargs.get("use_flash_attention") + "use_flash_attention": kwargs.get("use_flash_attention"), } ) return model_inputs diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index a865862f75..9fc078cda9 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -838,7 +838,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): if self.model.generation_config.attn_softmax_bf16: inputs["attn_softmax_bf16"] = True if self.model.generation_config.use_flash_attention: - inputs["use_flash_attention"] = False + inputs["use_flash_attention"] = True # TODO: keep syncs for fast DDP? with self.accelerator.accumulate(model): @@ -1539,7 +1539,7 @@ def evaluation_loop( if self.model.generation_config.attn_softmax_bf16: inputs["attn_softmax_bf16"] = True if self.model.generation_config.use_flash_attention: - inputs["use_flash_attention"] = False + inputs["use_flash_attention"] = True # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)