Skip to content

Commit

Permalink
Update bi-encoder-batch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
svjack authored Feb 18, 2024
1 parent 93dbb6e commit 1777be2
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions script/bi_encoder/bi-encoder-batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,34 @@
# Earlier variant kept for reference: read the three CSV splits from the
# absolute path of the current working directory.
#train_part, test_part, valid_part = map(lambda save_type: pd.read_csv(os.path.join(os.path.abspath(""), "{}_part.csv".format(save_type))).dropna(), ["train", "test", "valid"])
# Read "train_part.csv", "test_part.csv" and "valid_part.csv" with pd.read_csv
# and drop every row that contains a missing value (.dropna()). The three
# results are unpacked, in that order, into train_part, test_part, valid_part.
# NOTE(review): "..data/" has no separator after "..", so os.path.join yields
# "..data/<split>_part.csv" (a directory literally named "..data"), not the
# parent-relative "../data/". Presumably "../data/" was intended -- confirm
# against the repository layout before changing.
train_part, test_part, valid_part = map(lambda save_type: pd.read_csv(os.path.join("..data/", "{}_part.csv".format(save_type))).dropna(), ["train", "test", "valid"])


from sentence_transformers import InputExample
class TripletsDataset(Dataset):
    """Dataset of (question, positive answer, negative answer) triplets.

    Each item pairs one row's question and answer with an answer sampled at
    random from the rows whose ``q_idx`` differs from that row's ``q_idx``.
    The three texts are wrapped in an ``InputExample`` with ``label=1``.
    """

    def __init__(self, model, qa_df):
        """Store the model and the question/answer frame.

        Args:
            model: Kept on the instance as ``self.model``; the methods of
                this class do not call it.
            qa_df: DataFrame that must contain the columns "question",
                "answer" and "q_idx".

        Raises:
            AssertionError: If any of the required columns is missing.
        """
        required = {"question", "answer", "q_idx"}
        assert required.issubset(qa_df.columns), (
            "qa_df is missing required columns: "
            f"{sorted(required.difference(qa_df.columns))}"
        )
        self.model = model
        self.qa_df = qa_df
        # Distinct q_idx values; used to pick a *different* group when
        # sampling the negative answer.
        self.q_idx_set = set(qa_df["q_idx"].value_counts().index.tolist())

    def __getitem__(self, index):
        """Build the triplet for the row at integer position ``index``.

        Args:
            index: Positional row index into ``qa_df`` (used with ``iloc``).

        Returns:
            ``InputExample(texts=[question, answer, negative_answer], label=1)``
            where ``negative_answer`` is drawn at random, so repeated calls
            with the same index may return different negatives.

        Raises:
            ValueError: If ``qa_df`` has no ``q_idx`` other than this row's,
                so no negative can be sampled.
        """
        row = self.qa_df.iloc[index]
        query_text = row.loc["question"]
        pos_text = row.loc["answer"]
        q_idx = row.loc["q_idx"]

        # Candidates are every q_idx except the current row's own.
        neg_candidates = list(self.q_idx_set.difference({q_idx}))
        if not neg_candidates:
            raise ValueError(
                "cannot sample a negative answer: qa_df contains no q_idx "
                f"other than {q_idx!r}"
            )
        neg_q_idx = np.random.choice(neg_candidates)
        # One random answer from the rows belonging to the chosen q_idx.
        neg_rows = self.qa_df[self.qa_df["q_idx"] == neg_q_idx]
        neg_text = neg_rows.sample()["answer"].iloc[0]

        # Raw texts are returned (not tokenized features), e.g.
        # InputExample(texts=[anchor, positive, negative], label=1).
        return InputExample(texts=[query_text, pos_text, neg_text], label=1)

    def __len__(self):
        """Return the number of rows in ``qa_df``."""
        return self.qa_df.shape[0]

Expand Down

0 comments on commit 1777be2

Please sign in to comment.