update (#8359)
* change llama/modeling.py to optimize NPU performance

* update

* update

* Update modeling.py

* add judge

* update

* update

---------

Co-authored-by: Wang Huan <[email protected]>
Galaxy1458 and wanghuancoder authored May 7, 2024
1 parent 09a0ce7 commit ae0bea9
Showing 1 changed file with 3 additions and 2 deletions: paddlenlp/transformers/llama/modeling.py
@@ -96,7 +96,7 @@ def swiglu(x, y=None):
"LlamaForCausalLM",
"LlamaPretrainingCriterion",
]

global npu_is_casual
npu_is_casual = False

def _get_interleave(n):
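
The change above introduces a module-level flag, npu_is_casual, shared across functions via Python's global statement. (Two notes: the identifier spells "casual", apparently for "causal", and is kept as-is since that is the name in the code; also, a "global" statement at module scope is a no-op, since global only has an effect inside a function body, where it lets assignment rebind the module-level name instead of creating a local.) A minimal, self-contained sketch of the pattern, with hypothetical names (set_flag, attention_path) that are not PaddleNLP APIs:

# Sketch of the shared-flag pattern used in this commit; names are hypothetical.
npu_is_casual = False  # module-level default ("global" here would be redundant)

def set_flag(value):
    # "global" is needed to rebind the module-level name from inside a
    # function; without it, the assignment would create a local variable.
    global npu_is_casual
    npu_is_casual = value

def attention_path():
    # Reading alone would not strictly need "global", but declaring it
    # makes the shared-state dependency explicit, as the diff does.
    global npu_is_casual
    return "causal fast path" if npu_is_casual else "masked path"

print(attention_path())  # masked path
set_flag(True)
print(attention_path())  # causal fast path
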
@@ -213,7 +213,7 @@ def scaled_dot_product_attention(
):
bsz, q_len, num_heads, head_dim = query_states.shape
_, kv_seq_len, _, _ = value_states.shape

global npu_is_casual
if config.use_flash_attention and flash_attention:
# Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
# Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
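
The comments in this hunk record the layout mismatch the flash-attention path has to bridge: Paddle's flash attention consumes [bs, seqlen, nhead, head_dim], while Torch-style kernels expect [bs, nhead, seqlen, head_dim]. Converting between the two is a single swap of axes 1 and 2. A small sketch (assuming a working Paddle install; an illustration, not code from the diff):

import paddle

bs, seqlen, nhead, head_dim = 2, 16, 8, 64

# Paddle flash attention layout: [bs, seqlen, nhead, head_dim]
q = paddle.randn([bs, seqlen, nhead, head_dim])

# Torch-style layout swaps the sequence and head axes.
q_torch_layout = paddle.transpose(q, perm=[0, 2, 1, 3])

print(q.shape)               # [2, 16, 8, 64]
print(q_torch_layout.shape)  # [2, 8, 16, 64]
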
@@ -1613,6 +1613,7 @@ def forward(
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
global npu_is_casual
if self.config.use_flash_attention:
is_casual = is_casual_mask(attention_mask)
if get_env_device() != "npu":
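
This truncated hunk gates the NPU-specific path on whether the prepared [bs, 1, seq_len, seq_len] attention mask is purely causal (is_casual_mask), in which case a fused causal kernel can skip reading a materialized mask. The actual helper is not shown in the diff; a plausible, hypothetical sketch of such a check for a boolean mask (True = "may attend"):

import paddle

def looks_causal(mask_2d):
    # Hypothetical check, not the PaddleNLP implementation: a boolean
    # [seq_len, seq_len] mask is causal iff it equals the
    # lower-triangular all-True pattern.
    seq_len = mask_2d.shape[-1]
    causal = paddle.tril(paddle.ones([seq_len, seq_len])).astype("bool")
    return bool((mask_2d.astype("bool") == causal).all().item())

mask = paddle.tril(paddle.ones([4, 4])).astype("bool")
print(looks_causal(mask))  # True
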
