This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Blocking repeated ngrams during beam search #5216

Merged
merged 15 commits into allenai:main from the ngram-blocking branch on Jun 1, 2021

Conversation

danieldeutsch
Contributor

Closes #5205.

Changes proposed in this pull request:

  • Adds a new Constraint abstract class for enforcing constraints during beam search
  • Adds a NGramBlockingConstraint class to prevent decoding repeated ngrams.

This is a work in progress and may be close to finished, so I wanted to get feedback. @epwalsh, any thoughts on

  1. The Constraint interface? I limited the methods to just those I needed to implement the ngram blocking (see the sketch after this list).
  2. How to more efficiently implement the ngram blocking? I'm not sure there's a way to get around maintaining which ngrams have appeared before in a dictionary.
  3. How to test the end-to-end beam search with the ngram blocking constraint? I've tested the individual methods and some toy examples on my own. To really test it, I would need to come up with a transition matrix which has repeated ngrams by default and then block them. I couldn't think of a simple one which wouldn't require a lot of manual effort to ensure the output is correct.
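
For orientation, a rough sketch of the interface under discussion is below. It is only illustrative: the method names follow the excerpts quoted later in this thread (`apply()`, `update_state()`), but the exact signatures and the `ConstraintStateType` alias are assumptions rather than the merged code.

```python
from typing import Any, Dict, List, Optional

import torch

from allennlp.common.registrable import Registrable

# Assumed alias: one dictionary of bookkeeping per (batch item, beam slot).
ConstraintStateType = List[List[Dict[str, Any]]]


class Constraint(Registrable):
    """Lets a subclass veto tokens during beam search by adjusting the class log probabilities."""

    def init_state(self, batch_size: int) -> ConstraintStateType:
        """Build the initial per-batch, per-beam state before the first step."""
        raise NotImplementedError

    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,  # (batch_size, beam_size, num_classes)
    ) -> torch.Tensor:
        """Return log probabilities with disallowed classes pushed to a negligible value."""
        raise NotImplementedError

    def update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,  # (batch_size, beam_size)
        last_backpointer: Optional[torch.Tensor] = None,  # (batch_size, beam_size)
    ) -> ConstraintStateType:
        """Fold the newest predictions into the state, reordered by the beam backpointers."""
        raise NotImplementedError
```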

Member

@epwalsh epwalsh left a comment

This looks pretty good so far. I do think having an end-to-end test is important. Is it feasible to at least make a test that assumes ngram_size is 1?

One comment I have about the API for the Constraint class is that having to deal with backpointers in update_state() is not ideal. Can we automatically handle reordering the state list outside of update_state(), either in the BeamSearch class or the Constraint base class?
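
One way to do that, roughly the shape of the `_copy_state` helper excerpted later in this thread, is a base-class method that deep-copies each beam slot's state from the parent slot named by the backpointers. The signature below is a sketch under that assumption, not necessarily the merged code.

```python
import copy
from typing import Optional

import torch


def _copy_state(
    state,  # ConstraintStateType: per-batch, per-beam dictionaries of bookkeeping
    batch_size: int,
    beam_size: int,
    last_backpointer: Optional[torch.Tensor] = None,  # (batch_size, beam_size)
):
    new_state = []
    for i in range(batch_size):
        batch_state = []
        for j in range(beam_size):
            if last_backpointer is None:
                # First step of the search: every beam slot descends from the single
                # initial prediction, so there is nothing to reorder.
                backpointer = 0
            else:
                backpointer = last_backpointer[i, j].item()
            # Copy the parent beam's bookkeeping so each child beam can diverge independently.
            batch_state.append(copy.deepcopy(state[i][backpointer]))
        new_state.append(batch_state)
    return new_state
```

`update_state()` in the base class could then call this before folding in the newest predictions, so that subclasses never have to deal with backpointers at all.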

allennlp/nn/beam_search.py: 3 review threads (outdated, resolved)
@danieldeutsch danieldeutsch changed the title from "[WIP] Blocking repeated ngrams during beam search" to "Blocking repeated ngrams during beam search" on May 25, 2021
@danieldeutsch
Contributor Author

In this version:

  • I incorporated your comments on the code
  • I moved the logic to copy the parent's state into the Constraint class
  • I added end-to-end unit tests for using the n-gram blocking in BeamSearch with a search that I manually traced. See test_take_repeated_ngram_step

I removed "WIP" from the PR title because I think this version is complete to me, pending comments.

Member

@epwalsh epwalsh left a comment

Looks great @danieldeutsch! Just a few more minor comments.

CHANGELOG.md Outdated
@@ -33,6 +33,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences.
- Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`.
- Added `shuffle` argument to `BucketBatchSampler` which allows for disabling shuffling.
- Added a `Constrant` abstract class to `BeamSearch`, which allows for incorporating constraints on the predictions found by `BeamSearch`.
Member

Suggested change
- Added a `Constrant` abstract class to `BeamSearch`, which allows for incorporating constraints on the predictions found by `BeamSearch`.
- Added a `Constraint` abstract class to `BeamSearch`, which allows for incorporating constraints on the predictions found by `BeamSearch`.

class Constraint(Registrable):
    """
    An abstract class that can be used to enforce constraints on the output predictions
    by manipulate the class log probabilities during beam search.
Member

Suggested change
by manipulate the class log probabilities during beam search.
by manipulating the class log probabilities during beam search.


The `apply()` method should manipulate the `class_log_probabilities` in place to enforce the constraint
for this step of beam search. For instance, it may prevent a specific class from being selected by setting
the corresponding log probability to `-inf` (by using `min_value_of_dtype(class_log_probabilities.dtype)`).
Member

Suggested change
the corresponding log probability to `-inf` (by using `min_value_of_dtype(class_log_probabilities.dtype)`).
the corresponding log probability to a negligible value by using `min_value_of_dtype(class_log_probabilities.dtype)`, which is essentially equivalent to `-inf`.
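
As a small, concrete illustration of the wording above (the tensor shape and the blocked class index 7 are arbitrary):

```python
import torch

from allennlp.nn.util import min_value_of_dtype

# (batch_size=2, beam_size=3, num_classes=10)
class_log_probabilities = torch.randn(2, 3, 10)

# Make class 7 effectively unselectable for every batch item and beam slot.
class_log_probabilities[..., 7] = min_value_of_dtype(class_log_probabilities.dtype)
```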

Comment on lines 987 to 991
for constraint, constraint_state in zip(self.constraints, constraint_states):
    # shape: (batch_size, beam_size, num_classes)
    reshaped_class_log_probabilities = class_log_probabilities.view(
        batch_size, self.beam_size, -1
    )
Member

Any reason not to put the reshape outside of the loop?

Suggested change
for constraint, constraint_state in zip(self.constraints, constraint_states):
    # shape: (batch_size, beam_size, num_classes)
    reshaped_class_log_probabilities = class_log_probabilities.view(
        batch_size, self.beam_size, -1
    )
# shape: (batch_size, beam_size, num_classes)
reshaped_class_log_probabilities = class_log_probabilities.view(
    batch_size, self.beam_size, -1
)
for constraint, constraint_state in zip(self.constraints, constraint_states):

- `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.

The `apply()` method should manipulate the `class_log_probabilities` in place to enforce the constraint
Member

I think it's more natural for the apply() method to return new class_log_probabilities. It's a little easier to reason about.

Member

@epwalsh epwalsh left a comment

LGTM, thanks @danieldeutsch!

@epwalsh epwalsh enabled auto-merge (squash) June 1, 2021 17:27
@epwalsh epwalsh merged commit c014232 into allenai:main Jun 1, 2021
@danieldeutsch danieldeutsch deleted the ngram-blocking branch June 1, 2021 19:12
@danieldeutsch
Contributor Author

Thanks @epwalsh for finishing it. I forgot about it over the holiday.

                backpointer = last_backpointer[i, j].item()
            batch_state.append(copy.deepcopy(state[i][backpointer]))
        new_state.append(batch_state)
    return new_state
Contributor

@JohnGiorgi JohnGiorgi Jul 21, 2021

@danieldeutsch @epwalsh Sorry, I know this PR is closed but I have a question that's easiest to ask here because of the existing context.

It seems that last_backpointer, which is passed to _copy_state by update_state, will always be None (it's never provided to update_state in BeamSearch). That would mean that backpointer is always 0 and that it is always state[i][0] being copied, regardless of the actual timestep. Isn't this a problem? The tests don't catch this because they provide the backpointers to update_state manually (see my other comment).

Contributor Author

Yes, it looks like you are right to me. I think the fix should be to pass backpointer here, right?

for i, constraint in enumerate(self.constraints):
    constraint_states[i] = constraint.update_state(
        constraint_states[i], restricted_predicted_classes
    )
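
If that is indeed the fix, the call above would presumably grow one extra argument, along these lines (this assumes the per-step backpointer tensor inside `BeamSearch.search` is called `backpointer` and that `update_state` accepts it as `last_backpointer`; the exact names in an eventual fix may differ):

```python
for i, constraint in enumerate(self.constraints):
    constraint_states[i] = constraint.update_state(
        constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
    )
```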

],
]
predictions = torch.LongTensor([[5, 6], [0, 3]])
backpointers = torch.LongTensor([[1, 1], [0, 1]])
Contributor

backpointers is provided to update_state here, but not in BeamSearch (see my other comment).

Contributor Author

I would need to retrace the beam search that I used to write this test, but my guess is that this is not caught because the backpointer happens to always be 0.

def test_repeated_ngram_blocking_end_to_end(self):

Abhishek-P pushed a commit to Abhishek-P/allennlp that referenced this pull request Aug 11, 2021
* Implementing blocking repeated ngrams

* Adding comment

* Adding unit tests for the end to end beam search

* Renaming class

* Adding comment about  function

* Simplifying indexing to variable

* Refactoring the state copying into the  class

* Reformatting

* Editing changelog

* fix line too long

* comments

* doc updates

Co-authored-by: Pete <[email protected]>
Co-authored-by: epwalsh <[email protected]>
Development

Successfully merging this pull request may close these issues.

Add features to beam search that are supported in other libraries
3 participants