From fd6716eb0154a04d91eadde68d35103db23971f4 Mon Sep 17 00:00:00 2001 From: markus583 Date: Sat, 18 May 2024 18:09:17 +0000 Subject: [PATCH] update data pth, idcs --- wtpsplit/evaluation/intrinsic_pairwise.py | 40 +++++++++++++++++++---- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py index f20864e1..fe109cb0 100644 --- a/wtpsplit/evaluation/intrinsic_pairwise.py +++ b/wtpsplit/evaluation/intrinsic_pairwise.py @@ -45,7 +45,7 @@ class Args: # } # } # } - eval_data_path: str = "data/all_data_04_05.pth" + eval_data_path: str = "data/all_data_11_05-all.pth" valid_text_path: str = None # "data/sentence/valid.parquet" device: str = "cpu" block_size: int = 512 @@ -61,7 +61,7 @@ class Args: keep_logits: bool = True skip_corrupted: bool = True skip_punct: bool = True - return_indices: bool = False + return_indices: bool = True # k_mer-specific args k: int = 2 @@ -254,7 +254,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st # eval data for dataset_name, dataset in eval_data[lang_code]["sentence"].items(): - if args.skip_corrupted and "corrupted" in dataset_name: + if args.skip_corrupted and "corrupted" in dataset_name and"ted2020" not in dataset_name: continue try: if args.adapter_path: @@ -407,6 +407,8 @@ def main(args): # now, compute the intrinsic scores. results = {} clfs = {} + if args.return_indices: + indices = {} # Initialize lists to store scores for each metric across all languages u_scores, t_scores, punct_scores = [], [], [] u_accs, t_accs, punct_accs = [], [], [] @@ -419,6 +421,8 @@ def main(args): print(f"Predicting {lang_code}...") results[lang_code] = {} clfs[lang_code] = {} + if args.return_indices: + indices[lang_code] = {} for dataset_name, dataset in dsets["sentence"].items(): sentences = dataset["data"][: args.max_n_test_sentences] @@ -437,7 +441,7 @@ def main(args): ) if lang_code not in f or dataset_name not in f[lang_code]: continue - + if "train_logits" in f[lang_code][dataset_name] and not args.skip_adaptation: feature_indices = None # it is sufficient to feed in 1 long sequence of tokens here since we only use logits for LR @@ -489,6 +493,8 @@ def main(args): score_u = [] acc_u = [] thresholds = [] + u_indices, true_indices = [], [] + length = [] for i, k_mer in enumerate(sent_k_mers): start, end = f[lang_code][dataset_name]["test_logit_lengths"][i] if args.adjust_threshold: @@ -504,7 +510,7 @@ def main(args): thresholds.append(threshold_adjusted) else: thresholds.append(args.threshold) - single_score_u, _, info, u_indices, _ = evaluate_mixture( + single_score_u, _, info, cur_u_indices, _ = evaluate_mixture( lang_code, f[lang_code][dataset_name]["test_logits"][:][start:end], list(k_mer), @@ -517,11 +523,16 @@ def main(args): score_u = np.mean(score_u) score_t = np.mean(score_t) if score_t and not args.skip_adaptation else None - score_punct = np.mean(score_punct) if score_punct and not (args.skip_punct or args.skip_adaptation) else None + score_punct = ( + np.mean(score_punct) if score_punct and not (args.skip_punct or args.skip_adaptation) else None + ) acc_u = np.mean(acc_u) acc_t = np.mean(acc_t) if score_t else None acc_punct = np.mean(acc_punct) if score_punct else None threshold = np.mean(thresholds) + u_indices.append(cur_u_indices["pred_indices"] if cur_u_indices["pred_indices"] else []) + true_indices.append(cur_u_indices["true_indices"] if cur_u_indices["true_indices"] else []) + length.append(cur_u_indices["length"]) results[lang_code][dataset_name] = { "u": score_u, @@ -534,6 +545,10 @@ def main(args): "threshold_adj": threshold, } + if args.return_indices: + indices[lang_code][dataset_name] = { + "u": {"predicted_indices": u_indices, "true_indices": true_indices, "length": length}, + } # just for printing score_t = score_t or 0.0 score_punct = score_punct or 0.0 @@ -583,6 +598,19 @@ def main(args): ), indent=4, ) + + if args.return_indices: + json.dump( + indices, + open( + Constants.CACHE_DIR / "intrinsic_pairwise" / f"{save_str}_IDX.json", + "w", + ), + default=int, + indent=4, + ) + print(Constants.CACHE_DIR / "intrinsic_pairwise" / f"{save_str}_IDX.json") + print("Indices saved to file.") if not args.keep_logits: os.remove(f.filename)