Commit 60e11b1 (0 parents, root commit)
Showing 11 changed files with 1,831 additions and 0 deletions.
@@ -0,0 +1,6 @@
[codestyle]
indentation = True

[main]
version = 0.1.0
@@ -0,0 +1,6 @@
[encoding]
text_encoding = utf-8

[main]
version = 0.1.0
@@ -0,0 +1,7 @@
[vcs]
use_version_control = False
version_control_system =

[main]
version = 0.1.0
@@ -0,0 +1,10 @@
[workspace]
restore_data_on_startup = True
save_data_on_exit = True
save_history = True
save_non_project_files = False

[main]
version = 0.1.0
recent_files = ['/home/svjack/.config/spyder-py3/temp.py', '/home/svjack/temp_dir/bi_cross_model/script/bi-encoder-batch.py', '/home/svjack/temp_dir/bi_cross_model/script/bi-encoder-data.py', '/home/svjack/temp_dir/bi_cross_model/script/bi_encoder/bi-encoder-batch.py', '/home/svjack/temp_dir/bi_cross_model/script/bi_encoder/bi-encoder-data.py', '/home/svjack/temp_dir/bi_cross_model/download.sh', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder_data_prepare_train.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder_data_prepare.py', '/home/svjack/temp_dir/bi_cross_model/script/try_sbert_neg_sampler.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_enccoder_v0.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/try_sbert_neg_sampler.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/try_sbert_neg_sampler_valid.py', '/home/svjack/temp_dir/bi_cross_model/script/choose_right_params.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder_random_on_multi_eval.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/cross_encoder_random_on_multi_eval.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/valid_cross_encoder_on_bi_encoder.py']
@@ -0,0 +1,7 @@
editdistance==0.5.3
elasticsearch==7.8.1
elasticsearch-dbapi==0.1.3
es-pandas==0.0.16
progressbar2==3.53.1
seaborn==0.10.1
sentence-transformers==0.3.9
@@ -0,0 +1,142 @@
#!/usr/bin/env python
# coding: utf-8
import gzip
import logging
import math
import os
import random
import tarfile
from collections import defaultdict
from datetime import datetime
from glob import glob

import numpy as np
import pandas as pd
import torch
from sentence_transformers import (LoggingHandler, SentenceTransformer,
                                   evaluation, losses, models, util)
from torch.utils.data import DataLoader, Dataset, IterableDataset

#train_part, test_part, valid_part = map(lambda save_type: pd.read_csv(os.path.join(os.path.abspath(""), "{}_part.csv".format(save_type))).dropna(), ["train", "test", "valid"])
# Load the train/test/valid splits written by the data-preparation script
# (data directory assumed to be ../data/; the original "..data/" reads like a typo).
train_part, test_part, valid_part = map(lambda save_type: pd.read_csv(os.path.join("../data/", "{}_part.csv".format(save_type))).dropna(), ["train", "test", "valid"])
class TripletsDataset(Dataset):
    """Builds (anchor, positive, negative) training triplets: the question is
    the anchor, its paired answer the positive, and the negative is an answer
    sampled from a different question group (q_idx)."""
    def __init__(self, model, qa_df):
        assert {"question", "answer", "q_idx"}.issubset(qa_df.columns)
        self.model = model
        self.qa_df = qa_df
        self.q_idx_set = set(qa_df["q_idx"].value_counts().index.tolist())

    def __getitem__(self, index):
        label = torch.tensor(1, dtype=torch.long)
        choice_s = self.qa_df.iloc[index]
        query_text, pos_text, q_idx = choice_s.loc["question"], choice_s.loc["answer"], choice_s.loc["q_idx"]
        # Pick a random other question group and sample one of its answers as the negative.
        neg_q_idx = np.random.choice(list(self.q_idx_set.difference(set([q_idx]))))
        neg_text = self.qa_df[self.qa_df["q_idx"] == neg_q_idx].sample()["answer"].iloc[0]
        return [self.model.tokenize(query_text), self.model.tokenize(pos_text), self.model.tokenize(neg_text)], label

    def __len__(self):
        return self.qa_df.shape[0]
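
# Minimal sanity check (a sketch; assumes model and train_part are defined as
# they are further down in this script):
#   ds = TripletsDataset(model=model, qa_df=train_part)
#   features, label = ds[0]  # three tokenized texts: anchor, positive, negative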


class NoSameLabelsBatchSampler:
    """Yields batches of row indices whose q_idx labels are pairwise distinct:
    MultipleNegativesRankingLoss treats the other examples in a batch as
    negatives, so repeating a label would create false negatives."""
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.idx_org = list(range(len(dataset)))
        random.shuffle(self.idx_org)
        self.idx_copy = self.idx_org.copy()
        self.batch_size = batch_size

    def __iter__(self):
        batch = []
        labels = set()
        num_miss = 0

        num_batches_returned = 0
        while num_batches_returned < self.__len__():
            if len(self.idx_copy) == 0:
                random.shuffle(self.idx_org)
                self.idx_copy = self.idx_org.copy()

            idx = self.idx_copy.pop(0)
            label = self.dataset.qa_df["q_idx"].iloc[idx]

            if label not in labels:
                num_miss = 0
                batch.append(idx)
                labels.add(label)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
                    labels = set()
                    num_batches_returned += 1
            else:
                num_miss += 1
                self.idx_copy.append(idx)  # Put the item back at the end

                if num_miss >= len(self.idx_copy):  # Too many failures: flush idx_copy and start clean
                    self.idx_copy = []

    def __len__(self):
        return math.ceil(len(self.dataset) / self.batch_size)
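
# Sketch of the sampler's contract (illustrative, assumes train_dataset below):
#   sampler = NoSameLabelsBatchSampler(train_dataset, batch_size=8)
#   first_batch = next(iter(sampler))  # a list of row indices
#   assert len(set(train_dataset.qa_df["q_idx"].iloc[first_batch])) == len(first_batch)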


def transform_part_df_into_Evaluator_format(part_df):
    """Builds the queries / corpus / relevant_docs dicts expected by
    InformationRetrievalEvaluator. Note: qid/cid come from Python's hash(),
    which is only stable within one process unless PYTHONHASHSEED is fixed."""
    req = part_df.copy()
    req["qid"] = req["question"].fillna("").map(hash).map(str)
    req["cid"] = req["answer"].fillna("").map(hash).map(str)
    queries = dict(map(tuple, req[["qid", "question"]].drop_duplicates().values.tolist()))
    corpus = dict(map(tuple, req[["cid", "answer"]].drop_duplicates().values.tolist()))
    qid_cid_set_df = req[["qid", "cid"]].groupby("qid")["cid"].apply(set).apply(sorted).apply(tuple).reset_index()
    qid_cid_set_df.columns = ["qid", "cid_set"]
    relevant_docs = dict(map(tuple, qid_cid_set_df.drop_duplicates().values.tolist()))
    relevant_docs = dict(map(lambda t2: (t2[0], set(t2[1])), relevant_docs.items()))
    return queries, corpus, relevant_docs


dev_queries, dev_corpus, dev_rel_docs = transform_part_df_into_Evaluator_format(valid_part.sample(frac=0.1))
ir_evaluator = evaluation.InformationRetrievalEvaluator(dev_queries, dev_corpus, dev_rel_docs, name='ms-marco-train_eval', batch_size=2)
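# During fit() the evaluator embeds dev_queries against dev_corpus and reports
# retrieval metrics (e.g. Accuracy@k, MRR@k, NDCG@k, MAP) computed from dev_rel_docs.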


model_str = "xlm-roberta-base"
#word_embedding_model = models.Transformer(model_str, max_seq_length=512)
word_embedding_model = models.Transformer(model_str, max_seq_length=256)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])


train_dataset = TripletsDataset(model=model, qa_df=train_part.sample(frac=1.0, replace=False))
bs_obj = NoSameLabelsBatchSampler(train_dataset, batch_size=8)
# batch_sampler replaces batch_size/shuffle, so neither is passed here.
train_dataloader = DataLoader(train_dataset, batch_sampler=bs_obj, num_workers=1)
train_loss = losses.MultipleNegativesRankingLoss(model=model)


model_save_path = os.path.join(os.path.abspath(""), "bi_encoder_save")
if not os.path.exists(model_save_path):
    os.mkdir(model_save_path)


model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=ir_evaluator,
          epochs=10,
          warmup_steps=1000,
          output_path=model_save_path,
          evaluation_steps=5000,
          use_amp=True
          )
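
# A minimal inference sketch (assumes training completed and wrote the best
# checkpoint to model_save_path):
#   trained = SentenceTransformer(model_save_path)
#   q_emb = trained.encode(["some question"], convert_to_tensor=True)
#   a_emb = trained.encode(["candidate answer a", "candidate answer b"], convert_to_tensor=True)
#   scores = util.pytorch_cos_sim(q_emb, a_emb)  # higher = better match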
@@ -0,0 +1,117 @@
#!/usr/bin/env python
# coding: utf-8
import os
from copy import deepcopy
from functools import reduce
from glob import glob

import editdistance
import numpy as np
import pandas as pd

###https://github.com/brightmart/nlp_chinese_corpus
###https://github.com/brightmart/nlp_chinese_corpus#4%E7%A4%BE%E5%8C%BA%E9%97%AE%E7%AD%94json%E7%89%88webtext2019zh-%E5%A4%A7%E8%A7%84%E6%A8%A1%E9%AB%98%E8%B4%A8%E9%87%8F%E6%95%B0%E6%8D%AE%E9%9B%86
###https://drive.google.com/open?id=1u2yW_XohbYL2YAK6Bzc5XrngHstQTf0v

data_dir = r"/home/svjack/temp_dir/webtext2019zh"
json_files = glob(os.path.join(data_dir, "*.json"))
train_json = list(filter(lambda path: "train" in path.lower(), json_files))[0]


def json_reader(path, chunksize=100):
    assert path.endswith(".json")
    return pd.read_json(path, lines=True, chunksize=chunksize)
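
# webtext2019zh ships as JSON lines (one record per line), hence lines=True
# plus a chunked reader to keep memory bounded instead of loading the corpus at once.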


# Read the first 100 chunks of 10,000 rows (~1M QA pairs).
train_reader = json_reader(train_json, chunksize=10000)
times = 100
df_list = []
for i, df in enumerate(train_reader):
    df_list.append(df)
    if i + 1 >= times:
        break

train_head_df = pd.concat(df_list, axis=0)
content_len_df = pd.concat([train_head_df["content"], train_head_df["content"].map(len)], axis=1)
content_len_df.columns = ["content", "c_len"]


# Treat the post title as the question and the post body as the answer,
# dropping very long texts.
qa_df = train_head_df[["title", "content"]].copy()
qa_df = qa_df.rename(columns={"title": "question", "content": "answer"}).fillna("")

qa_df = qa_df[qa_df["question"].map(len) <= 500]
qa_df = qa_df[qa_df["answer"].map(len) <= 500]


# Group near-duplicate questions: sort them, compare each question to its
# sorted neighbor with a length-normalized edit distance, and chain pairs
# below the threshold into the same group.
quests = deepcopy(qa_df["question"])
question_cmp = pd.concat([quests.sort_values().shift(1), quests.sort_values()], axis=1)
question_cmp["edit_val"] = question_cmp.fillna("").apply(lambda s: editdistance.eval(s.iloc[0], s.iloc[1]) / (len(s.iloc[0]) + len(s.iloc[1])), axis=1)
question_cmp.columns = ["q0", "q1", "edit_val"]

threshold = 0.2
question_nest_list = [[]]
for idx, r in question_cmp.iterrows():
    q0, q1, v = r.iloc[0], r.iloc[1], r.iloc[2]
    if v < threshold:
        question_nest_list[-1].append(q0)
        question_nest_list[-1].append(q1)
    else:
        question_nest_list.append([])
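
# Illustration: with the 0.2 threshold, sorted neighbors that differ only by
# punctuation or a few characters (normalized distance near 0) chain into one
# group; the first dissimilar neighbor closes the group and starts a new one.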


# One list of near-duplicate questions per group; empty groups are dropped and
# each surviving group gets an integer q_idx.
idx_question_df_zip = pd.DataFrame(list(map(lambda x: [x], question_nest_list)))

idx_question_df_zip = idx_question_df_zip[idx_question_df_zip.iloc[:, 0].map(len) > 0]
idx_question_df_zip.columns = ["question"]
idx_question_df_zip["q_idx"] = np.arange(idx_question_df_zip.shape[0]).tolist()

idx_question_df = idx_question_df_zip.explode("question")

idx_question_df_dd = idx_question_df.drop_duplicates()


qa_df_dd = qa_df.drop_duplicates()
cat_qa_df_with_idx = pd.merge(qa_df_dd, idx_question_df_dd, on="question", how="inner")
q_idx_set = set(cat_qa_df_with_idx["q_idx"].value_counts().index.tolist())

# Keep only question groups with at least 3 QA pairs, so that a group can
# contribute rows to train, test and valid.
q_idx_counts = cat_qa_df_with_idx["q_idx"].value_counts()
q_idx_size_bigger_or_eql_3 = q_idx_counts[q_idx_counts >= 3].index.tolist()
q_idx_size_bigger_or_eql_3_df = cat_qa_df_with_idx[cat_qa_df_with_idx["q_idx"].isin(q_idx_size_bigger_or_eql_3)].copy()


def produce_label_list(length=10, p_list=[0.1, 0.1, 0.8]):
    """Returns a shuffled label array whose proportions roughly follow p_list,
    with every label present at least once; p_list must be non-decreasing and
    the final (largest) share absorbs the rounding remainder."""
    assert np.isclose(sum(p_list), 1)
    p_array = np.asarray(p_list)
    assert all((p_array[:-1] <= p_array[1:]).astype(bool).tolist())
    num_array = (p_array * length).astype(np.int32)
    num_list = num_array.tolist()
    num_list = list(map(lambda x: max(x, 1), num_list))
    num_list[-1] = length - sum(num_list[:-1])
    return np.random.permutation(reduce(lambda a, b: a + b, map(lambda idx: [idx] * num_list[idx], range(len(p_list)))))
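
# e.g. produce_label_list(10) returns a permutation of
# [0, 1, 2, 2, 2, 2, 2, 2, 2, 2]: one valid row, one test row and eight train
# rows for a group of ten (labels 0/1/2 are used below as valid/test/train).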


q_idx_size_bigger_or_eql_3_df["r_idx"] = q_idx_size_bigger_or_eql_3_df.index.tolist()


def map_r_idx_list_to_split_label_zip(r_idx_list):
    split_label_list = produce_label_list(len(r_idx_list))
    assert len(split_label_list) == len(r_idx_list)
    return zip(*[r_idx_list, split_label_list])


# Assign every row a split label, grouped by q_idx so that each question group
# is spread across the three splits.
r_idx_split_label_items = reduce(lambda a, b: a + b, q_idx_size_bigger_or_eql_3_df.groupby("q_idx")["r_idx"].apply(set).apply(list).apply(map_r_idx_list_to_split_label_zip).apply(list).tolist())
r_idx_split_label_df = pd.DataFrame(r_idx_split_label_items)
r_idx_split_label_df.columns = ["r_idx", "split_label"]
assert r_idx_split_label_df.shape[0] == pd.merge(q_idx_size_bigger_or_eql_3_df, r_idx_split_label_df, on="r_idx", how="inner").shape[0]


q_idx_size_bigger_or_eql_3_df_before_split = pd.merge(q_idx_size_bigger_or_eql_3_df, r_idx_split_label_df, on="r_idx", how="inner")
train_part = q_idx_size_bigger_or_eql_3_df_before_split[q_idx_size_bigger_or_eql_3_df_before_split["split_label"] == 2].copy()
# Groups with fewer than 3 QA pairs go entirely to train.
train_part = pd.concat([train_part, cat_qa_df_with_idx[~cat_qa_df_with_idx["q_idx"].isin(q_idx_size_bigger_or_eql_3)].copy()], axis=0)
valid_part = q_idx_size_bigger_or_eql_3_df_before_split[q_idx_size_bigger_or_eql_3_df_before_split["split_label"] == 0].copy()
test_part = q_idx_size_bigger_or_eql_3_df_before_split[q_idx_size_bigger_or_eql_3_df_before_split["split_label"] == 1].copy()

# Every valid/test question group must also be represented in train.
assert set(valid_part["q_idx"].tolist()) == set(test_part["q_idx"].tolist())
assert set(valid_part["q_idx"].tolist()) == set(valid_part["q_idx"].tolist()).intersection(train_part["q_idx"].tolist())

# Write the splits to the repo's data directory (assumed to sit two levels
# above this script).
train_part.to_csv(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, "data", "train_part.csv"), index=False)
test_part.to_csv(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, "data", "test_part.csv"), index=False)
valid_part.to_csv(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, "data", "valid_part.csv"), index=False)
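
# The resulting train/test/valid CSVs keep the question, answer and q_idx
# columns that the bi-encoder training script above expects to read back.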