From 7edc640ef4f9b57dd2a6659374a6982c55b7003b Mon Sep 17 00:00:00 2001
From: Fei Wang <19998174+FeiWang96@users.noreply.github.com>
Date: Mon, 12 Jun 2023 21:06:44 -0700
Subject: [PATCH 1/3] Fix LLaMa beam search when using parallelize

same issue as T5 #11717
---
 src/transformers/models/llama/modeling_llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 3d31928562a2..d636a32a17fd 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -760,7 +760,7 @@ def prepare_inputs_for_generation(
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
         return reordered_past

From 5563a48fac61bd3ffdddb63990fdd751f1c08aa2 Mon Sep 17 00:00:00 2001
From: Fei Wang <19998174+FeiWang96@users.noreply.github.com>
Date: Tue, 13 Jun 2023 10:26:30 -0700
Subject: [PATCH 2/3] fix code format in modeling_llama.py

---
 src/transformers/models/llama/modeling_llama.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index d636a32a17fd..205cf2b19b86 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -760,7 +760,12 @@ def prepare_inputs_for_generation(
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx.to(past_state.device))
+                    for past_state in layer_past
+                ),
+            )
         return reordered_past

From a44abefcd18082d983ce9fda9b1c0c92a21de709 Mon Sep 17 00:00:00 2001
From: Fei Wang <19998174+FeiWang96@users.noreply.github.com>
Date: Wed, 14 Jun 2023 20:49:40 -0700
Subject: [PATCH 3/3] fix format of _reorder_cache in modeling_llama.py

---
 src/transformers/models/llama/modeling_llama.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 205cf2b19b86..27c62ca49673 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -761,10 +761,7 @@ def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(
-                    past_state.index_select(0, beam_idx.to(past_state.device))
-                    for past_state in layer_past
-                ),
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
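
Note: a minimal sketch of the failure mode this series fixes, for anyone who wants to reproduce it. When the model is sharded across devices (e.g. via parallelize() or a multi-GPU device_map), the cached key/value states of different layers live on different GPUs, while beam search builds beam_idx on a single device; torch.index_select requires the index tensor to be on the same device as the input, so reordering the cache raises a RuntimeError until beam_idx is moved. The tensor shapes and the two-GPU setup below are illustrative assumptions, not the real LLaMA cache layout.

import torch

# Cached states for one layer; under model parallelism they can sit on
# different devices (here: an assumed two-GPU machine).
layer_past = (
    torch.randn(4, 8, 16, device="cuda:0"),  # keys (illustrative shape)
    torch.randn(4, 8, 16, device="cuda:1"),  # values (illustrative shape)
)

# Beam search emits beam_idx on a single device.
beam_idx = torch.tensor([2, 0, 1, 3], device="cuda:0")

# Pre-patch behavior: fails on the cuda:1 tensor, because index_select
# needs its index on the same device as the input tensor.
# tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)

# Post-patch behavior: move beam_idx to each state's device first.
reordered = tuple(
    past_state.index_select(0, beam_idx.to(past_state.device))
    for past_state in layer_past
)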