Segmentation of Multiple Sclerosis Lesions across Hospitals: Learn Continually or Train from Scratch?
This is the official PyTorch-based repository containing the code and instructions to reproduce the results of the above-mentioned paper, accepted at the Medical Imaging Meets NeurIPS (MedNeurIPS) 2022 Workshop in New Orleans, LA, USA.
The paper can be found on arXiv: https://arxiv.org/pdf/2210.15091.pdf
This work presents a case for using continual learning to segment Multiple Sclerosis (MS) lesions across multi-center data. In particular, we formalize the problem as domain-incremental learning and use experience replay, a well-known continual learning method, for MS lesion segmentation across eight different hospitals/centers. As shown in the figure below, four types of experiments are performed: single-domain, multi-domain, sequential fine-tuning, and experience replay. Our results show that replay outperforms fine-tuning as more data arrive and also achieves positive backward transfer, both measured with the Dice score on a held-out test set. More importantly, replay also outperforms multi-domain (IID) training, suggesting that lifelong learning is a promising long-term solution for improving the automated segmentation of MS lesions compared to training from scratch.
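Experience replay keeps a small memory of samples from previously seen centers and mixes them into the training data of each new center. The following is a minimal, framework-agnostic sketch of that idea, assuming a fixed number of stored samples per center chosen uniformly at random. `ReplayBuffer`, `samples_per_center`, and the center and sample names are hypothetical and are not taken from this repository's main_pl_*.py files, whose buffer policy may differ.

```python
import random
from typing import Dict, List, Sequence


class ReplayBuffer:
    """Hypothetical per-center replay memory (illustrative only).

    Stores a fixed-size random subset of each previously seen center's
    training samples and mixes them with the current center's data.
    """

    def __init__(self, samples_per_center: int, seed: int = 0) -> None:
        self.samples_per_center = samples_per_center
        self.memory: Dict[str, List[str]] = {}
        self.rng = random.Random(seed)

    def add_center(self, center: str, samples: Sequence[str]) -> None:
        # Keep at most `samples_per_center` randomly chosen samples from this center.
        k = min(self.samples_per_center, len(samples))
        self.memory[center] = self.rng.sample(list(samples), k)

    def training_set(self, current_samples: Sequence[str]) -> List[str]:
        # Current center's data plus everything stored from earlier centers.
        replayed = [s for stored in self.memory.values() for s in stored]
        mixed = list(current_samples) + replayed
        self.rng.shuffle(mixed)
        return mixed


# Sequential training over (hypothetical) centers: train on center i with
# replayed samples from centers 0..i-1, then store a subset of center i.
centers = {
    "center_A": [f"A_{i}" for i in range(10)],
    "center_B": [f"B_{i}" for i in range(10)],
    "center_C": [f"C_{i}" for i in range(10)],
}
buffer = ReplayBuffer(samples_per_center=3)
sizes = []
for name, samples in centers.items():
    train_data = buffer.training_set(samples)
    sizes.append(len(train_data))
    # ... train the segmentation model on `train_data` here ...
    buffer.add_center(name, samples)

print(sizes)  # [10, 13, 16]
```

Setting `samples_per_center=0` reduces this loop to sequential fine-tuning, the baseline that replay is compared against: each center is then trained on its own data only.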
During training, we use soft labels (rather than binarizing them), as they have been shown to improve generalizability and reduce model overconfidence. The resulting soft segmentation outputs also provide a measure of uncertainty, which is visible at the lesion boundaries.
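To illustrate why keeping soft labels matters, the sketch below compares a soft label with its binarized version on a toy 1-D lesion profile. It assumes voxel values in [0, 1] and uses binary cross-entropy with soft targets only because its minimum lies exactly at the soft target; this is not necessarily the loss used in this repository. `binarize`, `soft_bce`, and `uncertain_voxels` are hypothetical helpers written for this example.

```python
import math
from typing import List, Sequence


def binarize(values: Sequence[float], threshold: float = 0.5) -> List[float]:
    """Hard labels: every voxel becomes 0 or 1 (partial values are lost)."""
    return [1.0 if v >= threshold else 0.0 for v in values]


def soft_bce(pred: Sequence[float], target: Sequence[float], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy that accepts soft targets in [0, 1]."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(pred)


def uncertain_voxels(seg: Sequence[float]) -> List[int]:
    """Indices of voxels that are neither confidently lesion nor background."""
    return [i for i, v in enumerate(seg) if 0.0 < v < 1.0]


# Toy 1-D lesion profile: certain core, partial-volume boundary voxels.
soft_label = [0.0, 0.3, 1.0, 1.0, 0.6, 0.0]
hard_label = binarize(soft_label)  # [0.0, 0.0, 1.0, 1.0, 1.0, 0.0]

# Against the soft label, a prediction equal to it has lower loss than an
# overconfident, fully binary prediction.
loss_matching = soft_bce(soft_label, soft_label)
loss_overconfident = soft_bce(hard_label, soft_label)
print(loss_matching < loss_overconfident)  # True
print(uncertain_voxels(soft_label))        # [1, 4]
```

The only voxels with intermediate values sit at the edges of the lesion, which mirrors the boundary uncertainty described above; binarizing the label erases exactly that information.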
- main_pl_*.py: Contains the main code for the four types of experiments, with a separate file for each.
- train.sh: Bash script that calls one of the main_pl_*.py files to train the model across multiple seeds.
- utils/: Contains 3 files:
  a. create_json_data.py: Creates a JSON file (in the Decathlon format) for each center based on the defined train/test split.
  b. generate_json.sh: Bash script for generating the JSON files mentioned above.
  c. metrics.py: Contains implementations of some continual learning metrics.
- plots/: Contains code for creating the plots described in the paper.
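Backward transfer, which the results above refer to, is one such continual learning metric. The sketch below computes it under the commonly used definition, from a matrix of test Dice scores collected after each training stage. It is not the code in metrics.py; `backward_transfer`, the score-matrix layout, and the toy values (made up, not results from the paper) are assumptions for illustration.

```python
from typing import Sequence


def backward_transfer(scores: Sequence[Sequence[float]]) -> float:
    """Backward transfer (BWT) from a T x T score matrix.

    scores[j][i] is the test score (e.g. Dice) on center i after the model has
    been trained sequentially up to and including center j. BWT averages, over
    all but the last center, how much the final model's score on center i
    differs from its score right after training on center i:

        BWT = 1/(T-1) * sum_{i=0}^{T-2} (scores[T-1][i] - scores[i][i])

    Positive values mean later training improved earlier centers;
    negative values indicate forgetting.
    """
    T = len(scores)
    if T < 2:
        raise ValueError("BWT needs results for at least two centers")
    if any(len(row) != T for row in scores):
        raise ValueError("scores must be a square T x T matrix")
    final = scores[T - 1]
    return sum(final[i] - scores[i][i] for i in range(T - 1)) / (T - 1)


# Made-up Dice scores for three centers; entries above the diagonal
# (centers not yet seen) are not used by BWT.
dice = [
    [0.60, 0.40, 0.35],
    [0.62, 0.55, 0.42],
    [0.64, 0.58, 0.57],
]
print(round(backward_transfer(dice), 3))  # 0.035
```

Under this definition, sequential fine-tuning typically yields negative BWT (forgetting), so a positive value indicates that training on later centers also helped the earlier ones.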
The code uses the following main packages: torch, pytorch-lightning, monai, wandb, and ivadomed. It has been tested only on Linux with Python 3.8. The first step is to clone the repository:
git clone https://github.com/naga-karthik/continual-learning-ms
Then, move into the repository, create and activate a conda environment, and install the dependencies:
cd continual-learning-ms/
conda create -n venv_cl_ms python=3.8
conda activate venv_cl_ms
pip install -r requirements.txt
If you find this work or code useful in your research, please consider citing:
@article{nagakarthik2022Segmentation,
title={Segmentation of Multiple Sclerosis Lesions across Hospitals: Learn Continually or Train from Scratch?},
author={Naga Karthik, Enamundram and Kerbrat, Anne and Labauge, Pierre and Granberg, Tobias and Talbott, Jason and Reich, Daniel S and Filippi, Massimo and Bakshi, Rohit and Callot, Virginie and Chandar, Sarath and Cohen-Adad, Julien},
journal={MedNeurIPS: Medical Imaging Meets NeurIPS Workshop},
year={2022},
url={https://arxiv.org/pdf/2210.15091.pdf}
}