From e2dfad841b2d505478cff3923ce46ffc703674c9 Mon Sep 17 00:00:00 2001 From: Elad Segal Date: Sun, 5 Mar 2023 17:05:35 +0000 Subject: [PATCH 1/4] Fix MinNewTokensLengthLogitsProcessor when used with a list of eos tokens --- src/transformers/generation/logits_process.py | 11 +++++-- tests/generation/test_logits_process.py | 33 +++++++++++++++---- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index ba777f1e8e86..ddac29dee39b 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -137,15 +137,19 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): The id of the *end-of-sequence* token. """ - def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int): + def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]): for arg_name, arg_value in [ ("prompt_length_to_skip", prompt_length_to_skip), ("min_new_tokens", min_new_tokens), - ("eos_token_id", eos_token_id), ]: if not isinstance(arg_value, int) or arg_value < 0: raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}") + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]): + raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + self.prompt_length_to_skip = prompt_length_to_skip self.min_new_tokens = min_new_tokens self.eos_token_id = eos_token_id @@ -153,7 +157,8 @@ def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip if new_tokens_length < self.min_new_tokens: - scores[:, self.eos_token_id] = -float("inf") + for i in self.eos_token_id: + scores[:, i] = -float("inf") return scores diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index c377a23e7a70..bb4419afbea5 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. - import unittest +from typing import List, Union + +from parameterized import parameterized from transformers import is_torch_available from transformers.testing_utils import require_torch, torch_device @@ -76,10 +78,15 @@ def test_min_length_dist_processor(self): scores_before_min_length = min_dist_processor(input_ids, scores) self.assertFalse(torch.isinf(scores_before_min_length).any()) - def test_new_min_length_dist_processor(self): + @parameterized.expand( + [ + (0,), + ([0, 18],), + ] + ) + def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]]): vocab_size = 20 batch_size = 4 - eos_token_id = 0 # check that first input is skipped (min new length applying) input_ids = ids_tensor((batch_size, 5), vocab_size=20) @@ -87,9 +94,15 @@ def test_new_min_length_dist_processor(self): prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id ) + expected_eos_scores_before_min_length = batch_size * [-float("inf")] + if isinstance(eos_token_id, list): + expected_eos_scores_before_min_length *= len(eos_token_id) + scores = self._get_uniform_logits(batch_size, vocab_size) scores_before_min_length = new_min_dist_processor(input_ids, scores) - self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")]) + self.assertListEqual( + scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length + ) # check that, for skipping, now prompt length is 5, after that we expect first 5 tokens will be skipped self.assertTrue(new_min_dist_processor.prompt_length_to_skip == 5) @@ -98,19 +111,25 @@ def test_new_min_length_dist_processor(self): input_ids = ids_tensor((batch_size, 2), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) scores_before_min_length = new_min_dist_processor(input_ids, scores) - self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")]) + self.assertListEqual( + scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length + ) # check that min new length is applied at length 6 (because it has only 1 new token) input_ids = ids_tensor((batch_size, 6), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) scores_before_min_length = new_min_dist_processor(input_ids, scores) - self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")]) + self.assertListEqual( + scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length + ) # check that min new length is applied at length 7 (because it has only 2 new tokens) input_ids = ids_tensor((batch_size, 7), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) scores_before_min_length = new_min_dist_processor(input_ids, scores) - self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")]) + self.assertListEqual( + scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length + ) # check that min new length is not applied anymore at length 8 input_ids = ids_tensor((batch_size, 8), vocab_size=20) From c7872956d246a1814ce1e2783923de3b7d2e46c5 Mon Sep 17 00:00:00 2001 From: Elad Segal Date: Sun, 5 Mar 2023 17:14:14 +0000 Subject: [PATCH 2/4] fix docs --- src/transformers/generation/logits_process.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index ddac29dee39b..73e1bdb214e6 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -133,8 +133,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): The input tokens length. min_new_tokens (`int`): The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`. - eos_token_id (`int`): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. """ def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]): From 2f3c9eaaaeceaa46df68aadb5d2a25cdabc87dd0 Mon Sep 17 00:00:00 2001 From: Elad Segal Date: Sun, 5 Mar 2023 17:42:45 +0000 Subject: [PATCH 3/4] Empty commit From ec9f8151ad567752897658c695cdb92b89774090 Mon Sep 17 00:00:00 2001 From: Elad Segal Date: Tue, 7 Mar 2023 11:34:02 +0000 Subject: [PATCH 4/4] formatting --- tests/generation/test_logits_process.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index bb4419afbea5..85ebf780f7dc 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -78,12 +78,7 @@ def test_min_length_dist_processor(self): scores_before_min_length = min_dist_processor(input_ids, scores) self.assertFalse(torch.isinf(scores_before_min_length).any()) - @parameterized.expand( - [ - (0,), - ([0, 18],), - ] - ) + @parameterized.expand([(0,), ([0, 18],)]) def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]]): vocab_size = 20 batch_size = 4