From 256ab7c4d231b29bcdf8c138b92992efbf7e69db Mon Sep 17 00:00:00 2001
From: Khedia
Date: Thu, 8 Nov 2018 14:03:05 -0800
Subject: [PATCH 1/7] fixing gradcam

---
 example/cnn_visualization/gradcam.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/example/cnn_visualization/gradcam.py b/example/cnn_visualization/gradcam.py
index a8708f787584..54cb65eef11b 100644
--- a/example/cnn_visualization/gradcam.py
+++ b/example/cnn_visualization/gradcam.py
@@ -249,8 +249,8 @@ def visualize(net, preprocessed_img, orig_img, conv_layer_name):
     imggrad = get_image_grad(net, preprocessed_img)
     conv_out, conv_out_grad = get_conv_out_grad(net, preprocessed_img, conv_layer_name=conv_layer_name)
 
-    cam = get_cam(imggrad, conv_out)
-
+    cam = get_cam(conv_out_grad, conv_out)
+
     cam = cv2.resize(cam, (imggrad.shape[1], imggrad.shape[2]))
     ggcam = get_guided_grad_cam(cam, imggrad)
     img_ggcam = grad_to_image(ggcam)

From 50feef5a21b20b88104a5f665e0ec237be363500 Mon Sep 17 00:00:00 2001
From: Khedia
Date: Thu, 8 Nov 2018 15:37:59 -0800
Subject: [PATCH 2/7] changed loading parameters code

---
 example/cnn_visualization/vgg.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/example/cnn_visualization/vgg.py b/example/cnn_visualization/vgg.py
index b6215a334e3b..a8a0ef6c8dee 100644
--- a/example/cnn_visualization/vgg.py
+++ b/example/cnn_visualization/vgg.py
@@ -72,11 +72,17 @@ def get_vgg(num_layers, pretrained=False, ctx=mx.cpu(),
             root=os.path.join('~', '.mxnet', 'models'), **kwargs):
     layers, filters = vgg_spec[num_layers]
     net = VGG(layers, filters, **kwargs)
-    if pretrained:
-        from mxnet.gluon.model_zoo.model_store import get_model_file
-        batch_norm_suffix = '_bn' if kwargs.get('batch_norm') else ''
-        net.load_params(get_model_file('vgg%d%s'%(num_layers, batch_norm_suffix),
-                                       root=root), ctx=ctx)
+    net.initialize(ctx=ctx)
+
+    # Get the pretrained model
+    vgg = mx.gluon.model_zoo.vision.get_vgg(num_layers, pretrained=True, ctx=ctx)
+
+    # Set the parameters in the new network
+    params = vgg.collect_params()
+    for key in params:
+        param = params[key]
+        net.collect_params()[net.prefix+key.replace(vgg.prefix, '')].set_data(param.data())
+
     return net
 
 def vgg16(**kwargs):

From 9f490af21ae4b5eebae1fe8df8328e11a64464af Mon Sep 17 00:00:00 2001
From: Khedia
Date: Mon, 12 Nov 2018 15:07:22 -0800
Subject: [PATCH 3/7] fixing type conversions issue with previous versions of matplotlib

---
 docs/tutorials/vision/cnn_visualization.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/docs/tutorials/vision/cnn_visualization.md b/docs/tutorials/vision/cnn_visualization.md
index 940c261efc8f..a350fffaa36f 100644
--- a/docs/tutorials/vision/cnn_visualization.md
+++ b/docs/tutorials/vision/cnn_visualization.md
@@ -151,7 +151,8 @@ def show_images(pred_str, images):
     for i in range(num_images):
         fig.add_subplot(rows, cols, i+1)
         plt.xlabel(titles[i])
-        plt.imshow(images[i], cmap='gray' if i==num_images-1 else None)
+        img = images[i].astype(np.uint8)
+        plt.imshow(img, cmap='gray' if i==num_images-1 else None)
     plt.show()
 ```
 

From cc6e2cd69b7a368084d06abe19b3832fb66458b6 Mon Sep 17 00:00:00 2001
From: Khedia
Date: Tue, 13 Nov 2018 15:49:49 -0800
Subject: [PATCH 4/7] gradcam consolidation

---
 docs/conf.py                               |   2 +-
 .../tutorial_utils}/gradcam.py             |   0
 docs/tutorials/vision/cnn_visualization.md |   2 +-
 example/cnn_visualization/README.md        |  17 ---
 example/cnn_visualization/gradcam_demo.py  | 110 ------------------
 example/cnn_visualization/vgg.py           |  90 --------------
 6 files changed, 2 insertions(+), 219 deletions(-)
 rename {example/cnn_visualization => docs/tutorial_utils}/gradcam.py (100%)
 delete mode 100644 example/cnn_visualization/README.md
 delete mode 100644 example/cnn_visualization/gradcam_demo.py
 delete mode 100644 example/cnn_visualization/vgg.py

diff --git a/docs/conf.py b/docs/conf.py
index 656a1da96d69..af235219e4f1 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -107,7 +107,7 @@
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
-exclude_patterns = ['3rdparty', 'build_version_doc', 'virtualenv', 'api/python/model.md', 'README.md']
+exclude_patterns = ['3rdparty', 'build_version_doc', 'virtualenv', 'api/python/model.md', 'README.md', 'tutorial_utils']
 
 # The reST default role (used for this markup: `text`) to use for all documents.
 #default_role = None
diff --git a/example/cnn_visualization/gradcam.py b/docs/tutorial_utils/gradcam.py
similarity index 100%
rename from example/cnn_visualization/gradcam.py
rename to docs/tutorial_utils/gradcam.py
diff --git a/docs/tutorials/vision/cnn_visualization.md b/docs/tutorials/vision/cnn_visualization.md
index a350fffaa36f..2fd22d408051 100644
--- a/docs/tutorials/vision/cnn_visualization.md
+++ b/docs/tutorials/vision/cnn_visualization.md
@@ -22,7 +22,7 @@ from matplotlib import pyplot as plt
 import numpy as np
 
 gradcam_file = "gradcam.py"
-base_url = "https://raw.githubusercontent.com/indhub/mxnet/cnnviz/example/cnn_visualization/{}?raw=true"
+base_url = "https://github.com/apache/incubator-mxnet/tree/master/docs/tutorial_utils/{}?raw=true"
 mx.test_utils.download(base_url.format(gradcam_file), fname=gradcam_file)
 import gradcam
 ```
diff --git a/example/cnn_visualization/README.md b/example/cnn_visualization/README.md
deleted file mode 100644
index 10b91492600e..000000000000
--- a/example/cnn_visualization/README.md
+++ /dev/null
@@ -1,17 +0,0 @@
-# Visualzing CNN decisions
-
-This folder contains an MXNet Gluon implementation of [Grad-CAM](https://arxiv.org/abs/1610.02391) that helps visualize CNN decisions.
-
-A tutorial on how to use this from Jupyter notebook is available [here](https://mxnet.incubator.apache.org/tutorials/vision/cnn_visualization.html).
-
-You can also do the visualization from terminal:
-```
-$ python gradcam_demo.py hummingbird.jpg
-Predicted category  : hummingbird (94)
-Original Image      : hummingbird_orig.jpg
-Grad-CAM            : hummingbird_gradcam.jpg
-Guided Grad-CAM     : hummingbird_guided_gradcam.jpg
-Saliency Map        : hummingbird_saliency.jpg
-```
-
-![Output of gradcam_demo.py](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/cnn_visualization/hummingbird_filenames.png)
diff --git a/example/cnn_visualization/gradcam_demo.py b/example/cnn_visualization/gradcam_demo.py
deleted file mode 100644
index d9ca5ddade8e..000000000000
--- a/example/cnn_visualization/gradcam_demo.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import mxnet as mx
-from mxnet import gluon
-
-import argparse
-import os
-import numpy as np
-import cv2
-
-import vgg
-import gradcam
-
-# Receive image path from command line
-parser = argparse.ArgumentParser(description='Grad-CAM demo')
-parser.add_argument('img_path', metavar='image_path', type=str, help='path to the image file')
-
-args = parser.parse_args()
-
-# We'll use VGG-16 for visualization
-network = vgg.vgg16(pretrained=True, ctx=mx.cpu())
-# We'll resize images to 224x244 as part of preprocessing
-image_sz = (224, 224)
-
-def preprocess(data):
-    """Preprocess the image before running it through the network"""
-    data = mx.image.imresize(data, image_sz[0], image_sz[1])
-    data = data.astype(np.float32)
-    data = data/255
-    # These mean values were obtained from
-    # https://mxnet.incubator.apache.org/api/python/gluon/model_zoo.html
-    data = mx.image.color_normalize(data,
-                                    mean=mx.nd.array([0.485, 0.456, 0.406]),
-                                    std=mx.nd.array([0.229, 0.224, 0.225]))
-    data = mx.nd.transpose(data, (2,0,1)) # Channel first
-    return data
-
-def read_image_mxnet(path):
-    with open(path, 'rb') as fp:
-        img_bytes = fp.read()
-    return mx.img.imdecode(img_bytes)
-
-def read_image_cv(path):
-    return cv2.resize(cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB), image_sz)
-
-# synset.txt contains the names of Imagenet categories
-# Load the file to memory and create a helper method to query category_index -> category name
-synset_url = "http://data.mxnet.io/models/imagenet/synset.txt"
-synset_file_name = "synset.txt"
-mx.test_utils.download(synset_url, fname=synset_file_name)
-
-synset = []
-with open('synset.txt', 'r') as f:
-    synset = [l.rstrip().split(' ', 1)[1].split(',')[0] for l in f]
-
-def get_class_name(cls_id):
-    return "%s (%d)" % (synset[cls_id], cls_id)
-
-def run_inference(net, data):
-    """Run the input image through the network and return the predicted category as integer"""
-    out = net(data)
-    return out.argmax(axis=1).asnumpy()[0].astype(int)
-
-def visualize(net, img_path, conv_layer_name):
-    """Create Grad-CAM visualizations using the network 'net' and the image at 'img_path'
-    conv_layer_name is the name of the top most layer of the feature extractor"""
-    image = read_image_mxnet(img_path)
-    image = preprocess(image)
-    image = image.expand_dims(axis=0)
-
-    pred_str = get_class_name(run_inference(net, image))
-
-    orig_img = read_image_cv(img_path)
-    vizs = gradcam.visualize(net, image, orig_img, conv_layer_name)
-    return (pred_str, (orig_img, *vizs))
-
-# Create Grad-CAM visualization for the user provided image
-last_conv_layer_name = 'vgg0_conv2d12'
-cat, vizs = visualize(network, args.img_path, last_conv_layer_name)
-
-print("{0:20}: {1:80}".format("Predicted category", cat))
-
-# Write the visualiations into file
-img_name = os.path.split(args.img_path)[1].split('.')[0]
-suffixes = ['orig', 'gradcam', 'guided_gradcam', 'saliency']
-image_desc = ['Original Image', 'Grad-CAM', 'Guided Grad-CAM', 'Saliency Map']
-
-for i, img in enumerate(vizs):
-    img = img.astype(np.float32)
-    if len(img.shape) == 3:
-        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-    out_file_name = "%s_%s.jpg" % (img_name, suffixes[i])
-    cv2.imwrite(out_file_name, img)
-    print("{0:20}: {1:80}".format(image_desc[i], out_file_name))
-
diff --git a/example/cnn_visualization/vgg.py b/example/cnn_visualization/vgg.py
deleted file mode 100644
index a8a0ef6c8dee..000000000000
--- a/example/cnn_visualization/vgg.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import mxnet as mx
-from mxnet import gluon
-
-import os
-from mxnet.gluon.model_zoo import model_store
-
-from mxnet.initializer import Xavier
-from mxnet.gluon.nn import MaxPool2D, Flatten, Dense, Dropout, BatchNorm
-from gradcam import Activation, Conv2D
-
-class VGG(mx.gluon.HybridBlock):
-    def __init__(self, layers, filters, classes=1000, batch_norm=False, **kwargs):
-        super(VGG, self).__init__(**kwargs)
-        assert len(layers) == len(filters)
-        with self.name_scope():
-            self.features = self._make_features(layers, filters, batch_norm)
-            self.features.add(Dense(4096, activation='relu',
-                                    weight_initializer='normal',
-                                    bias_initializer='zeros'))
-            self.features.add(Dropout(rate=0.5))
-            self.features.add(Dense(4096, activation='relu',
-                                    weight_initializer='normal',
-                                    bias_initializer='zeros'))
-            self.features.add(Dropout(rate=0.5))
-            self.output = Dense(classes,
-                                weight_initializer='normal',
-                                bias_initializer='zeros')
-
-    def _make_features(self, layers, filters, batch_norm):
-        featurizer = mx.gluon.nn.HybridSequential(prefix='')
-        for i, num in enumerate(layers):
-            for _ in range(num):
-                featurizer.add(Conv2D(filters[i], kernel_size=3, padding=1,
-                                      weight_initializer=Xavier(rnd_type='gaussian',
-                                                                factor_type='out',
-                                                                magnitude=2),
-                                      bias_initializer='zeros'))
-                if batch_norm:
-                    featurizer.add(BatchNorm())
-                featurizer.add(Activation('relu'))
-            featurizer.add(MaxPool2D(strides=2))
-        return featurizer
-
-    def hybrid_forward(self, F, x):
-        x = self.features(x)
-        x = self.output(x)
-        return x
-
-vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
-            13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
-            16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
-            19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
-
-def get_vgg(num_layers, pretrained=False, ctx=mx.cpu(),
-            root=os.path.join('~', '.mxnet', 'models'), **kwargs):
-    layers, filters = vgg_spec[num_layers]
-    net = VGG(layers, filters, **kwargs)
-    net.initialize(ctx=ctx)
-
-    # Get the pretrained model
-    vgg = mx.gluon.model_zoo.vision.get_vgg(num_layers, pretrained=True, ctx=ctx)
-
-    # Set the parameters in the new network
-    params = vgg.collect_params()
-    for key in params:
-        param = params[key]
-        net.collect_params()[net.prefix+key.replace(vgg.prefix, '')].set_data(param.data())
-
-    return net
-
-def vgg16(**kwargs):
-    return get_vgg(16, **kwargs)
-

From 15f972e8d1edd16454687833179737f2fc4bbe0c Mon Sep 17 00:00:00 2001
From: Khedia
Date: Wed, 14 Nov 2018 10:24:20 -0800
Subject: [PATCH 5/7] creating directory structures in utils

---
 docs/tutorial_utils/{ => vision/cnn_visualization}/gradcam.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename docs/tutorial_utils/{ => vision/cnn_visualization}/gradcam.py (100%)

diff --git a/docs/tutorial_utils/gradcam.py b/docs/tutorial_utils/vision/cnn_visualization/gradcam.py
similarity index 100%
rename from docs/tutorial_utils/gradcam.py
rename to docs/tutorial_utils/vision/cnn_visualization/gradcam.py

From 31d450678df95ac814d7db69992e8e757b552647 Mon Sep 17 00:00:00 2001
From: Khedia
Date: Wed, 14 Nov 2018 10:27:11 -0800
Subject: [PATCH 6/7] changing location

---
 docs/tutorials/vision/cnn_visualization.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/tutorials/vision/cnn_visualization.md b/docs/tutorials/vision/cnn_visualization.md
index 2fd22d408051..cea420e1385e 100644
--- a/docs/tutorials/vision/cnn_visualization.md
+++ b/docs/tutorials/vision/cnn_visualization.md
@@ -22,7 +22,7 @@ from matplotlib import pyplot as plt
 import numpy as np
 
 gradcam_file = "gradcam.py"
-base_url = "https://github.com/apache/incubator-mxnet/tree/master/docs/tutorial_utils/{}?raw=true"
+base_url = "https://github.com/apache/incubator-mxnet/tree/master/docs/tutorial_utils/vision/cnn_visualization/{}?raw=true"
 mx.test_utils.download(base_url.format(gradcam_file), fname=gradcam_file)
 import gradcam
 ```

From d100733f4fd4905f9f05e053273d2915a0910e26 Mon Sep 17 00:00:00 2001
From: Khedia
Date: Thu, 15 Nov 2018 10:46:00 -0800
Subject: [PATCH 7/7] empty commit

---
 docs/tutorials/vision/cnn_visualization.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/tutorials/vision/cnn_visualization.md b/docs/tutorials/vision/cnn_visualization.md
index cea420e1385e..fd6a464e99d4 100644
--- a/docs/tutorials/vision/cnn_visualization.md
+++ b/docs/tutorials/vision/cnn_visualization.md
@@ -182,6 +182,7 @@ Next, we'll write a method to get an image, preprocess it, predict category and
 2. **Guided Grad-CAM:** Guided Grad-CAM shows which exact pixels contributed the most to the CNN's decision.
 3. **Saliency map:** Saliency map is a monochrome image showing which pixels contributed the most to the CNN's decision. Sometimes, it is easier to see the areas in the image that most influence the output in a monochrome image than in a color image.
 
+
 ```python
 def visualize(net, img_path, conv_layer_name):
     orig_img = mx.img.imread(img_path)
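
PATCH 1/7 fixes the core Grad-CAM computation by passing `conv_out_grad` (the gradients flowing into the chosen convolutional layer) instead of `imggrad` (the input-image gradients) to `get_cam`: Grad-CAM weights each convolutional feature map by its spatially averaged gradient. The following is a minimal NumPy sketch of that weighting, not the actual `get_cam` in `gradcam.py`; shapes and names are illustrative only.

```python
import numpy as np

def grad_cam(conv_out, conv_out_grad):
    """Sketch of the Grad-CAM weighting. Both inputs have shape (channels, height, width):
    the feature maps of the last conv layer and the gradients of the class score w.r.t. them."""
    weights = conv_out_grad.mean(axis=(1, 2))      # one weight per feature map (global average pooling)
    cam = np.tensordot(weights, conv_out, axes=1)  # weighted sum over channels -> (height, width)
    cam = np.maximum(cam, 0)                       # ReLU: keep only positive evidence for the class
    return cam / (cam.max() + 1e-12)               # normalize to [0, 1] for visualization
```

PATCH 2/7 replaces the model-file loading path in `get_vgg` with an explicit parameter copy from the Gluon model-zoo VGG into the locally defined `VGG` block, remapping each parameter name from the source network's prefix to the destination's. Below is a self-contained sketch of that copy pattern, assuming MXNet 1.x Gluon; the `copy_params` helper and the use of two model-zoo `vgg16` instances are purely illustrative and not part of the patch.

```python
import mxnet as mx

def copy_params(src_net, dst_net):
    """Copy every parameter of src_net into dst_net, matching names after swapping
    the source prefix (e.g. 'vgg0_') for the destination prefix (e.g. 'vgg1_')."""
    src_params = src_net.collect_params()
    dst_params = dst_net.collect_params()
    for src_name in src_params:
        dst_name = dst_net.prefix + src_name.replace(src_net.prefix, '')
        dst_params[dst_name].set_data(src_params[src_name].data())

ctx = mx.cpu()
# Pretrained source and a freshly initialized target of the same architecture
pretrained = mx.gluon.model_zoo.vision.vgg16(pretrained=True, ctx=ctx)
target = mx.gluon.model_zoo.vision.vgg16(pretrained=False, ctx=ctx)
target.initialize(ctx=ctx)
copy_params(pretrained, target)
```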