RoLI

This repository contains the official PyTorch implementation for Robust Linear Initialization (RoLI), CVPR 2024.

Paper link: Initialization Matters for Adversarial Transfer Learning


Environment settings

We provide the required packages in the requirements.txt file. To set them up in a new Conda environment:

conda create --name roli python=3.8.17
conda activate roli
pip install -r requirements.txt

Dataset preparation:

We use CIFAR-10 and CIFAR-100 from the Visual Task Adaptation Benchmark (VTAB). Please see VTAB_SETUP.md for setup instructions; note that only CIFAR-10 and CIFAR-100 are required. We follow the official train/test splits, and the size of the validation set is controlled by TRAIN_SPLIT_PERCENT in the code.

The train/val/test splits for other datasets are provided in data_splits. Copy the corresponding JSON file to the location where your dataset is downloaded.
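
If you want to sanity-check a copied split before training, a minimal sketch like the following can help. The paths and the assumed JSON schema (image file name mapped to an integer label) are illustrative, not the repo's documented format; adjust them to the actual split files.

import json
from pathlib import Path

# Hypothetical locations -- point these at your dataset root and the copied split file.
data_root = Path("/path/to/StanfordDogs")
split_file = data_root / "train.json"

# Assumed schema: a mapping from relative image path to an integer class label.
with open(split_file) as f:
    split = json.load(f)

missing = [name for name in split if not (data_root / name).exists()]
print(f"{len(split)} entries in the split, {len(missing)} files not found on disk")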

Adversarially robust/non-robust pretraining preparation:

Download the pretrained models and place them in pretrain (a minimal loading sketch follows the list below).

  • Robust Pretrained Model: ARES2.0 for SwinB and ViTB.

  • Non-robust Pretrained Model: For the Swin Transformer (swin_base_patch4_window7_224), we use the weights from the official implementation. For ViT, we use the weights from torchvision.
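
As a rough illustration of how a downloaded checkpoint might be inspected, the sketch below loads the torchvision ViT-B/16 weights and a robust checkpoint from disk. The robust file name is hypothetical, and the ARES 2.0 state-dict keys generally will not match torchvision's naming, so treat this as a sanity check rather than the repo's actual loading code.

import torch
import torchvision

# Non-robust ViT-B/16 weights shipped with torchvision (ImageNet-1k).
vit = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)

# Robust checkpoint downloaded from ARES 2.0 (hypothetical file name under pretrain/).
ckpt = torch.load("pretrain/vitb_ares2_robust.pth", map_location="cpu")
state = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt

# Keys will likely differ between the two naming schemes; strict=False only reports the mismatch.
missing, unexpected = vit.load_state_dict(state, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")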

Training:

We provide training configs for both RoLI and RanLI. Our trained RoLI checkpoints can be found on Google Drive.

Before training, set OUTPUT_DIR and change DATAPATH to the location where your dataset is downloaded.

RanLI

For adversarial full finetuning (RanLI-Full-FT) on CIFAR-10:

python train.py --config-file=configs/finetune/cifar10_ranli.yaml

For adversarial finetuning with PEFT methods, such as RanLI-LoRA on Stanford Dogs:

python train.py --config-file=configs/lora/stanforddogs_ranli.yaml

For adversarial linear probing (RanLI-Linear) on Caltech256:

python train.py --config-file=configs/linear/caltech256_ranli.yaml

RoLI

RoLI consists of two training stages: robust linear initialization followed by adversarial finetuning.

For RoLI-Full-FT on CIFAR-10, we first train RanLI-Linear. After training, two checkpoints are saved: the best-validation checkpoint best_model.pth and the last checkpoint last_model.pth. Next, we set WEIGHT_PATH to the location of best_model.pth. Finally, we perform adversarial full finetuning on CIFAR-10.

python train.py --config-file=configs/linear/cifar10_ranli.yaml
# Set WEIGHT_PATH in finetune/cifar10_roli.yaml to the path of RanLI-Linear trained checkpoint best_model.pth.
python train.py --config-file=configs/finetune/cifar10_roli.yaml
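
Conceptually, the two stages amount to adversarially training only the linear head on a frozen backbone, then using that robust head as the initialization for adversarial finetuning of the whole model. The sketch below illustrates this with a stand-in backbone and head and a generic L-inf PGD attack; none of the names, shapes, or hyperparameters are taken from the repo.

import torch
import torch.nn as nn

# Stand-ins for the repo's backbone and classification head (shapes are illustrative).
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 768))
head = nn.Linear(768, 10)
model = nn.Sequential(backbone, head)
criterion = nn.CrossEntropyLoss()

def pgd_attack(model, x, y, eps=8 / 255, alpha=2 / 255, steps=10):
    # Generic L-inf PGD; the attack settings used by the repo may differ.
    delta = torch.empty_like(x).uniform_(-eps, eps).requires_grad_(True)
    for _ in range(steps):
        loss = criterion(model((x + delta).clamp(0, 1)), y)
        grad, = torch.autograd.grad(loss, delta)
        delta = (delta + alpha * grad.sign()).clamp(-eps, eps).detach().requires_grad_(True)
    return (x + delta).clamp(0, 1).detach()

x = torch.rand(4, 3, 32, 32)
y = torch.randint(0, 10, (4,))

# Stage 1 (RanLI-Linear): freeze the backbone and adversarially train only the head.
for p in backbone.parameters():
    p.requires_grad_(False)
opt = torch.optim.SGD(head.parameters(), lr=0.1)
loss = criterion(model(pgd_attack(model, x, y)), y)
opt.zero_grad()
loss.backward()
opt.step()

# Stage 2 (RoLI): keep the robustly trained head as initialization, unfreeze, and
# adversarially finetune everything (or only a PEFT subset such as LoRA or bias terms).
for p in backbone.parameters():
    p.requires_grad_(True)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss = criterion(model(pgd_attack(model, x, y)), y)
opt.zero_grad()
loss.backward()
opt.step()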

Similarly, to train RoLI-Bias on CUB200:

python train.py --config-file=configs/linear/cub_ranli.yaml
# Set WEIGHT_PATH in bias/cub_roli.yaml to the path of RanLI-Linear trained checkpoint best_model.pth.
python train.py --config-file=configs/bias/cub_roli.yaml

Testing:

To test, set SOLVER.TOTAL_EPOCH to 0 and set WEIGHT_PATH to the checkpoint location in the config file.

We provide a sample test config to evaluate RoLI-Adapter on Stanford Dogs:

python train.py --config-file=configs/adapter/stanforddogs_roli_eval.yaml

To test with AutoAttack:

python autoattack.py --config-file=configs/adapter/stanforddogs_roli_eval.yaml
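
The repo's autoattack.py presumably wraps the standalone autoattack package; a generic usage sketch of that package is shown below. The tiny model and random tensors are placeholders for the finetuned RoLI model and the real test set, and eps=8/255 is an assumed perturbation budget.

import torch
from autoattack import AutoAttack

# Placeholders -- replace with the finetuned RoLI model and the actual test images/labels in [0, 1].
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10)).eval()
x_test = torch.rand(16, 3, 32, 32)
y_test = torch.randint(0, 10, (16,))

adversary = AutoAttack(model, norm='Linf', eps=8 / 255, version='standard')
x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=16)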

Citation:

If you find our work helpful in your research, please cite it as:

@inproceedings{hua2024initialization,
  title={Initialization Matters for Adversarial Transfer Learning},
  author={Hua, Andong and Gu, Jindong and Xue, Zhiyu and Carlini, Nicholas and Wong, Eric and Qin, Yao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2024}
}

Acknowledgement:

This repo is built on VPT.
