Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,27 +133,32 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
The input tokens length.
min_new_tokens (`int`):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
"""

def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int):
def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
("min_new_tokens", min_new_tokens),
("eos_token_id", eos_token_id),
]:
if not isinstance(arg_value, int) or arg_value < 0:
raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens
self.eos_token_id = eos_token_id

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
if new_tokens_length < self.min_new_tokens:
scores[:, self.eos_token_id] = -float("inf")
for i in self.eos_token_id:
scores[:, i] = -float("inf")

return scores

Expand Down
28 changes: 21 additions & 7 deletions tests/generation/test_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
from typing import List, Union

from parameterized import parameterized

from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
Expand Down Expand Up @@ -76,20 +78,26 @@ def test_min_length_dist_processor(self):
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores_before_min_length).any())

def test_new_min_length_dist_processor(self):
@parameterized.expand([(0,), ([0, 18],)])
def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]]):
vocab_size = 20
batch_size = 4
eos_token_id = 0

# check that the prompt tokens are skipped, so the min new tokens constraint still applies (0 new tokens so far)
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
new_min_dist_processor = MinNewTokensLengthLogitsProcessor(
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id
)

expected_eos_scores_before_min_length = batch_size * [-float("inf")]
if isinstance(eos_token_id, list):
expected_eos_scores_before_min_length *= len(eos_token_id)

scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
self.assertListEqual(
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
)

# check that the prompt length to skip is now 5, i.e. the first 5 tokens are expected to be skipped
self.assertTrue(new_min_dist_processor.prompt_length_to_skip == 5)
Expand All @@ -98,19 +106,25 @@ def test_new_min_length_dist_processor(self):
input_ids = ids_tensor((batch_size, 2), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
self.assertListEqual(
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
)

# check that min new length is applied at length 6 (because it has only 1 new token)
input_ids = ids_tensor((batch_size, 6), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
self.assertListEqual(
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
)

# check that min new length is applied at length 7 (because it has only 2 new tokens)
input_ids = ids_tensor((batch_size, 7), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
self.assertListEqual(
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
)

# check that min new length is not applied anymore at length 8
input_ids = ids_tensor((batch_size, 8), vocab_size=20)
Expand Down