From bdf6952033020f38d7e03b8842828507004df2da Mon Sep 17 00:00:00 2001 From: igorsterner Date: Sat, 8 Jun 2024 18:00:56 +0100 Subject: [PATCH] SentSeg data, +SM training and permutation tests --- wtpsplit/data_acquisition/extract_all_data.py | 480 +++++++++++++++--- wtpsplit/evaluation/permutation_test.py | 249 +++++---- wtpsplit/evaluation/permutation_test_data.py | 247 +++++++++ wtpsplit/evaluation/permutation_test_utils.py | 127 +++-- wtpsplit/train/train_FT.py | 296 +++++------ 5 files changed, 1045 insertions(+), 354 deletions(-) create mode 100644 wtpsplit/evaluation/permutation_test_data.py diff --git a/wtpsplit/data_acquisition/extract_all_data.py b/wtpsplit/data_acquisition/extract_all_data.py index af12e5c8..6ec71db1 100644 --- a/wtpsplit/data_acquisition/extract_all_data.py +++ b/wtpsplit/data_acquisition/extract_all_data.py @@ -20,7 +20,6 @@ from wtpsplit.evaluation import preprocess_sentence from wtpsplit.utils import Constants - UD_TREEBANK_PATH = "../data/ud-treebanks-v2.13" # source: https://universaldependencies.org/#download ERSATZ_DATA_PATH = "../data/ersatz-test-suite/segmented" # source: https://github.com/rewicks/ersatz-test-suite @@ -80,7 +79,7 @@ punct_chars = set(Constants.PUNCTUATION_CHARS) -def corrupt(sentences, lang): +def corrupt_asr(sentences, lang): if sentences is None: return None @@ -116,9 +115,33 @@ def corrupt(sentences, lang): return corrupted_sentences +def corrupt_social_media(sentences, lang): + + if sentences is None: + return None + + corrupted_sentences = [] + for sentence in sentences: + if random.random() < 0.5: + sentence = "".join([char for char in sentence if char not in punct_chars]) + if random.random() < 0.5: + sentence = sentence.lower() + + for punct in punct_chars: + count = 0 + while random.random() < 0.5: + count += 1 + sentence = sentence.replace(punct, punct * count) + + sentence = preprocess_sentence(sentence) + corrupted_sentences.append(sentence) + + return corrupted_sentences + + @dataclass class Args: - output_file: str = "../data/preprocessed_training_data/all_data_02_05.pth" + output_file: str = "../data/preprocessed_training_data/all_data_11_05.pth" include_train_data: bool = True cache_dir: str = "../data/cache/" @@ -128,7 +151,7 @@ class Args: eval_data = {lang_code: {"sentence": {}, "compound": {}} for lang_code in Constants.LANGINFO.index} - # # Ersatz data + # Ersatz data for lang_code in tqdm(Constants.LANGINFO.index): if lang_code in ERSATZ_TEST_DATASETS: eval_data[lang_code]["sentence"]["ersatz"] = { @@ -154,7 +177,45 @@ class Args: ], } - # UD + OPUS100 sentences + TED + eval_data[lang_code]["sentence"]["ersatz-corrupted-asr"] = { + "meta": { + "train_data": ( + corrupt_asr( + ( + eval_data[lang_code]["sentence"]["ersatz"]["meta"]["train_data"][:10000] + if eval_data[lang_code]["sentence"]["ersatz"]["meta"]["train_data"] is not None + else None + ), + lang_code, + ) + ) + }, + "data": corrupt_asr( + eval_data[lang_code]["sentence"]["ersatz"]["data"][:10000], + lang_code, + ), + } + + eval_data[lang_code]["sentence"]["ersatz-corrupted-social-media"] = { + "meta": { + "train_data": ( + corrupt_social_media( + ( + eval_data[lang_code]["sentence"]["ersatz"]["meta"]["train_data"][:10000] + if eval_data[lang_code]["sentence"]["ersatz"]["meta"]["train_data"] is not None + else None + ), + lang_code, + ) + ) + }, + "data": corrupt_social_media( + eval_data[lang_code]["sentence"]["ersatz"]["data"][:10000], + lang_code, + ), + } + + # UD + OPUS100 sentences + TED + NLLB for lang_code in tqdm(Constants.LANGINFO.index): opus_dset_name = Constants.LANGINFO.loc[lang_code, "opus100"] @@ -187,23 +248,39 @@ class Args: ] opus100_train_sentences = None + opus100_train_sentences = opus100_train_sentences[:10000] if opus100_train_sentences is not None else None + eval_data[lang_code]["sentence"]["opus100"] = { "meta": {"train_data": (opus100_train_sentences if args.include_train_data else None)}, "data": opus100_sentences, } - corrupted_opus100_train_sentences = ( - opus100_train_sentences[:10000] if opus100_train_sentences is not None else None - ) - corrupted_opus100_sentences = opus100_sentences[:10000] - - corrupted_opus100_sentences = corrupt(corrupted_opus100_sentences, lang_code) - - corrupted_opus100_train_sentences = corrupt(opus100_train_sentences, lang_code) + eval_data[lang_code]["sentence"]["opus100-corrupted-asr"] = { + "meta": { + "train_data": ( + corrupt_asr( + (opus100_train_sentences), + lang_code, + ) + if args.include_train_data + else None + ) + }, + "data": corrupt_asr(opus100_sentences[:10000], lang_code), + } - eval_data[lang_code]["sentence"]["opus100-corrupted"] = { - "meta": {"train_data": (corrupted_opus100_train_sentences if args.include_train_data else None)}, - "data": corrupted_opus100_sentences, + eval_data[lang_code]["sentence"]["opus100-corrupted-social-media"] = { + "meta": { + "train_data": ( + corrupt_social_media( + (opus100_train_sentences), + lang_code, + ) + if args.include_train_data + else None + ) + }, + "data": corrupt_social_media(opus100_sentences[:10000], lang_code), } if Constants.LANGINFO.loc[lang_code, "ud"] not in (np.nan, None): @@ -231,19 +308,33 @@ class Args: )[0] ).read() ) - ud_train_sentences = [preprocess_sentence(sentence.metadata["text"]) for sentence in ud_train_data] + ud_train_sentences = [preprocess_sentence(sentence.metadata["text"]) for sentence in ud_train_data][ + :10000 + ] except IndexError: ud_train_sentences = None ud_sentences = [preprocess_sentence(sentence.metadata["text"]) for sentence in ud_data] + eval_data[lang_code]["sentence"]["ud"] = { "meta": {"train_data": (ud_train_sentences if args.include_train_data else None)}, "data": ud_sentences, } - eval_data[lang_code]["sentence"]["ud-corrupted"] = { - "meta": {"train_data": (corrupt(ud_train_sentences, lang_code) if args.include_train_data else None)}, - "data": corrupt(ud_sentences, lang_code), + eval_data[lang_code]["sentence"]["ud-corrupted-asr"] = { + "meta": { + "train_data": (corrupt_asr(ud_train_sentences, lang_code) if args.include_train_data else None) + }, + "data": corrupt_asr(ud_sentences, lang_code), + } + + eval_data[lang_code]["sentence"]["ud-corrupted-social-media"] = { + "meta": { + "train_data": ( + corrupt_social_media(ud_train_sentences, lang_code) if args.include_train_data else None + ) + }, + "data": corrupt_social_media(ud_sentences, lang_code), } # TED 2020 @@ -258,19 +349,64 @@ class Args: sentences = [preprocess_sentence(sentence) for sentence in sentences] - corrupted_sentences = corrupt(sentences, lang_code) + train_sentences = sentences[: len(sentences) // 2] + test_sentences = sentences[len(sentences) // 2 :] + + eval_data[lang_code]["sentence"]["ted2020-corrupted-asr"] = { + "meta": {"train_data": (corrupt_asr(train_sentences, lang_code) if args.include_train_data else None)}, + "data": corrupt_asr(test_sentences, lang_code), + } + + eval_data[lang_code]["sentence"]["ted2020-corrupted-social-media"] = { + "meta": { + "train_data": ( + corrupt_social_media(train_sentences, lang_code) if args.include_train_data else None + ) + }, + "data": corrupt_social_media(test_sentences, lang_code), + } else: + print(f"Failed to download TED2020 data for {lang_code}") - train_sentences = corrupted_sentences[: len(sentences) // 2] - test_sentences = corrupted_sentences[len(sentences) // 2 :] + for lang_code in ["ceb", "jv", "mn", "yo"]: + url = f"https://object.pouta.csc.fi/OPUS-NLLB/v1/mono/{lang_code}.txt.gz" + res = requests.get(url) + + if res.status_code == 200: + with gzip.open(BytesIO(res.content), "rt", encoding="utf-8") as f: + sentences = f.read().splitlines() + + random.shuffle(sentences) # because they come alphabetically sorted + + sentences = sentences[:20000] + + sentences = [preprocess_sentence(sentence) for sentence in sentences] - eval_data[lang_code]["sentence"]["ted2020-corrupted"] = { + else: + raise Exception + + train_sentences = sentences[: len(sentences) // 2] + test_sentences = sentences[len(sentences) // 2 :] + + eval_data[lang_code]["sentence"]["nllb"] = { "meta": {"train_data": (train_sentences if args.include_train_data else None)}, "data": test_sentences, } + eval_data[lang_code]["sentence"]["nllb-corrupted-asr"] = { + "meta": {"train_data": (corrupt_asr(train_sentences, lang_code) if args.include_train_data else None)}, + "data": corrupt_asr(test_sentences, lang_code), + } + + eval_data[lang_code]["sentence"]["nllb-corrupted-social-media"] = { + "meta": { + "train_data": (corrupt_social_media(train_sentences, lang_code) if args.include_train_data else None) + }, + "data": corrupt_social_media(test_sentences, lang_code), + } + # UD Code-Switching Corpora # UD_Turkish_German-SAGT @@ -310,9 +446,14 @@ class Args: "data": ud_test_sentences, } - eval_data["tr-de"]["sentence"]["code-switching-corrupted"] = { - "meta": {"train_data": corrupt(ud_train_sentences, "en")}, - "data": corrupt(ud_test_sentences, "en"), + eval_data["tr-de"]["sentence"]["code-switching-corrupted-asr"] = { + "meta": {"train_data": corrupt_asr(ud_train_sentences, "en")}, + "data": corrupt_asr(ud_test_sentences, "en"), + } + + eval_data["tr-de"]["sentence"]["code-switching-corrupted-social-media"] = { + "meta": {"train_data": corrupt_social_media(ud_train_sentences, "en")}, + "data": corrupt_social_media(ud_test_sentences, "en"), } # UD_Spanish_English-Miami @@ -341,14 +482,17 @@ class Args: "data": ud_test_sentences, } - eval_data["es-en"]["sentence"]["code-switching-corrupted"] = { - "meta": {"train_data": corrupt(ud_train_sentences, "es")}, - "data": corrupt(ud_test_sentences, "es"), + eval_data["es-en"]["sentence"]["code-switching-corrupted-asr"] = { + "meta": {"train_data": corrupt_asr(ud_train_sentences, "es")}, + "data": corrupt_asr(ud_test_sentences, "es"), } - # Vietnamese--English + eval_data["es-en"]["sentence"]["code-switching-corrupted-social-media"] = { + "meta": {"train_data": corrupt_social_media(ud_train_sentences, "es")}, + "data": corrupt_social_media(ud_test_sentences, "es"), + } - # no headers + # Vietnamese--English canvec_corpus = pd.read_excel("../data/vietnamese-english/CanVEC_CSW.xlsx", header=None) # sentences are the first columnn @@ -367,9 +511,14 @@ class Args: "data": test_sentences, } - eval_data["vi-en"]["sentence"]["code-switching-corrupted"] = { - "meta": {"train_data": corrupt(train_sentences, "vi")}, - "data": corrupt(test_sentences, "vi"), + eval_data["vi-en"]["sentence"]["code-switching-corrupted-asr"] = { + "meta": {"train_data": corrupt_asr(train_sentences, "vi")}, + "data": corrupt_asr(test_sentences, "vi"), + } + + eval_data["vi-en"]["sentence"]["code-switching-corrupted-social-media"] = { + "meta": {"train_data": corrupt_social_media(train_sentences, "vi")}, + "data": corrupt_social_media(test_sentences, "vi"), } # Denglisch @@ -425,22 +574,16 @@ class Args: "data": denglisch_test_sentences, } - eval_data["en-de"]["sentence"]["code-switching-corrupted"] = { - "meta": {"train_data": corrupt(denglisch_train_sentences, "de")}, - "data": corrupt(denglisch_test_sentences, "de"), + eval_data["en-de"]["sentence"]["code-switching-corrupted-asr"] = { + "meta": {"train_data": corrupt_asr(denglisch_train_sentences, "de")}, + "data": corrupt_asr(denglisch_test_sentences, "de"), } - # keep only if a 1 or 2 in labels of any sentence in a post - - denglisch_test_posts = [ - post for post in denglisch_test_posts if any(("1" in labels and "2" in labels) for sentence, labels in post) - ] - - denglisch_train_posts = [ - post for post in denglisch_train_posts if any(("1" in labels and "2" in labels) for sentence, labels in post) - ] + eval_data["en-de"]["sentence"]["code-switching-corrupted-social-media"] = { + "meta": {"train_data": corrupt_social_media(denglisch_train_sentences, "de")}, + "data": corrupt_social_media(denglisch_test_sentences, "de"), + } - # remove labels denglisch_test_posts = [[sentence for sentence, labels in post] for post in denglisch_test_posts] denglisch_train_posts = [[sentence for sentence, labels in post] for post in denglisch_train_posts] @@ -449,9 +592,14 @@ class Args: "data": denglisch_test_posts, } - eval_data["en-de"]["sentence"]["short-sequences-corrupted"] = { - "meta": {"train_data": [corrupt(s, "de") for s in denglisch_train_posts]}, - "data": [corrupt(s, "de") for s in denglisch_test_posts], + eval_data["en-de"]["sentence"]["short-sequences-corrupted-asr"] = { + "meta": {"train_data": [corrupt_asr(s, "de") for s in denglisch_train_posts]}, + "data": [corrupt_asr(s, "de") for s in denglisch_test_posts], + } + + eval_data["en-de"]["sentence"]["short-sequences-corrupted-social-media"] = { + "meta": {"train_data": [corrupt_social_media(s, "de") for s in denglisch_train_posts]}, + "data": [corrupt_social_media(s, "de") for s in denglisch_test_posts], } # Short sequences @@ -486,7 +634,6 @@ class Args: if tweet_sentences: serbian_test_tweets.append(tweet_sentences) - # keep only if more than one sentence in a tweet serbian_train_tweets = [tweet for tweet in serbian_train_tweets if len(tweet) > 1] serbian_test_tweets = [tweet for tweet in serbian_test_tweets if len(tweet) > 1] @@ -495,9 +642,14 @@ class Args: "data": serbian_test_tweets, } - eval_data["sr"]["sentence"]["short-sequences-corrupted"] = { - "meta": {"train_data": [corrupt(s, "sr") for s in serbian_train_tweets]}, - "data": [corrupt(s, "sr") for s in serbian_test_tweets], + eval_data["sr"]["sentence"]["short-sequences-corrupted-asr"] = { + "meta": {"train_data": [corrupt_asr(s, "sr") for s in serbian_train_tweets]}, + "data": [corrupt_asr(s, "sr") for s in serbian_test_tweets], + } + + eval_data["sr"]["sentence"]["short-sequences-corrupted-social-media"] = { + "meta": {"train_data": [corrupt_social_media(s, "sr") for s in serbian_train_tweets]}, + "data": [corrupt_social_media(s, "sr") for s in serbian_test_tweets], } # slovenian @@ -531,9 +683,14 @@ class Args: "data": slovenian_test_tweets, } - eval_data["sl"]["sentence"]["short-sequences-corrupted"] = { - "meta": {"train_data": [corrupt(s, "sl") for s in slovenian_train_tweeets]}, - "data": [corrupt(s, "sl") for s in slovenian_test_tweets], + eval_data["sl"]["sentence"]["short-sequences-corrupted-asr"] = { + "meta": {"train_data": [corrupt_asr(s, "sl") for s in slovenian_train_tweeets]}, + "data": [corrupt_asr(s, "sl") for s in slovenian_test_tweets], + } + + eval_data["sl"]["sentence"]["short-sequences-corrupted-social-media"] = { + "meta": {"train_data": [corrupt_social_media(s, "sl") for s in slovenian_train_tweeets]}, + "data": [corrupt_social_media(s, "sl") for s in slovenian_test_tweets], } # estonian @@ -569,12 +726,207 @@ class Args: "data": estonian_test_paragraphs, } - eval_data["et"]["sentence"]["short-sequences-corrupted"] = { - "meta": {"train_data": [corrupt(s, "et") for s in estonian_train_paragraphs]}, - "data": [corrupt(s, "et") for s in estonian_test_paragraphs], + eval_data["et"]["sentence"]["short-sequences-corrupted-asr"] = { + "meta": {"train_data": [corrupt_asr(s, "et") for s in estonian_train_paragraphs]}, + "data": [corrupt_asr(s, "et") for s in estonian_test_paragraphs], + } + + eval_data["et"]["sentence"]["short-sequences-corrupted-social-media"] = { + "meta": {"train_data": [corrupt_social_media(s, "et") for s in estonian_train_paragraphs]}, + "data": [corrupt_social_media(s, "et") for s in estonian_test_paragraphs], + } + + langs = ["de", "en", "es", "fr", "it", "pt"] + + all_subset_data = { + lang: { + "laws": {"train": [], "test": []}, + "judgements": {"train": [], "test": []}, + } + for lang in langs } - with open("../data/preprocessed_training_data/all_data_02_05.json", "w") as f: - json.dump(eval_data, f, indent=4, ensure_ascii=False) + for lang in tqdm(langs, desc="Legal data"): + data_dir = f"../data/MultiLegalSBD/data/{lang}/gold/" + + if lang == "pt": + all_files = glob.glob(f"{data_dir}/*.jsonl") + subsets = [file.split("/")[-1].split(".jsonl")[0] for file in all_files] + else: + all_files = glob.glob(f"{data_dir}/*_test.jsonl") + subsets = [file.split("/")[-1].split("_test.jsonl")[0] for file in all_files] + + for subset in subsets: + + if subset == "Constitution": + continue + + if lang == "pt": + train_data = [None] + else: + + train_data = [] + + with open( + f"../data/MultiLegalSBD/data/{lang}/gold/{subset}_train.jsonl", + "r", + encoding="utf-8", + ) as f: + for line in f: + train_data.append(json.loads(line)) + + train_subset_sentences = [] + for doc in train_data: + doc_sentences = [] + text = doc["text"] + for span in doc["spans"]: + sentence = text[span["start"] : span["end"]] + doc_sentences.append(preprocess_sentence(sentence)) + train_subset_sentences.append(doc_sentences) + + test_data = [] + + if lang == "pt": + test_data_file = f"../data/MultiLegalSBD/data/{lang}/gold/{subset}.jsonl" + else: + test_data_file = f"../data/MultiLegalSBD/data/{lang}/gold/{subset}_test.jsonl" + + with open( + test_data_file, + "r", + encoding="utf-8", + ) as f: + for line in f: + test_data.append(json.loads(line)) + + test_subset_sentences = [] + for doc in test_data: + doc_sentences = [] + text = doc["text"] + for span in doc["spans"]: + sentence = text[span["start"] : span["end"]] + doc_sentences.append(preprocess_sentence(sentence)) + test_subset_sentences.append(doc_sentences) + + eval_data[lang]["sentence"][f"legal-{subset}"] = { + "meta": {"train_data": train_subset_sentences}, + "data": test_subset_sentences, + } + + eval_data[lang]["sentence"][f"legal-{subset}-corrupted-asr"] = { + "meta": {"train_data": [corrupt_asr(s, lang) for s in train_subset_sentences]}, + "data": [corrupt_asr(s, lang) for s in test_subset_sentences], + } + + eval_data[lang]["sentence"][f"legal-{subset}-corrupted-social-media"] = { + "meta": {"train_data": [corrupt_social_media(s, lang) for s in train_subset_sentences]}, + "data": [corrupt_social_media(s, lang) for s in test_subset_sentences], + } + + subsets2set = { + "CD_jug": "judgements", + "gesCode": "laws", + "CD_multi_legal": "judgements", + "CD_wipolex": "judgements", + "CivilCode": "laws", + "CriminalCode": "laws", + "CD_swiss_judgement": "judgements", + } + + if lang != "en": + set = subsets2set[subset] + else: + set = "judgements" + + all_subset_data[lang][set]["train"].extend(train_subset_sentences) + all_subset_data[lang][set]["test"].extend(test_subset_sentences) + + # constitution + + if lang in ["de", "en"]: + continue + + test_data = [] + test_data_file = f"../data/MultiLegalSBD/data/{lang}/gold/Constitution.jsonl" + with open( + test_data_file, + "r", + encoding="utf-8", + ) as f: + for line in f: + test_data.append(json.loads(line)) + + test_subset_sentences = [] + for doc in test_data: + doc_sentences = [] + text = doc["text"] + for span in doc["spans"]: + sentence = text[span["start"] : span["end"]] + doc_sentences.append(preprocess_sentence(sentence)) + test_subset_sentences.append(doc_sentences) + + eval_data[lang]["sentence"][f"legal-constitution"] = { + "meta": {"train_data": None}, + "data": test_subset_sentences, + } + + eval_data[lang]["sentence"][f"legal-constitution-corrupted-asr"] = { + "meta": {"train_data": None}, + "data": [corrupt_asr(s, lang) for s in test_subset_sentences], + } + + eval_data[lang]["sentence"][f"legal-constitution-corrupted-social-media"] = { + "meta": {"train_data": None}, + "data": [corrupt_social_media(s, lang) for s in test_subset_sentences], + } + + all_subset_data[lang]["laws"]["test"].extend(test_subset_sentences) + + for lang in all_subset_data: + for set in ["laws", "judgements"]: + eval_data[lang]["sentence"][f"legal-all-{set}"] = { + "meta": {"train_data": all_subset_data[lang][set]["train"]}, + "data": all_subset_data[lang][set]["test"], + } + + eval_data[lang]["sentence"][f"legal-all-{set}-corrupted-asr"] = { + "meta": {"train_data": [corrupt_asr(s, lang) for s in all_subset_data[lang][set]["train"]]}, + "data": [corrupt_asr(s, lang) for s in all_subset_data[lang][set]["test"]], + } + + eval_data[lang]["sentence"][f"legal-all-{set}-corrupted-social-media"] = { + "meta": {"train_data": [corrupt_social_media(s, lang) for s in all_subset_data[lang][set]["train"]]}, + "data": [corrupt_social_media(s, lang) for s in all_subset_data[lang][set]["test"]], + } + + torch.save(eval_data, args.output_file.replace(".pth", "-all.pth")) + + eval_data = {} + eval_data["en"] = {} + eval_data["en"]["sentence"] = {} + + lyric_data = torch.load("../data/mldbW_verses.pth") + + for lang in lyric_data.keys(): + for genre in lyric_data[lang]["sentence"].keys(): + train_data = lyric_data[lang]["sentence"][genre]["meta"]["train_data"] + train_data = [[verse for verse in song if verse != ""] for song in train_data] + test_data = lyric_data[lang]["sentence"][genre]["data"] + test_data = [[verse for verse in song if verse != ""] for song in test_data] + + eval_data[lang]["sentence"]["lyrics-" + genre] = { + "meta": {"train_data": train_data}, + "data": test_data, + } + + eval_data[lang]["sentence"]["lyrics-" + genre + "-corrupted-asr"] = { + "meta": {"train_data": [corrupt_asr(s, lang) for s in train_data]}, + "data": [corrupt_asr(s, lang) for s in test_data], + } + + eval_data[lang]["sentence"]["lyrics-" + genre + "corrupted-social-media"] = { + "meta": {"train_data": [corrupt_social_media(s, lang) for s in train_data]}, + "data": [corrupt_social_media(s, lang) for s in test_data], + } - torch.save(eval_data, args.output_file) + torch.save(eval_data, args.output_file.replace(".pth", "-lyrics.pth")) diff --git a/wtpsplit/evaluation/permutation_test.py b/wtpsplit/evaluation/permutation_test.py index 2da807c0..c825c371 100644 --- a/wtpsplit/evaluation/permutation_test.py +++ b/wtpsplit/evaluation/permutation_test.py @@ -1,112 +1,161 @@ -import json +import argparse +import pickle from collections import defaultdict from pathlib import Path import numpy as np import pandas as pd + from wtpsplit.evaluation.permutation_test_utils import compute_prf, permutation_test, print_latex, reverse_where -from tqdm import tqdm -ALL_DIR = Path("data/permutation-test-data/03-05") +parser = argparse.ArgumentParser() + +parser.add_argument("--lang", type=str, required=True) +parser.add_argument("--table", type=str, required=True) + +args = parser.parse_args() + +ALL_DIR = Path("../data/permutation-test-data/") raw_data = defaultdict(lambda: defaultdict(dict)) val_results = defaultdict(lambda: defaultdict(dict)) -DATA_DIR = ALL_DIR / "data" -LATEX_DIR = ALL_DIR / "results" - -for file in DATA_DIR.glob("*IDX.json"): - if "12l" in file.stem: - continue - model = str(file.stem)[:-4] - with open(file, "r") as f: - data = json.load(f) - for lang in data.keys(): - for dataset in data[lang].keys(): - for model_type in data[lang][dataset].keys(): - if model_type == "true_indices" or model_type == "length": - continue - raw_data[lang][dataset][model + "-" + model_type] = data[lang][dataset][model_type] - - if "true_indicies" in raw_data[lang][dataset]: - assert raw_data[lang][dataset]["true_indices"] == data[lang][dataset]["true_indices"] - else: - raw_data[lang][dataset]["true_indices"] = data[lang][dataset]["true_indices"] - - if "length" in raw_data[lang][dataset]: - assert raw_data[lang][dataset]["length"] == data[lang][dataset]["length"] - else: - raw_data[lang][dataset]["length"] = data[lang][dataset]["length"] - - -for file in DATA_DIR.glob("*.json"): - if file.stem.endswith("IDX"): - continue - - with open(file, "r") as f: - data = json.load(f) - - model = file.stem - - for lang in data.keys(): - for dataset in data[lang].keys(): - for model_type in data[lang][dataset].keys(): - val_results[lang][dataset][model + "-" + model_type] = data[lang][dataset][model_type] - - -for lang in tqdm(raw_data.keys()): - for dataset in raw_data[lang].keys(): - - systems = sorted(list(raw_data[lang][dataset].keys())) - systems.remove("true_indices") - systems.remove("length") - - systems = [s for s in systems if val_results[lang][dataset][s] is not None] - - num_systems = len(systems) - - p_pvalues = pd.DataFrame(-100 + np.zeros((num_systems, num_systems)), index=systems, columns=systems) - r_pvalues = pd.DataFrame(-100 + np.zeros((num_systems, num_systems)), index=systems, columns=systems) - f_pvalues = pd.DataFrame(-100 + np.zeros((num_systems, num_systems)), index=systems, columns=systems) - - total_permutation_tests = num_systems * (num_systems - 1) // 2 - - for model in systems: - true_indices = raw_data[lang][dataset]["true_indices"] - pred_indices = raw_data[lang][dataset][model] - if pred_indices is None: - continue - length = raw_data[lang][dataset]["length"] - y_true, y_pred = reverse_where(true_indices, pred_indices, length) - _, _, f1 = compute_prf(y_true, y_pred) - print(f"{lang} {dataset} {model} F1: {f1}") - assert np.allclose( - f1, val_results[lang][dataset][model] - ), f" MISMATCH! {lang} {dataset} {model} F1: {f1} intrinsic_py_out: {val_results[lang][dataset][model]}" - - for i in range(num_systems): - for j in range(i + 1, num_systems): - true_indices = raw_data[lang][dataset]["true_indices"] - pred1_indices = raw_data[lang][dataset][systems[i]] - pred2_indices = raw_data[lang][dataset][systems[j]] - length = raw_data[lang][dataset]["length"] - y_true, y_pred1 = reverse_where(true_indices, pred1_indices, length) - y_true, y_pred2 = reverse_where(true_indices, pred2_indices, length) - - p_pvalue, r_pvalue, f_pvalue = permutation_test( - y_pred1, - y_pred2, - y_true, - num_rounds=10000, - ) - - p_pvalues.at[systems[i], systems[j]] = p_pvalue - r_pvalues.at[systems[i], systems[j]] = r_pvalue - f_pvalues.at[systems[i], systems[j]] = f_pvalue - - print_latex(p_pvalues, systems, LATEX_DIR / f"{lang}_{dataset}_p_pvalues.tex") - print_latex(r_pvalues, systems, LATEX_DIR / f"{lang}_{dataset}_r_pvalues.tex") - print_latex(f_pvalues, systems, LATEX_DIR / f"{lang}_{dataset}_f_pvalues.tex") - -print("All tests passed!") +DATA_DIR = ALL_DIR / f"all-stat-test-data/{args.table}" +LATEX_DIR = ALL_DIR / "p-values" +RESULTS_DATA_DIR = ALL_DIR / "results_data" + +spacy_langs = open("../data/permutation-test-data/all-stat-test-data/spacy_m_langs.txt").read().splitlines() + +with open(DATA_DIR / f"{args.table}_raw_data.pkl", "rb") as f: + raw_data = pickle.load(f) + +with open(DATA_DIR / f"{args.table}_val_results.pkl", "rb") as f: + val_results = pickle.load(f) + +all_systems_mapping = { + "benjamin_wtp-canine-s-3l_b512_s64_u0.01_k10-punct": "WtP-P", + "benjamin_wtp-canine-s-3l_b512_s64_u0.01_k10-t": "WtP-T", + "benjamin_wtp-canine-s-3l_b512_s64_u0.01_k10-u": "WtP-U", + "command-r_k10_s0Pv2-all-command-r": "C-R", + "meta-llama-3-8b-instruct_k10_s0Pv2-all-meta/meta-llama-3-8b-instruct": "L-3", + "xlmr-3l-v3_lc0.1-mix2-FT-33-33-33-v2-CorrSep_b512_s64_u0.25_k10-u": "SaT-SM", + "xlmr-3l-v3_look48_lc0.1-mix2_b512_s64_u0.025_k10-t": "SaT-T", + "xlmr-3l-v3_look48_lc0.1-mix2_b512_s64_u0.025_k10-u": "SaT-U", + "xlmr-3l-v4_LL_lora-v2_ep30_s10k_b512_s64_u0.5_k10-t": "SaT-Lora-T", + "xlmr-3l-v4_LL_lora-v2_ep30_s10k_b512_s64_u0.5_k10-u": "SaT-Lora-U", + "intrinsic_baselines-spacy_dp": "spacy-dp", + "intrinsic_baselines_multi-spacy_dp": "spacy-m", +} + +lora_filter_data = [ + ["ceb", "ted2020-corrupted-asr"], + ["et", "short-sequences"], + ["et", "short-sequences-corrupted-asr"], + ["ga", "ted2020-corrupted-asr"], + ["ha", "ted2020-corrupted-asr"], + ["ig", "ted2020-corrupted-asr"], + ["kk", "ud"], + ["ky", "ted2020-corrupted-asr"], + ["la", "ted2020-corrupted-asr"], + ["mg", "ted2020-corrupted-asr"], + ["mr", "ud"], + ["mt", "ted2020-corrupted-asr"], + ["pa", "ted2020-corrupted-asr"], + ["ta", "ud"], + ["tg", "ted2020-corrupted-asr"], + ["en-de", "short-sequences"], + ["en-de", "short-sequences-corrupted-asr"], + ["sr", "short-sequences"], + ["sr", "short-sequences-corrupted-asr"], + ["sl", "short-sequences"], + ["sl", "short-sequences-corrupted-asr"], +] + +for dataset in raw_data[args.lang].keys(): + + systems = list(all_systems_mapping.keys()).copy() + + if [args.lang, dataset] in lora_filter_data: + systems.remove("xlmr-3l-v4_LL_lora-v2_ep30_s10k_b512_s64_u0.5_k10-t") + else: + systems.remove("xlmr-3l-v4_LL_lora-v2_ep30_s10k_b512_s64_u0.5_k10-u") + + systems = [ + s for s in systems if s in val_results[args.lang][dataset] and val_results[args.lang][dataset][s] is not None + ] + + systems = sorted(systems, key=lambda x: val_results[args.lang][dataset][x], reverse=True) + + num_systems = len(systems) + + p_pvalues = pd.DataFrame(-100 + np.zeros((num_systems, num_systems)), index=systems, columns=systems) + r_pvalues = pd.DataFrame(-100 + np.zeros((num_systems, num_systems)), index=systems, columns=systems) + f_pvalues = pd.DataFrame(-100 + np.zeros((num_systems, num_systems)), index=systems, columns=systems) + + all_diffs = {system1: {} for system1 in systems} + + total_permutation_tests = num_systems * (num_systems - 1) // 2 + + for model in systems: + + true_indices = raw_data[args.lang][dataset]["true_indices"] + pred_indices = raw_data[args.lang][dataset][model] + if pred_indices is None: + continue + lengths = raw_data[args.lang][dataset]["lengths"] + y_true, y_pred = reverse_where(true_indices, pred_indices, lengths) + num_docs = len(y_true) + + _, _, f1 = compute_prf(y_true, y_pred, num_docs) + + assert np.allclose( + f1, val_results[args.lang][dataset][model] + ), f" MISMATCH! {args.lang} {dataset} {model} F1: {f1} intrinsic_py_out: {val_results[args.lang][dataset][model]}" + + for i in range(num_systems): + for j in range(i + 1, num_systems): + true_indices = raw_data[args.lang][dataset]["true_indices"] + pred1_indices = raw_data[args.lang][dataset][systems[i]] + pred2_indices = raw_data[args.lang][dataset][systems[j]] + lengths = raw_data[args.lang][dataset]["lengths"] + y_true, y_pred1 = reverse_where(true_indices, pred1_indices, lengths) + y_true, y_pred2 = reverse_where(true_indices, pred2_indices, lengths) + + diffs, p_pvalue, r_pvalue, f_pvalue = permutation_test( + y_pred1, + y_pred2, + y_true, + num_rounds=10000, + ) + + p_pvalues.at[systems[i], systems[j]] = p_pvalue + r_pvalues.at[systems[i], systems[j]] = r_pvalue + f_pvalues.at[systems[i], systems[j]] = f_pvalue + + all_diffs[systems[i]][systems[j]] = diffs + + print_latex( + f_pvalues, + systems, + all_systems_mapping, + val_results[args.lang][dataset], + LATEX_DIR / f"{dataset}/{args.lang}_f.tex", + ) + + saving_data = { + "p_pvalues": p_pvalues, + "r_pvalues": r_pvalues, + "f_pvalues": f_pvalues, + "all_diffs": all_diffs, + } + + if not (RESULTS_DATA_DIR / dataset).exists(): + (RESULTS_DATA_DIR / dataset).mkdir() + + with open(RESULTS_DATA_DIR / f"{dataset}/{args.lang}_data.pkl", "wb") as f: + pickle.dump(saving_data, f) + + print(f"Finished {args.lang} {dataset}") + +print("All validation tests passed and significance tests done!") diff --git a/wtpsplit/evaluation/permutation_test_data.py b/wtpsplit/evaluation/permutation_test_data.py new file mode 100644 index 00000000..f298021f --- /dev/null +++ b/wtpsplit/evaluation/permutation_test_data.py @@ -0,0 +1,247 @@ +import json +from collections import defaultdict +from pathlib import Path + +from tqdm import tqdm +import pickle + +ALL_DIR = Path("../data/permutation-test-data/") + +raw_data = defaultdict(lambda: defaultdict(dict)) +val_results = defaultdict(lambda: defaultdict(dict)) + + +DATA_DIR = ALL_DIR / "all-stat-test-data/main_all" +LATEX_DIR = ALL_DIR / "p-values" +RESULTS_DATA_DIR = ALL_DIR / "results_data" + +spacy_langs = open("../data/permutation-test-data/all-stat-test-data/spacy_m_langs.txt").read().splitlines() + + +for file in tqdm(DATA_DIR.glob("*IDX.json"), desc="Loading indices"): + + model = str(file.stem)[:-4] + with open(file, "r") as f: + data = json.load(f) + for lang in data.keys(): + if file.stem.startswith("intrinsic_baselines_multi") and lang not in spacy_langs: + continue + + for dataset in data[lang].keys(): + if ( + dataset.startswith("legal") + or dataset.startswith("ted") + or "corrupted-asr" in dataset + or "short-sequences" in dataset + or "code-switching" in dataset + ): + continue + + for model_type in data[lang][dataset].keys(): + + if lang == "cs" and (model_type == "meta/meta-llama-3-8b-instruct" or model_type == "command-r"): + continue + + if model_type.startswith("spacy_sent"): + continue + + if ( + ( + model_type == "true_indices" + or model_type == "length" + or model_type == "lengths" + or model_type == "refused" + ) + or data[lang][dataset][model_type] is None + or "predicted_indices" not in data[lang][dataset][model_type] + ): + continue + + data_list = data[lang][dataset][model_type]["predicted_indices"] + + if data_list is None: + continue + + if len(data_list) == 0: + data_list = [[]] + try: + if isinstance(data_list[0], int): + data_list = [data_list] + except: + print(data_list) + print(lang, dataset, model_type) + raise Exception + + raw_data[lang][dataset][model + "-" + model_type] = data_list + + true_indices = data[lang][dataset][model_type]["true_indices"] + + if isinstance(true_indices[0], int): + true_indices = [true_indices] + + if "true_indicies" in raw_data[lang][dataset]: + assert raw_data[lang][dataset]["true_indices"] == true_indices + else: + raw_data[lang][dataset]["true_indices"] = true_indices + + data_lengths = ( + data[lang][dataset][model_type]["length"] + if "length" in data[lang][dataset][model_type] + else data[lang][dataset][model_type]["lengths"] + ) + + if isinstance(data_lengths, int): + data_lengths = [data_lengths] + + if "lengths" in raw_data[lang][dataset]: + assert ( + raw_data[lang][dataset]["lengths"] == data_lengths + ), f"{lang}, {dataset}, {model_type}... [lengths assertion] before: {raw_data[lang][dataset]['lengths']} after: {data_lengths}" + else: + raw_data[lang][dataset]["lengths"] = data_lengths + + +for file in tqdm(DATA_DIR.glob("*.json"), desc="Loading F1s"): + if file.stem.endswith("IDX"): + continue + + with open(file, "r") as f: + data = json.load(f) + + model = file.stem + + for lang in data.keys(): + if file.stem.startswith("intrinsic_baselines_multi") and lang not in spacy_langs: + continue + + for dataset in data[lang].keys(): + for model_type in data[lang][dataset].keys(): + if lang == "cs" and (model_type == "meta/meta-llama-3-8b-instruct" or model_type == "command-r"): + continue + + if model_type == "f1": + renamed_model_type = "llm" + else: + renamed_model_type = model_type + result = data[lang][dataset][model_type] + + if result is None or result == {}: + continue + elif not isinstance(result, float): + result = result["f1"] + + val_results[lang][dataset][model + "-" + renamed_model_type] = result + + +all_systems_mapping = { + "benjamin_wtp-canine-s-3l_b512_s64_u0.01_k10-punct": "WtP-P", + "benjamin_wtp-canine-s-3l_b512_s64_u0.01_k10-t": "WtP-T", + "benjamin_wtp-canine-s-3l_b512_s64_u0.01_k10-u": "WtP-U", + "command-r_k10_s0Pv2-all-command-r": "C-R", + "meta-llama-3-8b-instruct_k10_s0Pv2-all-meta/meta-llama-3-8b-instruct": "L-3", + "xlmr-3l-v3_lc0.1-mix2-FT-33-33-33-v2-CorrSep_b512_s64_u0.25_k10-u": "SaT-SM", + "xlmr-3l-v3_look48_lc0.1-mix2_b512_s64_u0.025_k10-t": "SaT-T", + "xlmr-3l-v3_look48_lc0.1-mix2_b512_s64_u0.025_k10-u": "SaT-U", + "xlmr-3l-v4_LL_lora-v2_ep30_s10k_b512_s64_u0.5_k10-t": "SaT-Lora-T", + "xlmr-3l-v4_LL_lora-v2_ep30_s10k_b512_s64_u0.5_k10-u": "SaT-Lora-U", + "intrinsic_baselines-spacy_dp": "spacy-dp", + "intrinsic_baselines_multi-spacy_dp": "spacy-m", +} + +main_table_exclude = [ + ["bn", "ud"], + ["ceb", "ud"], + ["es", "ersatz"], + ["fr", "opus100"], + ["hy", "opus100"], + ["id", "ud"], + ["jv", "ud"], + ["mn", "opus100"], + ["nl", "opus100"], + ["ru", "opus100"], + ["sq", "ud"], + ["th", "ud"], + ["yo", "opus100"], + ["yo", "ud"], +] + + +for lang in raw_data.keys(): + for system in all_systems_mapping.keys(): + + main_results = [] + if ( + "opus100" in raw_data[lang] + and system in raw_data[lang]["opus100"] + and [lang, "opus100"] not in main_table_exclude + ): + main_results.append(val_results[lang]["opus100"][system]) + if "ud" in raw_data[lang] and system in raw_data[lang]["ud"] and [lang, "ud"] not in main_table_exclude: + main_results.append(val_results[lang]["ud"][system]) + if ( + "ersatz" in raw_data[lang] + and system in raw_data[lang]["ersatz"] + and [lang, "ersatz"] not in main_table_exclude + ): + main_results.append(val_results[lang]["ersatz"][system]) + + if main_results == []: + continue + + avg_f1 = sum(main_results) / len(main_results) + + preds_main_results_indicies = [] + trues_main_results_indicies = [] + lengths_main_results = [] + + val_results[lang]["main_table_mean"][system] = avg_f1 + + if ( + "opus100" in raw_data[lang] + and system in raw_data[lang]["opus100"] + and [lang, "opus100"] not in main_table_exclude + ): + preds_main_results_indicies.append(raw_data[lang]["opus100"][system][0]) + trues_main_results_indicies.append(raw_data[lang]["opus100"]["true_indices"]) + lengths_main_results.append(raw_data[lang]["opus100"]["lengths"][0]) + + if "ud" in raw_data[lang] and system in raw_data[lang]["ud"] and [lang, "ud"] not in main_table_exclude: + preds_main_results_indicies.append(raw_data[lang]["ud"][system][0]) + trues_main_results_indicies.append(raw_data[lang]["ud"]["true_indices"]) + lengths_main_results.append(raw_data[lang]["ud"]["lengths"][0]) + + if ( + "ersatz" in raw_data[lang] + and system in raw_data[lang]["ersatz"] + and [lang, "ersatz"] not in main_table_exclude + ): + preds_main_results_indicies.append(raw_data[lang]["ersatz"][system][0]) + trues_main_results_indicies.append(raw_data[lang]["ersatz"]["true_indices"]) + lengths_main_results.append(raw_data[lang]["ersatz"]["lengths"][0]) + + raw_data[lang]["main_table_mean"][system] = preds_main_results_indicies + + if "true_indices" in raw_data[lang]["main_table_mean"]: + assert ( + raw_data[lang]["main_table_mean"]["true_indices"] == trues_main_results_indicies + ), f"{lang} {system}, {[len(i) for i in trues_main_results_indicies]}, {[len(i) for i in raw_data[lang]['main_table_mean']['true_indices']]}" + else: + raw_data[lang]["main_table_mean"]["true_indices"] = trues_main_results_indicies + + if "lengths" in raw_data[lang]["main_table_mean"]: + assert ( + raw_data[lang]["main_table_mean"]["lengths"] == lengths_main_results + ), f"{lang} {system} {raw_data[lang]['main_table_mean']['lengths']} {lengths_main_results}" + else: + raw_data[lang]["main_table_mean"]["lengths"] = lengths_main_results + + +raw_data = {k: dict(v) for k, v in raw_data.items()} + +with open(DATA_DIR / "main_all_raw_data.pkl", "wb") as f: + pickle.dump(raw_data, f) + +val_results = {k: dict(v) for k, v in val_results.items()} + +with open(DATA_DIR / "main_all_val_results.pkl", "wb") as f: + pickle.dump(val_results, f) diff --git a/wtpsplit/evaluation/permutation_test_utils.py b/wtpsplit/evaluation/permutation_test_utils.py index 466cd925..c1402d3d 100644 --- a/wtpsplit/evaluation/permutation_test_utils.py +++ b/wtpsplit/evaluation/permutation_test_utils.py @@ -1,23 +1,37 @@ import numpy as np from tqdm import tqdm +from multiprocessing import Pool -def compute_prf(true_values, predicted_values): +def compute_prf(true_values, predicted_values, num_docs): - TP = np.sum((predicted_values == 1) & (true_values == 1)) - FP = np.sum((predicted_values == 1) & (true_values == 0)) - FN = np.sum((predicted_values == 0) & (true_values == 1)) + f1 = 0 + r = 0 + p = 0 - precision = TP / (TP + FP) if (TP + FP) > 0 else 0 - recall = TP / (TP + FN) if (TP + FN) > 0 else 0 - f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + for true, pred in zip(true_values, predicted_values): - return precision, recall, f1_score + TP = np.sum((pred == 1) & (true == 1)) + FP = np.sum((pred == 1) & (true == 0)) + FN = np.sum((pred == 0) & (true == 1)) + precision = TP / (TP + FP) if (TP + FP) > 0 else 0 + recall = TP / (TP + FN) if (TP + FN) > 0 else 0 -def test_func(x, y, true): - p_x, r_x, f1_x = compute_prf(true, x) - p_y, r_y, f1_y = compute_prf(true, y) + p += precision + r += recall + f1 += 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + p /= num_docs + r /= num_docs + f1 /= num_docs + + return p, r, f1 + + +def test_func(x, y, true, num_docs): + p_x, r_x, f1_x = compute_prf(true, x, num_docs) + p_y, r_y, f1_y = compute_prf(true, y, num_docs) diff_p = np.abs(p_x - p_y) diff_r = np.abs(r_x - r_y) @@ -26,49 +40,50 @@ def test_func(x, y, true): return diff_p, diff_r, diff_f1 +def permutation_test_single_round(x, y, true, y_lengths, num_docs, flips): + sample_x = [np.where(flips[:m], y[i], x[i]) for i, m in enumerate(y_lengths)] + sample_y = [np.where(flips[:m], x[i], y[i]) for i, m in enumerate(y_lengths)] + + return test_func(sample_x, sample_y, true, num_docs) + + def permutation_test( x, y, true, num_rounds=10000, - seed=None, ): - rng = np.random.RandomState(seed) - - m, n = len(x), len(y) + # print(num_rounds) - if m != n: - raise ValueError( - f"x and y must have the same" f" length if `paired=True`, but they had lengths {m} and {n} respectively." - ) + x_lengths = [len(i) for i in x] + y_lengths = [len(i) for i in y] - sample_x = np.empty(m) - sample_y = np.empty(n) + for i, j in zip(x_lengths, y_lengths): + assert i == j p_at_least_as_extreme = 0.0 r_at_least_as_extreme = 0.0 f_at_least_as_extreme = 0.0 - p_reference_stat, r_reference_stat, f_reference_stat = test_func(x, y, true) - - # this loop currently takes 30 seconds. Probably all because of test_func although that is already quite optimized - # maybe we can do these rounds in parallel - - for i in range(num_rounds): - flip = rng.randn(m) > 0.0 + num_docs = len(true) - # for i, f in enumerate(flip): - # if f: - # sample_x[i], sample_y[i] = y[i], x[i] - # else: - # sample_x[i], sample_y[i] = x[i], y[i] + p_reference_stat, r_reference_stat, f_reference_stat = test_func(x, y, true, num_docs) - sample_x = np.where(flip, y, x) - sample_y = np.where(flip, x, y) + flips = np.random.randint(2, size=(num_rounds, max(y_lengths))) - diff_p, diff_r, diff_f = test_func(sample_x, sample_y, true) + with Pool(5) as pool: + results = list( + pool.starmap( + permutation_test_single_round, + tqdm( + [(x, y, true, y_lengths, num_docs, flips[i]) for i in range(num_rounds)], + total=num_rounds, + ), + ), + ) + for diff_p, diff_r, diff_f in results: if diff_p > p_reference_stat or np.isclose(diff_p, p_reference_stat): p_at_least_as_extreme += 1.0 @@ -79,13 +94,16 @@ def permutation_test( f_at_least_as_extreme += 1.0 return ( + results, p_at_least_as_extreme / num_rounds, r_at_least_as_extreme / num_rounds, f_at_least_as_extreme / num_rounds, ) -def print_latex(df, systems, filename): +def print_latex(df, systems, all_systems_mapping, results, filename): + + filename.parent.mkdir(parents=True, exist_ok=True) with open(filename, "w") as f: @@ -93,16 +111,31 @@ def print_latex(df, systems, filename): while " " in latex: latex = latex.replace(" ", " ") latex = latex.replace("-100.000", "-") + + for system, system_name in all_systems_mapping.items(): + latex = latex.replace(system, system_name) + + for system, system_name in all_systems_mapping.items(): + if system in results: + latex += "\n" + latex += f"% {system_name}: {round(results[system], 3)}" + f.write(latex) -def reverse_where(true_indices, pred_indices, length): - y_true = np.zeros(length) - y_true[true_indices] = 1 - y_pred = np.zeros(length) - try: - y_pred[pred_indices] = 1 - except: - print(f"pred_indices: {pred_indices}") - raise Exception - return y_true, y_pred +def reverse_where(true_indices, pred_indices, lengths): + + y_true_all = [] + y_pred_all = [] + + for true, pred, length in zip(true_indices, pred_indices, lengths): + + y_true = np.zeros(length) + y_true[true] = 1 + y_pred = np.zeros(length) + y_pred[pred] = 1 + + y_true_all.append(y_true) + y_pred_all.append(y_pred) + + return y_true_all, y_pred_all diff --git a/wtpsplit/train/train_FT.py b/wtpsplit/train/train_FT.py index 3159cf8f..494a78b5 100644 --- a/wtpsplit/train/train_FT.py +++ b/wtpsplit/train/train_FT.py @@ -19,18 +19,16 @@ from wtpsplit.utils import Constants parser = argparse.ArgumentParser() + parser.add_argument("--block_size", type=int, default=256) parser.add_argument("--num_layers", type=int, default=12) parser.add_argument("--lim_lookahead", type=bool, default=False) -parser.add_argument("--upsample_non_whitespace", type=bool, default=False) parser.add_argument("--without_pretraining", type=bool, default=False) -parser.add_argument("--corruption_in_pretraining", type=bool, default=False) -parser.add_argument("--new_tokenizer", type=bool, default=False) -parser.add_argument("--upsampling_in_pretraining", type=bool, default=False) -args = parser.parse_args() +parser.add_argument("--no_sm_corruption", type=bool, default=False) +args = parser.parse_args() -data_path = "data/all_data.pth" +data_path = "../data/preprocessed_training_data/all_data_11_05-all.pth" all_data = torch.load(data_path) block_size = args.block_size @@ -38,64 +36,116 @@ train_sentences = defaultdict(lambda: defaultdict(list)) test_sentences = defaultdict(lambda: defaultdict(list)) +punct_chars = set(Constants.PUNCTUATION_CHARS) + + for lang_code in tqdm(all_data, desc="Loading train/dev data"): - if ( + if "-" in lang_code or "_" in lang_code: + pass + elif ( "ud" in all_data[lang_code]["sentence"] and all_data[lang_code]["sentence"]["ud"]["meta"]["train_data"] is not None ): - print(f"Found UD data for {lang_code}") train_data = all_data[lang_code]["sentence"]["ud"]["meta"]["train_data"] - train_sentences[lang_code]["all"].extend(train_data) + if len(train_data) < 10000: + train_data = train_data * (10000 // len(train_data) + 1) + + if len(train_data) < 5000: + train_data = train_data * (10000 // len(train_data) + 1) + + train_sentences[lang_code]["uncorrupted"].extend(train_data) + + if not args.no_sm_corruption: + + train_data = all_data[lang_code]["sentence"]["ud-corrupted-asr"]["meta"]["train_data"] + + if len(train_data) < 5000: + train_data = train_data * (10000 // len(train_data) + 1) + + train_sentences[lang_code]["corrupted-asr"].extend(train_data) + + train_data = all_data[lang_code]["sentence"]["ud-corrupted-social-media"]["meta"]["train_data"] + + if len(train_data) < 5000: + train_data = train_data * (10000 // len(train_data) + 1) - train_data = all_data[lang_code]["sentence"]["ud-corrupted"]["meta"]["train_data"] - train_sentences[lang_code]["all"].extend(train_data) + train_sentences[lang_code]["corrupted-social-media"].extend(train_data) elif ( "opus100" in all_data[lang_code]["sentence"] and all_data[lang_code]["sentence"]["opus100"]["meta"]["train_data"] is not None ): - print(f"Found Opus100 data for {lang_code}") + train_data = all_data[lang_code]["sentence"]["opus100"]["meta"]["train_data"] + assert len(train_data) == 10000 + train_sentences[lang_code]["uncorrupted"].extend(train_data) - train_data = all_data[lang_code]["sentence"]["opus100"]["meta"]["train_data"][:10000] - train_sentences[lang_code]["all"].extend(train_data) + if not args.no_sm_corruption: - train_data = all_data[lang_code]["sentence"]["opus100-corrupted"]["meta"]["train_data"][:10000] - train_sentences[lang_code]["all"].extend(train_data) + train_data = all_data[lang_code]["sentence"]["opus100-corrupted-asr"]["meta"]["train_data"] + assert len(train_data) == 10000 + train_sentences[lang_code]["corrupted-asr"].extend(train_data) + + train_data = all_data[lang_code]["sentence"]["opus100-corrupted-social-media"]["meta"]["train_data"] + assert len(train_data) == 10000 + train_sentences[lang_code]["corrupted-social-media"].extend(train_data) else: - print(f"No data found for {lang_code}") + + train_data = all_data[lang_code]["sentence"]["nllb"]["meta"]["train_data"] + assert len(train_data) == 10000 + train_sentences[lang_code]["uncorrupted"].extend(train_data) + + if not args.no_sm_corruption: + + train_data = all_data[lang_code]["sentence"]["nllb-corrupted-asr"]["meta"]["train_data"] + assert len(train_data) == 10000 + train_sentences[lang_code]["corrupted-asr"].extend(train_data) + + train_data = all_data[lang_code]["sentence"]["nllb-corrupted-social-media"]["meta"]["train_data"] + assert len(train_data) == 10000 + train_sentences[lang_code]["corrupted-social-media"].extend(train_data) for dataset in all_data[lang_code]["sentence"]: + if any(dataset.startswith(x) for x in ["short-sequences", "legal"]): + continue + test_data = all_data[lang_code]["sentence"][dataset]["data"] - test_sentences[dataset][lang_code].extend(test_data[:100]) + test_sentences[lang_code][dataset].extend(test_data[:200]) tokenizer_checkpoint = "xlm-roberta-base" if args.without_pretraining: model_checkpoint = "xlm-roberta-base" -elif args.upsampling_in_pretraining: - model_checkpoint = "data/models/xlmr-3l-v3_look48_snW4" +elif args.num_layers == 1: + if args.lim_lookahead: + raise NotImplementedError("Not implemented") + else: + model_checkpoint = "../data/models/xlmr-1l-v3_lc0.1_mix2" elif args.num_layers == 3: if args.lim_lookahead: - model_checkpoint = "data/models/xlmr-3l-v3_look48-NEW" + model_checkpoint = "../data/models/xlmr-3l-v3_look48_lc0.1-mix2" else: - model_checkpoint = "data/models/xlmr-3l-v3" + model_checkpoint = "../data/models/xlmr-3l-v3_lc0.1-mix2" elif args.num_layers == 6: if args.lim_lookahead: - model_checkpoint = "data/models/xlmr-6l-v3_look48-OLD" + model_checkpoint = "../data/models/xlmr-6l-v3_look48_lc0.1-mix2" else: - model_checkpoint = "data/models/xlmr-6l-v3" + model_checkpoint = "../data/models/xlmr-6l-v3_lc0.1-mix2" +elif args.num_layers == 9: + if args.lim_lookahead: + raise NotImplementedError("Not downloaded") + else: + model_checkpoint = "../data/models/xlmr-9l-v3_lc0.1_mix2" elif args.num_layers == 12: - if args.corruption_in_pretraining: - model_checkpoint = "data/models/xlmr-12l-v3_lc0.25-mix2" - elif args.new_tokenizer: - model_checkpoint = "data/models/xlmr-12l-v3-NEW" + if args.lim_lookahead: + model_checkpoint = "../data/models/xlmr-12l-v3_look48_lc0.1" else: - model_checkpoint = "data/models/xlmr-12l-v3-OLD" + model_checkpoint = "../data/models/xlmr-12l-v3_lc0.1-mix2" + else: raise ValueError("Invalid number of layers. Valid values are 3, 6, 12.") @@ -104,25 +154,41 @@ tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast) -model = SubwordXLMForTokenClassification.from_pretrained( - model_checkpoint, - num_labels=1, - ignore_mismatched_sizes=True, -) +if args.num_layers == 3 and args.without_pretraining: + model = SubwordXLMForTokenClassification.from_pretrained( + model_checkpoint, + num_labels=1, + ignore_mismatched_sizes=True, + num_hidden_layers=3, + ) +else: + model = SubwordXLMForTokenClassification.from_pretrained( + model_checkpoint, + num_labels=1, + ignore_mismatched_sizes=True, + ) -print("Model loaded") +def tokenize_and_get_labels(sentences, lang_code, dataset_name): -def tokenize_and_get_labels(sentences, separator): + separator = Constants.SEPARATORS.get(lang_code, " ") joined_sentence = "" sentence_start_positions = [] current_position = 0 for sentence in sentences: + if random.random() < 0.1 and sentence[-1] in punct_chars and dataset_name == "corrupted-social-media": + if separator == " ": + separator_used = "" + else: + separator_used = " " + else: + separator_used = separator + if joined_sentence: - joined_sentence += separator - current_position += len(separator) + joined_sentence += separator_used + current_position += len(separator_used) start_position = current_position joined_sentence += sentence current_position += len(sentence) @@ -153,33 +219,35 @@ def tokenize_and_get_labels(sentences, separator): return input_ids, labels -def pack_sentences(input_data_dict, block_size, by_dataset_too=False): +def pack_sentences(input_data_dict, block_size): - if not by_dataset_too: - packed_data = defaultdict(lambda: {"input_ids": [], "attention_mask": [], "labels": []}) - else: - packed_data = defaultdict(lambda: defaultdict(lambda: {"input_ids": [], "attention_mask": [], "labels": []})) + packed_data = defaultdict(lambda: defaultdict(lambda: {"input_ids": [], "attention_mask": [], "labels": []})) - for dataset_name in tqdm(input_data_dict): - for lang_code, sentences in input_data_dict[dataset_name].items(): + for lang_code in tqdm(input_data_dict): + for dataset_name, sentences in input_data_dict[lang_code].items(): - separator = Constants.SEPARATORS.get(lang_code, " ") + if dataset_name == "corrupted-social-media": + p_add_to_block = 0.5 + else: + p_add_to_block = 1.0 token_count, one_block_sentences = 0, [] for sentence in sentences: - if not sentence or sentence.isnumeric(): - continue - # TODO change this to tokenize in one go for efficiency num_sentence_tokens = len(tokenizer(sentence, add_special_tokens=False)["input_ids"]) - if token_count + num_sentence_tokens < block_size - 4: + if not sentence or sentence.isnumeric() or num_sentence_tokens == 0: + continue + + if token_count + num_sentence_tokens < block_size - 4 and ( + random.random() <= p_add_to_block or len(one_block_sentences) == 0 + ): one_block_sentences.append(sentence) token_count += num_sentence_tokens else: if one_block_sentences: - input_ids, labels = tokenize_and_get_labels(one_block_sentences, separator) + input_ids, labels = tokenize_and_get_labels(one_block_sentences, lang_code, dataset_name) num_to_pad = block_size - len(input_ids) attention_mask = [1] * len(input_ids) + [0] * num_to_pad @@ -192,14 +260,9 @@ def pack_sentences(input_data_dict, block_size, by_dataset_too=False): len(labels), ) - if not by_dataset_too: - packed_data[lang_code]["input_ids"].append(input_ids) - packed_data[lang_code]["attention_mask"].append(attention_mask) - packed_data[lang_code]["labels"].append(labels) - else: - packed_data[dataset_name][lang_code]["input_ids"].append(input_ids) - packed_data[dataset_name][lang_code]["attention_mask"].append(attention_mask) - packed_data[dataset_name][lang_code]["labels"].append(labels) + packed_data[lang_code][dataset_name]["input_ids"].append(input_ids) + packed_data[lang_code][dataset_name]["attention_mask"].append(attention_mask) + packed_data[lang_code][dataset_name]["labels"].append(labels) if num_sentence_tokens > block_size - 4: one_block_sentences = [] @@ -209,7 +272,7 @@ def pack_sentences(input_data_dict, block_size, by_dataset_too=False): token_count = num_sentence_tokens if one_block_sentences: - input_ids, labels = tokenize_and_get_labels(one_block_sentences, separator) + input_ids, labels = tokenize_and_get_labels(one_block_sentences, lang_code, dataset_name) num_to_pad = block_size - len(input_ids) attention_mask = [1] * len(input_ids) + [0] * num_to_pad @@ -219,49 +282,31 @@ def pack_sentences(input_data_dict, block_size, by_dataset_too=False): assert len(input_ids) == block_size, len(input_ids) assert len(input_ids) == len(labels), (len(input_ids), len(labels)) - if not by_dataset_too: - packed_data[lang_code]["input_ids"].append(input_ids) - packed_data[lang_code]["attention_mask"].append(attention_mask) - packed_data[lang_code]["labels"].append(labels) - else: - packed_data[dataset_name][lang_code]["input_ids"].append(input_ids) - packed_data[dataset_name][lang_code]["attention_mask"].append(attention_mask) - packed_data[dataset_name][lang_code]["labels"].append(labels) + packed_data[lang_code][dataset_name]["input_ids"].append(input_ids) + packed_data[lang_code][dataset_name]["attention_mask"].append(attention_mask) + packed_data[lang_code][dataset_name]["labels"].append(labels) - if not by_dataset_too: - assert len(packed_data[lang_code]["input_ids"]) == len(packed_data[lang_code]["labels"]) - else: - assert len(packed_data[dataset_name][lang_code]["input_ids"]) == len( - packed_data[dataset_name][lang_code]["labels"] - ) + assert len(packed_data[lang_code][dataset_name]["input_ids"]) == len( + packed_data[lang_code][dataset_name]["labels"] + ) return packed_data packed_train_data = pack_sentences(train_sentences, block_size) +packed_test_data = pack_sentences(test_sentences, block_size) +test_dataset = {lang_code: defaultdict(dict) for lang_code in packed_test_data} -packed_test_data = pack_sentences(test_sentences, block_size, by_dataset_too=True) - -test_dataset = {dataset_name: defaultdict(dict) for dataset_name in packed_test_data} - -for dataset_name in packed_test_data: - for lang_code in packed_test_data[dataset_name]: - test_dataset[dataset_name][lang_code] = Dataset.from_dict(packed_test_data[dataset_name][lang_code]) - -if args.lim_lookahead: - lookahead = 48 -else: - lookahead = 512 +for lang_code in packed_test_data: + for dataset_name in packed_test_data[lang_code]: + test_dataset[lang_code][dataset_name] = Dataset.from_dict(packed_test_data[lang_code][dataset_name]) experiment_name = model_checkpoint.split("/")[-1] -experiment_name += f"-FT-{args.num_layers}L-{args.block_size}BS-{lookahead}LA" - -if args.upsampling_in_pretraining: - experiment_name += "-upsample_nW4_in_WtP" +experiment_name += str(args.num_layers) + "L" -if args.upsample_non_whitespace: - experiment_name += "-upsample_nW4_in_FT" +if args.no_sm_corruption: + experiment_name += "-no-corruption" def compute_prf(true_values, predicted_values): @@ -293,39 +338,16 @@ def compute_metrics(p): predictions = predictions[labels != -100] labels = labels[labels != -100] - thresholds = np.concatenate( - [ - np.arange(0.0000001, 0.000001, 0.0000001), - np.arange(0.000001, 0.00001, 0.000001), - np.arange(0.00001, 0.0001, 0.00001), - np.arange(0.0001, 0.001, 0.0001), - np.arange(0.001, 0.01, 0.001), - np.arange(0.01, 0.1, 0.01), - np.arange(0.1, 1, 0.1), - ] - ) + threshold = 0.25 - precision_best = 0 - recall_best = 0 - f1_best = 0 - threshold_best = 0 + preds = (predictions > threshold).astype(int) - for threshold in thresholds[::-1]: - preds = (predictions > threshold).astype(int) - - precision, recall, f1 = compute_prf(labels, preds) - - if f1 > f1_best: - precision_best = precision - recall_best = recall - f1_best = f1 - threshold_best = threshold + precision, recall, f1 = compute_prf(labels, preds) output_dict = { - "precision": precision_best, - "recall": recall_best, - "f1": f1_best, - "threshold": threshold_best, + "precision": precision, + "recall": recall, + "f1": f1, } return output_dict @@ -341,11 +363,9 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs): def on_step_end(self, args, state, control, **kwargs): if state.global_step % args.eval_steps == 0: - for dataset_name in self.eval_datasets: - for lang_code, eval_dataset in self.eval_datasets[dataset_name].items(): - + for lang_code in self.eval_datasets: + for dataset_name, eval_dataset in self.eval_datasets[lang_code].items(): metrics = trainer.evaluate(eval_dataset) - for metric, result in metrics.items(): wandb.log( { @@ -357,29 +377,21 @@ def on_step_end(self, args, state, control, **kwargs): multi_dataset_eval_callback = MultiDatasetEvalCallback(test_dataset) -if not args.upsample_non_whitespace: - - train_datasets = [Dataset.from_dict(data) for lang_code, data in packed_train_data.items()] -else: - train_datasets = [] +train_datasets = [] - for lang_code, data in packed_train_data.items(): - if Constants.SEPARATORS.get(lang_code, " ") == "": - train_datasets.extend([Dataset.from_dict(data)] * 20) - else: - train_datasets.append(Dataset.from_dict(data)) +for lang_code in packed_train_data: + for dataset_name in packed_train_data[lang_code]: + train_datasets.append(Dataset.from_dict(packed_train_data[lang_code][dataset_name])) random.shuffle(train_datasets) - train_datasets = ConcatDataset(train_datasets) - -run = wandb.init(project="WtP-FT", entity="igorsterner") +run = wandb.init(project="SaT-SM", entity="igorsterner") wandb.run.name = experiment_name args = TrainingArguments( - output_dir=Path("data/models") / experiment_name, + output_dir=Path("../data/models") / experiment_name, overwrite_output_dir=True, evaluation_strategy="steps", eval_steps=250, @@ -401,7 +413,6 @@ def on_step_end(self, args, state, control, **kwargs): class RoundRobinSampler: def __init__(self, samplers: Sequence[Iterable], reinit: bool = False): - self.samplers = samplers self.reinit = reinit @@ -418,7 +429,6 @@ def __iter__(self): if not self.reinit: break - # re-initialize the iterator it = iter(self.samplers[i]) iterators[i] = it yield next(it) @@ -491,7 +501,7 @@ def get_train_dataloader(self) -> DataLoader: batch_sampler=DistributedRoundRobinBatchSampler( lengths=sizes, batch_size=self.args.train_batch_size, - drop_last=self.args.dataloader_drop_last, + drop_last=False, rank=self.args.process_index, num_replicas=self.args.world_size, seed=self.args.seed, @@ -507,7 +517,7 @@ def get_train_dataloader(self) -> DataLoader: trainer = CustomTrainer( model=model, args=args, - train_dataset=train_datasets, # Now it's a concatenated dataset + train_dataset=train_datasets, eval_dataset=None, compute_metrics=compute_metrics, tokenizer=tokenizer,