
Commit

Add/bleu tokenizer (#9)
* add: bleu with tokenizer
* add: hf evaluation test
shenxiangzhuang authored Apr 22, 2024
1 parent 87ab215 commit beccf8e
Showing 7 changed files with 129 additions and 41 deletions.
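In short, this commit makes the Rust compute_bleu apply the Tokenizer13a tokenizer internally, so bleuscore's scores line up with Hugging Face's evaluate BLEU metric, and it adds a benchmark plus a test that check exactly that. A minimal sketch of the comparison the new benchmark/compare.py exercises (assuming the evaluate and bleuscore packages are installed; the names mirror that file):

import evaluate
import bleuscore

references = [["Oh, hello bleu score from rust!",
               "Oh, hello bleu score from python!"]]
predictions = ["hello bleu score from"]

# Hugging Face BLEU (tokenizes with Tokenizer13a under the hood)
hf_result = evaluate.load("bleu").compute(predictions=predictions,
                                          references=references)

# bleuscore now tokenizes the same way, so the two scores should agree
rust_result = bleuscore.compute_bleu(reference_corpus=references,
                                     translation_corpus=predictions,
                                     max_order=4,
                                     smooth=False)

print(hf_result["bleu"], rust_result.get("bleu"))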
48 changes: 48 additions & 0 deletions benchmark/compare.py
@@ -0,0 +1,48 @@
from typing import List
import time

import evaluate
import bleuscore


def hf_bleu(references: List[List[str]], predictions: List[str]):
    # Compute BLEU score
    bleu = evaluate.load("bleu")
    results = bleu.compute(predictions=predictions, references=references)
    print(results)
    return results


def rust_bleu(references: List[List[str]], predictions: List[str]):
    rust_result = bleuscore.compute_bleu(reference_corpus=references,
                                         translation_corpus=predictions,
                                         max_order=4,
                                         smooth=False)
    print(rust_result)
    return rust_result


def compare():
    references = [
        ["Oh, hello bleu score from rust!",
         "Oh, hello bleu score from python!"],
    ]
    predictions = ["hello bleu score from"]
    t0 = time.time()
    n_times = 10
    for _ in range(n_times):
        hf_bleu(references, predictions)
    t1 = time.time()
    for _ in range(n_times):
        rust_bleu(references, predictions)
    t2 = time.time()

    hf_py_bleu_spend_seconds = t1 - t0
    rust_bleu_spend_seconds = t2 - t1
    print(f"hf py: {hf_py_bleu_spend_seconds}s\n"
          f"rust: {rust_bleu_spend_seconds}s\n"
          f"fast: {hf_py_bleu_spend_seconds / rust_bleu_spend_seconds:.2f}")


if __name__ == "__main__":
    compare()
12 changes: 11 additions & 1 deletion benchmark/py_bleu.py
@@ -24,6 +24,9 @@
import collections
import math

from py_token import Tokenizer13a


def _get_ngrams(segment, max_order):
    """Extracts all n-grams up to a given maximum order from an input segment.
@@ -65,6 +68,13 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
    possible_matches_by_order = [0] * max_order
    reference_length = 0
    translation_length = 0

    tokenizer = Tokenizer13a()
    reference_corpus = [[tokenizer(r) for r in ref] for ref in reference_corpus]
    translation_corpus = [tokenizer(p) for p in translation_corpus]
    # print(f"translation_corpus: {translation_corpus}\n"
    #       f"reference_corpus: {reference_corpus}")

    for (references, translation) in zip(reference_corpus,
                                         translation_corpus):
        reference_length += min(len(r) for r in references)
@@ -113,5 +123,5 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,


if __name__ == "__main__":
    res = compute_bleu(reference_corpus=[["Hello"]], translation_corpus=["Yellow"], max_order=4, smooth=True)
    res = compute_bleu(reference_corpus=[["Hello, World!"]], translation_corpus=["Yellow, World!"], max_order=4, smooth=True)
    print(res)
2 changes: 1 addition & 1 deletion benchmark/py_token.py
@@ -62,7 +62,7 @@ def __call__(self, line):
        :return: the tokenized line
        """
        for (_re, repl) in self._re:
            print(line)
            # print(line)
            line = _re.sub(repl, line)

        # no leading or trailing spaces, single space within words
4 changes: 2 additions & 2 deletions benchmark/test_benchmark_bleu.py
@@ -53,9 +53,9 @@ def test_bleu(input_text):
                                         translation_corpus=predictions,
                                         max_order=max_order,
                                         smooth=smooth)
    print(rust_result)
    # print(rust_result)
    rust_result = rust_result.get("bleu")
    t2 = time.time()
    # print(t1 - t0, t2 - t1, (t1 - t0) > (t2 - t1))
    print(t1 - t0, t2 - t1, (t1 - t0) > (t2 - t1))
    print(py_result, rust_result, abs(py_result - rust_result))
    assert abs(py_result - rust_result) < 1e-10
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -37,7 +37,7 @@ Homepage = 'https://github.com/shenxiangzhuang/bleuscore'
Source = 'https://github.com/shenxiangzhuang/bleuscore'

[project.optional-dependencies]
test = ["pytest", "hypothesis"]
test = ["pytest", "hypothesis", "evaluate"]
lint = ["black", "ruff~=0.3.7"]
#docs = []
#dev = []
77 changes: 42 additions & 35 deletions src/bleu.rs
@@ -1,5 +1,8 @@
use counter::Counter;
use crate::ngram::get_ngram_counter;
use std::cmp::min;
use std::collections::HashMap;
use crate::ngram::{get_token_ngram_counter};
use crate::tokenizer::{Tokenizer, Tokenizer13a};


#[derive(Debug, Default)]
pub struct BleuScore {
@@ -23,22 +26,47 @@ pub fn compute_bleu(
    let mut references_length: usize = 0;
    let mut translation_length: usize = 0;

    let tokenizer = Tokenizer13a::new();

    for (references, translation) in
        reference_corpus.iter().zip(translation_corpus.iter()) {
        references_length += references.iter().map(|x| x.len()).min().unwrap();
        translation_length += translation.len();
        let translation_ngram_counts = get_ngram_counter(translation, max_order);
        let mut merged_ref_ngram_counts = Counter::new();
        for reference in references {
            merged_ref_ngram_counts |= get_ngram_counter(&reference, max_order);

        // tokenize
        let translation_tokens = tokenizer.tokenize(translation);
        let references_tokens: Vec<Vec<String>> = references.iter().map(|x| tokenizer.tokenize(x)).collect();
        // println!("translation_tokens: {:?}\nreferences_tokens: {:?}", translation_tokens, references_tokens);

        references_length += references_tokens.iter().map(|x| x.len()).min().unwrap();
        translation_length += translation_tokens.len();
        let translation_ngram_counts = get_token_ngram_counter(&translation_tokens, max_order);
        let mut merged_ref_ngram_counts = HashMap::new();
        for reference_tokens in references_tokens {
            let reference_ngram_counts = get_token_ngram_counter(&reference_tokens, max_order);
            for (key, value) in reference_ngram_counts {
                merged_ref_ngram_counts.entry(key).and_modify(|v| *v += value).or_insert(value);
            }
        }

        // let overlap: Vec<String> = merged_ref_ngram_counts.keys().filter(|&key| translation_ngram_counts.contains_key(key)).cloned().collect();
        let mut overlap_counts = HashMap::new();
        for (k, v) in translation_ngram_counts {
            let key = k.clone();
            if merged_ref_ngram_counts.contains_key(&key) {
                overlap_counts.insert(k, min(merged_ref_ngram_counts[&key], v));
                // println!("({}, {}): trans: {}; ref: {}", key.0, key.1, v, merged_ref_ngram_counts[&key]);
            }
            else {
                continue
            }
        }
        let overlap = translation_ngram_counts & merged_ref_ngram_counts;

        for ngram in overlap.keys() {
            matches_by_order[ngram.len() - 1] += overlap[ngram]
        for key in overlap_counts.keys() {
            let (_, order) = key;
            matches_by_order[order - 1] += overlap_counts[&key];
            // println!("order: {order}, match: {}", matches_by_order[order - 1]);
        }
        for order in 1..=max_order {
            let possible_matches = translation.len().saturating_sub(order - 1);
            let possible_matches = translation_tokens.len().saturating_sub(order - 1);
            if possible_matches > 0 {
                // println!("Order: {order}");
                possible_matches_by_order[order - 1] += possible_matches
@@ -91,29 +119,8 @@ mod test {
        let max_order: usize = 4;
        let smooth: bool = true;
        let res = compute_bleu(reference_corpus, translation_corpus, max_order, smooth);
        // (0.7241577342575828, [0.8666666666666667, 0.7857142857142857, 0.6923076923076923, 0.5833333333333334], 1.0, 1.0769230769230769, 14, 13)
        println!("BLEU: {:?}", res);
        assert_eq!((res.bleu - 0.7241577342575828).abs() < 1e-10, true);
    }

    #[test]
    fn test_bleu_error() {
        let reference_corpus: Vec<Vec<String>> = vec![
            vec!["0000000000".to_string()],
            vec!["0000000000".to_string()],
            vec!["0000000000".to_string()],
            vec!["0000000000".to_string()],
        ];
        let translation_corpus: Vec<String> = vec!["000000".to_string(),
                                                   "00000".to_string(),
                                                   "0000000000".to_string(),
                                                   "00".to_string()
        ];
        let max_order: usize = 4;
        let smooth: bool = true;
        let res = compute_bleu(reference_corpus, translation_corpus, max_order, smooth);
        // (0.7241577342575828, [0.8666666666666667, 0.7857142857142857, 0.6923076923076923, 0.5833333333333334], 1.0, 1.0769230769230769, 14, 13)
        // (0.668740304976422, [0.8, 0.75, 0.6666666666666666, 0.5], 1.0, 1.0, 4, 4)
        println!("BLEU: {:?}", res);
        assert_eq!((res.bleu - 0.47752897762233404).abs() < 1e-10, true);
        assert_eq!((res.bleu - 0.668740304976422).abs() < 1e-10, true);
    }
}
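The heart of the rewritten Rust loop above is modified-precision clipping over (ngram, order) keys: each translation n-gram count is capped at the merged reference count for the same key before it is added to matches_by_order. A small Python sketch of just that clipping step (the dicts stand in for the HashMap<(String, usize), usize> values returned by get_token_ngram_counter; the helper name clip_counts is made up for illustration):

def clip_counts(translation_counts, merged_ref_counts):
    # Cap each translation n-gram count by the merged reference count
    # for the same (ngram, order) key; n-grams absent from every
    # reference contribute nothing.
    return {key: min(count, merged_ref_counts[key])
            for key, count in translation_counts.items()
            if key in merged_ref_counts}

# Example: unigram counts keyed by (ngram, order)
trans = {("hello", 1): 2, ("bleu", 1): 1}
refs = {("hello", 1): 1, ("score", 1): 1}
print(clip_counts(trans, refs))  # {('hello', 1): 1}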
25 changes: 24 additions & 1 deletion src/ngram.rs
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use counter::Counter;


@@ -15,9 +16,22 @@ pub fn get_ngram_counter(line: &str, max_order: usize) -> Counter<&str> {
}


pub fn get_token_ngram_counter(tokens: &Vec<String>, max_order: usize) -> HashMap<(String, usize), usize> {
    let mut count_map: HashMap<(String, usize), usize> = HashMap::new();
    for order in 1..=max_order {
        for start_index in 0..(tokens.len().saturating_sub(order - 1)) {
            // println!("line: {}, start_index: {}, order: {}", line, start_index, order);
            let ngram = tokens[start_index..(start_index + order)].join("");
            // println!("ngram: {}", ngram);
            count_map.entry((ngram, order)).and_modify(|counter| *counter += 1).or_insert(1);
        }
    }
    count_map
}

#[cfg(test)]
mod test {
    use crate::ngram::{get_ngram_counter};
    use crate::ngram::{get_ngram_counter, get_token_ngram_counter};

    #[test]
    fn test_get_ngram() {
@@ -43,4 +57,13 @@ mod test {
assert_eq!(counter[&"b"], 1);
assert_eq!(counter[&"ab"], 1);
}

#[test]
fn test_get_token_ngram_short() {
let tokens = vec!["a".to_string(), "b".to_string()];
let counter = get_token_ngram_counter(&tokens,4);
assert_eq!(counter[&("a".to_string(), 1)], 1);
assert_eq!(counter[&("b".to_string(), 1)], 1);
assert_eq!(counter[&("ab".to_string(), 2)], 1);
}
}
