diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 996cd48db3..f2ba957217 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -212,6 +212,11 @@ def setup_parser(parser): help="Store kv-cache in float8 when kv-cache is used", ) parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8") + parser.add_argument( + "--use_flash_attention", + action="store_true", + help="Whether to enable Habana Flash Attention, provided that the model supports it.", + ) args = parser.parse_args() diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 83829f6013..0df9363197 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -29,6 +29,8 @@ class GaudiGenerationConfig(GenerationConfig): Only active if `static_shapes` is used. Can't be used with `reuse_cache`. kv_cache_fp8 (`bool`, *optional*): Store kv-cache in float8 when kv-cache is used + use_flash_attention (`bool`, *optional*): + Whether to use flash attention optimization. """ def __init__(self, **kwargs): @@ -41,3 +43,4 @@ def __init__(self, **kwargs): self.reuse_cache = kwargs.get("reuse_cache", None) self.bucket_size = kwargs.get("bucket_size", -1) self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) + self.use_flash_attention = kwargs.get("use_flash_attention", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index ba7ca3c3ed..7c739051bd 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -670,6 +670,9 @@ def generate( # prepare for allocate kv cache model_kwargs["reuse_cache"] = generation_config.reuse_cache + # determine whether flash attention needs to be used + model_kwargs["use_flash_attention"] = generation_config.use_flash_attention + if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[-1] if not generation_config.static_shapes and generation_config.max_new_tokens is not None: diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index e8109a9433..6d3dae0a29 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -30,6 +30,12 @@ print("Not using HPU fused kernel for RMSNorm") FusedRMSNorm = None +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + def update(prev, cur, dim, idx, inp_seq_len): orig_cur = cur @@ -150,7 +156,8 @@ def pre_attn_forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, - ): + use_flash_attention: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: @@ -158,6 +165,7 @@ def pre_attn_forward( - optimize KV cache - add new args attn_softmax_bf16 - add new args reuse_cache + - add new args use_flash_attention """ bsz, q_len, _ = hidden_states.size() @@ -222,30 +230,45 @@ def pre_attn_forward( key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups) value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups) - attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor + if use_flash_attention and FusedSDPA: + import habana_frameworks.torch.hpu as ht - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) + if q_len == 1: + # next token + with ht.sdp_kernel(enable_recompute=False): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + # first token + with ht.sdp_kernel(enable_recompute=False): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + else: + attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) - attn_weights = attn_weights + attention_mask - if attn_softmax_bf16: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) - else: - # upcast attention to fp32 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) - attn_output = self.matmul_av(attn_weights, value_states) + attn_output = self.matmul_av(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( @@ -330,6 +353,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -337,6 +361,7 @@ def forward( - add new args token_idx - add new args attn_softmax_bf16 - add new args reuse_cache + - add new args use_flash_attention """ residual = hidden_states output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( @@ -349,6 +374,7 @@ def forward( token_idx, attn_softmax_bf16, reuse_cache, + use_flash_attention=use_flash_attention, ) self.self_attn.attention_all_reduce(output_pre_attn) output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) @@ -375,6 +401,7 @@ def pre_attn( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( @@ -387,6 +414,7 @@ def pre_attn( token_idx, attn_softmax_bf16, reuse_cache, + use_flash_attention, ) return output_attn, attn_weights, present_key_value @@ -433,6 +461,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -440,6 +469,7 @@ def forward( - add new args token_idx - add new args attn_softmax_bf16 - add new args reuse_cache + - add new args use_flash_attention """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -514,7 +544,13 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions, attn_softmax_bf16=attn_softmax_bf16) + return module( + *inputs, + past_key_value, + output_attentions, + attn_softmax_bf16=attn_softmax_bf16, + use_flash_attention=use_flash_attention, + ) return custom_forward @@ -532,6 +568,7 @@ def custom_forward(*inputs): token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, ) hidden_states = layer_outputs[0] @@ -596,6 +633,7 @@ def forward( trim_logits: Optional[bool] = False, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -616,6 +654,7 @@ def forward( token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -699,6 +738,7 @@ def prepare_inputs_for_generation( "trim_logits": kwargs.get("trim_logits"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), "reuse_cache": reuse_cache, + "use_flash_attention": kwargs.get("use_flash_attention"), } ) return model_inputs diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 305d019f97..9fc078cda9 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -832,10 +832,13 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - # attn_softmax_bf16 is enabled only for llama + # attn_softmax_bf16 and use_flash_attention is enabled only for llama if hasattr(self.model, "generation_config") and self.model.generation_config is not None: - if self.model.config.model_type == "llama" and self.model.generation_config.attn_softmax_bf16: - inputs["attn_softmax_bf16"] = True + if self.model.config.model_type == "llama": + if self.model.generation_config.attn_softmax_bf16: + inputs["attn_softmax_bf16"] = True + if self.model.generation_config.use_flash_attention: + inputs["use_flash_attention"] = True # TODO: keep syncs for fast DDP? with self.accelerator.accumulate(model): @@ -1530,10 +1533,13 @@ def evaluation_loop( if batch_size is None: batch_size = observed_batch_size - # attn_softmax_bf16 is enabled only for llama + # attn_softmax_bf16 and use_flash_attention are enabled only for llama if hasattr(self.model, "generation_config") and self.model.generation_config is not None: - if self.model.config.model_type == "llama" and self.model.generation_config.attn_softmax_bf16: - inputs["attn_softmax_bf16"] = True + if self.model.config.model_type == "llama": + if self.model.generation_config.attn_softmax_bf16: + inputs["attn_softmax_bf16"] = True + if self.model.generation_config.use_flash_attention: + inputs["use_flash_attention"] = True # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)