diff --git a/src/transformers/generation/beam_constraints.py b/src/transformers/generation/beam_constraints.py index e6462f322c49..daf64209b796 100644 --- a/src/transformers/generation/beam_constraints.py +++ b/src/transformers/generation/beam_constraints.py @@ -48,10 +48,13 @@ def test(self): @abstractmethod def advance(self): """ - When called, returns the token that would take this constraint one step closer to being fulfilled. + When called, returns the token(s) that would take this constraint one step closer to being fulfilled. Return: - token_ids(`torch.tensor`): Must be a tensor of a list of indexable tokens, not some integer. + token_ids (Union[int, List[int], None]): + - A single token ID (int) that advances the constraint, or + - A list of token IDs that could advance the constraint + - None if the constraint is completed or cannot be advanced """ raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."