Skip to content

Commit

Permalink
first commit of source code for the class activation mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
zhoubolei committed Apr 11, 2016
0 parents commit 7e76a6e
Show file tree
Hide file tree
Showing 15 changed files with 181 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# CAM
Binary file added categories1000.mat
Binary file not shown.
Binary file added data_img1.mat
Binary file not shown.
Binary file added data_img2.mat
Binary file not shown.
Binary file added data_net.mat
Binary file not shown.
61 changes: 61 additions & 0 deletions demo.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
% Sample code to generate class activation map from 10 crops of activations
% Bolei Zhou, March 15, 2016
% for the online prediction, make sure you have compiled matcaffe
%
% The script reads one of the bundled images, obtains the class scores, the
% last-conv activations and the softmax weights (either by running the CAM
% network through matcaffe or by loading cached .mat files), and displays
% the input image next to heat maps for the top predictions.

clear
addpath('caffe-for-cudnn-v2.5.48/matlab');

imgID = 2; % 1 or 2
img = imread(['img' num2str(imgID) '.jpg']);
img = imresize(img, [256 256]);
online = 0; % whether extract features online or load pre-extracted features

load('categories1000.mat'); % presumably provides the `categories` cell array used below
if online == 1
    % load the CAM model and extract features

    net_weights = 'googlenet_imagenet/CAMmodels/imagenet_googleletCAM_train_iter_120000.caffemodel';
    % NOTE: the original line carried a stray closing bracket after the
    % string literal, which made the whole script fail to parse.
    net_model = 'googlenet_imagenet/deploy_googlenetCAM.prototxt';
    net = caffe.Net(net_model, net_weights, 'test');

    weights_LR = net.params('CAM_fc',1).get_data(); % get the softmax layer of the network

    scores = net.forward({prepare_image(img)}); % extract conv features online
    activation_lastconv = net.blobs('CAM_conv').get_data();
    scores = scores{1};
else
    % use the extracted features and softmax parameters cached before hand
    load('data_net.mat'); % it contains the softmax weights and the category names of the network
    load(['data_img' num2str(imgID) '.mat']); % it contains the pre-extracted conv features
end


%% Class Activation Mapping

topNum = 5; % generate heatmap for top X prediction results
scoresMean = mean(scores,2); % average the scores along dimension 2
[value_category, IDX_category] = sort(scoresMean,'descend');
[curCAMmapAll] = returnCAMmap(activation_lastconv, weights_LR(:,IDX_category(1:topNum)));

imgDouble = im2double(img); % loop-invariant, so converted once
curResult = imgDouble;
curPrediction = '';

for j=1:topNum
    % maps of the j-th ranked class for every crop
    curCAMmap_crops = squeeze(curCAMmapAll(:,:,j,:));
    curCAMmapLarge_crops = imresize(curCAMmap_crops,[256 256]);
    curCAMLarge = mergeTenCrop(curCAMmapLarge_crops);
    curHeatMap = imresize(im2double(curCAMLarge),[256 256]);
    curHeatMap = map2jpg(curHeatMap,[], 'jet');
    % blend the input image with the colour-coded map
    curHeatMap = imgDouble*0.2+curHeatMap*0.7;
    % append the blended map to the result strip, separated by a white bar
    curResult = [curResult ones(size(curHeatMap,1),8,3) curHeatMap];
    curPrediction = [curPrediction ' --top' num2str(j) ':' categories{IDX_category(j)}];
end
figure,imshow(curResult);title(curPrediction)

if online==1
    caffe.reset_all();
end
Binary file added ilsvrc_2012_mean.mat
Binary file not shown.
Binary file added img1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 14 additions & 0 deletions map2jpg.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
function [img] = map2jpg(imgmap, range, colorMap)
%MAP2JPG Render a numeric map as an RGB heat-map image.
%   img = map2jpg(imgmap, range, colorMap)
%
%   imgmap   - numeric array; it is converted to double before scaling.
%   range    - [low high] limits handed to mat2gray. When omitted or empty,
%              the minimum and maximum of imgmap are used.
%   colorMap - name of a colormap function (e.g. 'jet') or a function
%              handle; it is called with 256 to obtain the colour table.
%              When omitted or empty, jet is used.
%
%   img      - RGB image produced by ind2rgb. Positions where imgmap is NaN
%              are assigned index 0, i.e. the first colormap entry.
imgmap = double(imgmap);
if nargin < 2 || isempty(range)
    range = [min(imgmap(:)) max(imgmap(:))];
end
if nargin < 3 || isempty(colorMap)
    colorMap = 'jet';
end

heatmap_gray = mat2gray(imgmap, range);
heatmap_x = gray2ind(heatmap_gray, 256);
heatmap_x(isnan(imgmap)) = 0;

% feval replaces the original eval([colorMap '(256)']): same result for a
% colormap name, but the argument is never executed as arbitrary code.
img = ind2rgb(heatmap_x, feval(colorMap, 256));

42 changes: 42 additions & 0 deletions mergeTenCrop.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
function alignImgMean = mergeTenCrop( CAMmap_crops)
%MERGETENCROP Combine ten per-crop CAM maps into a single 256 x 256 map.
%   alignImgMean = mergeTenCrop(CAMmap_crops)
%
%   CAMmap_crops must fit a 256 x 256 x 1 x 10 slot (e.g. 256 x 256 x 10).
%   Every map is copied into three channels, transposed (maps 6-10 are
%   flipped along their first dimension beforehand), written into a
%   single-precision canvas, and the absolute values are summed over the
%   channel and crop dimensions.
%
%   (take a look at caffe matlab wrapper about how ten crops are generated)

canvasSize = 256;
numChannels = 3;

% replicate the maps over three channels
stacked = zeros([canvasSize canvasSize numChannels 10]);
for ch = 1:numChannels
    stacked(:,:,ch,:) = CAMmap_crops;
end

patchSize = size(stacked,1);
offsets = [0 canvasSize-patchSize] + 1;

aligned = zeros(canvasSize,canvasSize,numChannels,'single');

% corner crops occupy slots 1-4, their flipped counterparts slots 6-9
slot = 0;
for a = 1:numel(offsets)
    rowSpan = offsets(a):offsets(a)+patchSize-1;
    for b = 1:numel(offsets)
        colSpan = offsets(b):offsets(b)+patchSize-1;
        slot = slot + 1;

        plainPatch = permute(stacked(:,:,:,slot),[2 1 3 4]);
        flippedPatch = permute(stacked(end:-1:1,:,:,slot+5),[2 1 3 4]);

        aligned(rowSpan, colSpan, :, slot) = plainPatch;
        aligned(rowSpan, colSpan, :, slot+5) = flippedPatch;
    end
end

% centre crop (slot 5) and its flipped counterpart (slot 10)
mid = floor(offsets(2) / 2)+1;
midSpan = mid:mid+patchSize-1;
aligned(midSpan, midSpan, :, 5) = permute(stacked(:,:,:,5),[2 1 3 4]);
aligned(midSpan, midSpan, :, 10) = permute(stacked(end:-1:1,:,:,10),[2 1 3 4]);

% collapse channels first, then crops
alignImgMean = squeeze(sum(sum(abs(aligned),3),4));

end



32 changes: 32 additions & 0 deletions prepare_image.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
function crops_data = prepare_image(im)
% PREPARE_IMAGE Convert an image into ten mean-subtracted crops for caffe.
%   crops_data = prepare_image(im)
%
%   im         - image as returned by imread: H x W x 3 (RGB), or H x W
%                (grayscale, which is replicated to three channels).
%   crops_data - CROPPED_DIM x CROPPED_DIM x 3 x 10 single array holding
%                the four corner crops and the centre crop (slots 1-5)
%                followed by their flips along the first dimension
%                (slots 6-10).
% ------------------------------------------------------------------------
% caffe/matlab/+caffe/imagenet/ilsvrc_2012_mean.mat contains mean_data that
% is already in W x H x C with BGR channels
d = load('ilsvrc_2012_mean.mat');
mean_data = d.mean_data;
IMAGE_DIM = 256;
CROPPED_DIM = 224; % 224 for googLeNet , 227 for VGG and AlexNet

% A grayscale image has no third dimension to reorder; replicate it so the
% channel indexing below does not fail.
if size(im, 3) == 1
    im = repmat(im, [1 1 3]);
end

% Convert an image returned by Matlab's imread to im_data in caffe's data
% format: W x H x C with BGR channels
im_data = im(:, :, [3, 2, 1]); % permute channels from RGB to BGR
im_data = permute(im_data, [2, 1, 3]); % flip width and height
im_data = single(im_data); % convert from uint8 to single
im_data = imresize(im_data, [IMAGE_DIM IMAGE_DIM], 'bilinear'); % resize im_data
im_data = im_data - mean_data; % subtract mean_data (already in W x H x C, BGR)

% oversample (4 corners, center, and their x-axis flips)
crops_data = zeros(CROPPED_DIM, CROPPED_DIM, 3, 10, 'single');
indices = [0 IMAGE_DIM-CROPPED_DIM] + 1;
n = 1;
for i = indices
    for j = indices
        crops_data(:, :, :, n) = im_data(i:i+CROPPED_DIM-1, j:j+CROPPED_DIM-1, :);
        crops_data(:, :, :, n+5) = crops_data(end:-1:1, :, :, n);
        n = n + 1;
    end
end
center = floor(indices(2) / 2) + 1;
crops_data(:,:,:,5) = ...
    im_data(center:center+CROPPED_DIM-1,center:center+CROPPED_DIM-1,:);
crops_data(:,:,:,10) = crops_data(end:-1:1, :, :, 5);
3 changes: 3 additions & 0 deletions readme.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Here we provide sample code to generate class activation maps from the pre-extracted CNN activations for the given image. To make it work for your own CAM network and images, you need to integrate a feature extraction wrapper into it.

Bolei Zhou
3 changes: 3 additions & 0 deletions readme.txt~
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Here we provide sample code to generate class activation maps from the pre-extracted convolution activations for the given image.

Bolei Zhou
25 changes: 25 additions & 0 deletions returnCAMmap.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
function [curColumnMap] = returnCAMmap( featureObjectSwitchSpatial, weights_LR)
%RETURNCAMMAP Project feature maps onto class weight vectors.
%   curColumnMap = returnCAMmap(featureObjectSwitchSpatial, weights_LR)
%
%   featureObjectSwitchSpatial - H x W x C array (single crop) or
%                                H x W x C x N array (N crops).
%   weights_LR                 - C x K matrix, one column per class.
%
%   curColumnMap - H x W x K when the fourth dimension of the input is 1,
%                  otherwise H x W x K x N. Each spatial position holds the
%                  product of its C-element feature vector with weights_LR.
%                  In the multi-crop case the result is stored in a double
%                  array (as in the original implementation).

mapHeight = size(featureObjectSwitchSpatial,1);
mapWidth = size(featureObjectSwitchSpatial,2);
numChannels = size(featureObjectSwitchSpatial,3);
numCrops = size(featureObjectSwitchSpatial,4);
numClasses = size(weights_LR,2);

if numCrops == 1
    flatFeatures = reshape(featureObjectSwitchSpatial,[mapHeight*mapWidth numChannels]);
    curColumnMap = reshape(flatFeatures*weights_LR,[mapHeight, mapWidth, numClasses]);
else
    curColumnMap = zeros(mapHeight,mapWidth,numClasses,numCrops);
    for i=1:numCrops
        % Sizes are taken from the full array rather than from a squeezed
        % slice: squeeze would drop a singleton H or W and break the
        % reshape / matrix product below.
        flatFeatures = reshape(featureObjectSwitchSpatial(:,:,:,i),[mapHeight*mapWidth numChannels]);
        curColumnMap(:,:,:,i) = reshape(flatFeatures*weights_LR,[mapHeight, mapWidth, numClasses]);
    end
end

end

0 comments on commit 7e76a6e

Please sign in to comment.