ViT training ? #10

Open
alexcbb opened this issue Dec 18, 2023 · 24 comments

@alexcbb

alexcbb commented Dec 18, 2023

Hello,

Thank you for your very interesting work! I'm currently trying to replicate your results with the provided codebase, and I was wondering whether you also tested a Vision Transformer architecture as the encoder. You compare against DINO in the paper, but I wanted to know whether you were able to obtain properties close to what they observed (a kind of saliency map, with the attention concentrated on the object of interest).

Thank you again for your response!

@xwen99
Member

xwen99 commented Dec 18, 2023

Hi @alexcbb, thanks for your interest in our work!

We actually didn't experiment thoroughly with ViTs due to computational constraints. Regarding DINO's object-centric attention maps, we believe they are a merit of Transformers; for CNNs we need to find another path. Our method does it via explicit clustering on top of CNN features, which indeed worked.
Besides that, we also tried to find similar visualizations within CNNs themselves, and found that PCA on dense feature maps produced plausible results. Due to the hierarchical structure of CNNs, the resulting visualizations have relatively low resolution. We tried tricks like modifying the stride, which did not help much.
Hope that is helpful for you!
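A minimal NumPy sketch of the PCA visualization mentioned above, assuming you already have a dense feature map of shape [C, H, W] for one image (for example, the last ResNet-50 stage, which is 7×7 for a 224×224 input because of its stride of 32). The function name and the per-component min-max scaling are our own choices, not the authors' visualization code: each spatial location is projected onto the top three principal components, which are then displayed as RGB channels.

```python
import numpy as np


def pca_feature_map(features, n_components=3):
    """Project a dense feature map onto its top principal components.

    features: [C, H, W] features of one image.
    Returns [H, W, n_components] scaled to [0, 1] per component (displayable as RGB).
    """
    C, H, W = features.shape
    X = features.reshape(C, H * W).T                # [H*W, C]: one row per location
    X = X - X.mean(axis=0, keepdims=True)           # center every channel
    # Rows of Vt are the principal directions of the centered data.
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    proj = X @ Vt[:n_components].T                  # [H*W, n_components]
    lo, hi = proj.min(axis=0), proj.max(axis=0)
    proj = (proj - lo) / np.maximum(hi - lo, 1e-8)  # min-max scale each component
    return proj.reshape(H, W, n_components)


feats = np.random.default_rng(0).normal(size=(2048, 7, 7))  # e.g. a ResNet-50 stage-5 map
print(pca_feature_map(feats).shape)                          # (7, 7, 3)
```

To overlay the result on the image, upsample the [H, W, 3] map to the input size; with a 7×7 grid the output is necessarily coarse, which matches the low-resolution effect described above.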

@alexcbb
Author

alexcbb commented Dec 18, 2023

Thank you for your quick answer!

I would be very interested in exploring whether such training would be beneficial for Vision Transformers (even for a small version like ViT-S/16). I'm first checking whether I can reproduce your results with ResNet, and then I want to apply it to ViT. I think this could help extract object knowledge to some extent and provide a useful prior for training, especially on scene-centric datasets.

Could I ask for your help in this process (mainly with replicating the results)?

Thanks again!

@xwen99
Member

xwen99 commented Dec 18, 2023

Feel free to leave a message if you run into trouble working on that.

@alexcbb
Author

alexcbb commented Dec 18, 2023

It is indicated that pre-training on COCO was performed on 8 NVIDIA 2080 Ti GPUs for 800 epochs. Do you have an approximate training time for such a run, and possibly some information on memory consumption?

@xwen99
Member

xwen99 commented Dec 18, 2023

It should take up almost all the memory of the 8x 2080 Ti, roughly 80GB in total. I don't remember the precise training time well; maybe roughly 2 to 3 days?

@alexcbb
Author

alexcbb commented Dec 19, 2023

I made some small changes to launch the training (I created a PyTorch Lightning module to ease deployment on clusters) and started an 800-epoch training run on COCO. Here is the current evolution of the loss (it is now at around epoch 230 after ~1 day of training). Does this look like a correct convergence curve?
[Screenshot from 2023-12-19 17-30-11: training loss curve]

@alexcbb
Author

alexcbb commented Dec 22, 2023

It seems that I'm not able to replicate your Figure 3 after pre-training, and I don't understand why: the prototypes look a bit weird, and there's no mask on my final image.

@xwen99
Member

xwen99 commented Dec 22, 2023

Fig. 3 is simply produced using viz_slots.py with the default configs. The model is the default COCO model, trained for 800 epochs. Please check whether there are any errors in your reimplementation.

@alexcbb
Author

alexcbb commented Jan 3, 2024

Hello, the problem was in my visualization script; I'm now able to obtain well-aligned concepts! I saw in the paper that one would need to scale the loss according to the batch size (when increasing it). Could you tell me more about this? (I trained my model with your default parameters and a batch size of 1536 without any major problem in the results.)

@xwen99
Member

xwen99 commented Jan 4, 2024

Hi, we scale the learning rate linearly with the batch size, as done by many previous works. This part is already implemented in the code, and basically no more modification is needed for you:

lr=args.batch_size * args.world_size / 256 * args.base_lr,
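The linear scaling rule quoted above can be checked in isolation with a small helper. This is a sketch that mirrors the quoted expression, assuming args.batch_size holds the per-process batch size at that point (so the global batch is batch_size * world_size); the function name is hypothetical.

```python
def linear_scaled_lr(base_lr, batch_size, world_size, reference_batch=256):
    """Linear scaling rule: the learning rate grows with the global batch size.

    Mirrors `lr = args.batch_size * args.world_size / 256 * args.base_lr`,
    assuming batch_size is the per-process batch size.
    """
    global_batch = batch_size * world_size
    return global_batch / reference_batch * base_lr


# Doubling the global batch (512 instead of 256) doubles the learning rate.
print(linear_scaled_lr(base_lr=1.0, batch_size=64, world_size=8))  # 2.0
```

If the batch of 1536 mentioned above is the global batch, the rule gives 1536 / 256 = 6 times base_lr, which is why no manual change is needed when the batch size changes.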

alexcbb closed this as completed Jan 9, 2024
alexcbb reopened this Jan 9, 2024
@alexcbb
Author

alexcbb commented Jan 9, 2024

Hello, I have a question about the slot loss. In Equation (5), you mask out the slots that do not occupy dominant groups, and then use the remaining slots to compute an InfoNCE loss. In the code, in the ctr_loss_filtered function of SlotCon, why do you use mask_intersection instead of mask_q to select the slots of q? Is it to avoid slots that have no positive pair in k, or is there another reason? Also, did you run any ablations on whether this masking helps training? Thank you in advance.

@xwen99
Member

xwen99 commented Jan 10, 2024

Your understanding is correct: this is to make sure they form a positive pair, i.e., both the query and key slots exist across views. As far as I remember, we didn't ablate much on that.
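To make the masking logic concrete, here is a simplified single-process NumPy sketch of a slot-level InfoNCE in which queries are kept only where the slot occupies a dominant group in both views (the intersection), while keys are kept wherever they occupy a group in the key view, so keys present only in k still act as negatives. This is our reading of the discussion, not the repository's ctr_loss_filtered: it omits the predictor head, cross-GPU gathering and any loss scaling, and the function and argument names are hypothetical. When the intersection is empty there is no query to average over (a mean over zero elements is 0/0), which is one way a NaN can appear; the sketch returns None in that case instead.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    """Normalize each row (last axis) to unit length."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def filtered_slot_infonce(q, k, mask_q, mask_k, tau=0.2):
    """Hypothetical slot-level InfoNCE with intersection-filtered queries.

    q, k:           [B, K, D] slot embeddings of the two views.
    mask_q, mask_k: [B, K] booleans, True where a slot occupies a dominant group.
    Returns the mean loss, or None when no query slot has a positive key.
    """
    D = q.shape[-1]
    qf = l2_normalize(q.reshape(-1, D))            # [B*K, D]
    kf = l2_normalize(k.reshape(-1, D))            # [B*K, D]
    mask_k_flat = mask_k.reshape(-1)
    # A query is usable only if the same slot is also present in the key view.
    idxs_q = np.flatnonzero((mask_q & mask_k).reshape(-1))
    idxs_k = np.flatnonzero(mask_k_flat)           # every present key (positives + negatives)
    if idxs_q.size == 0:
        return None                                # empty intersection: nothing to average
    logits = qf[idxs_q] @ kf[idxs_k].T / tau       # [Nq, Nk] cosine similarities / tau
    # Column of each query's positive key among the kept keys.
    labels = np.cumsum(mask_k_flat)[idxs_q] - 1
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(idxs_q.size), labels].mean())
```

Selecting queries with mask_q alone would admit slots whose counterpart is absent from k, so their "positive" column would not exist among the kept keys; filtering by the intersection guarantees that every label points at a valid key.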

@alexcbb
Author

alexcbb commented Jan 10, 2024

Ok, thank you again for your answer. Did you ever encounter the case where there is no positive pair between the views? While trying to train the model with a ViT backbone, I got a NaN loss, and the issue comes from the slot loss: at some point, the intersection mask is empty. If you encountered this during your experiments, it would help me a lot!

@xwen99
Member

xwen99 commented Jan 11, 2024

Well, actually I can't recall the details well... You may consider dropping that pair in this case.
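One way to "drop that pair" at the loss level is to skip the slot term for a step when it is undefined. This sketch assumes the --group-loss-weight flag mixes the two terms as w * group + (1 - w) * slot (check the model code before relying on it); the helper name is hypothetical, and slot_loss is expected to be None or NaN when the intersection is empty.

```python
import math


def combine_losses(group_loss, slot_loss, group_loss_weight=0.5):
    """Hypothetical helper mixing the group-level and slot-level losses.

    Assumes total = w * group + (1 - w) * slot. When slot_loss is None
    (no positive pair) or not finite, the slot term is dropped for this
    step and only the weighted group term is kept.
    """
    group_term = group_loss_weight * group_loss
    if slot_loss is None or not math.isfinite(slot_loss):
        return group_term
    return group_term + (1.0 - group_loss_weight) * slot_loss
```

Keeping the group term at its usual weight (rather than renormalizing it to 1) leaves its gradient scale unchanged on skipped steps. Under DistributedDataParallel, skipping the term on only some ranks may leave the slot predictor without gradients on those ranks, so treat this as a starting point rather than a drop-in fix.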

@KJ-rc

KJ-rc commented Jan 29, 2024

Hi, @alexcbb @xwen99

I have just run a small-scale experiment with ViT-S on COCO for 100 epochs.
The rest of the settings can be found below.

The prototype visualization makes sense but looks weird. I am checking the code now.
I would appreciate any hints or suggestions.

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --master_port 12348 --nproc_per_node=8 \
    main_pretrain.py \
    --dataset COCO \
    --data-dir ${data_dir} \
    --output-dir ${output_dir} \
    \
    --arch vit_small \
    --dim-hidden 4096 \
    --dim-out 256 \
    --num-prototypes 256 \
    --teacher-momentum 0.99 \
    --teacher-temp 0.07 \
    --group-loss-weight 0.5 \
    \
    --batch-size 256 \
    --optimizer adamw \
    --base-lr 5e-4 \
    --weight-decay 0.04 \
    --warmup-epoch 5 \
    --epochs 100 \
    --fp16 \
    \
    --print-freq 10 \
    --save-freq 50 \
    --auto-resume \
    --num-workers 12

[Image: slotcon_vits_coco_100eps (prototype visualization)]

@alexcbb
Author

alexcbb commented Jan 29, 2024

Hello, on my side I was not able to make the training converge properly. The slot loss occasionally returns NaN because of the masking, and I don't know why this is happening. Could you tell me what changes you made to replace the backbone? My hyperparameters are the same as yours (I took the same hyperparameters as in DINO training). Your prototypes look coherent to me; what seems weird to you? I would gladly discuss your re-implementation with you if you agree!

@KJ-rc

KJ-rc commented Jan 29, 2024

Sure, you can send me an email. ([Update] my email: [email protected])
For a quick answer: I tried to make minimal changes. Specifically:

  • Borrowed the ViT-S implementation from DINO (v1) and made its output a 4D torch.Tensor of shape [B, C, H, W]:
    return x[:, 1:].transpose(-2, -1).reshape(-1, self.embed_dim, h, w)

  • Changed num_channels of SlotCon(nn.Module) to 384:
    self.num_channels = 384

  • Used an AdamW optimizer, as in the command above.
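The first change above can be illustrated with a NumPy stand-in for the tensor operations, so it runs without torch. It assumes DINO's token layout, one CLS token followed by h*w patch tokens in row-major order; the function name is hypothetical. The channel dimension C is the ViT-S embedding size, 384, which is why self.num_channels has to change from the ResNet-50 feature width.

```python
import numpy as np


def tokens_to_feature_map(x, img_h, img_w, patch_size=16):
    """Turn ViT output tokens into a CNN-style feature map.

    x: [B, 1 + h*w, C] tokens with the CLS token first (as in DINO's ViT).
    Returns [B, C, h, w], with h = img_h // patch_size, w = img_w // patch_size.
    """
    h, w = img_h // patch_size, img_w // patch_size
    B, n_tokens, C = x.shape
    assert n_tokens == 1 + h * w, "expected one CLS token plus h*w patch tokens"
    patches = x[:, 1:]                              # drop CLS -> [B, h*w, C]
    # Same as x[:, 1:].transpose(-2, -1).reshape(-1, C, h, w) in PyTorch.
    return np.swapaxes(patches, -2, -1).reshape(B, C, h, w)


x = np.random.default_rng(0).normal(size=(2, 1 + 14 * 14, 384))  # ViT-S/16, 224x224
print(tokens_to_feature_map(x, 224, 224).shape)                  # (2, 384, 14, 14)
```

Using reshape rather than view matters in PyTorch, because the tensor is no longer contiguous after transpose.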

Regarding the prototype visualization, I found some empty prototypes, and the 4th column from the right shows "cat", "cow" and "bear", while the ResNet-50-based model produces a pure cat prototype. I would say the semantic consistency is lower with the ViT-S backbone. Again, I am not sure whether this behavior is expected.

@alexcbb
Author

alexcbb commented Jan 29, 2024

@KJ-rc
Ok, thank you. Could you provide your email? I'm not able to find it.
Regarding the DINO implementation, you trained it from scratch, right? I will try as you did; maybe there are some issues in my code that I didn't see. I'll let you know if I get the same artifacts as yours. But I'm pretty sure I did the same as you (very few changes).

But yes, you are right, I hadn't noticed the empty prototypes! This is quite weird indeed. Maybe some hyperparameter changes could make it a bit better (like the student/teacher temperatures?). What about training for longer (like 300 to 400 epochs)? I will let you know as soon as I'm able to train the ViT from scratch on my side.

@alexcbb
Author

alexcbb commented Jan 31, 2024

@KJ-rc Just to be sure: since DINO does not use BatchNorm in its projection head, did you also remove the BatchNorm layers (and consequently the SyncBatchNorm calls in SlotCon)?

@KJ-rc

KJ-rc commented Jan 31, 2024

Hi,
I made only the modifications listed above.
I consider projectors to depend more on the pre-training method than on the backbone architecture,
so I kept the BatchNorm layers.

@alexcbb
Author

alexcbb commented Jan 31, 2024

It seems that after reducing the batch size to 256 and adding gradient clipping, the training now works. I'll see how it evolves and let you know about my final results!
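Gradient clipping by global norm, which is what helped here, can be sketched in NumPy as below. In PyTorch the usual call is torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) after backward() and before optimizer.step(); the max_norm actually used in this thread is not given (DINO's training script defaults to 3.0). The helper name is hypothetical.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients jointly so their global L2 norm is at most max_norm.

    Follows the same rule as torch.nn.utils.clip_grad_norm_:
    coef = max_norm / (total_norm + eps), applied only when coef < 1.
    Returns the (possibly) rescaled gradients and the norm before clipping.
    """
    total_norm = float(np.sqrt(sum(np.sum(g * g) for g in grads)))
    coef = min(1.0, max_norm / (total_norm + eps))
    return [g * coef for g in grads], total_norm


grads = [np.array([3.0]), np.array([4.0])]  # global norm 5
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)                                  # 5.0
```

If --fp16 is implemented with torch.cuda.amp.GradScaler (an assumption about this codebase), the PyTorch AMP docs recommend calling scaler.unscale_(optimizer) before clipping so that the threshold applies to the true gradients.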

@alexcbb
Author

alexcbb commented Feb 5, 2024

Hello, it seems that I have the same issue as @KJ-rc when training SlotCon with ViT, but I think it is because "dead slots" appear during training, as described in Annex D. I printed out 100 slots to check their semantics, and of those 100, around 20 were "dead slots" without any meaning. It seems quite related to the discussion in Annex D; what do you think?
[Image: res3 (slot visualization)]
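One way to quantify dead slots like the ~20 out of 100 observed above is to count, over a set of images, how often each prototype is the argmax assignment of a pixel. The sketch assumes per-pixel assignment scores of shape [N, K, H, W] are available from the model (extracting them from SlotCon is not covered here); the helper names and the min_frac threshold are hypothetical.

```python
import numpy as np


def prototype_usage(scores):
    """Fraction of pixels for which each prototype is the argmax.

    scores: [N, K, H, W] assignment scores of N images' pixels to K prototypes.
    """
    K = scores.shape[1]
    winners = scores.argmax(axis=1)                  # [N, H, W] best prototype per pixel
    counts = np.bincount(winners.ravel(), minlength=K)
    return counts / counts.sum()


def find_dead_prototypes(usage, min_frac=1e-4):
    """Indices of prototypes that (almost) never win a pixel."""
    return np.flatnonzero(usage < min_frac)


rng = np.random.default_rng(0)
scores = rng.random((4, 5, 8, 8))
scores[:, 2] = -1.0                                  # prototype 2 never wins
print(find_dead_prototypes(prototype_usage(scores)))  # [2]
```

Tracking this count over epochs would show whether dead slots appear early or accumulate during training.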

@xwen99
Copy link
Member

xwen99 commented Feb 7, 2024

See if this paper helps you understand the dead slots: https://openreview.net/forum?id=Z2dVrgLpsF

@alexcbb
Copy link
Author

alexcbb commented Feb 8, 2024

Thanks for sharing. It would be interesting to evaluate whether the problem disappears with such regularization!
