From ae0bea9e8580bbedea96b45d4153be7a8b884d74 Mon Sep 17 00:00:00 2001
From: Galaxy1458 <55453380+Galaxy1458@users.noreply.github.com>
Date: Tue, 7 May 2024 16:29:31 +0800
Subject: [PATCH] update (#8359)

* change llama/modeling.py to opt npu performence

* update

* update

* Update modeling.py

* add judge

* update

* update

---------

Co-authored-by: Wang Huan
---
 paddlenlp/transformers/llama/modeling.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 03dff343dcad..37c573189821 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -96,7 +96,7 @@ def swiglu(x, y=None):
     "LlamaForCausalLM",
     "LlamaPretrainingCriterion",
 ]
-
+global npu_is_casual
 npu_is_casual = False
 
 def _get_interleave(n):
@@ -213,7 +213,7 @@ def scaled_dot_product_attention(
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
-
+    global npu_is_casual
    if config.use_flash_attention and flash_attention:
        # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
        # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
@@ -1613,6 +1613,7 @@ def forward(
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         )  # [bs, 1, seq_len, seq_len]
+        global npu_is_casual
         if self.config.use_flash_attention:
             is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
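
Note on the pattern the patch relies on: the hunks only add `global npu_is_casual` declarations at the sites that share the flag; the flag itself is already defined at module scope (`npu_is_casual = False`). Below is a minimal, illustrative Python sketch of sharing such a module-level flag between functions. The function names (`mark_causal`, `pick_attention_path`) are hypothetical stand-ins, not PaddleNLP internals.

# Illustrative sketch only (not part of the patch, not the PaddleNLP API):
# one module-level flag is written in one function and read in another,
# which is the mechanism the patched code uses on the NPU path.

npu_is_casual = False  # module-level default, mirroring the patched file


def mark_causal(attention_mask):
    # Rebinding a module-level name inside a function requires `global`;
    # without it Python would create a local variable and the value set
    # here would never be visible to other functions.
    global npu_is_casual
    npu_is_casual = attention_mask is None


def pick_attention_path():
    # Reading the flag does not strictly need `global`, but declaring it at
    # every use site (as the patch does) keeps the shared state explicit.
    global npu_is_casual
    return "fused-causal-kernel" if npu_is_casual else "masked-kernel"


if __name__ == "__main__":
    mark_causal(None)             # causal case: no explicit mask supplied
    print(pick_attention_path())  # -> fused-causal-kernel

The `global` statement added at module scope in the first hunk is a syntactic no-op; the declarations that matter are the ones inside `scaled_dot_product_attention` and `forward`, where the name would otherwise be treated as local on assignment.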