Commit: up
markus583 committed Jun 5, 2024
1 parent e6a31e7 commit b4e70c5
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions wtpsplit/evaluation/llm_sentence.py
@@ -19,6 +19,7 @@
import replicate

from wtpsplit.evaluation import get_labels, evaluate_sentences_llm
+from wtpsplit.evaluation.intrinsic_pairwise import generate_k_mers
from wtpsplit.utils import Constants
import time

@@ -49,7 +50,7 @@
@dataclass
class Args:
eval_data_path: str = "data/all_data_11_05"
-type: str = "lyrics"  # all, lyrics
+type: str = "lyrics"  # all, lyrics, pairs, short_proc
llm_provider: str = "cohere" # cohere, replicate
label_delimiter: str = "|" # NOT \n or \n\n
gap_char = "@"
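The commit adds two evaluation modes, "pairs" and "short_proc", selected via Args.type; a minimal usage sketch (assuming the remaining, collapsed fields of Args also carry defaults):

args = Args(type="pairs", llm_provider="cohere")
assert args.type in ("all", "lyrics", "pairs", "short_proc")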
@@ -69,13 +70,13 @@ def replicate_provider(text, train_data, lang_code, args):
llm_prompt = prompt_factory(text, train_data, lang_code, args)
# print(llm_prompt)
n_tries = 0
-while n_tries < 100:
+while n_tries < 1:
try:
llm_input = {
"system_prompt": "",
"prompt": llm_prompt,
# "max_new_tokens": 50_000,
"max_tokens": 50_000,
"max_tokens": 4000,
}
llm_output = api.run(args.model, llm_input)
llm_output = "".join(llm_output)
@@ -217,6 +218,11 @@ def load_or_compute_logits(args, eval_data, save_str: str = None):
continue
if dataset_name not in lang_group:
dset_group = lang_group.create_group(dataset_name)
if args.type == "pairs" and dataset_name != "ersatz" and dataset_name != "ted2020-corrupted-asr":
continue
if (args.k != 10 or args.n_shots != 0) and dataset_name != "ersatz":
print("SKIP: ", lang_code, dataset_name)
continue
else:
dset_group = lang_group[dataset_name]
if "test_preds" not in dset_group and "test_preds_0" not in dset_group:
@@ -227,6 +233,7 @@ def load_or_compute_logits(args, eval_data, save_str: str = None):
isinstance(test_sentences[0], list)
and "lyrics" not in dataset_name
and "short" not in dataset_name
+and args.type != "pairs"
):
# documents: only 10% of documents. 1000 sentences --> 100 docs
max_n_sentences = args.max_n_test_sentences // 10
@@ -236,9 +243,18 @@ def load_or_compute_logits(args, eval_data, save_str: str = None):
else:
max_n_sentences = args.max_n_test_sentences
test_sentences = test_sentences[:max_n_sentences]
-if isinstance(test_sentences[0], list):
+if isinstance(test_sentences[0], list) or args.type == "pairs":
+    if args.type == "pairs":
+        all_pairs = generate_k_mers(
+            test_sentences,
+            k=2,
+            do_lowercase=False,
+            do_remove_punct=False,
+            sample_pct=0.5
+        )
+        test_sentences = all_pairs
# list of lists: chunk each sublist
if "short" in dataset_name or "lyrics" in dataset_name:
if "short" in dataset_name or "lyrics" in dataset_name or args.type == "pairs":
# only here: no chunking
test_chunks = test_sentences
test_texts = [
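generate_k_mers lives in wtpsplit/evaluation/intrinsic_pairwise.py and is not shown in this diff; as a rough mental model, a toy stand-in that emits windows of k consecutive sentences and keeps about sample_pct of them (the real function's sampling and signature may differ):

import random

def toy_k_mers(sentences, k=2, sample_pct=0.5, seed=42):
    # sliding window of k consecutive sentences, randomly subsampled
    rng = random.Random(seed)
    windows = [sentences[i : i + k] for i in range(len(sentences) - k + 1)]
    return [w for w in windows if rng.random() < sample_pct]

print(toy_k_mers(["One.", "Two.", "Three.", "Four."]))  # keeps a random subset of the three windows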
@@ -263,7 +279,7 @@ def load_or_compute_logits(args, eval_data, save_str: str = None):
if args.n_shots:
train_sentences = eval_data[lang_code]["sentence"][dataset_name]["meta"]["train_data"][:100]
if train_sentences:
if "short" in dataset_name:
if "short" in dataset_name or args.type == "pairs":
# here: entire samples (tweets e.g.)
train_chunks = train_sentences
train_texts = ["\n".join(train_chunk).strip() for train_chunk in train_chunks]
@@ -288,7 +304,7 @@ def load_or_compute_logits(args, eval_data, save_str: str = None):
dset_group.create_dataset(
f"test_chunks_{i}",
data=[test_chunks[i]]
if "short" in dataset_name or "lyrics" in dataset_name
if "short" in dataset_name or "lyrics" in dataset_name or args.type == "pairs"
else test_chunks[i],
)

@@ -342,7 +358,7 @@ def prompt_factory(test_chunk, train_data, lang_code, args):

prompt_start = (
main_prompt
+ f"When provided with multiple examples, you are to respond only to the last one: # Output {n_shots + 1}."
+ f"When provided with multiple examples, you are to respond only to the last one: Output {n_shots + 1}."
if n_shots
else main_prompt
)
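The Output {n_shots + 1} instruction refers to the numbered few-shot layout assembled further down in prompt_factory, which this diff collapses; a plausible skeleton of such a layout, purely illustrative:

examples = [("Hello world. Bye.", "Hello world.|Bye.")]  # hypothetical (raw text, segmented output) pairs
blocks = [f"# Input {i}:\n{inp}\n# Output {i}:\n{out}" for i, (inp, out) in enumerate(examples, 1)]
blocks.append(f"# Input {len(examples) + 1}:\nNew text to segment.\n# Output {len(examples) + 1}:")
prompt = "\n\n".join(blocks)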
@@ -383,6 +399,8 @@ def postprocess_llm_output(llm_output, lang):
llm_output = llm_output.replace(args.label_delimiter, " ")
llm_output = llm_output.replace("\n\n", args.label_delimiter)
llm_output = llm_output.replace("\n", args.label_delimiter)
+# replace multiple newlines with 1
+llm_output = re.sub(r"\n+", "\n", llm_output)

# remove leading #, # Input, :
llm_output = llm_output.strip("#").strip().strip("Input").strip(":").strip()
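Taken together, these replacements normalize a raw LLM response into the delimiter format used for scoring; a self-contained sketch of the pipeline as written:

import re

def normalize(llm_output, label_delimiter="|"):
    llm_output = llm_output.replace(label_delimiter, " ")     # drop stray delimiters emitted by the model
    llm_output = llm_output.replace("\n\n", label_delimiter)  # paragraph breaks become sentence labels
    llm_output = llm_output.replace("\n", label_delimiter)    # remaining line breaks become labels
    llm_output = re.sub(r"\n+", "\n", llm_output)             # collapse any leftover newline runs
    return llm_output.strip("#").strip().strip("Input").strip(":").strip()

print(normalize("Hello world.\nSecond sentence.\n\nThird one."))  # Hello world.|Second sentence.|Third one.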
@@ -537,10 +555,12 @@ def main(args):
default_dir.mkdir(parents=True, exist_ok=True)
alignment_dir.mkdir(parents=True, exist_ok=True)

if args.type == "all":
if args.type == "all" or args.type == "pairs":
eval_data_path = args.eval_data_path + "-all.pth"
elif args.type == "lyrics":
eval_data_path = args.eval_data_path + "-lyrics.pth"
elif args.type == "short_proc":
eval_data_path = args.eval_data_path + "-short_proc.pth"
else:
raise ValueError(f"Unknown type: {args.type}")
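The branch above maps each mode to a preprocessed data file, with "pairs" reusing the "-all.pth" data; equivalently, as a lookup table (a sketch only; the elif chain above is what the file actually uses):

suffix_by_type = {"all": "-all.pth", "pairs": "-all.pth", "lyrics": "-lyrics.pth", "short_proc": "-short_proc.pth"}
try:
    eval_data_path = args.eval_data_path + suffix_by_type[args.type]
except KeyError:
    raise ValueError(f"Unknown type: {args.type}")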

@@ -620,7 +640,7 @@ def concatenate_texts(group):
for dataset_name in df["dataset_name"].unique():
results[lang_code][dataset_name] = {args.model: {}} # Initialize nested dict with model
indices[lang_code][dataset_name] = {args.model: {}}
if "lyrics" in dataset_name or "short" in dataset_name:
if "lyrics" in dataset_name or "short" in dataset_name or args.type == "pairs":
exclude_every_k = 0
else:
exclude_every_k = args.k
@@ -708,7 +728,7 @@ def concatenate_texts(group):
"success_rate": len(df[df["test_preds"] != ""]) / len(df),
"model": args.model,
"time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"system_prompt": SYSTEM_PROMPT if args.type == "all" else LYRICS_PROMPT,
"system_prompt": LYRICS_PROMPT if args.type == "lyrics" else SYSTEM_PROMPT,
}

json.dump(
