From b4e70c52cbfde1dc5d551caee5849d0a84526953 Mon Sep 17 00:00:00 2001 From: markus583 Date: Wed, 5 Jun 2024 13:09:07 +0000 Subject: [PATCH] up --- wtpsplit/evaluation/llm_sentence.py | 42 +++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/wtpsplit/evaluation/llm_sentence.py b/wtpsplit/evaluation/llm_sentence.py index 8d03867a..23714762 100644 --- a/wtpsplit/evaluation/llm_sentence.py +++ b/wtpsplit/evaluation/llm_sentence.py @@ -19,6 +19,7 @@ import replicate from wtpsplit.evaluation import get_labels, evaluate_sentences_llm +from wtpsplit.evaluation.intrinsic_pairwise import generate_k_mers from wtpsplit.utils import Constants import time @@ -49,7 +50,7 @@ @dataclass class Args: eval_data_path: str = "data/all_data_11_05" - type: str = "lyrics" # all, lyrics + type: str = "lyrics" # all, lyrics, pairs, short_proc llm_provider: str = "cohere" # cohere, replicate label_delimiter: str = "|" # NOT \n or \n\n gap_char = "@" @@ -69,13 +70,13 @@ def replicate_provider(text, train_data, lang_code, args): llm_prompt = prompt_factory(text, train_data, lang_code, args) # print(llm_prompt) n_tries = 0 - while n_tries < 100: + while n_tries < 1: try: llm_input = { "system_prompt": "", "prompt": llm_prompt, # "max_new_tokens": 50_000, - "max_tokens": 50_000, + "max_tokens": 4000, } llm_output = api.run(args.model, llm_input) llm_output = "".join(llm_output) @@ -217,6 +218,11 @@ def load_or_compute_logits(args, eval_data, save_str: str = None): continue if dataset_name not in lang_group: dset_group = lang_group.create_group(dataset_name) + if args.type == "pairs" and dataset_name != "ersatz" and dataset_name != "ted2020-corrupted-asr": + continue + if (args.k != 10 or args.n_shots != 0) and dataset_name != "ersatz": + print("SKIP: ", lang_code, dataset_name) + continue else: dset_group = lang_group[dataset_name] if "test_preds" not in dset_group and "test_preds_0" not in dset_group: @@ -227,6 +233,7 @@ def load_or_compute_logits(args, eval_data, save_str: str = None): isinstance(test_sentences[0], list) and "lyrics" not in dataset_name and "short" not in dataset_name + and args.type != "pairs" ): # documents: only 10% of documents. 1000 sentences --> 100 docs max_n_sentences = args.max_n_test_sentences // 10 @@ -236,9 +243,18 @@ def load_or_compute_logits(args, eval_data, save_str: str = None): else: max_n_sentences = args.max_n_test_sentences test_sentences = test_sentences[:max_n_sentences] - if isinstance(test_sentences[0], list): + if isinstance(test_sentences[0], list) or args.type == "pairs": + if args.type == "pairs": + all_pairs = generate_k_mers( + test_sentences, + k=2, + do_lowercase=False, + do_remove_punct=False, + sample_pct=0.5 + ) + test_sentences = all_pairs # list of lists: chunk each sublist - if "short" in dataset_name or "lyrics" in dataset_name: + if "short" in dataset_name or "lyrics" in dataset_name or args.type == "pairs": # only here: no chunking test_chunks = test_sentences test_texts = [ @@ -263,7 +279,7 @@ def load_or_compute_logits(args, eval_data, save_str: str = None): if args.n_shots: train_sentences = eval_data[lang_code]["sentence"][dataset_name]["meta"]["train_data"][:100] if train_sentences: - if "short" in dataset_name: + if "short" in dataset_name or args.type == "pairs": # here: entire samples (tweets e.g.) train_chunks = train_sentences train_texts = ["\n".join(train_chunk).strip() for train_chunk in train_chunks] @@ -288,7 +304,7 @@ def load_or_compute_logits(args, eval_data, save_str: str = None): dset_group.create_dataset( f"test_chunks_{i}", data=[test_chunks[i]] - if "short" in dataset_name or "lyrics" in dataset_name + if "short" in dataset_name or "lyrics" in dataset_name or args.type == "pairs" else test_chunks[i], ) @@ -342,7 +358,7 @@ def prompt_factory(test_chunk, train_data, lang_code, args): prompt_start = ( main_prompt - + f"When provided with multiple examples, you are to respond only to the last one: # Output {n_shots + 1}." + + f"When provided with multiple examples, you are to respond only to the last one: Output {n_shots + 1}." if n_shots else main_prompt ) @@ -383,6 +399,8 @@ def postprocess_llm_output(llm_output, lang): llm_output = llm_output.replace(args.label_delimiter, " ") llm_output = llm_output.replace("\n\n", args.label_delimiter) llm_output = llm_output.replace("\n", args.label_delimiter) + # replace multiple newlines with 1 + llm_output = re.sub(r"\n+", "\n", llm_output) # remove leading #, # Input, : llm_output = llm_output.strip("#").strip().strip("Input").strip(":").strip() @@ -537,10 +555,12 @@ def main(args): default_dir.mkdir(parents=True, exist_ok=True) alignment_dir.mkdir(parents=True, exist_ok=True) - if args.type == "all": + if args.type == "all" or args.type == "pairs": eval_data_path = args.eval_data_path + "-all.pth" elif args.type == "lyrics": eval_data_path = args.eval_data_path + "-lyrics.pth" + elif args.type == "short_proc": + eval_data_path = args.eval_data_path + "-short_proc.pth" else: raise ValueError(f"Unknown type: {args.type}") @@ -620,7 +640,7 @@ def concatenate_texts(group): for dataset_name in df["dataset_name"].unique(): results[lang_code][dataset_name] = {args.model: {}} # Initialize nested dict with model indices[lang_code][dataset_name] = {args.model: {}} - if "lyrics" in dataset_name or "short" in dataset_name: + if "lyrics" in dataset_name or "short" in dataset_name or args.type == "pairs": exclude_every_k = 0 else: exclude_every_k = args.k @@ -708,7 +728,7 @@ def concatenate_texts(group): "success_rate": len(df[df["test_preds"] != ""]) / len(df), "model": args.model, "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "system_prompt": SYSTEM_PROMPT if args.type == "all" else LYRICS_PROMPT, + "system_prompt": LYRICS_PROMPT if args.type == "lyrics" else SYSTEM_PROMPT, } json.dump(