Tensorflow implementation of Gated Shape CNN for Semantic Segmentation

ben-davidson-6/Gated-SCNN

Gated-Shape CNN for Semantic Segmentation in Tensorflow 2.0

A TensorFlow implementation of Gated-Shape CNN for Semantic Segmentation (ICCV 2019). Started as a way for me to learn tensorflow==2.2.0.

Performance on CityScapes

Model weights are here

Class         | Paper | This repo
mean          | 80.8  | 77.7
road          | 98.3  | 97.8
sidewalk      | 86.3  | 83.2
building      | 93.3  | 92.4
wall          | 55.8  | 55.8
fence         | 64    | 59.8
pole          | 70.8  | 64.4
traffic light | 75.9  | 67.6
traffic sign  | 83.1  | 77.5
vegetation    | 93    | 92.4
terrain       | 65.1  | 63
sky           | 95.2  | 94.9
person        | 85.3  | 81.9
rider         | 67.9  | 63
car           | 96    | 95
truck         | 80.8  | 80.8
bus           | 91.2  | 86
train         | 83.3  | 78.3
motorcycle    | 69.6  | 65.2
bicycle       | 80.4  | 77.1
[Image: stuttgart_02]
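
As a quick sanity check on the table, the mean column appears to be the plain, unweighted arithmetic mean of the 19 per-class scores. The snippet below recomputes it from the values above; the variable and function names are only for illustration.

```python
# Per-class scores copied from the table above (19 CityScapes classes).
paper = [98.3, 86.3, 93.3, 55.8, 64, 70.8, 75.9, 83.1, 93, 65.1,
         95.2, 85.3, 67.9, 96, 80.8, 91.2, 83.3, 69.6, 80.4]
this_repo = [97.8, 83.2, 92.4, 55.8, 59.8, 64.4, 67.6, 77.5, 92.4, 63,
             94.9, 81.9, 63, 95, 80.8, 86, 78.3, 65.2, 77.1]


def mean_score(scores):
    """Unweighted mean over classes, rounded to one decimal like the table."""
    return round(sum(scores) / len(scores), 1)


print(mean_score(paper))      # 80.8
print(mean_score(this_repo))  # 77.7
```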

Install

The project uses semantic versioning (Maj.Min.Fix); see the version badge for the most recent version.

pip install git+https://github.com/ben-davidson-6/Gated-SCNN@<version>

Here <version> stands for the version you want to install.

Note that this will not work with tensorflow < 2.2.0

Training on your own data

Just give me the network!

import gated_shape_cnn.model
number_classes = 10
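# creates a tf.keras.Model
model = gated_shape_cnn.model.GSCNN(n_classes=number_classes)
# some_input is a float32 image batch of shape [b, h, w, 3]; it should not be
# normalised, as the xception preprocessing is part of the model
output = model(some_input)
# the last channel is the shape head, the remaining channels are the logits
logits, shape_head = output[..., :-1], output[..., -1:]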
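
To make the slicing above concrete, here is a small NumPy sketch that uses a random array in place of the real model output. It assumes the output has n_classes + 1 channels with the shape head in the last one, as the slicing suggests; the array sizes are arbitrary.

```python
import numpy as np

number_classes = 10
# Stand-in for model(some_input): [batch, height, width, n_classes + 1].
rng = np.random.default_rng(0)
fake_output = rng.normal(size=(2, 8, 8, number_classes + 1)).astype(np.float32)

# Same slicing as above: class channels first, shape head last.
logits, shape_head = fake_output[..., :-1], fake_output[..., -1:]

# A per-pixel class map is the argmax over the class channels.
class_map = np.argmax(logits, axis=-1)

print(logits.shape)      # (2, 8, 8, 10)
print(shape_head.shape)  # (2, 8, 8, 1)
print(class_map.shape)   # (2, 8, 8)
```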

Using the full training loop

You have two options:

Inheriting from gated_shape_cnn.training.Dataset

To use the existing dataset class, you need all of your images, labels, and edge boundaries prepared ahead of time, in the format described below.

  1. First make sure your data is in the right format:

    • Images should have 3 channels, [h, w, 3]. Anything that tf.io.decode_image can decode should work; note that this does not include TIFFs.
    • Segmentations should be flat, i.e. of shape [h, w, 1] with the last channel containing the class id, and should be saved as .png files.
    • Edge segmentations should follow the same format as segmentations. If you do not have edge segmentations, you can create them with gated_shape_cnn.training.utils.flat_label_to_edge_label.
  2. Subclass gated_shape_cnn.training.Dataset and implement its get_paths method.

    • get_paths should return 3 lists containing the paths to the images, the semantic segmentations, and the edge segmentations, such that the ith entry of each list relates to the ith image. For example:
    import gated_shape_cnn.datasets.cityscapes
    import gated_shape_cnn.datasets.cityscapes.raw_dataset
    from gated_shape_cnn.training import Dataset
    
    class CityScapes(Dataset):
    
        def __init__(
                self,
                batch_size,
                network_input_h,
                network_input_w,
                max_crop_downsample,
                colour_aug_factor,
                debug,
                data_dir):
            super(CityScapes, self).__init__(
                gated_shape_cnn.datasets.cityscapes.N_CLASSES,
                batch_size,
                network_input_h,
                network_input_w,
                max_crop_downsample,
                colour_aug_factor,
                debug)
            self.raw_data = gated_shape_cnn.datasets.cityscapes.raw_dataset.CityScapesRaw(data_dir)
    
        def get_paths(self, train):
            """
            :param train: True for the training split, False for validation
            :return image_paths, label_paths, edge_paths:
                image_paths[0] -> path to image 0
                label_paths[0] -> path to semantic seg of image 0
                edge_paths[0] -> path to edge seg of image 0
            """
            split = gated_shape_cnn.datasets.cityscapes.TRAIN if train else gated_shape_cnn.datasets.cityscapes.VAL
            paths = self.raw_data.dataset_paths(split)
            image_paths, label_paths, edge_paths = zip(*paths)
            return list(image_paths), list(label_paths), list(edge_paths)
  3. Train your model using gated_shape_cnn.training.train_model:

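    from gated_shape_cnn.training import train_model
    
    # instance_of_subclassed is an instance of your Dataset subclass (e.g. CityScapes above);
    # optimiser is whatever optimiser you want to train with
    train_model(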
        n_classes=instance_of_subclassed.n_classes,
        train_data=instance_of_subclassed.build_training_dataset(),
        val_data=instance_of_subclassed.build_validation_dataset(),
        optimiser=optimiser,
        epochs=300,
        log_dir='./logs',
        model_dir='./logs/model',
        accum_iterations=4,
        loss_weights=(1., 20., 1., 1.))
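
Going back to step 2, the key requirement on get_paths is that the three lists stay aligned. One way to guarantee that, sketched below, is to key every file by a shared stem and build all three lists from the same sorted keys. The directory layout and the _label / _edge suffixes are hypothetical, not something the library prescribes.

```python
from pathlib import PurePosixPath


def aligned_paths(image_files, label_files, edge_files):
    """Return three lists where index i refers to the same image in each.

    Assumes (hypothetically) that labels are named <stem>_label.png and
    edges <stem>_edge.png for an image named <stem>.<ext>.
    """
    images = {PurePosixPath(p).stem: p for p in image_files}
    labels = {PurePosixPath(p).stem[:-len('_label')]: p for p in label_files}
    edges = {PurePosixPath(p).stem[:-len('_edge')]: p for p in edge_files}

    # Only keep images that have both a label and an edge map.
    keys = sorted(images.keys() & labels.keys() & edges.keys())
    return ([images[k] for k in keys],
            [labels[k] for k in keys],
            [edges[k] for k in keys])


image_paths, label_paths, edge_paths = aligned_paths(
    ['data/images/b.png', 'data/images/a.jpg'],
    ['data/labels/a_label.png', 'data/labels/b_label.png'],
    ['data/edges/b_edge.png', 'data/edges/a_edge.png'])

print(image_paths)  # ['data/images/a.jpg', 'data/images/b.png']
print(label_paths)  # ['data/labels/a_label.png', 'data/labels/b_label.png']
print(edge_paths)   # ['data/edges/a_edge.png', 'data/edges/b_edge.png']
```

Sorting by a shared key, rather than sorting each list independently, keeps the lists aligned even when the three directories contain different sets of files.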

Building your own dataset

You do not have to go through the Dataset class to use the training loop. All you need to provide are two tf.data.Dataset's which, when iterated over, yield elements of the form

for im, label, edge_label in dataset:
    # im         [b, h, w, 3]       tf.float32 note this is not normalised, as the xception preprocessing is part of the model
    # label      [b, h, w, classes] tf.float32
    # edge_label [b, h, w, 2]       tf.float32
    pass

You can then feed these into train_model for train_data and val_data.
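
For illustration, the sketch below turns one flat label of shape [h, w, 1] into a one-hot label and a two-channel edge label with the per-image shapes listed above, using plain NumPy. This is not the library's flat_label_to_edge_label; the 4-neighbour boundary rule and the channel order of the edge label (non-edge first, edge second) are assumptions made for the example.

```python
import numpy as np


def one_hot(flat_label, n_classes):
    """[h, w, 1] integer class ids -> [h, w, n_classes] float32 one-hot."""
    return np.eye(n_classes, dtype=np.float32)[flat_label[..., 0]]


def edge_label(flat_label):
    """Mark pixels that have a 4-neighbour of a different class.

    Returns [h, w, 2] float32: channel 0 = non-edge, channel 1 = edge
    (the channel order is an assumption of this sketch).
    """
    ids = flat_label[..., 0]
    edge = np.zeros(ids.shape, dtype=bool)
    # Horizontal class changes mark the pixels on both sides of the change.
    diff_x = ids[:, 1:] != ids[:, :-1]
    edge[:, 1:] |= diff_x
    edge[:, :-1] |= diff_x
    # Vertical class changes, likewise.
    diff_y = ids[1:, :] != ids[:-1, :]
    edge[1:, :] |= diff_y
    edge[:-1, :] |= diff_y
    return one_hot(edge.astype(np.int64)[..., None], 2)


# Toy label: left half is class 0, right half is class 1.
flat = np.zeros((4, 4, 1), dtype=np.int64)
flat[:, 2:, 0] = 1

label = one_hot(flat, n_classes=3)
edges = edge_label(flat)

print(label.shape)   # (4, 4, 3)
print(edges.shape)   # (4, 4, 2)
print(edges[0, :, 1])  # [0. 1. 1. 0.]
```

Stacking several such arrays along a new first axis gives the [b, h, w, ...] batches. tf.data.Dataset.from_tensor_slices followed by .batch is one way to wrap them, though any pipeline that produces these shapes and dtypes should do.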

Inference

To convert your model to the SavedModel format:

from gated_shape_cnn.model import export_model, GSCNNInfer

# build a saved model
export_model(
    classes=num_classes, 
    ckpt_path='/path/to/weights', 
    out_dir='/dir/to/save/model/')

# Helper to use the saved model.
# Pass resize if you do not want to, or cannot, run inference on the full-size image.
model = GSCNNInfer('/dir/to/save/model/', resize=None)
seg, shape_head = model(path_or_imageio_image)

Differences to paper

  • Use Xception instead of WideResnet
  • Only the final downsampling layer is replaced with atrous convolution (usually the last two are replaced)
  • Use generalised dice loss instead of cross entropy for the edge segmentation
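
For readers unfamiliar with it, the generalised Dice loss weights each class by the inverse square of its pixel count, which helps when edge pixels are far rarer than non-edge pixels. Below is a minimal NumPy sketch of the formula as it is commonly written; the exact form used in this repo may differ, and the function name and eps value are my own.

```python
import numpy as np


def generalised_dice_loss(y_true, y_pred, eps=1e-6):
    """Generalised Dice loss over arrays shaped [..., n_classes].

    y_true is one-hot, y_pred holds per-class probabilities.
    """
    n_classes = y_true.shape[-1]
    t = y_true.reshape(-1, n_classes).astype(np.float64)
    p = y_pred.reshape(-1, n_classes).astype(np.float64)

    # Rare classes get large weights: w_c = 1 / (number of pixels in c)^2.
    weights = 1.0 / (t.sum(axis=0) ** 2 + eps)
    intersection = (weights * (t * p).sum(axis=0)).sum()
    union = (weights * (t + p).sum(axis=0)).sum()
    return 1.0 - 2.0 * intersection / (union + eps)


# Toy edge label: one column of edge pixels in a 4x4 image.
true = np.zeros((4, 4, 2))
true[..., 0] = 1.0
true[:, 1, :] = [0.0, 1.0]

loss_perfect = generalised_dice_loss(true, true)        # close to 0
loss_wrong = generalised_dice_loss(true, 1.0 - true)    # 1.0
print(loss_perfect, loss_wrong)
```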

For the results presented above, my training setup differs from the paper in that I:

  • Accumulate gradients over iterations, as I do not have 8 GPUs! Instead of a batch size of 2 synchronised across 8 GPUs, I use a batch size of 4 and accumulate gradients over 4 passes.
  • Train at a smaller resolution: 700x700 versus 800x800.
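
The gradient accumulation mentioned above can be sketched as follows. With equally sized micro-batches and a mean loss, averaging the gradients of 4 passes of batch size 4 gives the same gradient as one batch of 16. That matches an effective batch of 2 x 8 = 16 images, assuming the paper's batch size of 2 is per GPU and ignoring effects such as batch normalisation statistics. The example uses a toy linear model in NumPy rather than the real training loop; all names are illustrative.

```python
import numpy as np


def mse_gradient(w, x, y):
    """Gradient of mean((x @ w - y)^2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(x)


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 3))   # 16 samples, 3 features
y = rng.normal(size=16)
w = np.zeros(3)

# One large batch of 16.
full_batch_grad = mse_gradient(w, x, y)

# Four passes over micro-batches of 4, accumulating before any update.
accum_iterations = 4
accumulated = np.zeros_like(w)
for xb, yb in zip(np.split(x, accum_iterations), np.split(y, accum_iterations)):
    accumulated += mse_gradient(w, xb, yb)
accumulated /= accum_iterations  # average, then apply a single optimiser step

print(np.allclose(full_batch_grad, accumulated))  # True
```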

License

Copyright (C) 2019 NVIDIA Corporation. Towaki Takikawa, David Acuna, Varun Jampani, Sanja Fidler
All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

Permission to use, copy, modify, and distribute this software and its documentation
for any non-commercial purpose is hereby granted without fee, provided that the above
copyright notice appear in all copies and that both that copyright notice and this
permission notice appear in supporting documentation, and that the name of the author
not be used in advertising or publicity pertaining to distribution of the software
without specific, written prior permission.

THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

Todo

  • Write tests
    • easy unit tests
    • unit testing training loop
  • Add CI/CD so it looks like I know what I am doing
  • Build a version using keras.fit
