From cee7dd488c28e018c0dac1c2173c9a8c0928bee4 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 6 Jun 2024 17:56:35 +0000 Subject: [PATCH 1/8] initial commit --- .../models/qwen2/modeling_qwen2.py | 92 +++++++++++++++++-- 1 file changed, 82 insertions(+), 10 deletions(-) diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index f192cf4898..e5f895996c 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -20,6 +20,7 @@ import warnings from typing import List, Optional, Tuple, Union +import os import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -82,6 +83,16 @@ def gaudi_qwen2_rmsnorm_forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) + + class GaudiQwen2MLP(Qwen2MLP): def pre_mlp_forward(self, x): inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x) @@ -176,6 +187,12 @@ def forward(self, cur, dim, idx): class GaudiQwen2Attention(Qwen2Attention): def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) + + self.is_fp8 = os.getenv("QUANT_CONFIG", "") != "" + + # In the constructor we do not know which one we will need later in the forward, so creating both + # TODO, Does this affect memory usage? + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) self.matmul_qk = Matmul() self.matmul_av = Matmul() @@ -214,7 +231,7 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) return (self.k_cache.cache.shape, self.v_cache.cache.shape) - def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size): + def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size, softmax_mode): """ Gaudi version of Flash Attention V1 to support long sequence at prompt phase Causal mask is not supported in this optimization @@ -231,7 +248,7 @@ def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mas s, e = i * q_block_size, (i + 1) * q_block_size row_q = query_layer[:, :, s:e, :] row_mask = attention_mask[:, :, s:e, :] - attn_output_partial = FusedSDPA.apply(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None) + attn_output_partial = self.fused_scaled_dot_product_attention(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None, softmax_mode) row_o_list.append(attn_output_partial) attn_output = torch.cat(row_o_list, dim=-2) @@ -271,6 +288,7 @@ def pre_attn_forward( warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) + train_with_flash_attention = self.training and self._use_sdpa and not output_attentions and head_mask is None bsz, q_len, _ = hidden_states.size() @@ -324,32 +342,86 @@ def pre_attn_forward( kv_seq_len = key_states.shape[-2] else: past_key_value = None - + + ''' + if use_flash_attention or train_with_flash_attention: + is_causal = self.is_causal and query_length > 1 and flash_attention_causal_mask + if self.is_fp8: + attn_mask = None if is_causal else attention_mask + flash_attention_fast_softmax = True # TODO pass this along + softmax_mode = "fast" if flash_attention_fast_softmax else "None" + enable_recompute = query_length == 1 or flash_attention_recompute + with sdp_kernel(enable_recompute=enable_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_layer, key_layer, value_layer, attn_mask, 0.0, is_causal, None, softmax_mode + ) + else: + # TODO very similar to the fp8 case above, could be merged. + with sdp_kernel( + enable_recompute=flash_attention_recompute + ) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + is_causal and attention_mask is None, + ) + else: + if self.is_fp8: + attn_output = self.unfused_scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False + ) + else: + # Workaround util scaled_dot_product_attention support broadcast. + if self.training is True and query_layer.shape != key_layer.shape: + key_layer = torch.broadcast_to(key_layer, query_layer.shape) + value_layer = torch.broadcast_to(value_layer, query_layer.shape) + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + is_causal=self.is_causal and attention_mask is None and query_length > 1, + ) + ''' + + flash_attention_fast_softmax = True # TODO pass this along + softmax_mode = "fast" if flash_attention_fast_softmax else "None" if use_flash_attention and FusedSDPA: import habana_frameworks.torch.hpu as ht if q_len == 1: # next token - with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None + #with ht.sdp_kernel(enable_recompute=False): + # attn_output = FusedSDPA.apply( + # query_states, key_states, value_states, attention_mask, 0.0, False, None + # ) + + with ht.sdp_kernel(enable_recompute=True): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode ) else: # first token if flash_attention_causal_mask: # causal masking on first token requires inputs to be of the same length with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + attn_output = self.fused_scaled_dot_product_attention(query_states, key_states, value_states, None, 0.0, True, None, softmax_mode) else: with ht.sdp_kernel(enable_recompute=flash_attention_recompute): if q_len > 8192: attn_output = self.gaudi_flash_attn_v1( - query_states, key_states, value_states, attention_mask, 0.0, self.block_size + query_states, key_states, value_states, attention_mask, 0.0, self.block_size, softmax_mode ) htcore.mark_step() else: - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode ) else: From 4ed76f313e5b088757cb9bf861313b3fd0c87f29 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 6 Jun 2024 20:23:43 +0000 Subject: [PATCH 2/8] Similify fusedsdpa section --- .../models/qwen2/modeling_qwen2.py | 72 ++++++++++++------- 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index e5f895996c..71c8e22c9c 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -189,9 +189,6 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) self.is_fp8 = os.getenv("QUANT_CONFIG", "") != "" - - # In the constructor we do not know which one we will need later in the forward, so creating both - # TODO, Does this affect memory usage? self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) self.matmul_qk = Matmul() @@ -231,21 +228,22 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) return (self.k_cache.cache.shape, self.v_cache.cache.shape) - def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size, softmax_mode): + def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_casual, scale, softmax_mode): """ Gaudi version of Flash Attention V1 to support long sequence at prompt phase Causal mask is not supported in this optimization """ + assert not is_casual q_len = query_layer.size(-2) - q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) - q_padding = q_tiles * q_block_size - q_len + q_tiles = (q_len // self.block_size) if (q_len % self.block_size == 0) else math.ceil(q_len / self.block_size) + q_padding = q_tiles * self.block_size - q_len query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0) if attention_mask is not None: attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0) row_o_list = [] for i in range(q_tiles): - s, e = i * q_block_size, (i + 1) * q_block_size + s, e = i * self.block_size, (i + 1) * self.block_size row_q = query_layer[:, :, s:e, :] row_mask = attention_mask[:, :, s:e, :] attn_output_partial = self.fused_scaled_dot_product_attention(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None, softmax_mode) @@ -389,40 +387,64 @@ def pre_attn_forward( is_causal=self.is_causal and attention_mask is None and query_length > 1, ) ''' - + flash_attention_fast_softmax = True # TODO pass this along softmax_mode = "fast" if flash_attention_fast_softmax else "None" if use_flash_attention and FusedSDPA: import habana_frameworks.torch.hpu as ht + enable_recompute = q_len == 1 or flash_attention_recompute + is_causal = q_len > 1 and flash_attention_causal_mask + + attn_mask = None if flash_attention_causal_mask else attention_mask + args = (query_states, key_states, value_states, attn_mask, 0.0, is_causal, None, softmax_mode) + with ht.sdp_kernel(enable_recompute=enable_recompute): + if flash_attention_causal_mask or (not flash_attention_causal_mask and q_len <= 8192): + attn_output = self.fused_scaled_dot_product_attention(*args) + else: + attn_output = self.gaudi_flash_attn_v1(*args) + htcore.mark_step() + + ''' if q_len == 1: # next token #with ht.sdp_kernel(enable_recompute=False): # attn_output = FusedSDPA.apply( # query_states, key_states, value_states, attention_mask, 0.0, False, None # ) - with ht.sdp_kernel(enable_recompute=True): attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode ) else: - # first token - if flash_attention_causal_mask: - # causal masking on first token requires inputs to be of the same length - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = self.fused_scaled_dot_product_attention(query_states, key_states, value_states, None, 0.0, True, None, softmax_mode) - else: - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - if q_len > 8192: - attn_output = self.gaudi_flash_attn_v1( - query_states, key_states, value_states, attention_mask, 0.0, self.block_size, softmax_mode - ) - htcore.mark_step() - else: - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode - ) + + attn_mask = None if flash_attention_causal_mask else attention_mask + args = (query_states, key_states, value_states, attn_mask, 0.0, flash_attention_causal_mask, None, softmax_mode) + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + if flash_attention_causal_mask: + attn_output = self.fused_scaled_dot_product_attention(*args) + else: + attn_output = self.gaudi_flash_attn_v1(*args) + htcore.mark_step() + ''' + ''' + # first token + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same length + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention(query_states, key_states, value_states, attn_mask, 0.0, True, None, softmax_mode) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + if q_len > 8192: + attn_output = self.gaudi_flash_attn_v1( + query_states, key_states, value_states, attn_mask, 0.0, False, None, softmax_mode + ) + htcore.mark_step() + else: + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask, 0.0, False, None, softmax_mode + ) + ''' else: query_states, key_states, value_states, attention_mask = gaudi_qwen2_repeat_kv( From 65c01ca0fbfdb9996e5c28da768bee549ba53e5e Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 6 Jun 2024 22:22:09 +0000 Subject: [PATCH 3/8] kv cache opt --- .../transformers/models/qwen2/modeling_qwen2.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index 71c8e22c9c..f929764120 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -326,7 +326,9 @@ def pre_attn_forward( past_value = torch.zeros( key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device ) - past_key_value = (past_key, past_value) + #past_key_value = (past_key, past_value) + # Return list instead of tuple + past_key_value = [past_key, past_value] key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) if token_idx is None: @@ -484,6 +486,11 @@ def pre_attn_forward( if not output_attentions: attn_weights = None + + if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1: + # Return only past key value shapes and not the tensors during decode phase (q len is 1) + # to avoid making past key values as persistent output tensors of HPU graphs. + past_key_value = (past_key_value[0].shape, past_key_value[1].shape) return attn_output, attn_weights, past_key_value @@ -915,6 +922,7 @@ def prepare_inputs_for_generation( past_length = 0 reuse_cache = kwargs.get("reuse_cache") + bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) @@ -946,7 +954,7 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - elif reuse_cache and token_idx is not None: + elif (reuse_cache or bucket_internal) and token_idx is not None: # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] From 99ffca78148d670678279264b5a5a2c290e708b9 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 11 Jun 2024 22:31:41 +0000 Subject: [PATCH 4/8] clean up --- .../models/qwen2/modeling_qwen2.py | 90 +------------------ 1 file changed, 1 insertion(+), 89 deletions(-) diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index f929764120..503deeef4b 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -228,6 +228,7 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) return (self.k_cache.cache.shape, self.v_cache.cache.shape) + # TODO test with this function as well def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_casual, scale, softmax_mode): """ Gaudi version of Flash Attention V1 to support long sequence at prompt phase @@ -342,53 +343,6 @@ def pre_attn_forward( kv_seq_len = key_states.shape[-2] else: past_key_value = None - - ''' - if use_flash_attention or train_with_flash_attention: - is_causal = self.is_causal and query_length > 1 and flash_attention_causal_mask - if self.is_fp8: - attn_mask = None if is_causal else attention_mask - flash_attention_fast_softmax = True # TODO pass this along - softmax_mode = "fast" if flash_attention_fast_softmax else "None" - enable_recompute = query_length == 1 or flash_attention_recompute - with sdp_kernel(enable_recompute=enable_recompute): - attn_output = self.fused_scaled_dot_product_attention( - query_layer, key_layer, value_layer, attn_mask, 0.0, is_causal, None, softmax_mode - ) - else: - # TODO very similar to the fp8 case above, could be merged. - with sdp_kernel( - enable_recompute=flash_attention_recompute - ) if SDPContext else contextlib.nullcontext(): - attn_output = FusedSDPA.apply( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal and attention_mask is None, - ) - else: - if self.is_fp8: - attn_output = self.unfused_scaled_dot_product_attention( - query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False - ) - else: - # Workaround util scaled_dot_product_attention support broadcast. - if self.training is True and query_layer.shape != key_layer.shape: - key_layer = torch.broadcast_to(key_layer, query_layer.shape) - value_layer = torch.broadcast_to(value_layer, query_layer.shape) - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - ''' flash_attention_fast_softmax = True # TODO pass this along softmax_mode = "fast" if flash_attention_fast_softmax else "None" @@ -406,48 +360,6 @@ def pre_attn_forward( else: attn_output = self.gaudi_flash_attn_v1(*args) htcore.mark_step() - - ''' - if q_len == 1: - # next token - #with ht.sdp_kernel(enable_recompute=False): - # attn_output = FusedSDPA.apply( - # query_states, key_states, value_states, attention_mask, 0.0, False, None - # ) - with ht.sdp_kernel(enable_recompute=True): - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode - ) - else: - - attn_mask = None if flash_attention_causal_mask else attention_mask - args = (query_states, key_states, value_states, attn_mask, 0.0, flash_attention_causal_mask, None, softmax_mode) - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - if flash_attention_causal_mask: - attn_output = self.fused_scaled_dot_product_attention(*args) - else: - attn_output = self.gaudi_flash_attn_v1(*args) - htcore.mark_step() - ''' - ''' - # first token - if flash_attention_causal_mask: - # causal masking on first token requires inputs to be of the same length - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = self.fused_scaled_dot_product_attention(query_states, key_states, value_states, attn_mask, 0.0, True, None, softmax_mode) - else: - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - if q_len > 8192: - attn_output = self.gaudi_flash_attn_v1( - query_states, key_states, value_states, attn_mask, 0.0, False, None, softmax_mode - ) - htcore.mark_step() - else: - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask, 0.0, False, None, softmax_mode - ) - ''' - else: query_states, key_states, value_states, attention_mask = gaudi_qwen2_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups From eca543fa46f6241254d14beb2598484216ed988e Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 11 Jun 2024 22:54:29 +0000 Subject: [PATCH 5/8] style --- optimum/habana/transformers/models/qwen2/modeling_qwen2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index 503deeef4b..a6a6b9d949 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -17,10 +17,10 @@ ############################################################################### import math +import os import warnings from typing import List, Optional, Tuple, Union -import os import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -187,7 +187,7 @@ def forward(self, cur, dim, idx): class GaudiQwen2Attention(Qwen2Attention): def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - + self.is_fp8 = os.getenv("QUANT_CONFIG", "") != "" self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) @@ -398,7 +398,7 @@ def pre_attn_forward( if not output_attentions: attn_weights = None - + if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1: # Return only past key value shapes and not the tensors during decode phase (q len is 1) # to avoid making past key values as persistent output tensors of HPU graphs. From 7ba78d49cc35a0f3e791e6d479d2cfd74bd91375 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 13 Jun 2024 06:24:44 +0000 Subject: [PATCH 6/8] remove unused var and add style --- optimum/habana/transformers/models/qwen2/modeling_qwen2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index a6a6b9d949..c1d7e6badf 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -17,7 +17,6 @@ ############################################################################### import math -import os import warnings from typing import List, Optional, Tuple, Union @@ -188,7 +187,6 @@ class GaudiQwen2Attention(Qwen2Attention): def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.is_fp8 = os.getenv("QUANT_CONFIG", "") != "" self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) self.matmul_qk = Matmul() From a5b9afc64647a55f5e865205be7f69cd0881fff0 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 13 Jun 2024 07:07:15 +0000 Subject: [PATCH 7/8] remove unused var and add style --- .../transformers/models/qwen2/modeling_qwen2.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index c1d7e6badf..9f4e57a62d 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -227,7 +227,9 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): return (self.k_cache.cache.shape, self.v_cache.cache.shape) # TODO test with this function as well - def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_casual, scale, softmax_mode): + def gaudi_flash_attn_v1( + self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_casual, scale, softmax_mode + ): """ Gaudi version of Flash Attention V1 to support long sequence at prompt phase Causal mask is not supported in this optimization @@ -245,7 +247,9 @@ def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mas s, e = i * self.block_size, (i + 1) * self.block_size row_q = query_layer[:, :, s:e, :] row_mask = attention_mask[:, :, s:e, :] - attn_output_partial = self.fused_scaled_dot_product_attention(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None, softmax_mode) + attn_output_partial = self.fused_scaled_dot_product_attention( + row_q, key_layer, value_layer, row_mask, dropout_rate, False, None, softmax_mode + ) row_o_list.append(attn_output_partial) attn_output = torch.cat(row_o_list, dim=-2) @@ -285,7 +289,6 @@ def pre_attn_forward( warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - train_with_flash_attention = self.training and self._use_sdpa and not output_attentions and head_mask is None bsz, q_len, _ = hidden_states.size() @@ -325,7 +328,7 @@ def pre_attn_forward( past_value = torch.zeros( key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device ) - #past_key_value = (past_key, past_value) + # past_key_value = (past_key, past_value) # Return list instead of tuple past_key_value = [past_key, past_value] key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) From 4281d6bcb0ae806de35721c7e4593c6a29762c23 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Mon, 17 Jun 2024 10:01:30 -0700 Subject: [PATCH 8/8] Update modeling_qwen2.py --- optimum/habana/transformers/models/qwen2/modeling_qwen2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index 9f4e57a62d..27d2c24716 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -328,7 +328,6 @@ def pre_attn_forward( past_value = torch.zeros( key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device ) - # past_key_value = (past_key, past_value) # Return list instead of tuple past_key_value = [past_key, past_value] key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)