Merged
33 commits
14d5b6c  diverse beam search (Nov 17, 2020)
de6280f  bug fixes (Nov 18, 2020)
3ab6551  bug fixes (Nov 18, 2020)
da6654d  bug fix (Nov 18, 2020)
9e7d976  separate out diverse_beam_search function (Nov 20, 2020)
69a91b4  separate out diverse_beam_search function (Nov 20, 2020)
a1b57d2  bug fix (Nov 20, 2020)
8ffe8fd  improve code quality (Nov 20, 2020)
191a59d  bug fix (Nov 20, 2020)
ea945b2  bug fix (Nov 20, 2020)
b363a33  Merge branch 'master' into diverse_beam_search (ayushtiku5, Nov 20, 2020)
4683c04  separate out diverse beam search scorer (Nov 21, 2020)
77a861a  code format (Nov 21, 2020)
26eb8a8  code format (Nov 21, 2020)
c35492b  code format (Nov 21, 2020)
f63017a  code format (Nov 21, 2020)
217a1fa  add test (Nov 21, 2020)
f08cfe5  code format (Nov 21, 2020)
9d70206  documentation changes (Dec 5, 2020)
9b61e09  code quality (Dec 5, 2020)
b1d2269  add slow integration tests (patrickvonplaten, Dec 7, 2020)
26f9e0f  more general name (patrickvonplaten, Dec 7, 2020)
fb433cd  refactor into logits processor (patrickvonplaten, Dec 7, 2020)
6b61b8e  add test (patrickvonplaten, Dec 7, 2020)
68d841c  avoid too much copy paste (patrickvonplaten, Dec 7, 2020)
8751681  refactor (patrickvonplaten, Dec 7, 2020)
a221ef0  add to docs (patrickvonplaten, Dec 7, 2020)
7d8d469  fix-copies (patrickvonplaten, Dec 7, 2020)
3ca6a2a  Merge remote-tracking branch 'main/master' into diverse_beam_search (patrickvonplaten, Dec 7, 2020)
c99eb5a  bug fix (Dec 7, 2020)
da77f29  Revert "bug fix" (Dec 7, 2020)
757bd35  improve comment (patrickvonplaten, Dec 7, 2020)
81b9533  implement sylvains feedback (patrickvonplaten, Dec 7, 2020)
6 changes: 6 additions & 0 deletions docs/source/internal/generation_utils.rst
@@ -40,6 +40,12 @@ generation.
.. autoclass:: transformers.NoBadWordsLogitsProcessor
:members: __call__

.. autoclass:: transformers.PrefixConstrainedLogitsProcessor
:members: __call__

.. autoclass:: transformers.HammingDiversityLogitsProcessor
:members: __call__

BeamSearch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -305,12 +305,14 @@
)
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
HammingDiversityLogitsProcessor,
LogitsProcessor,
LogitsProcessorList,
LogitsWarper,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
8 changes: 8 additions & 0 deletions src/transformers/configuration_utils.py
@@ -95,6 +95,12 @@ class PretrainedConfig(object):
sentences are finished per batch or not.
- **num_beams** (:obj:`int`, `optional`, defaults to 1) -- Number of beams for beam search that will be used by
default in the :obj:`generate` method of the model. 1 means no beam search.
- **num_beam_groups** (:obj:`int`, `optional`, defaults to 1) -- Number of groups to divide :obj:`num_beams`
into in order to ensure diversity among different groups of beams. Will be used by default in the
:obj:`generate` method of the model. 1 means no group beam search.
- **diversity_penalty** (:obj:`float`, `optional`, defaults to 0.0) -- Value to control diversity for group
beam search that will be used by default in the :obj:`generate` method of the model. 0 means no diversity
penalty. The higher the penalty, the more diverse the outputs.
- **temperature** (:obj:`float`, `optional`, defaults to 1) -- The value used to modulate the next token
probabilities that will be used by default in the :obj:`generate` method of the model. Must be strictly
positive.
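For orientation (not part of this diff), a minimal sketch of overriding these new defaults at generation time; the checkpoint and input text are only illustrative, while the keyword arguments are the ones this PR introduces:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")

# num_beams must be divisible by num_beam_groups; diversity_penalty > 0.0 activates the Hamming penalty
outputs = model.generate(
    inputs["input_ids"],
    num_beams=6,
    num_beam_groups=3,
    diversity_penalty=1.0,
    num_return_sequences=3,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

With `diversity_penalty=0.0` (the default) this falls back to standard beam search behaviour within each group.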
@@ -185,6 +191,8 @@ def __init__(self, **kwargs):
self.do_sample = kwargs.pop("do_sample", False)
self.early_stopping = kwargs.pop("early_stopping", False)
self.num_beams = kwargs.pop("num_beams", 1)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
self.temperature = kwargs.pop("temperature", 1.0)
self.top_k = kwargs.pop("top_k", 50)
self.top_p = kwargs.pop("top_p", 1.0)
38 changes: 27 additions & 11 deletions src/transformers/generation_beam_search.py
@@ -122,9 +122,12 @@ class BeamSearchScorer(BeamScorer):
Adapted in part from `Facebook's XLM beam search code
<https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.

Reference for the diverse beam search algorithm and implementation: `Ashwin Kalyan's DBS implementation
<https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua>`__.

Args:
batch_size (:obj:`int`):
Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel.
Batch Size of :obj:`input_ids` for which standard beam search decoding is run in parallel.
max_length (:obj:`int`):
The maximum length of the sequence to be generated.
num_beams (:obj:`int`):
@@ -141,6 +144,9 @@
num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
:meth:`~transformer.BeamSearchScorer.finalize`.
num_beam_groups (:obj:`int`):
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
"""

def __init__(
@@ -152,13 +158,16 @@ def __init__(
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[bool] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
num_beam_groups: Optional[int] = 1,
):
self.max_length = max_length
self.num_beams = num_beams
self.device = device
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self.num_beam_groups = num_beam_groups
self.group_size = self.num_beams // self.num_beam_groups

self._is_init = False
self._beam_hyps = [
@@ -177,6 +186,12 @@ def __init__(
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
)

if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
raise ValueError(
f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)

@property
def is_done(self) -> bool:
return self._done.all()
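As a quick, illustrative sketch of the grouping arithmetic above (the concrete values are assumed, not taken from the diff): with the constructor signature shown here, `group_size` is simply `num_beams // num_beam_groups`, and `process` then works on one group of that size at a time:

```python
import torch
from transformers import BeamSearchScorer

# 6 beams split into 3 groups -> each call to `process` handles 2 beams per batch entry
scorer = BeamSearchScorer(
    batch_size=2,
    max_length=20,
    num_beams=6,
    device=torch.device("cpu"),
    num_beam_groups=3,
)
assert scorer.group_size == 2  # num_beams // num_beam_groups
assert not scorer.is_done      # no hypothesis finished yet
```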
@@ -192,12 +207,12 @@ def process(
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
batch_size = len(self._beam_hyps)
assert batch_size == (input_ids.shape[0] // self.num_beams)
assert batch_size == (input_ids.shape[0] // self.group_size)

device = input_ids.device
next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
@@ -218,11 +233,11 @@
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
batch_beam_idx = batch_idx * self.num_beams + next_index
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
continue
beam_hyp.add(
@@ -237,12 +252,12 @@
beam_idx += 1

# once the beam for next step is full, don't add more tokens to it.
if beam_idx == self.num_beams:
if beam_idx == self.group_size:
break

if beam_idx < self.num_beams:
if beam_idx < self.group_size:
raise ValueError(
f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
)

# Check if we are done so that we can save a pad step if all(done)
@@ -274,7 +289,8 @@ def finalize(
if self._done[batch_idx]:
continue

# need to add best num_beams hypotheses to generated hyps
# all open beam hypotheses are added to the beam hypothesis
# beam hypothesis class automatically keeps the best beams
for beam_id in range(self.num_beams):
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
76 changes: 74 additions & 2 deletions src/transformers/generation_logits_process.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
from abc import ABC
from typing import Callable, Iterable, List
@@ -37,6 +38,8 @@
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
kwargs:
Additional logits processor specific kwargs.

Return:
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
@@ -75,9 +78,16 @@ class LogitsProcessorList(list):
"""

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
for processor in self:
scores = processor(input_ids, scores)
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
assert all(
arg in kwargs for arg in list(function_args.keys())[2:]
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
scores = processor(input_ids, scores, **kwargs)
else:
scores = processor(input_ids, scores)
return scores
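To illustrate the dispatch above (an assumed, minimal sketch rather than code from this PR): processors whose `__call__` only takes `input_ids` and `scores` are invoked as before, while processors that declare extra parameters, such as the `HammingDiversityLogitsProcessor` added further down, receive them via `**kwargs`. The tensor shapes and values below are toy examples:

```python
import torch
from transformers import (
    HammingDiversityLogitsProcessor,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)

batch_size, num_beams, num_beam_groups, vocab_size = 1, 4, 2, 10
group_size = num_beams // num_beam_groups

processors = LogitsProcessorList([
    HammingDiversityLogitsProcessor(diversity_penalty=1.0, num_beams=num_beams, num_beam_groups=num_beam_groups),
    MinLengthLogitsProcessor(5, eos_token_id=0),  # needs no extra kwargs
])

input_ids = torch.zeros((batch_size * group_size, 3), dtype=torch.long)  # sequences of the current group
scores = torch.zeros((batch_size * group_size, vocab_size))              # next-token scores of the current group
current_tokens = torch.tensor([7, 7, 0, 0])                              # this step's choices, one entry per beam

out = processors(input_ids, scores, current_tokens=current_tokens, beam_group_idx=1)
# the Hamming processor consumed current_tokens/beam_group_idx, the min-length processor was called without them
```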


@@ -400,3 +410,65 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0

return scores + mask


class HammingDiversityLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only
effective for `group_beam_search`. See `Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models
<https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.

Args:
diversity_penalty (:obj:`float`):
This value is subtracted from a beam's score if it generates a token that was already chosen by any beam from
another group at the same time step. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is enabled.
num_beams (:obj:`int`):
Number of beams used for group beam search. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for
more details.
num_beam_groups (:obj:`int`):
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
"""

def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
self._diversity_penalty = diversity_penalty
if not isinstance(num_beams, int) or num_beams < 2:
raise ValueError("`num_beams` should be an integer strictly larger than 1.")
self._num_beams = num_beams
if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
if num_beam_groups > num_beams:
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
self._num_sub_beams = num_beams // num_beam_groups

def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
current_tokens: torch.LongTensor,
beam_group_idx: int,
) -> torch.FloatTensor:
# hamming diversity: penalise choosing a token in the current group that was already chosen by a beam
# in one of the previous groups at the same time step
batch_size = current_tokens.shape[0] // self._num_beams
group_start_idx = beam_group_idx * self._num_sub_beams
group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
group_size = group_end_idx - group_start_idx
vocab_size = scores.shape[-1]

if group_start_idx == 0:
return scores

for batch_idx in range(batch_size):
# predicted tokens of last time step of previous groups
previous_group_tokens = current_tokens[
batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
]
token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency

return scores
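A small worked example of the penalty above (all values are illustrative): if both beams of the first group picked token 7 at this step, every beam in the second group has `2 * diversity_penalty` subtracted from token 7's score:

```python
import torch
from transformers import HammingDiversityLogitsProcessor

processor = HammingDiversityLogitsProcessor(diversity_penalty=1.0, num_beams=4, num_beam_groups=2)

input_ids = torch.zeros((2, 3), dtype=torch.long)  # the 2 beams of the second group, 3 tokens generated so far
scores = torch.zeros((2, 10))                      # next-token scores over a toy vocabulary of size 10
current_tokens = torch.tensor([7, 7, 0, 0])        # per-beam choices this step; only the first group is filled in

new_scores = processor(input_ids, scores, current_tokens=current_tokens, beam_group_idx=1)
print(new_scores[:, 7])  # tensor([-2., -2.]) -> 2 occurrences of token 7 * diversity_penalty
```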