Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/transformers/generation/beam_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@ def test(self):
@abstractmethod
def advance(self):
"""
When called, returns the token that would take this constraint one step closer to being fulfilled.
When called, returns the token(s) that would take this constraint one step closer to being fulfilled.

Return:
token_ids(`torch.tensor`): Must be a tensor of a list of indexable tokens, not some integer.
token_ids (Union[int, List[int], None]):
- A single token ID (int) that advances the constraint, or
- A list of token IDs that could advance the constraint
- None if the constraint is completed or cannot be advanced
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
Expand Down