Commit e6a31e7
fix idcs count
markus583 committed Jun 4, 2024
1 parent 2b77d40 commit e6a31e7
Showing 2 changed files with 60 additions and 22 deletions.
40 changes: 30 additions & 10 deletions wtpsplit/evaluation/intrinsic_baselines.py
@@ -18,21 +18,22 @@
 )
 from wtpsplit.utils import Constants
 
 
 def split_language_data(eval_data):
     new_eval_data = {}
 
     for lang_code, lang_data in eval_data.items():
-        if '-' in lang_code:
-            lang1, lang2 = lang_code.split('-')
+        if "-" in lang_code:
+            lang1, lang2 = lang_code.split("-")
             new_lang1 = f"{lang_code}_{lang1.upper()}"
             new_lang2 = f"{lang_code}_{lang2.upper()}"
 
             # Adding the same content for both new language keys
             new_eval_data[new_lang1] = lang_data
             new_eval_data[new_lang2] = lang_data
         else:
             new_eval_data[lang_code] = lang_data
 
     return new_eval_data
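
A minimal sketch of what split_language_data does to a code-switched key (hypothetical payloads; the real values are dataset dicts keyed by dataset name):

    # Given the function above:
    eval_data = {
        "tr-de": {"ted2020": {"data": ["Merhaba! Wie geht's?"]}},
        "en": {"ted2020": {"data": ["Hello there."]}},
    }

    new_eval_data = split_language_data(eval_data)
    # The code-switched entry is duplicated once per constituent language;
    # plain entries pass through unchanged.
    assert set(new_eval_data) == {"tr-de_TR", "tr-de_DE", "en"}

    # Downstream, the eval loop recovers the per-language code from the suffix:
    assert "tr-de_TR".split("_")[1].lower() == "tr"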


@@ -67,12 +68,25 @@ class Args:
# if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
# print("SKIP: ", lang, dataset_name)
# continue
# if "ted2020-corrupted-asr" not in dataset_name:
# continue
if not dataset["data"]:
continue
results[lang][dataset_name] = {}
indices[lang][dataset_name] = {}
if "asr" in dataset_name and not any(
x in dataset_name for x in ["lyrics", "short", "code", "ted2020", "legal"]
):
continue
if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
continue
if "social-media" in dataset_name:
continue
if "nllb" in dataset_name:
continue

if "-" in lang:
# code-switched data: eval 2x
# code-switched data: eval 2x
lang_code = lang.split("_")[1].lower()
else:
lang_code = lang
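
The four new skip conditions boil down to a single predicate; a hypothetical helper (not part of the repo, shown only to summarize the filtering):

    def should_skip(dataset_name: str) -> bool:
        # ASR datasets are kept only for lyrics/short/code/ted2020/legal variants.
        if "asr" in dataset_name and not any(
            x in dataset_name for x in ["lyrics", "short", "code", "ted2020", "legal"]
        ):
            return True
        # Legal datasets are kept only for laws and judgements.
        if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
            return True
        # Social-media and NLLB datasets are skipped entirely.
        return "social-media" in dataset_name or "nllb" in dataset_name

Note that these skips run after results[lang][dataset_name] and indices[lang][dataset_name] are initialized, so filtered datasets still appear as empty dicts in the dumped JSON.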
@@ -92,13 +106,18 @@ class Args:
                 exclude_every_k = args.exclude_every_k
                 try:
                     if isinstance(dataset["data"][0], list):
-                        all_sentences = [[preprocess_sentence(s) for s in doc] for doc in dataset["data"]]
+                        # all_sentences = [[preprocess_sentence(s) for s in doc] for doc in dataset["data"]]
+                        all_sentences = dataset["data"]
                         metrics = []
                         for i, sentences in enumerate(all_sentences):
                             text = Constants.SEPARATORS[lang_code].join(sentences)
                             doc_metrics = {}
                             doc_metrics = evaluate_sentences(
-                                lang_code, sentences, f(lang_code, text), return_indices=True, exclude_every_k=exclude_every_k
+                                lang_code,
+                                sentences,
+                                f(lang_code, text),
+                                return_indices=True,
+                                exclude_every_k=exclude_every_k,
                             )
                             f1 = doc_metrics[0]
                             doc_metrics = doc_metrics[1]
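
Readability note: the doc_metrics[0] / doc_metrics[1] indexing implies that, with return_indices=True, evaluate_sentences returns an (f1, details) pair. Inside the loop above, the call could equivalently be unpacked directly (a sketch only, not a change this commit makes):

    f1, doc_metrics = evaluate_sentences(
        lang_code,
        sentences,
        f(lang_code, text),
        return_indices=True,
        exclude_every_k=exclude_every_k,
    )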
@@ -133,7 +152,8 @@ class Args:
                         results[lang][dataset_name][name] = avg_results
                         indices[lang][dataset_name][name] = concat_indices
                     else:
-                        sentences = [preprocess_sentence(s) for s in dataset["data"]]
+                        # sentences = [preprocess_sentence(s) for s in dataset["data"]]
+                        sentences = dataset["data"]
                         text = Constants.SEPARATORS[lang_code].join(sentences)
 
                         metrics = evaluate_sentences(
@@ -148,7 +168,7 @@ class Args:
metrics["f1"] = f1
print(f1)
indices[lang][dataset_name][name]["true_indices"] = [metrics.pop("true_indices")]
indices[lang][dataset_name][name]["predicted_indices"] =[ metrics.pop("predicted_indices")]
indices[lang][dataset_name][name]["predicted_indices"] = [metrics.pop("predicted_indices")]
indices[lang][dataset_name][name]["length"] = [metrics.pop("length")]
results[lang][dataset_name][name] = metrics
except LanguageError as e:
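
The commit title ("fix idcs count") refers to the sentence-boundary indices stored alongside F1. A self-contained sketch of the underlying bookkeeping, assuming boundaries are character offsets into the joined text (hypothetical helper; the real logic lives in evaluate_sentences):

    def boundary_indices(sentences, separator=" "):
        # Character offset of each sentence end in the joined text.
        indices, offset = [], 0
        for sentence in sentences:
            offset += len(sentence)
            indices.append(offset)  # boundary sits right after the sentence
            offset += len(separator)
        return indices

    sentences = ["Hello there.", "How are you?"]
    text = " ".join(sentences)  # mirrors Constants.SEPARATORS[lang_code].join(sentences)
    assert boundary_indices(sentences) == [12, 25]

Because the commit also stops running preprocess_sentence over the data, these offsets are now computed against the raw corpus text, which presumably keeps the stored index counts consistent with it.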
42 changes: 30 additions & 12 deletions wtpsplit/evaluation/intrinsic_baselines_multilingual.py
@@ -18,21 +18,22 @@
 )
 from wtpsplit.utils import Constants
 
 
 def split_language_data(eval_data):
     new_eval_data = {}
 
     for lang_code, lang_data in eval_data.items():
-        if '-' in lang_code:
-            lang1, lang2 = lang_code.split('-')
+        if "-" in lang_code:
+            lang1, lang2 = lang_code.split("-")
             new_lang1 = f"{lang_code}_{lang1.upper()}"
             new_lang2 = f"{lang_code}_{lang2.upper()}"
 
             # Adding the same content for both new language keys
             new_eval_data[new_lang1] = lang_data
             new_eval_data[new_lang2] = lang_data
         else:
             new_eval_data[lang_code] = lang_data
 
     return new_eval_data


@@ -71,15 +72,26 @@ class Args:
                     continue
                 results[lang][dataset_name] = {}
                 indices[lang][dataset_name] = {}
+                if "asr" in dataset_name and not any(
+                    x in dataset_name for x in ["lyrics", "short", "code", "ted2020", "legal"]
+                ):
+                    continue
+                if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
+                    continue
+                if "social-media" in dataset_name:
+                    continue
+                if "nllb" in dataset_name:
+                    continue
+
                 if "-" in lang:
-                # code-switched data: eval 2x
+                    # code-switched data: eval 2x
                     lang_code = lang.split("_")[1].lower()
                 else:
                     lang_code = lang
 
                 for f, name in [
                     (spacy_dp_sentencize, "spacy_dp"),
-                    (spacy_sent_sentencize, "spacy_sent"),
+                    # (spacy_sent_sentencize, "spacy_sent"),
                 ]:
                     print(f"Running {name} on {dataset_name} in {lang_code}...")
                     indices[lang][dataset_name][name] = {}
@@ -89,13 +101,18 @@ class Args:
                 exclude_every_k = args.exclude_every_k
                 try:
                     if isinstance(dataset["data"][0], list):
-                        all_sentences = [[preprocess_sentence(s) for s in doc] for doc in dataset["data"]]
+                        # all_sentences = [[preprocess_sentence(s) for s in doc] for doc in dataset["data"]]
+                        all_sentences = dataset["data"]
                         metrics = []
                         for i, sentences in enumerate(all_sentences):
                             text = Constants.SEPARATORS[lang_code].join(sentences)
                             doc_metrics = {}
                             doc_metrics = evaluate_sentences(
-                                lang_code, sentences, f("xx", text), return_indices=True, exclude_every_k=exclude_every_k
+                                lang_code,
+                                sentences,
+                                f("xx", text),
+                                return_indices=True,
+                                exclude_every_k=exclude_every_k,
                             )
                             f1 = doc_metrics[0]
                             doc_metrics = doc_metrics[1]
@@ -130,7 +147,8 @@ class Args:
                         results[lang][dataset_name][name] = avg_results
                         indices[lang][dataset_name][name] = concat_indices
                     else:
-                        sentences = [preprocess_sentence(s) for s in dataset["data"]]
+                        # sentences = [preprocess_sentence(s) for s in dataset["data"]]
+                        sentences = dataset["data"]
                         text = Constants.SEPARATORS[lang_code].join(sentences)
 
                         metrics = evaluate_sentences(
@@ -145,7 +163,7 @@ class Args:
metrics["f1"] = f1
print(f1)
indices[lang][dataset_name][name]["true_indices"] = [metrics.pop("true_indices")]
indices[lang][dataset_name][name]["predicted_indices"] =[ metrics.pop("predicted_indices")]
indices[lang][dataset_name][name]["predicted_indices"] = [metrics.pop("predicted_indices")]
indices[lang][dataset_name][name]["length"] = [metrics.pop("length")]
results[lang][dataset_name][name] = metrics
except LanguageError as l:
@@ -154,4 +172,4 @@ class Args:

     json.dump(results, open(Constants.CACHE_DIR / "intrinsic_baselines_multi.json", "w"), indent=4, default=int)
     json.dump(indices, open(Constants.CACHE_DIR / "intrinsic_baselines_multi_IDX.json", "w"), indent=4, default=int)
-    print(Constants.CACHE_DIR / "intrinsic_baselines.json")
+    print(Constants.CACHE_DIR / "intrinsic_baselines_multi.json")
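
After a run, the dumped files can be inspected directly; a small usage sketch (nesting inferred from the assignments above):

    import json

    from wtpsplit.utils import Constants

    with open(Constants.CACHE_DIR / "intrinsic_baselines_multi.json") as f:
        results = json.load(f)

    # results[lang][dataset_name][baseline_name] -> metrics dict
    for lang, datasets in results.items():
        for dataset_name, baselines in datasets.items():  # filtered datasets are empty dicts
            for name, metrics in baselines.items():
                print(lang, dataset_name, name, metrics.get("f1"))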
