Adding support for multiple hard negatives in gpt embedding model
Signed-off-by: adityavavre <[email protected]>
adityavavre committed Aug 2, 2024
1 parent c277b9a commit dce0105
Showing 2 changed files with 70 additions and 22 deletions.
File 1 of 2 (embedding dataset class):
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from random import choices, sample
from typing import Mapping, Optional

import datasets
@@ -47,6 +48,7 @@ def __init__(
truncation_method: str = 'right',
special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictory of {token_type: token}
data_type: str = 'train', # train, query or doc
num_hard_negatives: int = 4, # number of hard negatives to use per query during training
):
"""
file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format.
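For context, a hypothetical JSONL record in the shape this change expects: with multiple hard negatives, neg_doc is a list of passages rather than a single string (field names follow the docstring above; the values are illustrative only):

example = {
    "query": "what is the boiling point of water",
    "pos_doc": "Water boils at 100 degrees Celsius at sea level.",
    "neg_doc": [  # now a list; sampled down or up to num_hard_negatives
        "Water freezes at 0 degrees Celsius.",
        "The Pacific Ocean is the largest ocean on Earth.",
        "Ethanol boils at 78 degrees Celsius.",
    ],
}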
@@ -73,6 +75,7 @@ def __init__(
self.index_mapping_dir = index_mapping_dir
self.virtual_tokens = virtual_tokens
self.truncation_method = truncation_method
self.num_hard_negatives = num_hard_negatives
if special_tokens is None:
self.special_tokens = {
"system_turn_start": "<extra_id_0>",
@@ -157,7 +160,16 @@ def _process_example(self, example):
if self.data_type == 'train':
q = self.tokenizer.text_to_ids("query: " + example['query'].strip())
d = self.tokenizer.text_to_ids("passage: " + example['pos_doc'].strip())
nd = self.tokenizer.text_to_ids("passage: " + example['neg_doc'].strip())
# handle cases where the required number of hard negatives is not present
if len(example['neg_doc']) < self.num_hard_negatives:
nd = example['neg_doc']
# sample rest with replacement
nd = nd + choices(example['neg_doc'], k=self.num_hard_negatives - len(example['neg_doc']))
else:
# sample without replacement
nd = sample(example['neg_doc'], k=self.num_hard_negatives)
assert len(nd) == self.num_hard_negatives, "Error in sampling required number of hard negatives"
nd = [self.tokenizer.text_to_ids("passage: " + ex.strip()) for ex in nd]
elif self.data_type == 'query':
q = self.tokenizer.text_to_ids("query: " + example['query'].strip())
d, nd = None, None
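A minimal standalone sketch of the sampling behaviour added above, assuming neg_docs is the list of hard-negative passages for one example (the function name is hypothetical; this mirrors the logic rather than reproducing the committed code):

from random import choices, sample

def pick_hard_negatives(neg_docs, num_hard_negatives):
    """Return exactly num_hard_negatives passages from neg_docs.

    If too few are available, keep them all and top up by sampling with
    replacement; otherwise sample without replacement.
    """
    if len(neg_docs) < num_hard_negatives:
        picked = neg_docs + choices(neg_docs, k=num_hard_negatives - len(neg_docs))
    else:
        picked = sample(neg_docs, k=num_hard_negatives)
    assert len(picked) == num_hard_negatives
    return picked

# e.g. pick_hard_negatives(["n1", "n2"], 4) keeps both passages and resamples two more.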
@@ -179,22 +191,22 @@ def _process_example(self, example):
# these pad/eos tokens are placeholders for virtual tokens for ptuning (if used)
q = [self.tokenizer.eos_id] * self.virtual_tokens + q # type: ignore
d = [self.tokenizer.eos_id] * self.virtual_tokens + d # type: ignore
nd = [self.tokenizer.eos_id] * self.virtual_tokens + nd # type: ignore
nd = [[self.tokenizer.eos_id] * self.virtual_tokens + n for n in nd] # type: ignore

if self.add_bos:
q = [self.tokenizer.bos_id] + q # type: ignore
d = [self.tokenizer.bos_id] + d # type: ignore
nd = [self.tokenizer.bos_id] + nd # type: ignore
nd = [[self.tokenizer.bos_id] + n for n in nd] # type: ignore

# TODO: (@adithyare) should probably add a warning before truncation
q = q[: self.max_seq_length - 1]
d = d[: self.max_seq_length - 1]
nd = nd[: self.max_seq_length - 1]
nd = [n[: self.max_seq_length - 1] for n in nd]

if self.add_eos:
q = q + [self.tokenizer.eos_id] # type: ignore
d = d + [self.tokenizer.eos_id] # type: ignore
nd = nd + [self.tokenizer.eos_id] # type: ignore
nd = [n + [self.tokenizer.eos_id] for n in nd] # type: ignore

processed_example = {
'query': q,
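The pattern in this hunk is that nd changes from a single token-id list to a list of lists, so every per-document transform is mapped over it. A toy illustration under assumed values (hypothetical token ids, max_seq_length=8):

bos_id, eos_id, max_seq_length = 1, 2, 8

nd = [[11, 12, 13, 14, 15, 16, 17, 18, 19],   # too long, will be truncated
      [21, 22, 23]]

nd = [[bos_id] + n for n in nd]               # prepend BOS to each negative
nd = [n[: max_seq_length - 1] for n in nd]    # leave room for EOS
nd = [n + [eos_id] for n in nd]               # append EOS to each negative

# nd == [[1, 11, 12, 13, 14, 15, 16, 2], [1, 21, 22, 23, 2]]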
@@ -244,9 +256,12 @@ def collate_fn(self, batch):
lengths.append(len(item['query']))
input_ids.append(item['pos_doc'])
lengths.append(len(item['pos_doc']))
input_ids.append(item['neg_doc'])
lengths.append(len(item['neg_doc']))
max_length = max(max_length, len(item['query']), len(item['pos_doc']), len(item['neg_doc']))
for nd in item['neg_doc']:
input_ids.append(nd)
lengths.append(len(nd))
max_length = max(
max_length, len(item['query']), len(item['pos_doc']), *(len(nd) for nd in item['neg_doc'])
)
elif self.data_type == 'query':
input_ids.append(item['query'])
lengths.append(len(item['query']))
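With these collate changes, each training example contributes 2 + num_hard_negatives sequences to the flattened batch, in the order query, pos_doc, then its hard negatives; that ordering is what the model's loss_func relies on when it chunks the output tensor. A rough sketch of the flattening (names are illustrative, assuming each item already holds token-id lists):

def flatten_batch(batch, num_hard_negatives):
    input_ids, lengths = [], []
    for item in batch:
        # order matters: query, positive doc, then the hard negatives
        input_ids.append(item["query"])
        lengths.append(len(item["query"]))
        input_ids.append(item["pos_doc"])
        lengths.append(len(item["pos_doc"]))
        for nd in item["neg_doc"]:
            input_ids.append(nd)
            lengths.append(len(nd))
    assert len(input_ids) == len(batch) * (2 + num_hard_negatives)
    return input_ids, lengths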
File 2 of 2 (MegatronGPTEmbeddingModel):
@@ -62,8 +62,10 @@ class MegatronGPTEmbeddingModel(MegatronGPTSFTModel):
def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer=trainer)
self.temperature = self.cfg.get('temperature', 0.02)
self.use_all_possible_negatives = self.cfg.get("use_all_possible_negatives", True)
self.hard_negatives_to_train = self.cfg.data.train_ds.get("hard_negatives_to_train", 4)
self.use_all_possible_negatives = self.cfg.get("use_all_possible_negatives", False)
self.global_inbatch_negatives = self.cfg.get("global_inbatch_negatives", True)
self.use_inbatch_negatives = self.cfg.get("use_inbatch_negatives", True)
if self.cfg.get("do_mrl", False):
min_mrl = self.cfg.get("min_mrl_dim", int(np.log2(32))) - 1
max_mrl = int(np.log2(self.cfg.hidden_size // 2))
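A hedged sketch of the config fields this constructor reads; the key names and defaults come from the diff above, while the surrounding OmegaConf layout is an assumption, not the repository's actual example config:

from omegaconf import OmegaConf

# Sketch only: keys read in __init__ with the defaults visible in this hunk.
cfg = OmegaConf.create({
    "temperature": 0.02,
    "use_all_possible_negatives": False,   # False per the new line in the diff above
    "global_inbatch_negatives": True,
    "use_inbatch_negatives": True,
    "do_mrl": False,
    "data": {"train_ds": {"hard_negatives_to_train": 4}},
})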
@@ -398,17 +400,36 @@ def local_validation_step(self, dataloader_iter):

return loss, non_loss_tensors

def constrastive_scores(self, pos_doc_hs, neg_doc_hs, query_hs, bs, temperature, use_all_possible_negatives=False):
all_doc_hs = torch.cat([pos_doc_hs, neg_doc_hs], dim=0) # (2bs) x hidden_size
cs = torch.mm(query_hs, all_doc_hs.transpose(0, 1)) # (bs) x (2bs)
pos_cs = cs[:, :bs].diag()
neg_cs = cs[:, bs:].diag()
def constrastive_scores(
self,
pos_doc_hs,
neg_doc_hs,
query_hs,
bs,
temperature,
num_hard_negatives,
use_all_possible_negatives,
use_inbatch_negatives,
):
all_doc_hs = torch.cat([pos_doc_hs, neg_doc_hs], dim=0) # ((hn+1)bs) x hidden_size
cs = torch.mm(query_hs, all_doc_hs.transpose(0, 1)) # (bs) x ((hn+1)bs)
pos_cs = cs[:, :bs]
neg_cs = cs[:, bs:]
if use_all_possible_negatives:
labels = torch.arange(bs, device=cs.device).long()
else:
labels = torch.zeros(bs, device=cs.device).long()
cs = torch.cat([pos_cs.unsqueeze(1), neg_cs.unsqueeze(1)], dim=1)
pos_cs = pos_cs.clone().detach().mean()
neg_cs = neg_cs[
torch.arange(bs).unsqueeze(1).repeat(1, num_hard_negatives),
torch.arange(bs * num_hard_negatives).reshape(bs, num_hard_negatives),
]
if use_inbatch_negatives:
labels = torch.arange(bs, device=cs.device).long()
cs = torch.cat([pos_cs, neg_cs], dim=1)
else:
labels = torch.zeros(bs, device=cs.device).long()
cs = torch.cat([pos_cs.diag().unsqueeze(1), neg_cs], dim=1)

pos_cs = pos_cs.diag().clone().detach().mean()
neg_cs = neg_cs.clone().detach().mean()
cs = cs.clamp(-1.0, 1.0)
cs = cs / temperature
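To make the shapes in constrastive_scores concrete, here is a small self-contained sketch with random tensors (bs=2, hn=4); it mirrors the indexing above but is not the committed method:

import torch

# Embeddings are L2-normalized; negatives are example-major, i.e. rows
# [i*hn : (i+1)*hn] of neg_doc_hs belong to query i.
bs, hn, hidden = 2, 4, 8
query_hs = torch.nn.functional.normalize(torch.randn(bs, hidden), dim=1)
pos_doc_hs = torch.nn.functional.normalize(torch.randn(bs, hidden), dim=1)
neg_doc_hs = torch.nn.functional.normalize(torch.randn(bs * hn, hidden), dim=1)

all_doc_hs = torch.cat([pos_doc_hs, neg_doc_hs], dim=0)    # ((hn+1)*bs) x hidden
cs = torch.mm(query_hs, all_doc_hs.T)                      # bs x ((hn+1)*bs)
pos_cs = cs[:, :bs]                                        # bs x bs, all in-batch positives
neg_cs = cs[:, bs:]                                        # bs x (bs*hn)
# keep only each query's own hard negatives -> bs x hn
own = neg_cs[torch.arange(bs).unsqueeze(1).repeat(1, hn),
             torch.arange(bs * hn).reshape(bs, hn)]

# in-batch negatives: other queries' positives also act as negatives
scores_inbatch = torch.cat([pos_cs, own], dim=1)           # bs x (bs + hn), labels = arange(bs)
# hard negatives only: the correct document sits in column 0
scores_hard = torch.cat([pos_cs.diag().unsqueeze(1), own], dim=1)  # bs x (1 + hn), labels = zeros(bs)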
@@ -434,17 +455,27 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
eos_tensors = _gather_global_inbatch_representations(eos_tensors)
if not self.trainer.training:
return self.inference_loss_func(loss_mask, num_valid_tokens_in_ub, eos_tensors)
bs = eos_tensors.shape[0] // 3
query_hs = eos_tensors[::3, :] # every third tensor is a query (bs x hidden_size)
pos_doc_hs = eos_tensors[1::3, :] # every third tensor is a positive doc (bs x hidden_size)
neg_doc_hs = eos_tensors[2::3, :] # every third tensor is a negative doc (bs x hidden_size)

num_tensors_per_example = 2 + self.hard_negatives_to_train # query, pos_doc and 'n' hard neg_docs
bs = output_tensor.shape[0] // num_tensors_per_example
chunks = output_tensor.chunk(bs) # chunk to get tensors for each example
query_hs = torch.stack([item[0] for item in chunks]) # first item in every chunk is the query
pos_doc_hs = torch.stack([item[1] for item in chunks]) # second item is the pos_doc
neg_doc_hs = torch.stack([item[i + 2] for item in chunks for i in range(self.hard_negatives_to_train)]) # rest are hard negatives

query_hs = torch.nn.functional.normalize(query_hs, dim=1)
pos_doc_hs = torch.nn.functional.normalize(pos_doc_hs, dim=1)
neg_doc_hs = torch.nn.functional.normalize(neg_doc_hs, dim=1)

cs, pos_cs, neg_cs, labels = self.constrastive_scores(
pos_doc_hs, neg_doc_hs, query_hs, bs, self.temperature, self.use_all_possible_negatives
pos_doc_hs,
neg_doc_hs,
query_hs,
bs,
self.temperature,
self.hard_negatives_to_train,
self.use_all_possible_negatives,
self.use_inbatch_negatives,
)
loss = torch.nn.functional.cross_entropy(cs, labels)
if self.mrl_dims:
Expand All @@ -455,7 +486,9 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
query_hs[:, :dim],
bs,
self.temperature,
self.hard_negatives_to_train,
self.use_all_possible_negatives,
self.use_inbatch_negatives,
)
loss += torch.nn.functional.cross_entropy(cs_dim, labels)

