Thank you for publishing the BEiT v2 code! I’m pretraining BEiT v2 on a custom industrial dataset (1.1 M fault and normal images for training; 200 k normal images for validation) and have a few questions:
(1) Codebook Utilization with a Single Unit
I only have ~1,000 labeled images for one equipment unit. Pretraining VQ-KD on images from this single unit (700k images, 100 epochs) leaves ~5,000 of the 8,192 codebook embeddings unused.
To improve codebook usage, I added unlabeled images from other units; these units are connected through shared components and together form a complete vehicle. Do you think mixing in images from other units is a sound strategy for improving codebook utilization? The resulting training behavior is described in (2).
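For context, this is roughly how I count unused codebook entries (a minimal sketch: loading the tokenizer checkpoint and building the DataLoader are omitted, and `get_codebook_indices` is the method name I see in modeling_vqkd.py, so please treat the details as assumptions):

```python
# Minimal sketch of how I count unused codebook entries after VQ-KD training.
# Checkpoint loading and the DataLoader are placeholders; get_codebook_indices
# is the per-patch token lookup as I read it in modeling_vqkd.py.
import torch
from torch.utils.data import DataLoader


@torch.no_grad()
def codebook_utilization(vqkd, loader: DataLoader, n_codes: int = 8192, device: str = "cuda"):
    """Return (#used, #unused) codebook entries over the whole loader."""
    used = torch.zeros(n_codes, dtype=torch.bool, device=device)
    vqkd.eval().to(device)
    for images, *_ in loader:
        ids = vqkd.get_codebook_indices(images.to(device, non_blocking=True))
        used[ids.unique()] = True
    n_used = int(used.sum())
    return n_used, n_codes - n_used


# usage (placeholders): vqkd = <load checkpoint>, loader = <eval DataLoader>
# used, unused = codebook_utilization(vqkd, loader)
```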
(2) Slow Loss Decrease in VQ-KD Pretraining
I followed the tutorial and trained for 100 epochs, but the loss decreased very slowly. I then extended training to 300 epochs, yet the loss trend remained almost unchanged. Should I increase the number of epochs further, or would you recommend tuning other hyperparameters (e.g., the learning-rate schedule or weight decay)?
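For reference, this is how I check whether the loss has really plateaued (a minimal sketch; it assumes the training script writes one JSON dict per epoch to log.txt in the output directory and that the key is "train_loss", which may differ):

```python
# Sketch: plot the per-epoch training loss from log.txt to see if it has plateaued.
# Assumes one JSON dict per line with "epoch" and "train_loss" keys (names may differ).
import json
import matplotlib.pyplot as plt


def load_stats(log_path: str, key: str = "train_loss"):
    epochs, values = [], []
    with open(log_path) as f:
        for line in f:
            stats = json.loads(line)
            if key in stats:
                epochs.append(stats.get("epoch", len(epochs)))
                values.append(stats[key])
    return epochs, values


epochs, loss = load_stats("/data/beit2/vqkd_output/log.txt")
plt.plot(epochs, loss)
plt.xlabel("epoch")
plt.ylabel("train_loss")
plt.savefig("vqkd_loss.png")
```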
(3) Impact of Different CLIP Initializations
OpenCLIP provides several ViT-B/16 models pretrained on various datasets. Have you experimented with initializing the VQ-KD model with these different CLIP checkpoints? In your experience, how much does the choice of CLIP-pretrained weights influence the final performance?
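For reference, this is how I would load alternative ViT-B/16 teachers for comparison (a minimal sketch; the pretrained tags are examples from the open_clip model zoo, and actually wiring the visual tower into vqkd_encoder_base_decoder_3x768x12_clip is not shown):

```python
# Sketch: load alternative ViT-B/16 CLIP checkpoints from open_clip and sanity-check them.
# The pretrained tags below are examples from the open_clip model zoo; plugging the
# visual tower into the VQ-KD teacher is not shown here.
import torch
import open_clip

for tag in ("openai", "laion2b_s34b_b88k"):
    model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-16", pretrained=tag)
    model.eval()
    with torch.no_grad():
        feats = model.encode_image(torch.randn(1, 3, 224, 224))  # pooled image embedding
    print(tag, tuple(feats.shape))  # VQ-KD distills patch tokens, so this is only a load check
```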
(4) Using DINOv2 as a Teacher Model
DINOv2 is the newer version of DINO and also relies on knowledge distillation during its own pretraining. Do you think DINOv2 could serve as an effective teacher model for BEiT v2's VQ-KD pretraining?
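In case it helps the discussion, below is a minimal sketch of how I would pull patch features from DINOv2 as a candidate teacher. The torch.hub entry point and the "x_norm_patchtokens" key are what I see in the facebookresearch/dinov2 repo, and the shapes are assumptions; the patch-size mismatch (14 vs. 16) would also have to be reconciled with the VQ-KD decoder.

```python
# Sketch: DINOv2 ViT-B/14 as a candidate teacher. Its patch size is 14, so a
# 224x224 input yields 16x16 = 256 patch tokens instead of the 14x14 = 196 tokens
# of the CLIP ViT-B/16 teacher, so the VQ-KD target shape would need to change.
import torch

dinov2 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
dinov2.eval()

with torch.no_grad():
    out = dinov2.forward_features(torch.randn(1, 3, 224, 224))
    patch_tokens = out["x_norm_patchtokens"]  # expected shape: (1, 256, 768)

print(patch_tokens.shape)
```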
The model is trained on 8× A10 (24 GB) GPUs:
python -m torch.distributed.launch --nproc_per_node=8 run_vqkd_training.py \
    --data_set image_folder \
    --data_path /data/beit2/data \
    --eval_data_path /data/eval_data \
    --output_dir /data/beit2/vqkd_output \
    --log_dir /data/beit2/vqkd_output \
    --process_type default \
    --train_interpolation bicubic \
    --min_crop_scale 0.08 \
    --model vqkd_encoder_base_decoder_3x768x12_clip \
    --teacher_input_size 224 \
    --codebook_n_emd 8192 \
    --codebook_emd_dim 32 \
    --quantize_kmeans_init \
    --rec_loss_type cosine \
    --batch_size 64 \
    --opt adamw \
    --opt_betas 0.9 0.99 \
    --weight_decay 1e-4 \
    --warmup_epochs 10 \
    --epochs 300 \
    --save_ckpt_freq 20
python -m torch.distributed.launch --nproc_per_node=8 run_beitv2_pretraining.py \
    --data_set image_folder \
    --data_path /data/beit2/data \
    --output_dir /data/beit2/pretraining_output \
    --log_dir /data/beit2/pretraining_output \
    --model beit_base_patch16_224_8k_vocab_cls_pt \
    --shared_lm_head True \
    --early_layers 9 \
    --head_layers 2 \
    --num_mask_patches 75 \
    --second_input_size 224 \
    --second_interpolation bicubic \
    --min_crop_scale 0.2 \
    --tokenizer_model vqkd_encoder_base_decoder_3x768x12_clip \
    --tokenizer_weight /data/beit2/vqkd_output/checkpoint-279.pth \
    --batch_size 96 \
    --lr 5.625e-4 \
    --warmup_epochs 10 \
    --clip_grad 3.0 \
    --drop_path 0. \
    --layer_scale_init_value 0.1 \
    --imagenet_default_mean_and_std \
    --opt_betas 0.9 0.999 \
    --opt_eps 1e-8 \
    --epochs 300 \
    --save_ckpt_freq 20 \
    --init_ckpt /data/beit2/beitv2_base_patch16_224_pt1k.pth \
    --weight_decay 0.05
I have limited GPU resources, so any suggestions on accelerating convergence or making more efficient use of resources would be greatly appreciated.
Thanks in advance for your help!