Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: adityavavre <[email protected]>
  • Loading branch information
adityavavre committed Aug 2, 2024
1 parent 22b536d commit e936139
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
truncation_method: str = 'right',
special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictionary of {token_type: token}
data_type: str = 'train', # train, query or doc
num_hard_negatives: int = 4, # number of hard negatives to use per query during training
num_hard_negatives: int = 4, # number of hard negatives to use per query during training
):
"""
file_path: Path to a JSONL dataset of (query, pos_doc, neg_doc) triplets.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,12 +457,14 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
if not self.trainer.training:
return self.inference_loss_func(loss_mask, num_valid_tokens_in_ub, eos_tensors)

num_tensors_per_example = 2 + self.hard_negatives_to_train # query, pos_doc and 'n' hard neg_docs
bs = output_tensor.shape[0] // num_tensors_per_example
chunks = output_tensor.chunk(bs) # chunk to get tensors for each example
query_hs = torch.stack([item[0] for item in chunks]) # first item in every chunk is the query
pos_doc_hs = torch.stack([item[1] for item in chunks]) # second item is the pos_doc
neg_doc_hs = torch.stack([item[i + 2] for item in chunks for i in range(self.hard_negatives_to_train)]) # rest are hard negatives
num_tensors_per_example = 2 + self.hard_negatives_to_train # query, pos_doc and 'n' hard neg_docs
bs = output_tensor.shape[0] // num_tensors_per_example
chunks = output_tensor.chunk(bs) # chunk to get tensors for each example
query_hs = torch.stack([item[0] for item in chunks]) # first item in every chunk is the query
pos_doc_hs = torch.stack([item[1] for item in chunks]) # second item is the pos_doc
neg_doc_hs = torch.stack(
[item[i + 2] for item in chunks for i in range(self.hard_negatives_to_train)]
) # rest are hard negatives

query_hs = torch.nn.functional.normalize(query_hs, dim=1)
pos_doc_hs = torch.nn.functional.normalize(pos_doc_hs, dim=1)
Expand Down

0 comments on commit e936139

Please sign in to comment.