Problem in contrastive loss #16

Closed · Jarvis73 opened this issue Jul 29, 2021 · 10 comments

Jarvis73 commented Jul 29, 2021

Hi, Dr. Zhou,

Thanks for releasing the code. When reading the code for the contrastive loss in the function _contrastive(), a mask is computed by the following two lines:

mask = torch.eq(y_anchor, y_contrast.T).float().cuda()

and
mask = mask.repeat(anchor_count, contrast_count)

Now I think the shape of the mask is [anchor_num * anchor_count, class_num * cache_size]. If I have not misunderstood the code, the mask is a 'positive' mask, and each row marks the positive samples of an anchor view.
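
For reference, here is a minimal CPU sketch of how I understand that mask to be built, using made-up labels and the sizes from my example below (contrast_count = 1 is my assumption for the queue branch):

```python
import torch

anchor_num, anchor_count = 6, 2   # 2 images x 3 valid classes, 2 sampled pixels per class
class_num, cache_size = 5, 2      # memory size per class
contrast_count = 1                # assumed value when sampling from the queue

y_anchor = torch.randint(0, class_num, (anchor_num, 1)).float()   # [6, 1]
y_contrast = torch.arange(class_num, dtype=torch.float).repeat_interleave(cache_size).view(-1, 1)  # [10, 1]

mask = torch.eq(y_anchor, y_contrast.T).float()    # [6, 10], 1 where anchor label == contrast label
mask = mask.repeat(anchor_count, contrast_count)   # [12, 10], one row per anchor view
print(mask.shape)                                  # torch.Size([12, 10])
```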

Then in L134-L138, the purpose of logits_mask is confusing to me:

logits_mask = torch.ones_like(mask).scatter_(1,
torch.arange(anchor_num * anchor_count).view(-1, 1).cuda(),
0)
mask = mask * logits_mask

Could you please explain these lines?

Suppose I have anchor_num=6 (2 images, 3 valid classes per image), anchor_count=2 (two sampled pixels per class), class_num=5 (number of classes), and cache_size=2 (memory size per class). Then the following code raises a RuntimeError:

mask = torch.ones((6 * 2, 5 * 2)).scatter_(1, torch.arange(6 * 2).view(-1, 1), 0)

Output:

Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
RuntimeError: index 10 is out of bounds for dimension 1 with size 10
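
For comparison, here is a sketch of my understanding of the queue is None case, where the contrast set equals the anchor set: the mask is square, the same scatter succeeds, and it simply zeros the diagonal, i.e. each anchor view's similarity with itself:

```python
import torch

anchor_num, anchor_count = 6, 2
n = anchor_num * anchor_count   # 12 anchor views

# Square [12, 12] mask: scatter indices 0..11 are valid and zero out the diagonal
logits_mask = torch.ones(n, n).scatter_(1, torch.arange(n).view(-1, 1), 0)
print(logits_mask.diagonal())   # all zeros -> self-pairs are excluded from the loss
```
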
lorafei commented Aug 12, 2021

Same question here. Can anyone explain why the masks and logits are formed like this?

@wangbo-zhao

Why do we use
logits_mask = torch.ones_like(mask).scatter_(1,
torch.arange(anchor_num * anchor_count).view(-1, 1).cuda(),
0)

eezywu commented Aug 23, 2021

> Why do we use
> logits_mask = torch.ones_like(mask).scatter_(1,
> torch.arange(anchor_num * anchor_count).view(-1, 1).cuda(),
> 0)

I think the purpose is to ignore the similarity of the feature with itself.

@HenryPengZou

Same question here. Could you help explain why the masks and logits are formed like this? @tfzhou

@bomtorazek

> Why do we use
> logits_mask = torch.ones_like(mask).scatter_(1,
> torch.arange(anchor_num * anchor_count).view(-1, 1).cuda(),
> 0)
>
> I think the purpose is to ignore the similarity of the feature with itself.

Hmmm, but when "queue" is not None, "contrast_feature" comes from "queue" and "anchor_feature" comes from "X_anchor", so I think the two sets are entirely different and there is no need to ignore the feature itself.
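
A small sketch of what I mean, with random stand-in features and assuming contrast_count == 1 in the queue branch: rows index anchor views while columns index queue entries, so entry (i, i) is not a self-pair, and for i >= the number of queue entries the scatter index goes out of range (the RuntimeError above):

```python
import torch

anchor_views = 12     # anchor_num * anchor_count
queue_entries = 10    # class_num * cache_size
feat_dim = 64         # made-up feature dimension

anchor_feature = torch.randn(anchor_views, feat_dim)     # from X_anchor
contrast_feature = torch.randn(queue_entries, feat_dim)  # from the memory queue (a different set)

sim = anchor_feature @ contrast_feature.T   # [12, 10]; no entry here is "a feature vs. itself"
# Zeroing column i of row i would drop a legitimate anchor/queue pair,
# and for rows i >= 10 the scatter index exceeds the number of columns.
```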

@Xxxxiahaofeng

I have another question about the function _contrastive().
What is the purpose of

logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)  
logits = anchor_dot_contrast - logits_max.detach()

If the size of anchor_dot_contrast is [800, 800], then logits_max is the row-wise maximum (i.e. the diagonal elements of anchor_dot_contrast), with shape [800, 1]. So the subtraction subtracts logits_max from every column of anchor_dot_contrast.

@xianxuan-z

I finally found people with the same questions here. I also have a question:
what does the argument "n_views" mean in the contrastive loss? Can you show me an example?
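
From the example at the top of this thread (anchor_count=2, i.e. two sampled pixels per class), my guess is that n_views is the number of pixel embeddings sampled per class per image, something like the sketch below, but I am not sure:

```python
import torch

# Guessed layout (not confirmed): one row per sampled anchor class, n_views embeddings each
num_anchors, n_views, feat_dim = 6, 2, 256
feats_ = torch.randn(num_anchors, n_views, feat_dim)   # n_views pixel embeddings per class
labels_ = torch.randint(0, 5, (num_anchors,))          # one class label per anchor

# Each of the n_views rows of feats_[i] would be a different pixel of class labels_[i],
# so every sampled class contributes n_views "views" to the contrastive loss.
```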

tfzhou (Owner) commented Dec 13, 2021

Thanks for the questions! @Jarvis73, L134-138 is there to remove the self-contrast term. @Xxxxiahaofeng, the purpose of removing logits_max is numerical stability; it is not always necessary.
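
As a quick illustration of the stability point (not the repository code, just the standard max-shift trick): subtracting the row-wise maximum before exponentiating leaves the softmax unchanged but keeps exp() from overflowing when the logits are large:

```python
import torch

logits = torch.tensor([[1000.0, 999.0, 998.0]])   # e.g. large dot products / small temperature

naive = torch.exp(logits)                               # overflows to inf
probs_naive = naive / naive.sum(dim=1, keepdim=True)    # inf / inf -> nan

shifted = logits - logits.max(dim=1, keepdim=True).values.detach()
stable = torch.exp(shifted)
probs_stable = stable / stable.sum(dim=1, keepdim=True)

print(probs_naive)    # tensor([[nan, nan, nan]])
print(probs_stable)   # tensor([[0.6652, 0.2447, 0.0900]])
```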

tfzhou closed this as completed Dec 26, 2021
@purse1996

I agree with your opinion. I think this masking is only needed when queue is None:

logits_mask = torch.ones_like(mask).scatter_(1,
torch.arange(anchor_num * anchor_count).view(-1, 1).cuda(),
0)

crisz94 commented May 8, 2024

> Thanks for the questions! @Jarvis73, L134-138 is there to remove the self-contrast term. @Xxxxiahaofeng, the purpose of removing logits_max is numerical stability; it is not always necessary.

Hi @tfzhou, I think the self-contrast case only exists when queue == None.
