Multitask Prompt Learning for Vision-Language Models

This repo contains the codebase of a series of research projects focused on adapting vision-language models like CLIP to downstream datasets via multitask prompt learning:

[Figure panels: (a) CoOp, (b) VPT, (c) UPT]

How to Install

This code is built on top of the Dassl.pytorch toolbox and CoOp, so you need to install the dassl environment and PyTorch first. After that, run pip install -r requirements.txt under MVLPT/ to install a few more packages required by CLIP (do this with the dassl environment activated). Then you are ready to go.

Follow DATASETS.md to install the datasets from CoOp used for multitask source prompt initialization, or install gdown and run the following script:

bash scripts/data.sh

Note that the datasets for the target ELEVATER benchmark are downloaded automatically to MVLPT/trainers/vision_benchmark/.

How to Run

Click a paper below to see the detailed instructions on how to run the code to reproduce the results.

Models and Results

  • The pre-trained weights of MVLPT (MCoOp, MVPT, MUPT) trained on 11 tasks, based on ViT-B/16 and ViT-B/32, can all be downloaded via this link. The weights can be used to reproduce the results in Table 1 of MVLPT's paper (i.e., the results on ImageNet and its four variants with domain shift). To load the weights and run the evaluation code, specify --model-dir and --load-epoch (see this script for an example).
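A minimal sketch of such an evaluation call is shown below. It is modeled on the evaluation scripts of CoOp, which this code builds on, and is not taken from this repo: only --model-dir and --load-epoch come from the text above, while train.py, --root, --output-dir, --eval-only, the paths, and the epoch value are assumptions, so check the linked script for the real invocation. The hypothetical helper build_eval_cmd prints the command instead of running it.

```bash
#!/usr/bin/env bash
# Sketch: assemble (but do not run) an evaluation command for downloaded weights.
# Only --model-dir and --load-epoch come from the README; train.py, --root,
# --output-dir, --eval-only, and all paths/values are hypothetical placeholders
# modeled on CoOp's evaluation scripts.
set -euo pipefail

build_eval_cmd() {
  local data_root="$1"   # hypothetical: directory holding the datasets
  local model_dir="$2"   # directory with the downloaded pre-trained weights
  local epoch="$3"       # which saved epoch to load
  local out_dir="$4"     # hypothetical: where evaluation logs are written
  echo python train.py \
    --root "$data_root" \
    --output-dir "$out_dir" \
    --model-dir "$model_dir" \
    --load-epoch "$epoch" \
    --eval-only
}

# Placeholder values; replace them with the ones from the repo's script.
build_eval_cmd /path/to/data /path/to/downloaded/weights 50 output/evaluation
```

Printing rather than executing is deliberate: it makes the assumed flags easy to compare against the repo's own script before anything runs.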

Citation

If you use this code in your research, please cite the following paper:

@article{shen2022mvlpt,
    title={Multitask Vision-Language Prompt Tuning},
    author={Shen, Sheng and Yang, Shijia and Zhang, Tianjun and Zhai, Bohan and Gonzalez, Joseph E. and Keutzer, Kurt and Darrell, Trevor},
    journal={arXiv preprint arXiv:2211.11720},
    year={2022}
}