From 1d067aced606b7ad4b30b7fa7fe04bdc49ccdb40 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Fri, 29 Mar 2024 18:12:01 +0000 Subject: [PATCH] Fix the pytest for llama and falcon when token_idx is None. --- .../models/falcon/modeling_falcon.py | 16 ++++++---------- .../transformers/models/llama/modeling_llama.py | 2 ++ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index d86f91bb75..9b9a74c12f 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -334,16 +334,12 @@ def pre_attn_forward( device=self.query_key_value.weight.device, ) layer_past = (past_key, past_value) - key_layer = self.k_cache.update( - layer_past[0], key_layer, -2, token_idx, self.inp_seq_len - ) # k_layer bs*1, q_len, head_dim - value_layer = self.v_cache.update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) - else: - key_layer = self.k_cache.update( - layer_past[0], key_layer, -2, token_idx, self.inp_seq_len - ) # k_layer bs*1, q_len, head_dim - value_layer = self.v_cache.update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) - layer_past = (key_layer.contiguous(), value_layer.contiguous()) + key_layer = self.k_cache.update( + layer_past[0], key_layer, -2, token_idx, self.inp_seq_len + ) # k_layer bs*1, q_len, head_dim + value_layer = self.v_cache.update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + if token_idx is None: + layer_past = (key_layer, value_layer) present = layer_past if cache_idx is not None and query_length == 1: diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 15c74dc15b..4d0f3513d7 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -383,6 +383,8 @@ def pre_attn_forward( past_key_value = (past_key, past_value) key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_value = (key_states, value_states) if cache_idx is not None and q_len == 1: key_states = key_states[:, :, :cache_idx, :]