Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -785,15 +785,15 @@ def test_bert2bert_summarization(self):
EXPECTED_SUMMARY_STUDENTS = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months."""

input_dict = tokenizer(ARTICLE_STUDENTS, return_tensors="tf")
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])

# Test with the TF checkpoint
model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")

output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
Expand Down Expand Up @@ -887,7 +887,7 @@ def test_bert2gpt2_summarization(self):
EXPECTED_SUMMARY_STUDENTS = """SAS Alpha Epsilon suspended the students, but university president says it's permanent.\nThe fraternity has had to deal with a string of student deaths since 2010.\nSAS has more than 200,000 members, many of whom are students.\nA student died while being forced into excessive alcohol consumption."""

input_dict = tokenizer_in(ARTICLE_STUDENTS, return_tensors="tf")
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
summary = tokenizer_out.batch_decode(output_ids, skip_special_tokens=True)

self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
Expand Down