
Commit a78e884

gante authored and ArthurZucker committed
[generate] beam search -- fix output cropping (#37080)
* handle jagged beams
* better comment
* bart -- beam search tests print special tokens
* more bart test updates
* more tests!
* better comment
1 parent e9a5e32 commit a78e884

File tree

5 files changed: +74 −45 lines changed


src/transformers/generation/utils.py

+8-3
@@ -3887,9 +3887,14 @@ def _beam_search(
         beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
         beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])

-        # Crop the static-shaped tensors to the actual size
-        sequences = sequences[:, :cur_len]
-        beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
+        # Crop the static-shaped tensors to the actual size.
+        # `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
+        # step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
+        # previous decoding iteration)
+        max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
+        output_length = decoder_prompt_len + max_generated_length
+        sequences = sequences[:, :output_length]
+        beam_indices = beam_indices[:, :max_generated_length]

         if return_dict_in_generate:
             if not output_scores:
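The new cropping no longer trusts `cur_len`: it infers the true generated length from `beam_indices`, whose unused slots stay at -1 (the selected beam may have finished on an earlier decoding iteration). A minimal sketch of that computation on a hypothetical toy tensor (the values and prompt length below are invented for illustration):

import torch

# Hypothetical beam_indices buffer for 2 returned sequences and a static budget of
# 6 generation steps. Filled slots hold the origin beam index of each generated
# token; slots that were never written (the beam finished earlier) stay at -1.
beam_indices = torch.tensor(
    [
        [0, 1, 1, 0, -1, -1],    # 4 generated tokens
        [1, 1, -1, -1, -1, -1],  # 2 generated tokens
    ]
)

# -1 becomes 0 after `+ 1`, so `.bool()` is True only for filled slots.
max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
assert max_generated_length.item() == 4

decoder_prompt_len = 3  # assumed prompt length, for illustration only
output_length = decoder_prompt_len + max_generated_length
assert output_length.item() == 7  # sequences get cropped to this length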

tests/models/bart/test_modeling_bart.py

+28-19
@@ -599,13 +599,15 @@ def test_xsum_1_1_generation(self):
             " 2002 to prosecute genocide, crimes against humanity and war crimes."
         )
         EXPECTED = (
+            "</s>"
             " The International Criminal Court (ICC) has announced that it has been announced by the International"
             " Criminal court."
+            "</s>"
         )

         dct = tok(ARTICLE, return_tensors="pt")
         generated_ids = hf.generate(**dct, num_beams=4)
-        result = tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        result = tok.batch_decode(generated_ids)[0]
         assert EXPECTED == result

     def test_xsum_1_1_batch_generation(self):
@@ -729,16 +731,18 @@ def test_xsum_1_1_batch_generation(self):
             truncation=True,
         )
         generated_ids = self.xsum_1_1_model.generate(**batch, num_beams=4)
-        result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)
-        assert (
-            result[0]
-            == " The International Criminal Court (ICC) has announced that it has been announced by the International"
+        result = self.tok.batch_decode(generated_ids)
+        assert result[0] == (
+            "</s>"
+            " The International Criminal Court (ICC) has announced that it has been announced by the International"
             " Criminal court."
+            "</s><pad><pad><pad><pad><pad>"
         )
-        assert (
-            result[1]
-            == " An investigation into the crash that killed at least 10 people in the French capital has been"
+        assert result[1] == (
+            "</s>"
+            " An investigation into the crash that killed at least 10 people in the French capital has been"
             " released by the French police investigating the crash."
+            "</s>"
         )

     def test_encoder_equiv(self):
@@ -939,8 +943,10 @@ def test_xsum_summarization_same_as_fairseq(self):
         PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""

         EXPECTED_SUMMARY = (
+            "</s>"
             "California's largest power company has begun shutting off electricity to thousands of customers in the"
             " state."
+            "</s>"
         )
         dct = tok.batch_encode_plus(
             [PGE_ARTICLE],
@@ -962,10 +968,7 @@ def test_xsum_summarization_same_as_fairseq(self):
             decoder_start_token_id=model.config.eos_token_id,
         )

-        decoded = tok.batch_decode(
-            hypotheses_batch,
-            skip_special_tokens=True,
-        )
+        decoded = tok.batch_decode(hypotheses_batch)
         self.assertEqual(EXPECTED_SUMMARY, decoded[0])

     def test_xsum_config_generation_params(self):
@@ -1189,26 +1192,32 @@ def test_cnn_summarization_same_as_fairseq(self):
         assert hypotheses_batch[:, 1].eq(0).all().item()

         EXPECTED = [
+            "</s><s>"
             "A French prosecutor says he is not aware of any video footage from on board the plane. Two German "
             "magazines claim to have found a cell phone video showing the crash. The publications say they watched "
             "the video, which was found by a source close to the investigation. All 150 on board Germanwings Flight "
-            "9525 were killed.",
+            "9525 were killed."
+            "</s>",
+            "</s><s>"
             "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court "
             "jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the "
             "Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a "
-            "move toward greater justice.",
+            "move toward greater justice."
+            "</s><pad><pad><pad><pad>",
+            "</s><s>"
             "U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The "
             "debate that has already begun will likely result in more heat than light. He says critics have made "
             "dubious assumptions and doubtful assertions. Bergen says the goal was to block Iran from building a "
-            "nuclear weapon.",
+            "nuclear weapon."
+            "</s><pad><pad><pad>",
+            "</s><s>"
             "Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors "
             "say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the "
-            "Bronx on Friday. If convicted, she faces up to four years in prison.",
+            "Bronx on Friday. If convicted, she faces up to four years in prison."
+            "</s><pad><pad><pad><pad><pad>",
         ]

-        generated_summaries = tok.batch_decode(
-            hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
-        )
+        generated_summaries = tok.batch_decode(hypotheses_batch.tolist())
         assert generated_summaries == EXPECTED

     @slow
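These test updates follow from dropping `skip_special_tokens=True`: with the cropping fix, the returned sequences keep their bos/eos tokens plus any trailing padding up to the batch-wide output length, so the expected strings now spell those tokens out. A small sketch of the decode difference (the checkpoint name and example sentence are only assumptions for illustration; any BART-style tokenizer behaves the same way):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/bart-large-xsum")

# Fake "generated" ids shaped the way beam search returns them after the fix:
# decoder start token (</s> for BART), the summary tokens, eos, then padding.
ids = (
    [tok.eos_token_id]
    + tok(" A short summary.", add_special_tokens=False)["input_ids"]
    + [tok.eos_token_id]
    + [tok.pad_token_id] * 3
)

print(tok.decode(ids))                            # -> "</s> A short summary.</s><pad><pad><pad>"
print(tok.decode(ids, skip_special_tokens=True))  # -> " A short summary."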

tests/models/biogpt/test_modeling_biogpt.py

+5-3
@@ -434,7 +434,7 @@ def test_inference_lm_head_model(self):
         torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)

     @slow
-    def test_biogpt_generation(self):
+    def test_biogpt_generation_beam_search(self):
         tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
         model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
         model.to(torch_device)
@@ -448,13 +448,15 @@ def test_biogpt_generation(self):
             num_beams=5,
             early_stopping=True,
         )
-        output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        output_str = tokenizer.decode(output_ids[0])

         EXPECTED_OUTPUT_STR = (
+            "</s>"
             "COVID-19 is a global pandemic caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), the"
             " causative agent of coronavirus disease 2019 (COVID-19), which has spread to more than 200 countries and"
             " territories, including the United States (US), Canada, Australia, New Zealand, the United Kingdom (UK),"
             " and the United States of America (USA), as of March 11, 2020, with more than 800,000 confirmed cases and"
-            " more than 800,000 deaths."
+            " more than 800,000 deaths. "
+            "</s>"
         )
         self.assertEqual(output_str, EXPECTED_OUTPUT_STR)

tests/models/m2m_100/test_modeling_m2m_100.py

+10-6
@@ -415,16 +415,20 @@ def test_seq_to_seq_generation(self):
         )

         expected_en = [
-            "The NSA case highlights the total absence of intelligence debate",
-            "I think there are two levels of response from the French government.",
+            "</s> __en__ "
+            "The NSA case highlights the total absence of intelligence debate"
+            "</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>",
+            "</s> __en__ "
+            "I think there are two levels of response from the French government."
+            "</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>",
+            "</s> __en__ "
             "When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
             " Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
-            " communications in France.",
+            " communications in France."
+            "</s>",
         ]

-        generated = tokenizer.batch_decode(
-            hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
-        )
+        generated = tokenizer.batch_decode(hypotheses_batch)
        assert generated == expected_en

     @require_flash_attn
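Because the whole batch is cropped to a single output length (prompt plus the longest beam), beams that finished earlier keep trailing `<pad>` tokens, which is why the two short translations above end in long pad runs while the longest one does not. A toy sketch of that shape behavior (the ids and lengths below are invented):

import torch

pad_token_id = 1  # hypothetical pad id
# Static decoding buffer for 3 returned sequences, pre-filled with padding.
# Suppose the beams finished after 4, 2, and 6 generated tokens respectively.
sequences = torch.full((3, 10), pad_token_id)
sequences[0, :4] = torch.tensor([17, 23, 9, 2])
sequences[1, :2] = torch.tensor([42, 2])
sequences[2, :6] = torch.tensor([17, 5, 5, 5, 5, 2])

max_generated_length = 6  # longest beam in the batch
cropped = sequences[:, :max_generated_length]
print(cropped.shape)  # torch.Size([3, 6]) -- shorter beams keep trailing pads
print(cropped[1])     # tensor([42, 2, 1, 1, 1, 1])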

tests/models/t5/test_modeling_t5.py

+23-14
@@ -1475,19 +1475,27 @@ def test_summarization(self):
         )

         expected_summaries = [
+            "<pad> "
             'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
             " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
-            " magazine says .",
+            " magazine says ."
+            "</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
+            "<pad> "
             "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
             " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
-            " court, Palestinians may be subject to counter-charges as well .",
+            " court, Palestinians may be subject to counter-charges as well ."
+            "</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
+            "<pad> "
             "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
             " the debate that has already begun since the announcement of the new framework will likely result in more"
             " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
-            " implement a rigorous inspection regime .",
+            " implement a rigorous inspection regime ."
+            "</s>",
+            "<pad> "
             "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
             ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
-            " times, with nine of her marriages occurring between 1999 and 2002 .",
+            " times, with nine of her marriages occurring between 1999 and 2002 ."
+            "</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
         ]

         use_task_specific_params(model, "summarization")
@@ -1512,11 +1520,8 @@ def test_summarization(self):
             early_stopping=True,
         )

-        decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        self.assertListEqual(
-            expected_summaries,
-            decoded,
-        )
+        decoded = tok.batch_decode(hypotheses_batch)
+        self.assertListEqual(expected_summaries, decoded)

     @slow
     def test_translation_en_to_de(self):
@@ -1526,13 +1531,13 @@ def test_translation_en_to_de(self):

         en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.'
         expected_translation = (
-            '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.'
+            '<pad> "Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.</s>'
         )

         input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt")
         input_ids = input_ids.to(torch_device)
         output = model.generate(input_ids)
-        translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        translation = tok.decode(output[0])
         self.assertEqual(translation, expected_translation)

     @slow
@@ -1558,13 +1563,15 @@ def test_translation_en_to_fr(self):
             do_sample=False,
             early_stopping=True,
         )
-        translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        translation = tok.decode(output[0])
         new_truncated_translation = (
+            "<pad> "
             "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre "
             "un "
             "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées "
             "sous forme "
             "de points bleus."
+            "</s>"
         )

         self.assertEqual(translation, new_truncated_translation)
@@ -1575,11 +1582,13 @@ def test_translation_en_to_ro(self):
         tok = self.tokenizer
         use_task_specific_params(model, "translation_en_to_ro")
         en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022."
-        expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022."
+        expected_translation = (
+            "<pad> Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022.</s>"
+        )

         inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device)
         output = model.generate(**inputs)
-        translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        translation = tok.decode(output[0])
         self.assertEqual(translation, expected_translation)

     @slow
