This is the official PyTorch implementation of the paper:
Complementary Random Masking for RGB-T Semantic Segmentation
Ukcheol Shin, Kyunghyun Lee, In So Kweon
[Paper] [Project page]
Further visualization can be found in the video.
- 2023.03.30: Release evaluation code and pre-trained weights.
- 2023.12.19: Release training code.
This codebase was developed and tested with the following packages.
- OS: Ubuntu 20.04.1 LTS
- CUDA: 11.3
- PyTorch: 1.10.1
- Python: 3.9.16
- Detectron2: 0.6
You can build your conda environment with the provided YAML file.
conda env create --file environment.yml
Or you can build it manually.
conda create pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch -c conda-forge --name CRM
conda activate CRM
python -m pip install detectron2 -f \
https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
pip install mmcv==1.7.1 pytorch-lightning==1.9.2 scikit-learn==1.2.2 timm==0.6.13 imageio==2.27.0 setuptools==59.5.0
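Once the environment is built, a short script can confirm that the installed packages match the versions listed above. The following is a minimal sketch and not part of this repository: the helper name `check_versions` is hypothetical, and the distribution names (e.g., `torch` for PyTorch) are assumed to match what pip or conda registered.

```python
from importlib import metadata

# Versions copied from the package list and install commands in this README.
EXPECTED_VERSIONS = {
    "torch": "1.10.1",
    "torchvision": "0.11.2",
    "detectron2": "0.6",
    "mmcv": "1.7.1",
    "pytorch-lightning": "1.9.2",
    "scikit-learn": "1.2.2",
    "timm": "0.6.13",
    "imageio": "2.27.0",
}


def check_versions(expected):
    """Return {package: (expected, installed)} for each missing or mismatched package."""
    problems = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Wheels may carry a local suffix such as "+cu113"; compare the public part only.
        if installed is None or installed.split("+")[0] != wanted:
            problems[name] = (wanted, installed)
    return problems


if __name__ == "__main__":
    for name, (wanted, installed) in check_versions(EXPECTED_VERSIONS).items():
        print(f"{name}: expected {wanted}, found {installed}")
```

An empty result means every listed package is present at the expected version. It does not check the CUDA toolkit or the compiled MSDeformAttn kernel.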
If you want to test this codebase in another software stack, check the following compatibility:
After building the conda environment, compile the CUDA kernel for MSDeformAttn. If you have trouble, refer here.
cd models/mask2former/pixel_decoder/ops/
sh make.sh
Download the datasets and place them in the 'datasets' folder in the following structure:
- MF dataset or RTFNet preprocessed version
- PST900 dataset
- KP dataset, Segmentation label or Pre-organized KP dataset
Since the original KP dataset is large (>35GB) and requesting labels takes time, we recommend using our pre-organized dataset, which also includes the labels.
<datasets>
|-- <MFdataset>
    |-- <images>
    |-- <labels>
    |-- train.txt
    |-- val.txt
    |-- test.txt
    ...
|-- <PSTdataset>
    |-- <train>
        |-- rgb
        |-- thermal
        |-- labels
        ...
    |-- <test>
        |-- rgb
        |-- thermal
        |-- labels
        ...
|-- <KPdataset>
    |-- <images>
        |-- set00
        |-- set01
        ...
    |-- <labels>
    |-- train.txt
    |-- val.txt
    |-- test.txt
    ...
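Before training, it can help to confirm that the folders are where the loaders expect them. The sketch below is not part of this repository: `find_missing` is a hypothetical helper, and it assumes the bracketed names in the tree (`MFdataset`, `PSTdataset`, `KPdataset`, `images`, `labels`, `train`, `test`) are the literal folder names. Entries hidden behind `...` are not checked.

```python
from pathlib import Path

# Entries taken from the directory tree above; "..." lines are not checked.
EXPECTED_LAYOUT = {
    "MFdataset": ["images", "labels", "train.txt", "val.txt", "test.txt"],
    "PSTdataset": ["train/rgb", "train/thermal", "train/labels",
                   "test/rgb", "test/thermal", "test/labels"],
    "KPdataset": ["images", "labels", "train.txt", "val.txt", "test.txt"],
}


def find_missing(root, layout=EXPECTED_LAYOUT):
    """Return the relative paths from `layout` that do not exist under `root`."""
    root = Path(root)
    missing = []
    for dataset, entries in layout.items():
        for entry in entries:
            if not (root / dataset / entry).exists():
                missing.append(f"{dataset}/{entry}")
    return missing


if __name__ == "__main__":
    for path in find_missing("datasets"):
        print(f"missing: datasets/{path}")
```

If you only use one dataset, pass a reduced `layout` dictionary so the other two are not reported as missing.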
- Download the Swin Transformer backbone weights pretrained on ImageNet and convert their format for Mask2Former compatibility:
cd pretrained
sh download_backbone_pt_weight.sh
sh convert_pth_to_pkl.sh
- Train a model with a config file. To change hyperparameters (e.g., batch size, number of epochs, learning rate), edit the config file in the 'configs' folder.
Single GPU, MF dataset, Swin-S model
CUDA_VISIBLE_DEVICES=0 python train.py --config-file ./configs/MFdataset/swin/CRM_swin_small_224.yaml --num-gpus 1 --name MF_CRM_swin_S
Multi GPUs, KP dataset, Swin-B model
CUDA_VISIBLE_DEVICES=0,1 python train.py --config-file ./configs/KPdataset/swin/CRM_swin_base_224.yaml --num-gpus 2 --name KP_CRM_swin_B_multi
- Start a tensorboard session to check training progress.
tensorboard --logdir=checkpoints/
You can see the progress by opening http://localhost:6006 in your browser.
Evaluate the trained model by running:
CUDA_VISIBLE_DEVICES=0 python test.py --config-file ./configs/MFdataset/swin/CRM_swin_small_224.yaml --num-gpus 1 --name Eval_MF_CRM_swin_S --checkpoint "PATH for WEIGHT"
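In the training and evaluation commands above, the number of ids in `CUDA_VISIBLE_DEVICES` matches the value passed to `--num-gpus`. The sketch below keeps the two in sync when scripting several runs. It is not part of this repository: `build_command` and `as_shell` are hypothetical helpers, and they use only the flags that appear in the commands above.

```python
import shlex


def build_command(script, config_file, gpu_ids, name, checkpoint=None):
    """Return (env, argv) for train.py or test.py using only the flags shown above."""
    # CUDA_VISIBLE_DEVICES lists the GPU ids; --num-gpus is derived from their count.
    env = {"CUDA_VISIBLE_DEVICES": ",".join(str(g) for g in gpu_ids)}
    argv = ["python", script,
            "--config-file", config_file,
            "--num-gpus", str(len(gpu_ids)),
            "--name", name]
    if checkpoint is not None:
        argv += ["--checkpoint", checkpoint]
    return env, argv


def as_shell(env, argv):
    """Render the pair as a single shell line for copy-pasting."""
    prefix = " ".join(f"{key}={shlex.quote(value)}" for key, value in env.items())
    return f"{prefix} {shlex.join(argv)}"


if __name__ == "__main__":
    env, argv = build_command(
        "train.py", "./configs/KPdataset/swin/CRM_swin_base_224.yaml",
        gpu_ids=[0, 1], name="KP_CRM_swin_B_multi")
    print(as_shell(env, argv))
```

To launch the run from Python instead of printing it, the same pair can be passed to `subprocess.run(argv, env={**os.environ, **env})`.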
We offer pre-trained weights for three RGB-T semantic segmentation datasets.
MF dataset

Architecture | Backbone | mIoU | Weight |
---|---|---|---|
CRM (Mask2Former) | Swin-T | 59.1% | MF-CRM-Swin-T |
CRM (Mask2Former) | Swin-S | 61.2% | MF-CRM-Swin-S |
CRM (Mask2Former) | Swin-B | 61.4% | MF-CRM-Swin-B |

PST900 dataset

Architecture | Backbone | mIoU | Weight |
---|---|---|---|
CRM (Mask2Former) | Swin-T | 85.9% | PST-CRM-Swin-T |
CRM (Mask2Former) | Swin-S | 86.9% | PST-CRM-Swin-S |
CRM (Mask2Former) | Swin-B | 88.0% | PST-CRM-Swin-B |

KP dataset

Architecture | Backbone | mIoU | Weight |
---|---|---|---|
CRM (Mask2Former) | Swin-T | 51.2% | KP-CRM-Swin-T |
CRM (Mask2Former) | Swin-S | 54.4% | KP-CRM-Swin-S |
CRM (Mask2Former) | Swin-B | 55.2% | KP-CRM-Swin-B |
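For readers unfamiliar with the metric in the tables, mean intersection-over-union averages, over classes, the ratio TP / (TP + FP + FN) taken from a confusion matrix. The sketch below illustrates that general definition only. It is not the evaluation code used in this repository, and skipping classes that never appear is an assumption that may differ from the RTFNet metric.

```python
def mean_iou(confusion):
    """Mean IoU from a square confusion matrix (rows: ground truth, columns: prediction)."""
    num_classes = len(confusion)
    ious = []
    for c in range(num_classes):
        tp = confusion[c][c]
        fn = sum(confusion[c]) - tp  # pixels of class c predicted as another class
        fp = sum(confusion[r][c] for r in range(num_classes)) - tp  # other pixels predicted as c
        denom = tp + fp + fn
        if denom == 0:
            continue  # class absent from labels and predictions: skipped (assumption)
        ious.append(tp / denom)
    return sum(ious) / len(ious) if ious else 0.0


if __name__ == "__main__":
    # Two classes: IoU is 3/5 for class 0 and 5/7 for class 1.
    print(mean_iou([[3, 1], [1, 5]]))
```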
Our code is licensed under the MIT License.
Please cite the following paper if you use our work in your research.
@article{shin2023comp,
title={Complementary Random Masking for RGB-T Semantic Segmentation},
author={Shin, Ukcheol and Lee, Kyunghyun and Kweon, In So},
journal={arXiv preprint},
year={2023}
}
Our network architecture and codebase are built upon Mask2Former.
- Mask2Former (CVPR 2022)
We use the evaluation metric provided by RTFNet.
- RTFNet (RA-L 2019)