This repository hosts the code for the paper:
RelViT: Concept-guided Vision Transformer for Visual Relational Reasoning (ICLR 2022)
by Xiaojian Ma, Weili Nie, Zhiding Yu, Huaizu Jiang, Chaowei Xiao, Yuke Zhu, Song-Chun Zhu and Anima Anandkumar
- 🔥🔥 09/10/2022: Pre-trained models on GQA are now released.
Reasoning about visual relationships is central to how humans interpret the visual world. This task remains challenging for current deep learning algorithms since it requires addressing three key technical problems jointly: 1) identifying object entities and their properties, 2) inferring semantic relations between pairs of entities, and 3) generalizing to novel object-relation combinations, i.e., systematic generalization. In this work, we use vision transformers (ViTs) as our base model for visual reasoning and make better use of concepts defined as object entities and their relations to improve the reasoning ability of ViTs. Specifically, we introduce a novel concept-feature dictionary to allow flexible image feature retrieval at training time with concept keys. This dictionary enables two new concept-guided auxiliary tasks: 1) a global task for promoting relational reasoning, and 2) a local task for facilitating semantic object-centric correspondence learning. To examine the systematic generalization of visual reasoning models, we introduce systematic splits for the standard HICO and GQA benchmarks. We show that the resulting model, Concept-guided Vision Transformer (RelViT for short), significantly outperforms prior approaches on HICO and GQA by 16% and 13% on the original split, and by 43% and 18% on the systematic split. Our ablation analyses also reveal our model's compatibility with multiple ViT variants and robustness to hyper-parameters.
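The concept-feature dictionary described above can be pictured as a store of image features keyed by concept. The following is a minimal sketch under stated assumptions, not the authors' implementation: it assumes one bounded FIFO queue per concept and interprets the two sampling modes exposed by the relvit_sample_uniform option as uniform sampling versus taking the newest entry. The class and method names are hypothetical, and features are plain Python lists instead of tensors to keep the example self-contained.

```python
import random
from collections import defaultdict, deque
from typing import Deque, Dict, Hashable, List, Optional, Sequence


class ConceptFeatureDict:
    """Toy concept-feature dictionary (hypothetical; not the RelViT code).

    Each concept key (e.g. an object or relation label) maps to a bounded
    FIFO queue of image features seen during training.
    """

    def __init__(self, max_per_concept: int = 4, seed: Optional[int] = None) -> None:
        # Once a queue is full, appending a new feature evicts the oldest one.
        self.queues: Dict[Hashable, Deque[List[float]]] = defaultdict(
            lambda: deque(maxlen=max_per_concept)
        )
        self.rng = random.Random(seed)

    def add(self, concepts: Sequence[Hashable], feature: List[float]) -> None:
        """File one image feature under every concept present in that image."""
        for concept in concepts:
            self.queues[concept].append(feature)

    def sample(self, concept: Hashable, uniform: bool = True) -> Optional[List[float]]:
        """Retrieve a stored feature for `concept`, or None if it was never seen."""
        queue = self.queues.get(concept)  # .get does not create an empty queue
        if not queue:
            return None
        if uniform:
            # Uniform sampling: every stored feature is equally likely.
            return self.rng.choice(list(queue))
        # "Most-recent" sampling (one possible reading): the newest feature.
        return queue[-1]


# Usage with made-up 2-d features:
store = ConceptFeatureDict(max_per_concept=2, seed=0)
store.add(["person", "ride", "horse"], [0.1, 0.2])
store.add(["person"], [0.3, 0.4])
store.add(["person"], [0.5, 0.6])  # evicts [0.1, 0.2] from the "person" queue
print(store.sample("person", uniform=False))  # [0.5, 0.6]
print(store.sample("horse"))  # [0.1, 0.2]
```

A bounded queue keeps memory constant per concept and drops the oldest features first; the actual repository may size, update and sample its dictionary differently.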
- Install PyTorch:
  conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
- Install the necessary packages with requirements.txt:
  pip install -r requirements.txt

The code has been tested with Python 3.8, PyTorch 1.11.0 and CUDA 11.6 on Ubuntu 20.04.
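To compare your setup with the tested environment, a small check like the one below may help. It is an illustrative helper (report_environment is a hypothetical name). Note that torch.version.cuda reports the CUDA version PyTorch was built with, which can differ from the CUDA version of the installed driver.

```python
import platform
import sys


def report_environment() -> dict:
    """Collect the versions listed as tested above (Python, PyTorch, CUDA) plus the GPU count."""
    info = {"python": platform.python_version(), "torch": None, "cuda": None, "gpus": 0}
    try:
        import torch  # only importable after the install steps above
    except ImportError:
        return info
    info["torch"] = torch.__version__
    info["cuda"] = torch.version.cuda  # build-time CUDA version; None for CPU-only builds
    info["gpus"] = torch.cuda.device_count()
    return info


env_info = report_environment()
print(env_info)
if sys.version_info[:2] != (3, 8):
    print("note: the code was tested with Python 3.8")
```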
Please refer to data preparation to set up the datasets.
Training on HICO
bash scripts/train_hico_image.sh configs/train_hico.yaml
In configs/train_hico.yaml you may find some configurable options:
- use eval_mode to run different experiments: the original or the systematic generalization test
- use model_args.encoder_args.encoder and load_encoder to select the vision backbone. There are five options available: pvtv2_b2, pvtv2_b3, swin_small, swin_base and vit_small_16.
- use relvit to turn the RelViT auxiliary loss on or off
- use relvit_weight to adjust the coefficient of the RelViT auxiliary loss
- use relvit_local_only to control whether you use only the RelViT local or global task
- use relvit_mode to control whether the EsViT loss is included
- use relvit_sample_uniform to choose between uniform and "most-recent" concept sampling
- use relvit_concept_use and relvit_num_concepts to choose the concepts used by RelViT among HOI, verb and object

In general, we don't recommend modifying other parameters.
All available GPUs are used by default. To run with the recommended batch size, you may need one V100 32GB GPU.
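The options above live in a YAML file, so they can also be changed programmatically, for example after loading the file with PyYAML's yaml.safe_load. This is a minimal sketch, assuming the loaded config is a nested dict whose layout follows the dotted names listed above. set_option and select_backbone are hypothetical helpers, the relvit* keys are assumed to sit at the top level (check the actual file), and load_encoder is deliberately left untouched because its expected value is not documented here.

```python
from typing import Any, Dict

# The five backbones listed in this README.
BACKBONES = {"pvtv2_b2", "pvtv2_b3", "swin_small", "swin_base", "vit_small_16"}


def set_option(cfg: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set a nested option such as 'model_args.encoder_args.encoder' in a config dict."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})  # create missing intermediate sections
    node[leaf] = value


def select_backbone(cfg: Dict[str, Any], name: str) -> None:
    """Validate and set the encoder name; load_encoder must still be set to match by hand."""
    if name not in BACKBONES:
        raise ValueError(f"unknown backbone {name!r}; expected one of {sorted(BACKBONES)}")
    set_option(cfg, "model_args.encoder_args.encoder", name)


# Usage on a stand-in dict; in practice cfg would come from
# yaml.safe_load(open("configs/train_hico.yaml")) and be written back with yaml.safe_dump.
cfg: Dict[str, Any] = {"model_args": {"encoder_args": {"encoder": "pvtv2_b2"}}}
select_backbone(cfg, "swin_small")
set_option(cfg, "relvit", True)  # assumed top-level key
print(cfg["model_args"]["encoder_args"]["encoder"])  # swin_small
```

Validating the backbone name before writing it catches typos early instead of at model construction time.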
Training on GQA
bash scripts/train_gqa_image.sh configs/train_gqa.yaml
In configs/train_gqa.yaml you may find some configurable options:
- use eval_mode to run different experiments: the original or the systematic generalization test
- use model_args.encoder_args.encoder and load_encoder to select the vision backbone. There are five options available: pvtv2_b2, pvtv2_b3, swin_small, swin_base and vit_small_16.
- use relvit to turn the RelViT auxiliary loss on or off
- use relvit_weight to adjust the coefficient of the RelViT auxiliary loss
- use relvit_local_only to control whether you use only the RelViT local or global task
- use relvit_mode to control whether the EsViT loss is included
- use relvit_sample_uniform to choose between uniform and "most-recent" concept sampling

In general, we don't recommend modifying other parameters.
All available GPUs are used by default. To run with the recommended batch size, you may need up to 64 V100 32GB GPUs, because the vision backbone is fine-tuned during training.
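Because every visible GPU is used by default, one common way to limit a run to specific devices is the standard CUDA_VISIBLE_DEVICES environment variable. The sketch below wraps the training commands from this README under that assumption. build_train_command and launch are hypothetical helpers, and the sketch assumes the shell scripts do not override CUDA_VISIBLE_DEVICES themselves.

```python
import os
import subprocess
from typing import Dict, List, Optional, Sequence, Tuple

# Script/config pairs taken from the training commands in this README.
SCRIPTS = {"hico": "scripts/train_hico_image.sh", "gqa": "scripts/train_gqa_image.sh"}
CONFIGS = {"hico": "configs/train_hico.yaml", "gqa": "configs/train_gqa.yaml"}


def build_train_command(
    dataset: str, gpu_ids: Optional[Sequence[int]] = None
) -> Tuple[List[str], Dict[str, str]]:
    """Return the training command and an environment limiting which GPUs are visible."""
    env = dict(os.environ)
    if gpu_ids is not None:
        # CUDA only exposes the listed devices to processes started with this env.
        env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return ["bash", SCRIPTS[dataset], CONFIGS[dataset]], env


def launch(dataset: str, gpu_ids: Optional[Sequence[int]] = None) -> None:
    """Run training from the repository root (assumes the script honors CUDA_VISIBLE_DEVICES)."""
    cmd, env = build_train_command(dataset, gpu_ids)
    subprocess.run(cmd, env=env, check=True)


# Example: build (but do not run) a HICO job restricted to GPU 0.
cmd, env = build_train_command("hico", gpu_ids=[0])
print(" ".join(cmd), "| CUDA_VISIBLE_DEVICES =", env["CUDA_VISIBLE_DEVICES"])
```

Passing a copied environment to subprocess.run leaves the parent process untouched, so several jobs can be prepared with different device lists.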
Testing on HICO
bash scripts/train_hico_image.sh configs/train_hico.yaml --test_only --test_model <path to best_model.pth>
Testing on GQA
bash scripts/train_gqa_image.sh configs/train_gqa.yaml --test_only --test_model <path to best_model.pth>
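Evaluation needs the path to a best_model.pth checkpoint, either one you trained or a downloaded pre-trained model. The sketch below builds the test command shown above and fails early if the file is missing. find_best_model and build_test_command are hypothetical helpers, and the assumption that training writes best_model.pth somewhere under a single output directory is not confirmed by this README.

```python
from pathlib import Path
from typing import List, Optional

# Script/config pairs taken from the test commands in this README.
SCRIPTS = {"hico": "scripts/train_hico_image.sh", "gqa": "scripts/train_gqa_image.sh"}
CONFIGS = {"hico": "configs/train_hico.yaml", "gqa": "configs/train_gqa.yaml"}


def find_best_model(run_dir: str) -> Optional[Path]:
    """Return the most recently modified best_model.pth under run_dir (assumed layout), or None."""
    candidates = list(Path(run_dir).rglob("best_model.pth"))
    return max(candidates, key=lambda p: p.stat().st_mtime, default=None)


def build_test_command(dataset: str, checkpoint: str) -> List[str]:
    """Build the evaluation command, raising FileNotFoundError if the checkpoint is absent."""
    ckpt = Path(checkpoint)
    if not ckpt.is_file():
        raise FileNotFoundError(f"checkpoint not found: {ckpt}")
    return ["bash", SCRIPTS[dataset], CONFIGS[dataset], "--test_only", "--test_model", str(ckpt)]


# Usage (paths are placeholders):
#   ckpt = find_best_model("<your output directory>")
#   cmd = build_test_command("gqa", str(ckpt))
```

Checking the path before launching avoids waiting for a multi-GPU job to start only to fail on a typo.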
| tag | encoder | experiment | result | URL |
|---|---|---|---|---|
| swin-small-relvit | swin_small | GQA (val) | 61.38 | link |
| swin-base-relvit | swin_base | GQA (val) | 65.54 | link |
Please check the LICENSE file for both the code and the released pre-trained models. This work may be used non-commercially, meaning for research or evaluation purposes only. For business inquiries, please contact [email protected].
The authors have referred to the following projects:
Please consider citing our paper if you find our work helpful for your research:
@inproceedings{ma2022relvit,
title={RelViT: Concept-guided Vision Transformer for Visual Relational Reasoning},
author={Xiaojian Ma and Weili Nie and Zhiding Yu and Huaizu Jiang and Chaowei Xiao and Yuke Zhu and Song-Chun Zhu and Anima Anandkumar},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=afoV8W3-IYp}
}