From 9d7fbf60f275aca8408a1be038fd9fba66a0d328 Mon Sep 17 00:00:00 2001 From: Alex Calderwood Date: Tue, 13 Aug 2024 22:56:37 -0700 Subject: [PATCH 1/2] Fix beam_constraints.Constraint.advance() docstring --- src/transformers/generation/beam_constraints.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/beam_constraints.py b/src/transformers/generation/beam_constraints.py index e6462f322c49..1b6e3304db21 100644 --- a/src/transformers/generation/beam_constraints.py +++ b/src/transformers/generation/beam_constraints.py @@ -48,10 +48,13 @@ def test(self): @abstractmethod def advance(self): """ - When called, returns the token that would take this constraint one step closer to being fulfilled. + When called, returns the token(s) that would take this constraint one step closer to being fulfilled. Return: - token_ids(`torch.tensor`): Must be a tensor of a list of indexable tokens, not some integer. + token_ids (Union[int, List[int], None]): + - A single token ID (int) that advances the constraint, or + - A list of token IDs that could advance the constraint + - None if the constraint is completed or cannot be advanced """ raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." From 8176c7d9d32979c597bfaa31f2511887ddcb67d8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 16 Aug 2024 19:07:06 +0100 Subject: [PATCH 2/2] Update src/transformers/generation/beam_constraints.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/transformers/generation/beam_constraints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/beam_constraints.py b/src/transformers/generation/beam_constraints.py index 1b6e3304db21..daf64209b796 100644 --- a/src/transformers/generation/beam_constraints.py +++ b/src/transformers/generation/beam_constraints.py @@ -51,7 +51,7 @@ def advance(self): When called, returns the token(s) that would take this constraint one step closer to being fulfilled. Return: - token_ids (Union[int, List[int], None]): + token_ids (Union[int, List[int], None]): - A single token ID (int) that advances the constraint, or - A list of token IDs that could advance the constraint - None if the constraint is completed or cannot be advanced