diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 22652182f9..3b32c83e12 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -262,11 +262,10 @@ def update(self, prev, cur, dim, idx, inp_seq_len): if prev.shape == cur.shape: prev.copy_(cur) return orig_cur - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: # Initialize prev[:, :, :inp_seq_len, :].copy_(cur) return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" if idx is not None: prev.index_copy_(dim, idx - 1, cur) return prev diff --git a/tests/baselines/llama_7b.json b/tests/baselines/llama_7b.json index 8c6ac6b882..f1a556a35b 100644 --- a/tests/baselines/llama_7b.json +++ b/tests/baselines/llama_7b.json @@ -230,16 +230,16 @@ "multi_card": { "learning_rate": 5e-4, "train_batch_size": 1, - "train_runtime": 16.5, - "train_samples_per_second": 63.161, - "perplexity": 1.224, + "train_runtime": 16.1, + "train_samples_per_second": 63.249, + "perplexity": 1.172, "extra_arguments": [ "--num_virtual_tokens 8", "--max_seq_length 64", "--logging_steps 1", "--report_to none", "--max_steps 100", - "--peft_type prompt_tuning", + "--peft_type prefix_tuning", "--max_seq_length 64", "--lr_scheduler_type cosine", "--warmup_steps 0", @@ -256,16 +256,16 @@ "multi_card": { "learning_rate": 5e-4, "train_batch_size": 1, - "train_runtime": 16.5, + "train_runtime": 18.7, "train_samples_per_second": 63.161, - "perplexity": 1.224, + "perplexity": 1.047, "extra_arguments": [ "--num_virtual_tokens 8", "--max_seq_length 64", "--logging_steps 1", "--report_to none", "--max_steps 100", - "--peft_type prompt_tuning", + "--peft_type p_tuning", "--max_seq_length 64", "--lr_scheduler_type cosine", "--warmup_steps 0", @@ -276,4 +276,4 @@ } } } -} \ No newline at end of file +}