Official PyTorch implementation of our proposed CMF-Net and the corresponding dataset, M3OCTA.
Abstract: Ultra-wide optical coherence tomography angiography (UW-OCTA) is an emerging imaging technique that offers significant advantages over traditional OCTA by providing an exceptionally wide scanning range of up to 24 × 20 mm², covering both the anterior and posterior regions of the retina. However, currently accessible UW-OCTA datasets suffer from limited comprehensive hierarchical information and corresponding disease annotations. To address this limitation, we have curated the pioneering M3OCTA dataset, which is the first multimodal (i.e., multilayer), multi-disease, and widest field-of-view UW-OCTA dataset. Furthermore, the effective utilization of multi-layer ultra-wide ocular vasculature information from UW-OCTA remains underdeveloped. To tackle this challenge, we propose the first cross-modal fusion framework that leverages multi-modal information for diagnosing multiple diseases. Through extensive experiments conducted on our openly available M3OCTA dataset, we demonstrate the effectiveness and superior performance of our method in both fixed and varying modality settings. The construction of the M3OCTA dataset, the first multimodal OCTA dataset encompassing multiple diseases, aims to advance research in the ophthalmic image analysis community.
Our proposed M3OCTA is the first multi-modal ultra-wide retinal OCTA dataset, comprising 1637 scans from 1046 eyes of 620 individuals imaged at Zigong First People's Hospital using the 24×20 scan mode. Specifically, 1067 scans contain the choroidal large-vessel image; 1310 scans from 496 people are annotated with six classes in a multi-label setting, including healthy, diabetic retinopathy (DR), diabetic macular edema (DME), retinal vein occlusion (RVO), hypertension (HBP), and vitreous hemorrhage (VH), and are split into training, validation, and test sets at a 6:2:2 ratio. The remaining unlabeled data are used only in the pretraining step. Details of M3OCTA and other public datasets are listed in Table 1. Compared with the others, the M3OCTA dataset is superior in several respects, including the number of modalities, number of patients, image resolution, and FOV.
You can request this dataset by signing the agreement at this link. After downloading it, please place the dataset into the following directory structure:
Dataset
├── patients
│ ├── patient_ID
│ │ ├── Left
│ │ │ ├── surface.png
│ │ │ ├── deep.png
│ │ │ ├── Choroidal_capilla.png
│ │ │ └── Choroidal_vessel.png
│ │ ├── Right
│ │ └── ...
│ └── ...
├── pretrain.json
├── train.json
├── val.json
├── test.json
└── readme.txt
The different JSON files correspond to the pretraining, training, validation, and test stages. All of them are saved in a unified format. For example:
{
"20221229173526": { # the patient ID
"ID": "20221229173526",
"Right": { # the scan on the right eye
"Label": [ # the annotated label
0, # DR
0, # DME
0, # RVO
0, # HBP
0 # VH
],
"Paths": { # the file paths for each image
"deep": "patients\\20221229173526\\Right\\deep.png",
"Choroidal_capilla": "patients\\20221229173526\\Right\\Choroidal_capilla.png",
"surface": "patients\\20221229173526\\Right\\surface.png"
},
"Scan_protocol": [ # the scanning range
24.0,
20.0
]
},
"Right1": { # the second scan on the right eye
"Label": [
0,
0,
0,
0,
0
],
"Paths": {
"deep": "patients\\20221229173526\\Right1\\deep.png",
"Choroidal_capilla": "patients\\20221229173526\\Right1\\Choroidal_capilla.png",
"surface": "patients\\20221229173526\\Right1\\surface.png"
},
"Scan_protocol": [
24.0,
20.0
]
}
},
...
}
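For illustration only (this sketch is not part of the released code), the following Python snippet parses one of these JSON files and collects the per-scan image paths and label vectors; the key names follow the example above, and the order of the five label entries (DR, DME, RVO, HBP, VH) is assumed from the annotated comments.

import json
from pathlib import Path

# Assumed label order, taken from the annotated example above (a healthy eye is all zeros).
DISEASES = ["DR", "DME", "RVO", "HBP", "VH"]

def load_split(json_path, dataset_root):
    """Collect the per-modality image paths and label vector for every scan."""
    with open(json_path, "r") as f:
        records = json.load(f)

    samples = []
    for patient_id, patient in records.items():
        for scan_key, scan in patient.items():
            if scan_key == "ID":          # skip the metadata field
                continue
            # Windows-style separators stored in the JSON are normalized here.
            paths = {m: Path(dataset_root) / p.replace("\\", "/")
                     for m, p in scan["Paths"].items()}
            labels = scan.get("Label")    # may be absent for unlabeled pretraining scans
            samples.append({"patient": patient_id, "eye": scan_key,
                            "paths": paths, "labels": labels})
    return samples

# Example usage:
# samples = load_split("Dataset/train.json", "Dataset")
# print(len(samples), list(samples[0]["paths"].keys()))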
We design a novel framework that leverages a pre-trained transformer encoder and fuses multi-modal information for diagnosis. The proposed framework is divided into two stages: transformer encoder pretraining and multi-modal fusion. In other words, we first pretrain the encoder and then freeze it while finetuning the decoder on the downstream tasks.
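The sketch below is only illustrative of this freeze-then-finetune pattern and is not the repository's actual training code; encoder, decoder, and the multi-label loss shown in the comments are placeholders.

import torch

def build_finetune_optimizer(encoder: torch.nn.Module,
                             decoder: torch.nn.Module,
                             lr: float = 1e-4):
    """Freeze the pretrained encoder and optimize only the decoder parameters."""
    for p in encoder.parameters():
        p.requires_grad = False          # encoder weights stay fixed
    encoder.eval()                       # keep normalization/dropout layers in eval mode
    return torch.optim.AdamW(decoder.parameters(), lr=lr)

# During finetuning, the frozen encoder can run under torch.no_grad():
#   with torch.no_grad():
#       tokens = encoder(multi_modal_inputs)
#   logits = decoder(tokens)
#   loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)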
The pretraining set includes all the unlabeled samples (327 scans) and the whole train set (789 scans).
To pretrain the encoder, run:
CUDA_VISIBLE_DEVICES=4,5 OMP_NUM_THREADS=1 torchrun --nproc_per_node=2 run_pretraining_octa.py \
--config cfgs/pretrain/multimae-b_196_octa4_1600e.yaml --pretrain_path ./pretrain_weights/multimae_pretrain.pth \
--num_workers 10 --batch_size 80 --output_dir ./output/pretrain/multimae-b_196_octa4_1600e > logs/train_mmae_196_octa4_1600e.log 2>&1
multimae_pretrain.pth contains the pretrained weights from MultiMAE. You can download it from this link and then place the file in the ./pretrain_weights directory.
The training scripts support both YAML config files and command-line arguments. To modify pre-training settings, either edit / add config files or provide additional command-line arguments.
For a list of possible arguments, see run_pretraining_octa.py. When changing settings, make sure to modify the output_dir and wandb_run_name (if logging is activated) to reflect the changes.
To activate logging to Weights & Biases, either edit the config files or use the --log_wandb flag along with any other extra logging arguments.
To finetune the decoder, first download the weight file saved in the pretraining stage. Then place it at ./output/pretrain/multimae-b_196_octa4_1600e/checkpoint-1599.pth, which can be set in cfgs/finetune/cls/ft_octa_100e_multimae-b.yaml. You can place this file in any directory you like; just remember to modify the path in the following command to your customized directory.
Finally, run:
CUDA_VISIBLE_DEVICES=5 OMP_NUM_THREADS=1 torchrun --nproc_per_node=1 --master_port=29505 run_finetuning_cls_octa.py \
--config cfgs/finetune/cls/ft_octa_100e_multimae-b.yaml --in_domains surface-deep-Choroidal_capilla-Choroidal_vessel \
--finetune ./output/pretrain/multimae-b_196_octa4_1600e/checkpoint-1599.pth --input_size 224 --num_workers 4 --drop 0.2 \
--attn_drop_rate 0.1 --epochs 100 --type v3 --output_dir ./finetune/octa-b_196_octa4_100e_mask > logs/finetune_octa_196_octa4_100e_mask.log 2>&1
To evaluate, first prepare the weights trained in the finetuning stage. You can download them from this link. Then place the file at ./output/finetune/octa-b_196_octa4_100e_mask/. You can place this file in any directory you like; just remember to modify the path in the following command to your customized directory.
Then, run the following commands to compute metrics:
CUDA_VISIBLE_DEVICES=2 python run_evaluation_cls_octa.py --config cfgs/finetune/cls/ft_octa_100e_multimae-b.yaml \
--in_domains surface-deep-Choroidal_capilla-Choroidal_vessel --finetune ./output/finetune/octa-b_196_octa4_100e_mask/checkpoint-best.pth \
--input_size 224 --num_workers 10 --batch_size 128 --type v3 --mlp_ratio 0.5 --test --output_dir ./output/debug
Running this will produce the following results:
All compute: ACC 0.8498 AUC: 0.8583 Pre: 0.6019 Recall: 0.6186 AP: 0.5999 F1: 0.596
For missing-modality cases, please modify --in_domains:
# for 3 modalities
--in_domains surface-deep-Choroidal_capilla
# Results would be: All compute: ACC 0.8349 AUC: 0.8603 Pre: 0.5651 Recall: 0.6321 AP: 0.5884 F1: 0.5874
# for 2 modalities
--in_domains surface-deep
# Results would be: All compute: ACC 0.8342 AUC: 0.8597 Pre: 0.579 Recall: 0.5444 AP: 0.5834 F1: 0.5416
# for 1 modality
--in_domains surface
# Results would be: All compute: ACC 0.8235 AUC: 0.8309 Pre: 0.5077 Recall: 0.5271 AP: 0.5298 F1: 0.5098
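If you want to recompute such multi-label metrics yourself, a rough sketch using scikit-learn is given below; the exact averaging and thresholding inside run_evaluation_cls_octa.py may differ, so treat this only as an approximation of how ACC, AUC, Pre, Recall, AP, and F1 can be obtained.

import numpy as np
from sklearn.metrics import (average_precision_score, f1_score, precision_score,
                             recall_score, roc_auc_score)

def multilabel_metrics(y_true, y_prob, threshold=0.5):
    """y_true, y_prob: (num_samples, 5) arrays for the DR/DME/RVO/HBP/VH labels."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    y_pred = (y_prob >= threshold).astype(int)    # assumed decision threshold
    return {
        "ACC": float((y_pred == y_true).mean()),  # element-wise multi-label accuracy
        "AUC": roc_auc_score(y_true, y_prob, average="macro"),
        "Pre": precision_score(y_true, y_pred, average="macro", zero_division=0),
        "Recall": recall_score(y_true, y_pred, average="macro", zero_division=0),
        "AP": average_precision_score(y_true, y_prob, average="macro"),
        "F1": f1_score(y_true, y_pred, average="macro", zero_division=0),
    }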
Table 2 lists the results of all compared methods, where the lower bound of theoretical performance given by random predictions is shown in the first row. Compared with other approaches, our network achieves state-of-the-art (SOTA) performance across all evaluation metrics in both the four-modal and two-modal settings.
Note on a typo: the AUC of our method should be 85.83 instead of the 84.77 reported in the original paper.
To validate the performance gain brought by multiple modalities, we pretrain another encoder on only one modality, i.e., the retinal surface image, and then finetune the decoder on each of these two encoders with 1 (retinal surface), 2 (retinal surface + deep), 3 (retinal surface + deep + choroidal capillary), or all four modalities. The results are shown in Table 3.
We evaluate our approach (a 4-modal fine-tuned decoder with a 4-modal pre-trained encoder) across diverse input modalities to validate its stability under varying modalities, as shown in Fig. 4(a). As the number of input modalities increases, a performance gain is observed for almost all compared methods, and our model is more stable than the others, demonstrating the effectiveness of the learned multi-modal features in our design. Furthermore, this stability is also observed across all five disease types in Fig. 4(b).
This repository is built upon MultiMAE.
If you find this repository helpful, please consider citing our work:
@inproceedings{wei2023leveraging,
title={Leveraging Multimodal Fusion for Enhanced Diagnosis of Multiple Retinal Diseases in Ultra-wide OCTA},
author={Wei, Hao and Shi, Peilun and Bai, Guitao and Zhang, Minqing and Li, Shuangle and Yuan, Wu},
booktitle = {IEEE International Symposium on Biomedical Imaging},
year={2024}
}