29 changes: 12 additions & 17 deletions src/transformers/generation/beam_search.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import List, Optional, Tuple
@@ -156,6 +155,7 @@ def __init__(
self,
batch_size: int,
num_beams: int,
max_length: int,
Contributor:
BeamSearchScorer is a public class (even though I don't think it's used that much on its own). Do you think we could maybe do:

Suggested change
max_length: int,
max_length: Optional[int] = None,

and throw an error if do_early_stopping is set to False, saying that one should set do_early_stopping=True to avoid having to pass max_length.
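A minimal sketch of what that suggestion could look like (hypothetical, not the merged code; the class name and error message are illustrative):

    from typing import Optional

    class BeamSearchScorerSketch:
        # Hypothetical sketch of the suggested validation -- not the merged code.
        def __init__(
            self,
            batch_size: int,
            num_beams: int,
            max_length: Optional[int] = None,
            do_early_stopping: bool = False,
        ):
            if not do_early_stopping and max_length is None:
                # Without early stopping, is_done needs max_length to bound the
                # best achievable score, so the argument cannot be omitted.
                raise ValueError(
                    "`max_length` is required when `do_early_stopping=False`; "
                    "pass `max_length` or set `do_early_stopping=True`."
                )
            self.max_length = max_length
            self.do_early_stopping = do_early_stopping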

device: torch.device,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[bool] = False,
@@ -167,6 +167,7 @@ def __init__(
self.device = device
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.max_length = max_length
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self.num_beam_groups = num_beam_groups
self.group_size = self.num_beams // self.num_beam_groups
@@ -177,6 +178,7 @@ def __init__(
num_beams=self.num_beams,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
max_length=self.max_length,
)
for _ in range(batch_size)
]
@@ -194,13 +196,6 @@ def __init__(
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)

if "max_length" in kwargs:
Contributor Author:
max_length is now a mandatory argument, so this warning no longer makes sense. The test that confirmed this warning was thrown has also been removed.

Contributor:
Think we can then also delete kwargs from the __init__.

warnings.warn(
"Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
", or `group_beam_search(...)`."
)

@property
def is_done(self) -> bool:
return self._done.all()
@@ -424,6 +419,7 @@ def __init__(
self,
batch_size: int,
num_beams: int,
max_length: int,
constraints: List[Constraint],
device: torch.device,
length_penalty: Optional[float] = 1.0,
@@ -436,6 +432,7 @@ def __init__(
self.device = device
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.max_length = max_length
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self.num_beam_groups = num_beam_groups
self.group_size = self.num_beams // self.num_beam_groups
@@ -447,6 +444,7 @@ def __init__(
num_beams=self.num_beams,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
max_length=self.max_length,
)
for _ in range(batch_size)
]
@@ -464,13 +462,6 @@ def __init__(
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)

if "max_length" in kwargs:
warnings.warn(
"Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. "
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
", or `group_beam_search(...)`."
)

@property
def is_done(self) -> bool:
return self._done.all()
@@ -851,12 +842,13 @@ def finalize(


class BeamHypotheses:
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: int):
"""
Initialize n-best list of hypotheses.
"""
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.max_length = max_length
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
@@ -892,6 +884,9 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len**self.length_penalty
if self.length_penalty > 0.0:
cur_score = best_sum_logprobs / self.max_length**self.length_penalty
Contributor:
great find!

else:
cur_score = best_sum_logprobs / cur_len**self.length_penalty
Comment on lines +887 to +890
Contributor Author (@gante), Dec 26, 2022:
Implements the logic as described in the PR header. max_length is now needed as an input, which implied some changes in the tests (not in terms of results, but rather in terms of class initialization).

ret = self.worst_score >= cur_score
return ret
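To make the fixed bound concrete, here is a small worked example (illustrative numbers only). Log-probabilities are negative, so with a positive length penalty the normalized score of a running hypothesis can still rise as it grows toward max_length; the optimistic bound must therefore divide by max_length rather than cur_len.

    # Illustrative numbers only: why the bound uses max_length when length_penalty > 0.
    best_sum_logprobs = -10.0
    cur_len, max_length, length_penalty = 5, 20, 1.0

    score_now = best_sum_logprobs / cur_len**length_penalty          # -2.0
    best_reachable = best_sum_logprobs / max_length**length_penalty  # -0.5

    # The running beam could still climb to -0.5, so is_done may only return True
    # once the worst finished score already beats -0.5, not merely -2.0.
    assert best_reachable > score_now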
7 changes: 5 additions & 2 deletions src/transformers/generation/flax_utils.py
@@ -797,11 +797,14 @@ def beam_search_cond_fn(state):
not_max_length_yet = state.cur_len < max_length

# 2. can the new beams still improve?
best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
if length_penalty > 0.0:
Contributor:
Does this work with XLA? Gives no errors?

best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
else:
best_running_score = state.running_scores[:, -1:] / (state.cur_len**length_penalty)
worst_finished_score = jnp.where(
state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
)
improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
improvement_still_possible = jnp.any(best_running_score > worst_finished_score)
Contributor Author:
(see comment on the TF implementation)

Contributor:
incorrectly copied haha - good catch!


# 3. is there still a beam that has not finished?
still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
9 changes: 6 additions & 3 deletions src/transformers/generation/tf_utils.py
@@ -3066,16 +3066,19 @@ def beam_search_cond_fn(
not_max_length_yet = cur_len < max_length

# 2. can the new beams still improve?
best_running_score = running_scores[:, :1] / (max_length**length_penalty)
if length_penalty > 0.0:
best_running_score = running_scores[:, :1] / (max_length**length_penalty)
else:
best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
worst_finished_score = tf.where(
is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
)
improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score)
Contributor Author:
Before: improvement was possible when ALL finished scores were worse than the best running scores. In other words, if batch member 0 had running candidates that were better than its finished candidates' scores, but batch member 1 did not, this condition would evaluate to False because of batch member 1. This means that we were terminating beam search even though an improvement was still possible for batch member 0.
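A small repro of that per-batch subtlety, with made-up scores for a batch of two:

    import tensorflow as tf

    # Made-up scores: batch member 0 can no longer improve, batch member 1 still can.
    worst_finished_score = tf.constant([[-1.0], [-5.0]])
    best_running_score = tf.constant([[-2.0], [-4.0]])

    old = tf.math.reduce_all(worst_finished_score < best_running_score)  # False -> stop (wrong)
    new = tf.math.reduce_any(best_running_score > worst_finished_score)  # True  -> continue

    # The old reduction demanded headroom in EVERY batch member, so batch member 1's
    # remaining headroom was ignored and the search terminated too early.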


# 3. is there still a beam that has not finished?
still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)

return not_max_length_yet & (still_open_beam | improvement_still_possible)
return not_max_length_yet & still_open_beam & improvement_still_possible
Contributor Author:
🙈 this is what happens when test fixing is done without thinking deeply on the subject: the previous condition, combined with the previous handling of improvement_still_possible and best_running_score, made all tests pass. But for the wrong reasons -- early_stopping=True was not operating as intended before.
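A quick check of the fixed condition with hypothetical flag values -- early_stopping=True and every beam already finished:

    not_max_length_yet = True
    still_open_beam = False            # ~(all sentences finished & early_stopping)
    improvement_still_possible = True  # some running score could still win

    old_continue = not_max_length_yet & (still_open_beam | improvement_still_possible)  # True: keeps decoding
    new_continue = not_max_length_yet & still_open_beam & improvement_still_possible    # False: halts

    # With the AND form, early_stopping=True actually stops generation once all
    # beams have finished, which is the intended semantics.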


def beam_search_body_fn(
cur_len,
10 changes: 9 additions & 1 deletion src/transformers/generation/utils.py
@@ -1406,6 +1406,7 @@ def generate(
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
max_length=stopping_criteria.max_length,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
@@ -1442,6 +1443,7 @@ def generate(
beam_scorer = BeamSearchScorer(
batch_size=batch_size * generation_config.num_return_sequences,
num_beams=generation_config.num_beams,
max_length=stopping_criteria.max_length,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
@@ -1577,6 +1579,7 @@ def typeerror():
constraints=final_constraints,
batch_size=batch_size,
num_beams=generation_config.num_beams,
max_length=stopping_criteria.max_length,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
@@ -2534,6 +2537,7 @@ def beam_search(
>>> # instantiate beam scorer
>>> beam_scorer = BeamSearchScorer(
... batch_size=1,
... max_length=model.config.max_length,
... num_beams=num_beams,
... device=model.device,
... )
@@ -3543,7 +3547,11 @@ def constrained_beam_search(

>>> # instantiate beam scorer
>>> beam_scorer = ConstrainedBeamSearchScorer(
... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
... batch_size=1,
... num_beams=num_beams,
... max_length=model.config.max_length,
... device=model.device,
... constraints=constraints,
... )

>>> # instantiate logits processors
1 change: 1 addition & 0 deletions src/transformers/models/rag/modeling_rag.py
@@ -1562,6 +1562,7 @@ def extend_enc_output(tensor, num_beams=None):
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
max_length=generation_config.max_length,
device=self.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
14 changes: 10 additions & 4 deletions tests/generation/test_beam_search.py
@@ -66,6 +66,7 @@ def prepare_beam_scorer(self, **kwargs):
return BeamSearchScorer(
batch_size=kwargs.get("batch_size", self.batch_size),
num_beams=kwargs.get("num_beams", self.num_beams),
max_length=kwargs.get("max_length", self.max_length),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
@@ -81,7 +82,7 @@ def prepare_inputs(self):

def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
beam_scorer = self.prepare_beam_scorer(do_early_stopping=True)
beam_scorer = self.prepare_beam_scorer(do_early_stopping=True, max_length=input_ids.shape[-1])
beam_hyp = beam_scorer._beam_hyps[0]

self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size)
@@ -100,7 +101,7 @@ def check_beam_hypotheses(self, input_ids, *args):
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))

# re-init
beam_scorer = self.prepare_beam_scorer(do_early_stopping=False)
beam_scorer = self.prepare_beam_scorer(do_early_stopping=False, max_length=input_ids.shape[-1])
beam_hyp = beam_scorer._beam_hyps[0]

# add `num_beams + 1` beams to change `worst_score`
@@ -291,6 +292,7 @@ def prepare_constrained_beam_scorer(self, **kwargs):
constraints=kwargs.get("constraints", self.constraints),
batch_size=kwargs.get("batch_size", self.batch_size),
num_beams=kwargs.get("num_beams", self.num_beams),
max_length=kwargs.get("max_length", self.max_length),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
@@ -309,7 +311,9 @@ def prepare_inputs(self):

def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True)
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
do_early_stopping=True, max_length=input_ids.shape[-1]
)
beam_hyp = constrained_beam_scorer._beam_hyps[0]

self.parent.assertEqual(len(constrained_beam_scorer._beam_hyps), self.batch_size)
@@ -328,7 +332,9 @@ def check_beam_hypotheses(self, input_ids, *args):
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))

# re-init
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False)
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
do_early_stopping=False, max_length=input_ids.shape[-1]
)
beam_hyp = constrained_beam_scorer._beam_hyps[0]

# add `num_beams + 1` beams to change `worst_score`
67 changes: 13 additions & 54 deletions tests/generation/test_utils.py
@@ -174,6 +174,7 @@ def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=beam_kwargs["num_beams"],
max_length=max_length,
device=torch_device,
length_penalty=beam_kwargs["length_penalty"],
do_early_stopping=beam_kwargs["early_stopping"],
@@ -194,6 +195,7 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_seque
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=beam_kwargs["num_beams"],
max_length=max_length,
device=torch_device,
length_penalty=beam_kwargs["length_penalty"],
do_early_stopping=beam_kwargs["early_stopping"],
@@ -214,6 +216,7 @@ def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints,
batch_size=batch_size,
constraints=constraints,
num_beams=beam_kwargs["num_beams"],
max_length=max_length,
device=torch_device,
length_penalty=beam_kwargs["length_penalty"],
do_early_stopping=beam_kwargs["early_stopping"],
@@ -1898,6 +1901,7 @@ def test_max_length_backward_compat_beam_search(self):
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
max_length=max_length,
device=torch_device,
)
with self.assertWarns(UserWarning):
@@ -1930,6 +1934,7 @@ def test_max_length_backward_compat_group_beam_search(self):
diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
max_length=max_length,
device=torch_device,
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
@@ -1991,6 +1996,7 @@ def test_max_length_warning_if_different(self):
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
max_length=max_length,
device=torch_device,
)
with self.assertWarns(UserWarning):
@@ -2008,6 +2014,7 @@ def test_max_length_warning_if_different(self):
diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
max_length=max_length,
device=torch_device,
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
@@ -2022,59 +2029,6 @@ def test_max_length_warning_if_different(self):
**model_kwargs,
)

def test_beam_search_warning_if_max_length_is_passed(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)

batch_size = 1
num_beams = 3

input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
input_ids = input_ids.expand(num_beams, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})

# pretend decoder_input_ids correspond to first encoder input id
decoder_input_ids = input_ids[:, :1]

stopping_criteria_max_length = 18
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])

with self.assertWarns(UserWarning):
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
max_length=10,
)

generated_ids = bart_model.beam_search(
decoder_input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
beam_scorer=beam_scorer,
**model_kwargs,
)

beam_scorer_no_max_len = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
)

generated_ids_no_max_len = bart_model.beam_search(
decoder_input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
beam_scorer=beam_scorer_no_max_len,
**model_kwargs,
)

# BeamSearchScorer max_length should not influence "real" max_length
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())

def test_custom_stopping_criteria_overload_error(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
@@ -2744,6 +2698,7 @@ def test_beam_search_example_integration(self):
beam_scorer = BeamSearchScorer(
batch_size=1,
num_beams=num_beams,
max_length=model.config.max_length,
device=model.device,
)

@@ -2924,7 +2879,11 @@ def test_constrained_beam_search_example_integration(self):

# instantiate beam scorer
beam_scorer = ConstrainedBeamSearchScorer(
batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
batch_size=1,
num_beams=num_beams,
max_length=model.config.max_length,
device=model.device,
constraints=constraints,
)

# instantiate logits processors