Slight BERTScore gallery modifications for accuracy and styling (#2642)
Slight modifications for accuracy and styling
baskrahmer authored Jul 25, 2024
1 parent d8235b2 commit eaa85e0
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions examples/text/bertscore.py
@@ -15,37 +15,38 @@
# %%
# Define the prompt and target texts

prompt = "Economic recovery is underway with a 3.5% GDP growth and a decrease in unemployment. Experts forecast continued improvement with boosts from consumer spending and government projects."
target_text = "The economy is recovering, with GDP growth at 3.5% and unemployment at a two-year low. Experts expect this trend to continue due to higher consumer spending and government infrastructure investments."
prompt = "Economic recovery is underway with a 3.5% GDP growth and a decrease in unemployment. Experts forecast continued improvement with boosts from consumer spending and government projects. In summary: "
target_summary = "the recession is ending."

# %%
# Generate a sample text using the GPT-2 model

-generated_text = pipe(prompt, max_new_tokens=20, do_sample=False, temperature=0, pad_token_id=tokenizer.eos_token_id)[
-    0
-]["generated_text"][len(prompt) :]
+generated_summary = pipe(prompt, max_new_tokens=20, do_sample=False, pad_token_id=tokenizer.eos_token_id)[0][
+    "generated_text"
+][len(prompt) :].strip()

# %%
# Calculate the BERTScore of the generated text

bertscore = BERTScore(model_name_or_path="roberta-base")
-score = bertscore(preds=[generated_text], target=[target_text])
+score = bertscore(preds=[generated_summary], target=[target_summary])

print(f"Prompt: {prompt}")
print(f"Target Text: {target_text}")
print(f"Generated Text: {generated_text}")
print(f"BERTScore: {score['f1']}")
print(f"Target summary: {target_summary}")
print(f"Generated summary: {generated_summary}")
print(f"BERTScore: {score['f1']:.4f}")

# %%
# In addition, to illustrate BERTScore's robustness to paraphrasing, let's consider two candidate texts that are variations of the reference text.
reference = "the weather is freezing"
candidate_good = "it is cold today"
candidate_bad = "it is warm outside"

# %%
# Here we see that BERTScore is able to differentiate between the candidate texts based on their similarity to the reference text, whereas the ROUGE scores for the same text pairs are identical.
rouge = ROUGEScore()
bertscore = BERTScore(model_name_or_path="roberta-base")

print("ROUGE for candidate_good:", rouge(preds=[candidate_good], target=[reference])["rouge1_fmeasure"])
print("ROUGE for candidate_bad:", rouge(preds=[candidate_bad], target=[reference])["rouge1_fmeasure"])
print("BERTScore for candidate_good:", bertscore(preds=[candidate_good], target=[reference])["f1"])
print("BERTScore for candidate_bad:", bertscore(preds=[candidate_bad], target=[reference])["f1"])
print(f"ROUGE for candidate_good: {rouge(preds=[candidate_good], target=[reference])['rouge1_fmeasure'].item()}")
print(f"ROUGE for candidate_bad: {rouge(preds=[candidate_bad], target=[reference])['rouge1_fmeasure'].item()}")
print(f"BERTScore for candidate_good: {bertscore(preds=[candidate_good], target=[reference])['f1'].item():.4f}")
print(f"BERTScore for candidate_bad: {bertscore(preds=[candidate_bad], target=[reference])['f1'].item():.4f}")
