
Fix PNP loss to make it work with negatives without related positives#660

Merged
KevinMusgrave merged 10 commits intoKevinMusgrave:devfrom
Puzer:pnp_loss_nan_fix
Nov 11, 2023
Conversation

@Puzer
Contributor

@Puzer Puzer commented Sep 12, 2023

Currently, PNP loss returns NaN if you have negative examples without any related positive examples:
labels = torch.tensor([1, 1, 2])

Let's say you have an anchor, a positive, and a negative.
N_pos (from PNP loss) for these labels will be [2, 2, 0].
So in this case you get a division by zero, and the loss becomes NaN as a result.

This fix keeps only positive instances at the final stage.

I've tested this for my use case and it works quite well.
However, I'm not sure whether it's mathematically correct; maybe there is a more reasonable fix for this issue.
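The failure mode and the masking idea can be sketched in a few lines. This is not the library's actual PNPLoss implementation; the names `I_pos`, `N_pos`, and `safe_N` follow the discussion in this PR, `sim_all_rk` is a stand-in for the per-pair ranking terms, and the positive count here excludes the anchor itself, so it comes out as [1, 1, 0] rather than the [2, 2, 0] quoted above:

```python
import torch

labels = torch.tensor([1, 1, 2])

# I_pos[i, j] = 1 if samples i and j share a label (diagonal excluded).
same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
I_pos = same_label.float() - torch.eye(len(labels))

# N_pos[i] = number of positives for anchor i -> [1, 1, 0] here.
# Anchor 2 (label 2) has no positives, so an unmasked division by
# N_pos would produce NaN for it.
N_pos = I_pos.sum(dim=-1)

# Stand-in for the per-pair ranking terms of PNP loss.
sim_all_rk = torch.rand(3, 3)

# The fix: average only over anchors that actually have positives.
safe_N = N_pos > 0
per_anchor = torch.sum(sim_all_rk * I_pos, dim=-1)[safe_N] / N_pos[safe_N]
loss = per_anchor.sum() / safe_N.sum()

assert torch.isfinite(loss)  # no NaN even though label 2 has no positive
```

Without the `safe_N` mask, `per_anchor` would contain a `0 / 0 = NaN` entry for anchor 2, and summing would propagate NaN into the final loss.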

@KevinMusgrave
Owner

@interestingzhuo Any thoughts on this?

@KevinMusgrave KevinMusgrave changed the base branch from master to dev October 18, 2023 04:00
@interestingzhuo
Contributor

> @interestingzhuo Any thoughts on this?
That's right! And the code will be

loss = torch.sum(sim_all_rk * I_pos, dim=-1)[safe_N] / N_pos[safe_N].reshape(-1)
loss = torch.sum(loss) / torch.sum(safe_N)

for effective training.

@KevinMusgrave
Owner

Thanks @Puzer and @interestingzhuo !

@KevinMusgrave KevinMusgrave merged commit dc772a8 into KevinMusgrave:dev Nov 11, 2023