forked from zhoubolei/CAM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
first commit of source code for the class activation mapping
- Loading branch information
0 parents
commit 7e76a6e
Showing
15 changed files
with
181 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# CAM |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
% Sample code to generate class activation map from 10 crops of activations
% Bolei Zhou, March 15, 2016
% For the online prediction, make sure you have compiled matcaffe.
%
% The script reads img<imgID>.jpg, resizes it to 256x256, obtains the class
% scores, the 'CAM_conv' activations and the 'CAM_fc' weights (from the
% network when online == 1, otherwise from cached .mat files), and displays
% the input image side by side with the heatmaps of the top predictions.

clear
addpath('caffe-for-cudnn-v2.5.48/matlab');

imgID = 2; % 1 or 2
img = imread(['img' num2str(imgID) '.jpg']);
img = imresize(img, [256 256]);
online = 0; % whether extract features online or load pre-extracted features

% NOTE(review): presumably defines the cell array 'categories' used for the
% figure title below -- confirm against the .mat file.
load('categories1000.mat');
if online == 1
    % load the CAM model and extract features
    net_weights = 'googlenet_imagenet/CAMmodels/imagenet_googleletCAM_train_iter_120000.caffemodel';
    net_model = 'googlenet_imagenet/deploy_googlenetCAM.prototxt';
    net = caffe.Net(net_model, net_weights, 'test');

    weights_LR = net.params('CAM_fc',1).get_data();% get the softmax layer of the network

    scores = net.forward({prepare_image(img)});% extract conv features online
    activation_lastconv = net.blobs('CAM_conv').get_data();
    scores = scores{1};
else
    % use the extracted features and softmax parameters cached before hand
    % NOTE(review): these files are expected to define weights_LR, scores
    % and activation_lastconv, which are used below -- confirm.
    load('data_net.mat'); % it contains the softmax weights and the category names of the network
    load(['data_img' num2str(imgID) '.mat']); %it contains the pre-extracted conv features
end

%% Class Activation Mapping

topNum = 5; % generate heatmap for top X prediction results
scoresMean = mean(scores,2); % average the scores over the second dimension (the crops)
[value_category, IDX_category] = sort(scoresMean,'descend');
[curCAMmapAll] = returnCAMmap(activation_lastconv, weights_LR(:,IDX_category(1:topNum)));

curResult = im2double(img);
curPrediction = '';

for j=1:topNum
    % maps of the j-th ranked class for every crop
    curCAMmap_crops = squeeze(curCAMmapAll(:,:,j,:));
    curCAMmapLarge_crops = imresize(curCAMmap_crops,[256 256]);
    curCAMLarge = mergeTenCrop(curCAMmapLarge_crops);
    curHeatMap = imresize(im2double(curCAMLarge),[256 256]);
    curHeatMap = map2jpg(curHeatMap,[], 'jet');
    % blend the colored heatmap with a dimmed copy of the input image
    curHeatMap = im2double(img)*0.2+curHeatMap*0.7;
    % append an 8-pixel white separator and the blended heatmap to the right
    curResult = [curResult ones(size(curHeatMap,1),8,3) curHeatMap];
    curPrediction = [curPrediction ' --top' num2str(j) ':' categories{IDX_category(j)}];
end
figure,imshow(curResult);title(curPrediction)

if online==1
    caffe.reset_all();
end
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
function [img] = map2jpg(imgmap, range, colorMap)
% MAP2JPG Convert a numeric map into an RGB heatmap image.
%   img = map2jpg(imgmap) rescales imgmap with mat2gray using its own
%   minimum and maximum, quantizes it to 256 levels and colors it with
%   jet(256).
%
%   img = map2jpg(imgmap, range) rescales with range = [low high] instead.
%   An empty range falls back to the minimum/maximum of imgmap.
%
%   img = map2jpg(imgmap, range, colorMap) builds the colormap by calling
%   colorMap with the number of levels (256). colorMap may be a function
%   name such as 'jet' or a function handle such as @jet. An empty
%   colorMap selects jet.
%
%   Entries of imgmap that are NaN are forced to index 0 of the index
%   image before coloring.

imgmap = double(imgmap);
if ~exist('range', 'var') || isempty(range)
    range = [min(imgmap(:)) max(imgmap(:))];
end
if ~exist('colorMap', 'var') || isempty(colorMap)
    colorMap = 'jet';
end

numLevels = 256;
heatmap_gray = mat2gray(imgmap, range);
heatmap_x = gray2ind(heatmap_gray, numLevels);
heatmap_x(isnan(imgmap)) = 0;

% feval calls the colormap generator directly, so the argument is never
% executed as arbitrary code (as eval would do) and function handles are
% accepted as well as names.
img = ind2rgb(heatmap_x, feval(colorMap, numLevels));
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
function alignImgMean = mergeTenCrop( CAMmap_crops)
% MERGETENCROP Merge ten per-crop CAM maps into a single 256x256 map.
%   Crops 1-4 are placed at the corner offsets, crop 5 at the center
%   offset, and crops 6-10 are the same positions with the first dimension
%   reversed first. Every crop is transposed (dims 1 and 2 swapped) before
%   placement; take a look at the caffe matlab wrapper about how the ten
%   crops are generated.
%
%   CAMmap_crops must be assignable into a 256x256x1x10 slice. The result
%   is the sum of absolute values over three replicated channels and all
%   ten crops, in single precision.
%
%   Because the crop size is read from the 256x256 working buffer, the
%   corner and center offsets all evaluate to 1 here.

fullSize = 256;

% replicate the maps over three channels
cropStack = zeros([256 256 3 10]);
for ch = 1:3
    cropStack(:,:,ch,:) = CAMmap_crops;
end

patchSize = size(cropStack,1);
corner = [0 fullSize-patchSize] + 1;
mid = floor(corner(2) / 2)+1;

% start offsets for crops 1..5 (four corners, then center); crop k+5 shares
% the offsets of crop k
rowStart = [corner(1) corner(1) corner(2) corner(2) mid];
colStart = [corner(1) corner(2) corner(1) corner(2) mid];

aligned = zeros(256,256,size(cropStack,3),'single');
for k = 1:5
    rows = rowStart(k):rowStart(k)+patchSize-1;
    cols = colStart(k):colStart(k)+patchSize-1;

    plainCrop = permute(cropStack(:,:,:,k),[2 1 3 4]);
    flippedCrop = permute(cropStack(end:-1:1,:,:,k+5),[2 1 3 4]);

    aligned(rows, cols, :, k) = plainCrop;
    aligned(rows, cols, :, k+5) = flippedCrop;
end

alignImgMean = squeeze(sum(sum(abs(aligned),3),4));

end
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
function crops_data = prepare_image(im)
% PREPARE_IMAGE Build the ten oversampled crops fed to the network.
%   im is an image as returned by Matlab's imread (indexed here as
%   H x W x 3, RGB). The image is converted to caffe's data format
%   (W x H x C with BGR channels), resized to IMAGE_DIM x IMAGE_DIM, and
%   mean_data from ilsvrc_2012_mean.mat is subtracted
%   (caffe/matlab/+caffe/imagenet/ilsvrc_2012_mean.mat contains mean_data
%   that is already in W x H x C with BGR channels).
%
%   crops_data is CROPPED_DIM x CROPPED_DIM x 3 x 10, single: crops 1-4 are
%   the four corners, crop 5 is the center, and crops 6-10 are crops 1-5
%   with the first dimension reversed.

IMAGE_DIM = 256;
CROPPED_DIM = 224; % 224 for googLeNet , 227 for VGG and AlexNet

meanFile = load('ilsvrc_2012_mean.mat');

% RGB -> BGR, swap width and height, uint8 -> single
bgr = im(:, :, [3, 2, 1]);
whc = single(permute(bgr, [2, 1, 3]));
whc = imresize(whc, [IMAGE_DIM IMAGE_DIM], 'bilinear');
whc = whc - meanFile.mean_data;

% start offsets for crops 1..5 (four corners, then center)
corner = [0 IMAGE_DIM-CROPPED_DIM] + 1;
mid = floor(corner(2) / 2) + 1;
rowStart = [corner(1) corner(1) corner(2) corner(2) mid];
colStart = [corner(1) corner(2) corner(1) corner(2) mid];
span = 0:CROPPED_DIM-1;

crops_data = zeros(CROPPED_DIM, CROPPED_DIM, 3, 10, 'single');
for k = 1:5
    crops_data(:, :, :, k) = whc(rowStart(k)+span, colStart(k)+span, :);
    % mirrored copy of the crop just stored
    crops_data(:, :, :, k+5) = crops_data(end:-1:1, :, :, k);
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
Here we provide sample code to generate class activation maps from the pre-extracted CNN activations of a given image. To make it work with your own CAM network and images, you need to integrate a feature-extraction wrapper into it. | ||
|
||
Bolei Zhou |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
Here we provide sample code to generate class activation maps from the pre-extracted convolution activations of a given image. | ||
|
||
Bolei Zhou |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
function [curColumnMap] = returnCAMmap( featureObjectSwitchSpatial, weights_LR)
%RETURNCAMMAP Project feature maps onto class weights.
%   curColumnMap = returnCAMmap(featureObjectSwitchSpatial, weights_LR)
%
%   featureObjectSwitchSpatial  R x C x K x N array of features
%                               (N may be 1, i.e. a 3-D array).
%   weights_LR                  K x M weight matrix, one column per class.
%
%   For every sample the R*C spatial positions are flattened into rows,
%   multiplied by weights_LR, and reshaped back, giving one R x C map per
%   weight column. The result is R x C x M when N == 1 and R x C x M x N
%   (double) otherwise.
%
%   The per-sample slice is reshaped with the sizes of the 4-D input rather
%   than being squeezed first: squeeze would drop a singleton spatial
%   dimension (1 x C or R x 1 maps) and break the reshape.

numRows     = size(featureObjectSwitchSpatial,1);
numCols     = size(featureObjectSwitchSpatial,2);
numChannels = size(featureObjectSwitchSpatial,3);
numSamples  = size(featureObjectSwitchSpatial,4);
numClasses  = size(weights_LR,2);

if numSamples == 1
    flatFeatures = reshape(featureObjectSwitchSpatial,[numRows*numCols numChannels]);
    curColumnMap = reshape(flatFeatures*weights_LR,[numRows, numCols, numClasses]);
else
    curColumnMap = zeros(numRows,numCols,numClasses,numSamples);
    for i=1:numSamples
        flatFeatures = reshape(featureObjectSwitchSpatial(:,:,:,i),[numRows*numCols numChannels]);
        curColumnMap(:,:,:,i) = reshape(flatFeatures*weights_LR,[numRows, numCols, numClasses]);
    end
end

end
|