
Commit 9b9eb18

add main execution files

1 parent 29f931e

8 files changed: +1219 -0 lines changed

code/__init__.py

Whitespace-only changes.

code/create_embeddings.py (new file, +192 lines)
import os
import sys

from tqdm import tqdm
import numpy as np

import torch
from sentence_transformers import SentenceTransformer

from utils.args import create_embeddings_args
from utils import embed


parser = create_embeddings_args()
args = parser.parse_args()


TORCH_DTYPE = torch.float16 if args.use_amp_fp16 else torch.float32
NUMPY_DTYPE = np.float16 if args.use_amp_fp16 else np.float32
MAX_C4_SUBFILE_IDX = 1023
MAX_WIKIPEDIA_SUBFILE_IDX = 12
MAX_STACKEXCHANGE_SUBFILE_IDX = 13

device = args.device
model = SentenceTransformer(args.model_name).to(device)
db_path = args.saving_path + args.saving_file

if not os.path.exists(args.saving_path):
    os.makedirs(args.saving_path)
    print(f"Created directory path {args.saving_path}")

if args.group_name == "c4":
    c4_subfiles = [
        f"{args.data_path}c4-train.{i:05d}-of-01024.json.gz"
        for i in range(MAX_C4_SUBFILE_IDX + 1)
    ]
    last_uid = embed.retrieve_last_saved_uid(db_path, args.group_name)

    # define next starting point from already processed files in the database
    if last_uid is not None:
        _, last_subfile_idx_text = last_uid.split(".")
        last_subfile_idx = int(last_subfile_idx_text)

        if last_subfile_idx == MAX_C4_SUBFILE_IDX:
            print("All C4 subfiles have been processed.")
            sys.exit(0)
        else:
            next_subfile_idx = last_subfile_idx + 1
    else:
        next_subfile_idx = 0

    # start processing
    for c4_subfile in tqdm(c4_subfiles[next_subfile_idx:], file=sys.stdout):
        uid = c4_subfile.split("/")[-1].split("-")[1]  # uid = "train.xxxxx"
        print(f"Embedding file {uid}.")
        samples = embed.get_c4_subfile_texts(c4_subfile)
        if samples is not None:
            embeddings = embed.embed_samples(
                samples=samples,
                model=model,
                batch_size=args.batch_size,
                dtype=TORCH_DTYPE,
                to_cpu=True,
            )
            embeddings = np.array(embeddings, dtype=NUMPY_DTYPE)
            embed.write_to_hdf5(
                path=db_path,
                group_name=args.group_name,
                uid=uid,
                embeddings=embeddings,
                compression=args.compression,
                compression_opts=args.compression_opts,
            )
elif args.group_name == "wikipedia":
    wikipedia_subfiles = [
        f"{args.data_path}wiki_{i:02d}.jsonl"
        for i in range(MAX_WIKIPEDIA_SUBFILE_IDX + 1)
    ]
    last_uid = embed.retrieve_last_saved_uid(db_path, args.group_name)

    # define next starting point from already processed files in the database
    if last_uid is not None:
        _, last_subfile_idx_text = last_uid.split("_")
        last_subfile_idx = int(last_subfile_idx_text)

        if last_subfile_idx == MAX_WIKIPEDIA_SUBFILE_IDX:
            print("All Wikipedia subfiles have been processed.")
            sys.exit(0)
        else:
            next_subfile_idx = last_subfile_idx + 1
    else:
        next_subfile_idx = 0

    # start processing
    for wikipedia_subfile in tqdm(
        wikipedia_subfiles[next_subfile_idx:], file=sys.stdout
    ):
        uid = wikipedia_subfile.split("/")[-1].split(".")[0]  # uid = "wiki_xx"
        print(f"Embedding file {uid}.")
        samples = embed.get_jsonl_subfile_texts(wikipedia_subfile)
        if samples is not None:
            embeddings = embed.embed_samples(
                samples=samples,
                model=model,
                batch_size=args.batch_size,
                dtype=TORCH_DTYPE,
                to_cpu=True,
            )
            embeddings = np.array(embeddings, dtype=NUMPY_DTYPE)
            embed.write_to_hdf5(
                path=db_path,
                group_name=args.group_name,
                uid=uid,
                embeddings=embeddings,
                compression=args.compression,
                compression_opts=args.compression_opts,
            )
elif args.group_name == "wikihow":
    # there is only one file for wikihow
    last_uid = embed.retrieve_last_saved_uid(db_path, args.group_name)

    if last_uid is not None:
        print("Wikihow has been processed.")
        sys.exit(0)
    else:
        next_subfile_idx = 0

    # start processing
    uid = args.data_path.split("/")[-1].split(".")[0]  # uid = "train"
    print(f"Embedding file {uid}.")
    samples = embed.get_jsonl_subfile_texts(args.data_path)
    if samples is not None:
        embeddings = embed.embed_samples(
            samples=samples,
            model=model,
            batch_size=args.batch_size,
            dtype=TORCH_DTYPE,
            to_cpu=True,
        )
        embeddings = np.array(embeddings, dtype=NUMPY_DTYPE)
        embed.write_to_hdf5(
            path=db_path,
            group_name=args.group_name,
            uid=uid,
            embeddings=embeddings,
            compression=args.compression,
            compression_opts=args.compression_opts,
        )
elif args.group_name == "stackexchange":
    stackexchange_subfiles = [
        f"{args.data_path}stack_{i:02d}.jsonl"
        for i in range(MAX_STACKEXCHANGE_SUBFILE_IDX + 1)
    ]
    last_uid = embed.retrieve_last_saved_uid(db_path, args.group_name)

    # define next starting point from already processed files in the database
    if last_uid is not None:
        _, last_subfile_idx_text = last_uid.split("_")
        last_subfile_idx = int(last_subfile_idx_text)

        if last_subfile_idx == MAX_STACKEXCHANGE_SUBFILE_IDX:
            print("All Stackexchange subfiles have been processed.")
            sys.exit(0)
        else:
            next_subfile_idx = last_subfile_idx + 1
    else:
        next_subfile_idx = 0

    # start processing
    for stackexchange_subfile in tqdm(
        stackexchange_subfiles[next_subfile_idx:], file=sys.stdout
    ):
        uid = stackexchange_subfile.split("/")[-1].split(".")[0]  # uid = "stack_xx"
        print(f"Embedding file {uid}.")
        samples = embed.get_jsonl_subfile_texts(stackexchange_subfile)
        if samples is not None:
            embeddings = embed.embed_samples(
                samples=samples,
                model=model,
                batch_size=args.batch_size,
                dtype=TORCH_DTYPE,
                to_cpu=True,
            )
            embeddings = np.array(embeddings, dtype=NUMPY_DTYPE)
            embed.write_to_hdf5(
                path=db_path,
                group_name=args.group_name,
                uid=uid,
                embeddings=embeddings,
                compression=args.compression,
                compression_opts=args.compression_opts,
            )
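
Note: the resume logic above relies on embed.retrieve_last_saved_uid and embed.write_to_hdf5 from utils/embed.py, which are not part of this commit. A minimal sketch of what such helpers could look like with h5py, assuming one HDF5 group per corpus and one dataset per subfile uid (the helper names come from the imports above; the bodies below are assumptions, not the repository's implementation):

# Hypothetical sketch of the utils.embed helpers used above (not in this commit).
import os

import h5py


def retrieve_last_saved_uid(path, group_name):
    """Return the uid of the last dataset written under <group_name>, or None."""
    if not os.path.exists(path):
        return None
    with h5py.File(path, "r") as f:
        if group_name not in f:
            return None
        uids = sorted(f[group_name].keys())  # zero-padded uids sort correctly
        return uids[-1] if uids else None


def write_to_hdf5(path, group_name, uid, embeddings, compression=None, compression_opts=None):
    """Store one subfile's embeddings as dataset <group_name>/<uid>."""
    with h5py.File(path, "a") as f:
        group = f.require_group(group_name)
        group.create_dataset(
            uid,
            data=embeddings,
            compression=compression,
            compression_opts=compression_opts,
        )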

code/create_task_samples.py (new file, +126 lines)
import os
import json

from sklearn.model_selection import train_test_split
from datasets import Dataset
from thefuzz import fuzz

from utils.args import create_task_samples_args
from utils import common as c
from utils.ts_creation import deduplicate
from utils.ts_creation import (
    MetaInstructions,
    FormatExtractor,
    generate_few_shots,
    check_prompt_length,
)


parser = create_task_samples_args()
args = parser.parse_args()

if args.task in ["bioqa", "medqa"]:
    prompt_instruction = MetaInstructions.QA_MC_INSTRUCTION
    extract_fn = FormatExtractor.qa_mc
elif args.task == "csqa":
    prompt_instruction = [
        MetaInstructions.QA_YN_INSTRUCTION_Q,
        MetaInstructions.QA_YN_INSTRUCTION_S,
    ]
    extract_fn = FormatExtractor.qa_yn
elif args.task == "recipegen":
    prompt_instruction = MetaInstructions.RECIPEGEN_INSTRUCTION
    extract_fn = FormatExtractor.recipe
elif args.task == "summarization":
    prompt_instruction = MetaInstructions.SUMMARIZATION_INSTRUCTION
    extract_fn = FormatExtractor.summarization
else:
    raise ValueError("Unknown task or no instruction prompt found.")


configs = c.get_configs(args, sampling=True)
model = c.load_vllm_model(args)

few_shots = [fs for fs in c.jsonl_generator(args.few_shot_path, return_string=False)]
corpus_samples = [
    ex for ex in c.jsonl_generator(args.corpus_samples_path, return_string=False)
]

# prepare all few_shot + corpus combinations
prompts = [
    generate_few_shots(
        prompt_instruction=prompt_instruction,
        corpus_example=sample,
        few_shots=few_shots,
        task=args.task,
        num_shots=args.num_shots,
    )
    for sample in corpus_samples
]
prompts = check_prompt_length(args, prompts, max_length=args.max_tokenization_length)
print(f"Number of valid prompts to generate task samples from: {len(prompts)}")

generated = c.vllm_generate(prompts, model, configs["sampling_config"])
task_samples = [{"task_sample": task_sample} for task_sample in generated]

with open(args.output_path_raw, "w") as f:
    for sample in task_samples:
        f.write(json.dumps(sample) + "\n")
print(f"Finished saving {len(task_samples)} raw, unfiltered task samples.")

print("Starting filtering and cleaning of task samples...")
valid_task_samples = []
format_errors = ["index,exception\n"]
for i, sample in enumerate(task_samples):
    try:
        valid_task_samples.append(extract_fn(sample))
    except Exception as e:
        format_errors.append(f"{i},{e}\n")
        continue
print(
    f"Removed {len(task_samples) - len(valid_task_samples)} samples due to formatting errors."
)

with open(args.output_path_error_msgs, "w") as csvfile:
    csvfile.writelines(format_errors)
print("Saved extraction format error messages as a CSV file.")


# two-step fuzzy deduplication
# step 1: filter out task samples that are too similar to human few-shots
few_shot_strings = [extract_fn(s, is_few_shot=True) for s in few_shots]
filtered_1 = [
    s
    for s in valid_task_samples
    if max(fuzz.token_set_ratio(s, fss) for fss in few_shot_strings)
    < args.deduplication_ratio
]
len_filtered_1 = len(filtered_1)
print(
    f"Removed {len(valid_task_samples) - len_filtered_1} samples due to similarity with few-shots."
)

# step 2: deduplicate task samples among themselves
os.environ["TOKENIZERS_PARALLELISM"] = "0"
filtered_2 = deduplicate(filtered_1, ratio=args.deduplication_ratio)
print(
    f"\nRemoved {len_filtered_1 - len(filtered_2)} samples due to similarity among themselves."
)

with open(args.output_path_clean, "w") as f:
    for sample in filtered_2:
        f.write(json.dumps(sample) + "\n")
print(f"Finished saving {len(filtered_2)} clean and filtered task samples.")

num_final = args.num_final_task_samples - len(few_shots)
filtered_3, _ = train_test_split(filtered_2, train_size=num_final)
final_task_samples = [
    {**extract_fn(s, return_dict=True), "is_few_shot": 0} for s in filtered_3  # type: ignore
]
final_task_samples += [
    {**extract_fn(fs, return_dict=True), "is_few_shot": 1} for fs in few_shot_strings  # type: ignore
]

ds = Dataset.from_list(final_task_samples)
ds.save_to_disk(args.output_path_final)
print(f"Finished saving {len(final_task_samples)} final task samples.")

code/retrieve_docs.py (new file, +75 lines)
import os
from time import time

import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer

from utils.embed import embed_samples
from utils.args import retrieve_docs_args
from utils.common import jsonl_generator
from utils import retrieve


parser = retrieve_docs_args()
args = parser.parse_args()

start = time()
device = args.device
model = SentenceTransformer(args.model_name).to(device)
TORCH_DTYPE = torch.float16 if args.use_amp_fp16 else torch.float32

# load all few_shot texts
few_shots = [fs for fs in jsonl_generator(args.few_shot_path, return_string=True)]
print("Starting embedding of few-shot samples...")
fs_embeddings = embed_samples(
    samples=few_shots, model=model, batch_size=32, dtype=TORCH_DTYPE, to_cpu=False
).to(TORCH_DTYPE)
print("Finished embedding few-shot samples.")

print("Starting calculation of top-p similarities...")
top_p_out = retrieve.top_p_similarities(
    path=args.database_path,
    fs_embeddings=fs_embeddings,
    p=args.top_p_percentile,
    device=args.device,
)
print("Finished calculation of top-p similarities.")

print("Starting calculation of top-k final document similarities...")
indices_per_subfile = retrieve.top_k_indices(
    sim_metadata=top_p_out.sim_metadata,
    sim_values=top_p_out.sim_values,
    sim_idxs=top_p_out.sim_idxs,
    sim_values_mean=top_p_out.sim_values_mean,
    sim_idxs_mean=top_p_out.sim_idxs_mean,
    k=args.num_samples_to_retrieve,
)
num_samples = sum(len(l) for l in indices_per_subfile)
print("Finished calculation of top-k final document similarities.")
print(f"Identified {num_samples} documents.")
paths = retrieve.build_paths(top_p_out.sim_metadata, args)
assert (
    len(top_p_out[0]) == len(indices_per_subfile) == len(paths)
), "Contents of metadata, indices per subfile, and paths do not match."

# retrieve final text documents from corpora based on selected top indices
print("Retrieving top-k documents from files...")

os.environ["TOKENIZERS_PARALLELISM"] = "false"
most_similar_json = []
dataset = retrieve.JsonDataset(top_p_out.sim_metadata, indices_per_subfile, paths)
dataloader = DataLoader(dataset, batch_size=1, num_workers=2)
# open files and extract jsons from subfiles in parallel
for i, batch in enumerate(dataloader):
    most_similar_json.extend([t[0] for t in batch])
    print(f"Retrieved docs from file {i} of {len(top_p_out.sim_metadata)}.", end="\r")
end = time()
print("\nFinished retrieving top-k documents.")
print(f"Total retrieval took {(end-start)/60:.2f} minutes.")

print("Start saving retrieved documents...")
with open(args.saving_path, "w") as jsonl_file:
    for json_obj in most_similar_json:
        jsonl_file.write(json_obj)
print(f"Finished saving retrieved documents at {args.saving_path}.")
