Check out our colab-demo for a quick example of how test-time training works for multi-coil accelerated MRI reconstruction.
This repository provides code for reproducing the results in the paper:
"Test-Time Training Can Close the Natural Distribution Shift Performance Gap in Deep Learning Based Compressed Sensing," by Mohammad Zalbagi Darestani, Jiayu Liu, and Reinhard Heckel.
Code by: Mohammad Zalbagi Darestani and Reinhard Heckel
To study our domain adaptation method under multiple notions of robustness, the paper considers accelerated MRI reconstruction, where the task is to reconstruct an image from a few measurements. Specifically, we provide experiments that test our method for U-Net and the end-to-end variational network (VarNet) under four natural distribution shifts, each with its own Jupyter notebook:
(i) anatomy_shift.ipynb,
(ii) dataset_shift.ipynb,
(iii) modality_shift.ipynb, and
(iv) acceleration_shift.ipynb.
On a normal computer, it takes approximately 10 minutes to install all the required software and packages.
The code has been tested on the following operating system:
Linux: Ubuntu 20.04.2
To reproduce the results by running each of the Jupyter notebooks, the following software is required. Assuming the experiment is performed in a Docker container or on a Linux machine, the following libraries and packages need to be installed:
apt-get update
apt-get install python3.6 # --> or any other system-specific command for installing python3 on your system.
pip install jupyter
pip install numpy
pip install matplotlib
pip install sigpy
pip install h5py
pip install scikit-image
pip install runstats
pip install pytorch_msssim
pip install pytorch-lightning==0.7.5
pip install test-tube
pip install Pillow
If pip does not come with the version of python you installed, install pip manually from here. Also, install pytorch from here according to your system specifications.
Install the BART toolbox by following the instructions on its home page.
Note. After installing pytorch lightning, if you run into a 'state-dict' error for VarNet, you might need to replace parsing.py in /opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/parsing.py from here. This is due to the version mismatch in their recent release (0.7.5).
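Once the packages above are installed, a quick import check can confirm that nothing is missing before opening the notebooks. The snippet below is only a minimal sketch and not part of this repository: the helper name `missing_packages` and the `REQUIRED` mapping are illustrative assumptions, and the mapping pairs each pip name with the module name it is usually imported under. It does not check versions and does not cover the BART toolbox, which is installed separately.

```python
import importlib.util

# Assumed mapping: pip package name -> importable module name.
# (Illustrative only; adjust if your environment differs.)
REQUIRED = {
    "jupyter": "jupyter_core",
    "numpy": "numpy",
    "matplotlib": "matplotlib",
    "sigpy": "sigpy",
    "h5py": "h5py",
    "scikit-image": "skimage",
    "runstats": "runstats",
    "pytorch_msssim": "pytorch_msssim",
    "pytorch-lightning": "pytorch_lightning",
    "test-tube": "test_tube",
    "Pillow": "PIL",
    "torch": "torch",
}


def missing_packages(required):
    """Return the pip names whose module cannot be found in this environment."""
    return [
        pip_name
        for pip_name, module_name in required.items()
        if importlib.util.find_spec(module_name) is None
    ]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages were found.")
```

Run the script with the same Python interpreter that Jupyter uses; otherwise a package may be reported as missing even though it is installed elsewhere.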
The experiments are performed on the following datasets:
The fastMRI dataset (both knee and brain datasets are required).
The Stanford dataset (all 19 volumes should be downloaded).
In a train_data folder, we specify which files from the 2 datasets above are used for training and testing. In our training notebook, we show how to access those file names.
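Before running an experiment, it can help to confirm that a dataset folder contains the expected number of volumes (for example, 19 for the Stanford dataset). The following is a minimal sketch under stated assumptions: the helper names `list_volumes` and `check_volume_count`, the folder path in the usage line, and the `*.h5` file pattern are all hypothetical and should be adapted to how the data is stored on your machine.

```python
from pathlib import Path


def list_volumes(folder, pattern="*.h5"):
    """Return the sorted names of files in `folder` that match `pattern`.

    The default pattern assumes volumes are stored as .h5 files.
    """
    return sorted(p.name for p in Path(folder).glob(pattern) if p.is_file())


def check_volume_count(folder, expected, pattern="*.h5"):
    """Return (ok, names): ok is True if exactly `expected` volumes were found."""
    names = list_volumes(folder, pattern)
    return len(names) == expected, names
```

For instance, `check_volume_count("path/to/stanford", 19)` (with a placeholder path) returns `True` as its first element only if exactly 19 matching files are present.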
To run the experiments directly without training models from scratch, one can use our model checkpoints for U-Net and VarNet. We also provide training code to reproduce the model checkpoints.
You may simply clone this repository, enter U-Net's or VarNet's folder, and finally run each notebook to reproduce the results.
Note. You need to download the necessary datasets and the checkpoints according to the experiment you intend to run.
Code for training the U-Net and the VarNet is taken from the fastMRI repository with modifications.
We'll update this section later.
This project is covered by the Apache 2.0 License.