From c6add37a001b47824fc6288e68cf523c1366d67a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 4 Mar 2020 19:33:43 +0800 Subject: [PATCH] [DLMED] add 3D medical image classification example --- examples/densenet_classification_3d.py | 136 +++++++++++++++++++++++++ monai/data/nifti_reader.py | 19 ++-- 2 files changed, 148 insertions(+), 7 deletions(-) create mode 100644 examples/densenet_classification_3d.py diff --git a/examples/densenet_classification_3d.py b/examples/densenet_classification_3d.py new file mode 100644 index 0000000000..07ac3ffe04 --- /dev/null +++ b/examples/densenet_classification_3d.py @@ -0,0 +1,136 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import logging +import numpy as np +import torch +from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator +from ignite.handlers import ModelCheckpoint, EarlyStopping +from torch.utils.data import DataLoader + +# assumes the framework is found here, change as necessary +sys.path.append("..") +import monai +import monai.transforms.compose as transforms + +from monai.data.nifti_reader import NiftiDataset +from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch) +from monai.handlers.stats_handler import StatsHandler +from ignite.metrics import Accuracy +from monai.handlers.utils import stopping_fn_from_metric + +monai.config.print_config() + +# FIXME: temp test dataset, Wenqi will replace later +images = [ + "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz" +] +labels = np.array([ + 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 +]) + +# Define transforms for image and segmentation +imtrans = transforms.Compose([ + Rescale(), + AddChannel(), + UniformRandomPatch((96, 96, 96)), + ToTensor() +]) + +# Define nifti dataset, dataloader. +ds = NiftiDataset(image_files=images, labels=labels, transform=imtrans) +loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) +im, label = monai.utils.misc.first(loader) +print(type(im), im.shape, label) + +lr = 1e-5 + +# Create DenseNet121, CrossEntropyLoss and Adam optimizer. +net = monai.networks.nets.densenet3d.densenet121( + in_channels=1, + out_channels=2, +) + +loss = torch.nn.CrossEntropyLoss() +opt = torch.optim.Adam(net.parameters(), lr) + +# Create trainer +device = torch.device("cuda:0") +trainer = create_supervised_trainer(net, opt, loss, device, False) + +# adding checkpoint handler to save models (network params and optimizer stats) during training +checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) +trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, + handler=checkpoint_handler, + to_save={'net': net, 'opt': opt}) +train_stats_handler = StatsHandler() +train_stats_handler.attach(trainer) + +@trainer.on(Events.EPOCH_COMPLETED) +def log_training_loss(engine): + engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output) + +# Set parameters for validation +validation_every_n_epochs = 1 +metric_name = 'Accuracy' + +# add evaluation metric to the evaluator engine +val_metrics = {metric_name: Accuracy()} +evaluator = create_supervised_evaluator(net, val_metrics, device, True) + +# Add stats event handler to print validation stats via evaluator +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +val_stats_handler = StatsHandler() +val_stats_handler.attach(evaluator) + +# Add early stopping handler to evaluator. +early_stopper = EarlyStopping(patience=4, + score_function=stopping_fn_from_metric(metric_name), + trainer=trainer) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) + +# create a validation data loader +val_ds = NiftiDataset(image_files=images[-5:], labels=labels[-5:], transform=imtrans) +val_loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) + + +@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) +def run_validation(engine): + evaluator.run(val_loader) + +# create a training data loader +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +train_ds = NiftiDataset(image_files=images[:15], labels=labels[:15], transform=imtrans) +train_loader = DataLoader(train_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) + +train_epochs = 30 +state = trainer.run(train_loader, train_epochs) diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py index c803411b3b..012b46e4e4 100644 --- a/monai/data/nifti_reader.py +++ b/monai/data/nifti_reader.py @@ -66,7 +66,7 @@ class NiftiDataset(Dataset): for the image and segmentation arrays separately. """ - def __init__(self, image_files, seg_files, as_closest_canonical=False, + def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonical=False, transform=None, seg_transform=None, image_only=True, dtype=None): """ Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied @@ -74,7 +74,8 @@ def __init__(self, image_files, seg_files, as_closest_canonical=False, Args: image_files (list of str): list of image filenames - seg_files (list of str): list of segmentation filenames + seg_files (list of str): if in segmentation task, list of segmentation filenames + labels (list or array): if in classification task, list of classification labels as_closest_canonical (bool): if True, load the image as closest to canonical orientation transform (Callable, optional): transform to apply to image arrays seg_transform (Callable, optional): transform to apply to segmentation arrays @@ -82,11 +83,12 @@ def __init__(self, image_files, seg_files, as_closest_canonical=False, dtype (np.dtype, optional): if not None convert the loaded image to this data type """ - if len(image_files) != len(seg_files): + if seg_files is not None and len(image_files) != len(seg_files): raise ValueError('Must have same number of image and segmentation files') self.image_files = image_files self.seg_files = seg_files + self.labels = labels self.as_closest_canonical = as_closest_canonical self.transform = transform self.seg_transform = seg_transform @@ -104,7 +106,10 @@ def __getitem__(self, index): else: img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical, image_only=self.image_only, dtype=self.dtype) - seg = load_nifti(self.seg_files[index]) + if self.seg_files is not None: + target = load_nifti(self.seg_files[index]) + elif self.labels is not None: + target = self.labels[index] # https://github.com/pytorch/vision/issues/9#issuecomment-304224800 seed = np.random.randint(2147483647) @@ -116,12 +121,12 @@ def __getitem__(self, index): if self.seg_transform is not None: np.random.seed(seed) # ensure randomized transforms roll the same values for segmentations as images - seg = self.seg_transform(seg) + target = self.seg_transform(target) seg_seed = np.random.randint(2147483647) assert(random_sync_test == seg_seed) if self.image_only or meta_data is None: - return img, seg + return img, target compatible_meta = {} for meta_key in meta_data: @@ -130,4 +135,4 @@ def __getitem__(self, index): and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None: continue compatible_meta[meta_key] = meta_datum - return img, seg, compatible_meta + return img, target, compatible_meta