-
Notifications
You must be signed in to change notification settings - Fork 19
/
trans_mp.m
50 lines (40 loc) · 1.29 KB
/
trans_mp.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
classdef trans_mp < trans_basic
  %TRANS_MP Non-overlapping Max Pooling
  % Use Jonathan Masci's implementation (MaxPooling.cpp and .m)
  %
  % The pooling window is given by SCALE: a scalar s means an s-by-s window
  % (the original behavior); a 2-element [sh, sw] means an sh-by-sw window.
  % Windows do not overlap, so the input height/width must be divisible by
  % the window height/width (checked in init_param).
  properties
    scale; % scale (window size) for subsampling: scalar s or [sh, sw]
    idx;   % index (linear) for the local maximum elements, filled by ff
  end
  methods
    function obj = trans_mp(scale_)
      % scale_: window size, scalar or 2-element [sh, sw]
      obj.scale = scale_;
    end

    function [obj, data_o] = ff(obj, data_i, data_o)
      % Forward pass: max-pool data_i.a with a non-overlapping window and
      % write the result to data_o.a. The positions of the maxima returned
      % by MaxPooling are stored in obj.idx for use by deriv_input.
      [data_o.a, obj.idx] = MaxPooling(data_i.a, pool_size(obj));
    end % ff

    function data_i = deriv_input(obj, data_i, data_o)
      % Backward pass (up-pooling): scatter data_o.d back onto the input
      % positions recorded in obj.idx by the preceding ff call; every other
      % input position receives a zero derivative.
      % data_i.d: [Hi,Wi,Mi,N]
      % data_o.d: [Ho,Wo,Mo,N], where [Hi,Wi] = [sh,sw].*[Ho,Wo]
      % infer the size: as set by init_param, but with the last dimension
      % replaced by the current data_o.N
      szs = obj.szs_in;
      szs(end) = data_o.N;
      % initialize to zero, then route each output derivative to its max
      data_i.d = zeros( szs );
      data_i.d( obj.idx ) = data_o.d(:);
    end % deriv_input

    function obj = init_param(obj, szs_in_)
      % szs_in_: [Hin,Win,Min,N4]. input map size; all 4 elements are read.
      %          (N4 is presumably the batch size -- deriv_input replaces
      %          the last element with data_o.N; TODO confirm with callers)
      % Set:
      %   obj.szs_in:  szs_in_ as given
      %   obj.szs_out: [Hin/sh, Win/sw, Min, N4]
      % Raises trans_mp:init_param:sizeMismatch when Hin or Win is not a
      % multiple of the corresponding window size (non-overlapping pooling
      % would otherwise yield a fractional output size).
      obj.szs_in = szs_in_;
      ps = pool_size(obj);
      hw_in = [szs_in_(1), szs_in_(2)];
      if any(mod(hw_in, ps) ~= 0)
        error('trans_mp:init_param:sizeMismatch', ...
          'Input map size [%g %g] is not divisible by pool size [%g %g].', ...
          hw_in(1), hw_in(2), ps(1), ps(2));
      end
      hw_out = hw_in ./ ps;
      obj.szs_out = [hw_out(1), hw_out(2), szs_in_(3), szs_in_(4)];
    end % init_param
  end % methods

  methods (Access = private)
    function ps = pool_size(obj)
      % Return the pooling window as a 1x2 row vector [sh, sw].
      % A scalar scale s gives [s, s], matching the original behavior.
      s = obj.scale;
      if isscalar(s)
        ps = [s, s];
      elseif numel(s) == 2
        ps = reshape(s, 1, 2);
      else
        error('trans_mp:pool_size:badScale', ...
          'scale must be a scalar or a 2-element vector.');
      end
    end % pool_size
  end % private methods
end