CLIMATv2: Clinically-Inspired Multi-Agent Transformers for Disease Trajectory Forecasting from Multimodal Data
This is the implementation of the paper CLIMATv2: https://arxiv.org/abs/2210.13889. Its previous version (CLIMATv1) can be found at https://arxiv.org/abs/2104.03642.
The concept of the framework is as follows:
The differences of CLIMATv2 compared to CLIMATv1 are:
- The general practitioner (GP) is allowed to utilize multiple modalities to perform diagnosis prediction (i.e., y_0)
- The diagnosis predictions of the Radiologist and the GP are enforced to be consistent (see the sketch after this list)
- The cross-entropy loss is replaced by CLUB (Calibrated Loss based on Upper Bound), which takes into account both performance and calibration during optimization.
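For intuition only, here is a minimal sketch of what such a consistency term could look like, assuming both agents output class logits; the symmetric KL divergence and the function name are illustrative assumptions, not the repository's implementation.

```python
import torch
import torch.nn.functional as F

def consistency_loss(radiologist_logits: torch.Tensor, gp_logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical consistency term: a symmetric KL divergence between the
    diagnosis distributions of the Radiologist and the GP agents.
    The actual CLIMATv2 code may enforce consistency differently."""
    log_p = F.log_softmax(radiologist_logits, dim=-1)
    log_q = F.log_softmax(gp_logits, dim=-1)
    kl_pq = F.kl_div(log_q, log_p, reduction="batchmean", log_target=True)  # KL(p || q)
    kl_qp = F.kl_div(log_p, log_q, reduction="batchmean", log_target=True)  # KL(q || p)
    return 0.5 * (kl_pq + kl_qp)
```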
Run commands:
git clone git@github.com:Oulu-IMEDS/CLIMATv2.git
cd ./CLIMATv2
conda create -n CLIMATv2 python=3.7
conda activate CLIMATv2
pip install -e .
You can use the ADNI metadata prepared in ./adni/Metadata/adni_fdgpet_prognosis.csv, or regenerate it using
# Modify input and output paths, then run
python ./common/adni/preprocess_adni.py
# Standardize voxels if needed
python ./common/adni/standardize_voxels.py
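If you want to sanity-check the metadata before training, a small snippet like the one below can be used; it only loads the CSV shipped with the repository and prints its shape and columns (the exact column layout depends on preprocess_adni.py).

```python
import pandas as pd

# Metadata file shipped with the repository (or the one you regenerated).
meta = pd.read_csv("./adni/Metadata/adni_fdgpet_prognosis.csv")

print(meta.shape)              # number of rows and columns
print(meta.columns.tolist())   # available fields
print(meta.head())             # preview the first entries
```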
Command line:
# General setting using default values in configuration files in ./adni/configs/config_train.yaml
python train.py config=seq_multi_prog_climatv2
# Detailed setting
python train.py config=seq_multi_prog_climatv2 comment=mycomment \
bs=${BATCH_SIZE} num_workers=${NUM_WORKERS} root.path=/path/to/ADNI meta_root=/path/to/meta_dir/ fold_index=1 \
backbone_name=shufflenetv2 max_depth=4 num_cls_num=4 prognosis_coef=1 cons_coef=0.5 \
loss_name=CLUB club.s=0.5
config can be:
- seq_multi_prog_climatv1: CLIMATv1
- seq_multi_prog_climatv2: CLIMATv2
Processing:
- bs: batch size
- num_workers: the number of workers
Data setup:
- root.path: root directory of images
- meta_root: root directory of metadata (.csv, or a saved split configuration in .pkl)
- fold_index: fold index (starting from 1)
Model:
- backbone_name: backbone for imaging feature extraction
- max_depth: the number of CNN blocks in the imaging feature extraction module
- n_meta_features: the length of metadata features
- num_cls_num: the number of [CLS] embeddings in transformer P
Coefficients in loss (see the sketch after this list):
- prognosis_coef: coefficient for prognosis prediction
- cons_coef: coefficient for the consistency term
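As a rough, assumed illustration of how these coefficients weight their respective terms (the exact objective, including any additional terms, is defined in the paper and the loss code):

```python
def total_objective(prognosis_loss: float, consistency_loss: float,
                    prognosis_coef: float = 1.0, cons_coef: float = 0.5) -> float:
    # Schematic weighting only; CLIMATv2 may include further terms
    # (e.g., for the diagnosis predictions) that are omitted here.
    return prognosis_coef * prognosis_loss + cons_coef * consistency_loss
```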
loss_name is one of:
- CLUB: Calibrated loss based on upper bound (ours); club.s is the epsilon hyperparameter in CLUB
- CE: cross-entropy loss
- FL: focal loss
- FLA: adaptive focal loss
- MTL: multi-task loss
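For reference, a generic multi-class focal loss (the FL option is a standard baseline; this textbook sketch is not taken from the repository) could look like:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Textbook multi-class focal loss: (1 - p_t)^gamma * cross-entropy."""
    ce = F.cross_entropy(logits, targets, reduction="none")  # -log p_t per sample
    pt = torch.exp(-ce)                                      # probability of the true class
    return ((1.0 - pt) ** gamma * ce).mean()
```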
Hyperparameters used in the paper: see the detailed training command above.
Evaluation:
python eval.py root.path=/path/to/imgs_dir/ meta_root=/path/to/metadata_dir/ \
eval.root=/path/to/trained_models_dir/ eval.patterns=${PATTERN} eval.output=/path/to/output.json \
use_only_baseline=True seed=${SEED} \
save_predictions=${SAVE_PREDICTIONS} save_attn=${SAVE_ATTENTION_MAPS}
Input data for evaluation:
- root.path: root directory of images
- meta_root: root directory of metadata (.csv, or a saved split configuration in .pkl)
- eval.root: root directory containing sub-directories of trained settings
- eval.patterns: a common pattern of saved model files (e.g., pn_avg_ba for average balanced accuracies, or pn_avg_mauc for average mAUCs)
- eval.output: path to the file storing evaluation results (see the snippet after this list)
- use_only_baseline: whether to use only the data at the baseline as input (always True)
- save_predictions: whether to save predictions for visualization
- save_attn: whether to save attention maps for visualization
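eval.output is written as a JSON file; its exact schema is produced by eval.py, so the snippet below simply loads and pretty-prints whatever was stored:

```python
import json

# Path passed via eval.output; adjust to wherever you wrote the results.
with open("/path/to/output.json") as f:
    results = json.load(f)

print(json.dumps(results, indent=2))  # inspect the stored evaluation results
```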
Run commands for OAI:
# Generate longitudinal data
python ./common/prepare_1img_seq_metadata.py
# Split data
python ./common/do_split.py
# General setting using default values in configuration files in ./oai/configs/config_train.yaml
python train.py config=seq_multi_prog_climatv2
# Detailed setting
python train.py config=seq_multi_prog_climatv2 \
bs=64 num_workers=8 root.path=/path/to/OAI/ meta_root=/path/to/meta_dir backbone_name=resnet18 site=C \
prognosis_coef=1.0 cons_coef=0.5 loss_name=CLUB n_meta_features=128 \
num_cls_num=8 club.s=0.5 grading=KL \
fold_index=1 seed=12345
Besides the arguments used for ADNI, the following additional arguments are used for OAI:
Data:
- site: test acquisition site (C, the site with the most data, is chosen for testing, meaning that sites A, B, D, and E are used for training and validation; see the split sketch below)
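A hedged sketch of that site-based hold-out, assuming the OAI metadata exposes a hypothetical SITE column (do_split.py in the repository is the authoritative implementation):

```python
import pandas as pd

def split_by_site(meta: pd.DataFrame, test_site: str = "C"):
    """Hold out one acquisition site for testing; the 'SITE' column name is a
    placeholder -- check the metadata produced by the preprocessing scripts."""
    test_df = meta[meta["SITE"] == test_site]        # site C for testing
    train_val_df = meta[meta["SITE"] != test_site]   # sites A, B, D, E
    return train_val_df, test_df
```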
grading can be:
- KL: Kellgren and Lawrence
- JSL: Lateral joint space
- JSM: Medial joint space
- OSFL: Lateral osteophyte in femur
- OSFM: Medial osteophyte in femur
- OSTL: Lateral osteophyte in tibia
- OSTM: Medial osteophyte in tibia
List of augmentations applied to knee images (Note: all right knee images are vertically flipped):
Same as above.
If you find the manuscript or code useful, please cite as follows:
@article{nguyen2022clinically,
title={Clinically-Inspired Multi-Agent Transformers for Disease Trajectory Forecasting from Multimodal Data},
author={Nguyen, Huy Hoang and Blaschko, Matthew B and Saarakkala, Simo and Tiulpin, Aleksei},
journal={arXiv preprint arXiv:2210.13889},
year={2022}
}