update data pth, idcs
markus583 committed May 18, 2024
1 parent 545beca commit fd6716e
Showing 1 changed file with 34 additions and 6 deletions.
wtpsplit/evaluation/intrinsic_pairwise.py (40 changes: 34 additions & 6 deletions)
@@ -45,7 +45,7 @@ class Args:
     # }
     # }
     # }
-    eval_data_path: str = "data/all_data_04_05.pth"
+    eval_data_path: str = "data/all_data_11_05-all.pth"
     valid_text_path: str = None  # "data/sentence/valid.parquet"
     device: str = "cpu"
     block_size: int = 512
@@ -61,7 +61,7 @@ class Args:
     keep_logits: bool = True
     skip_corrupted: bool = True
     skip_punct: bool = True
-    return_indices: bool = False
+    return_indices: bool = True

     # k_mer-specific args
     k: int = 2
@@ -254,7 +254,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st

             # eval data
             for dataset_name, dataset in eval_data[lang_code]["sentence"].items():
-                if args.skip_corrupted and "corrupted" in dataset_name:
+                if args.skip_corrupted and "corrupted" in dataset_name and "ted2020" not in dataset_name:
                     continue
                 try:
                     if args.adapter_path:
@@ -407,6 +407,8 @@ def main(args):
     # now, compute the intrinsic scores.
     results = {}
     clfs = {}
+    if args.return_indices:
+        indices = {}
     # Initialize lists to store scores for each metric across all languages
     u_scores, t_scores, punct_scores = [], [], []
     u_accs, t_accs, punct_accs = [], [], []
@@ -419,6 +421,8 @@
         print(f"Predicting {lang_code}...")
         results[lang_code] = {}
         clfs[lang_code] = {}
+        if args.return_indices:
+            indices[lang_code] = {}

         for dataset_name, dataset in dsets["sentence"].items():
             sentences = dataset["data"][: args.max_n_test_sentences]
@@ -437,7 +441,7 @@
             )
             if lang_code not in f or dataset_name not in f[lang_code]:
                 continue

             if "train_logits" in f[lang_code][dataset_name] and not args.skip_adaptation:
                 feature_indices = None
                 # it is sufficient to feed in 1 long sequence of tokens here since we only use logits for LR
@@ -489,6 +493,8 @@
             score_u = []
             acc_u = []
             thresholds = []
+            u_indices, true_indices = [], []
+            length = []
             for i, k_mer in enumerate(sent_k_mers):
                 start, end = f[lang_code][dataset_name]["test_logit_lengths"][i]
                 if args.adjust_threshold:
@@ -504,7 +510,7 @@
                     thresholds.append(threshold_adjusted)
                 else:
                     thresholds.append(args.threshold)
-                single_score_u, _, info, u_indices, _ = evaluate_mixture(
+                single_score_u, _, info, cur_u_indices, _ = evaluate_mixture(
                     lang_code,
                     f[lang_code][dataset_name]["test_logits"][:][start:end],
                     list(k_mer),
@@ -517,11 +523,16 @@

             score_u = np.mean(score_u)
             score_t = np.mean(score_t) if score_t and not args.skip_adaptation else None
-            score_punct = np.mean(score_punct) if score_punct and not (args.skip_punct or args.skip_adaptation) else None
+            score_punct = (
+                np.mean(score_punct) if score_punct and not (args.skip_punct or args.skip_adaptation) else None
+            )
             acc_u = np.mean(acc_u)
             acc_t = np.mean(acc_t) if score_t else None
             acc_punct = np.mean(acc_punct) if score_punct else None
             threshold = np.mean(thresholds)
+            u_indices.append(cur_u_indices["pred_indices"] if cur_u_indices["pred_indices"] else [])
+            true_indices.append(cur_u_indices["true_indices"] if cur_u_indices["true_indices"] else [])
+            length.append(cur_u_indices["length"])

             results[lang_code][dataset_name] = {
                 "u": score_u,
@@ -534,6 +545,10 @@
                 "threshold_adj": threshold,
             }

+            if args.return_indices:
+                indices[lang_code][dataset_name] = {
+                    "u": {"predicted_indices": u_indices, "true_indices": true_indices, "length": length},
+                }
             # just for printing
             score_t = score_t or 0.0
             score_punct = score_punct or 0.0
@@ -583,6 +598,19 @@
         ),
         indent=4,
     )
+
+    if args.return_indices:
+        json.dump(
+            indices,
+            open(
+                Constants.CACHE_DIR / "intrinsic_pairwise" / f"{save_str}_IDX.json",
+                "w",
+            ),
+            default=int,
+            indent=4,
+        )
+        print(Constants.CACHE_DIR / "intrinsic_pairwise" / f"{save_str}_IDX.json")
+        print("Indices saved to file.")
     if not args.keep_logits:
         os.remove(f.filename)


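For context on the new output: when `return_indices` is enabled, the script dumps the collected indices to `{save_str}_IDX.json` under `Constants.CACHE_DIR / "intrinsic_pairwise"`, keyed by language and dataset. Below is a minimal sketch of reading that file back; the cache directory and save string shown are hypothetical placeholders (the real values come from the script's configuration), and the structure assumed is the one built in this diff.

import json
from pathlib import Path

# Hypothetical path; the script writes to
# Constants.CACHE_DIR / "intrinsic_pairwise" / f"{save_str}_IDX.json".
idx_path = Path(".cache/intrinsic_pairwise/my_run_IDX.json")

with open(idx_path) as f:
    indices = json.load(f)

# Per the diff, indices[lang_code][dataset_name]["u"] holds
# "predicted_indices", "true_indices", and "length" for the evaluated pairs.
for lang_code, datasets in indices.items():
    for dataset_name, entry in datasets.items():
        u = entry["u"]
        print(lang_code, dataset_name, len(u["predicted_indices"]), len(u["true_indices"]), u["length"])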