--[[
This module takes a table {arr1, arr2} of two feature arrays and outputs a
dim x dim matrix of matching energies: for every (row, col) pair listed in
arr_idx, entry (row, col) holds the head net's energy of the dprog match for
the pair (arr1[row], arr2[col]); all other entries are zero.
Note that we contrast refPos with refNeg as well as refPos with negPos.
--]]
require 'nn'
require 'cunn'
local headNetMulti, parent = torch.class('nn.headNetMulti', 'nn.Module')

function headNetMulti:__init(arr_idx, headNet)
   parent.__init(self)
   -- NB! the head net should be on the GPU for optimal speed
   -- clone with shared storage: this copy sees the same weights and
   -- accumulates into the same gradient buffers as the original net
   self.headNet = headNet:clone('weight', 'bias', 'gradWeight', 'gradBias')
   self.update = true
   self.arr_idx = arr_idx:cuda():cudaLong()
end

function headNetMulti:updateOutput(input)
   local arr1, arr2 = unpack(input)
   local dim = arr1:size(1)
   local nb_pairs = self.arr_idx:size(1)
   -- row / col hold, for every stored pair, the index into arr1 / arr2
   local row = self.arr_idx:select(2, 1)
   local col = self.arr_idx:select(2, 2)
   self.output = torch.CudaTensor(dim, dim):zero()
   -- gather both sides of every pair so the head net runs once over the whole batch
   local in1 = arr1:index(1, row)
   local in2 = arr2:index(1, col)
   self.headNet:forward{ in1, in2 }
   -- scatter the nb_pairs energies into the dense matrix through a flat
   -- row-major view: entry (row, col) lives at linear index col + (row-1)*dim
   local output_vec = self.output:view(dim * dim)
   local idx = col + (row - 1) * dim
   output_vec:indexAdd(1, idx:long(), self.headNet.output)
   -- if we stored states for all pairs we would run out of memory
   self.headNet:clearState()
   return self.output
end
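
--[[
Sanity sketch of the scatter above (toy values, CPU types for brevity; kept
as a comment so loading the module stays side-effect free): pair (row=2,
col=3) of a 4x4 matrix maps to flat index 3 + (2-1)*4 = 7, i.e. element
(2,3) of the row-major view.

local m = torch.zeros(4, 4)
m:view(16):indexAdd(1, torch.LongTensor{3 + (2 - 1) * 4}, torch.Tensor{1})
assert(m[2][3] == 1)
--]]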

function headNetMulti:updateGradInput(input, gradOutput)
   local arr1, arr2 = unpack(input)
   local dim = gradOutput:size(1)
   -- the output gradients are sparse: their nonzero entries tell us which
   -- pairs are actually useful
   local useful_arr_idx = gradOutput:double():nonzero():cuda() -- via double for server compatibility
   if useful_arr_idx:numel() > 0 then
      do
         -- same row-major flattening as in updateOutput
         local flat_idx = useful_arr_idx:select(2, 2) + (useful_arr_idx:select(2, 1) - 1) * dim
         local gradOutput_vec = gradOutput:view(dim * dim)
         self.useful_gradOutput = gradOutput_vec:index(1, flat_idx:long()):cuda() -- for server compatibility
      end
      -- the selector nets pick out the useful elements for the cost computation
      local selectorNet1 = nn.Index(1):cuda()
      local selectorNet2 = nn.Index(1):cuda()
      selectorNet1:forward{arr1, useful_arr_idx:select(2, 1)}
      selectorNet2:forward{arr2, useful_arr_idx:select(2, 2)}
      self.selectorOutput = {selectorNet1.output, selectorNet2.output}
      -- channel the selected elements through the head net
      self.headNet:forward(self.selectorOutput)
      -- compute the input gradient of the head net
      self.headNet:updateGradInput(self.selectorOutput, self.useful_gradOutput)
      -- compute the input gradients of the selector nets, which scatter the
      -- head net's gradients back to full-size tensors
      selectorNet1:backward({arr1, useful_arr_idx:select(2, 1)}, self.headNet.gradInput[1])
      selectorNet2:backward({arr2, useful_arr_idx:select(2, 2)}, self.headNet.gradInput[2])
      self.gradInput = {selectorNet1.gradInput[1], selectorNet2.gradInput[1]}
      self.update = true
   else
      self.update = false
      self.gradInput = {arr1:clone():zero(), arr2:clone():zero()}
   end
   return self.gradInput
end
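
--[[
Illustrative sketch of what the selector nets do (toy values, CPU types for
brevity; kept as a comment): nn.Index(1) gathers rows on forward and
scatter-adds gradients on backward, which is how the sparse output gradients
are routed back to full-size gradients for arr1 and arr2.

local sel = nn.Index(1)
local x = torch.randn(5, 3)
local idx = torch.LongTensor{2, 4}
local y = sel:forward{x, idx}  -- the same rows as x:index(1, idx)
local g = sel:backward({x, idx}, torch.ones(2, 3))[1]
-- g is zero everywhere except rows 2 and 4, which hold the upstream gradients
--]]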

function headNetMulti:accGradParameters(input, gradOutput)
   if self.update then
      -- reuse the activations and gradients cached by updateGradInput
      self.headNet:accGradParameters(self.selectorOutput, self.useful_gradOutput)
   end
end

function headNetMulti:parameters()
   -- declared local to avoid leaking into the global environment
   local weights, gradWeights = self.headNet:parameters()
   return weights, gradWeights
end
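
--[[
Usage sketch (hypothetical shapes and a toy head net, kept as a comment so
requiring this file has no side effects): build a dot-product head, wrap it
in nn.headNetMulti, and run one forward/backward pass.

local dim, featDim = 8, 16

-- toy head: elementwise product then sum over features -> one energy per pair
local head = nn.Sequential()
   :add(nn.CMulTable())
   :add(nn.Sum(2))
   :cuda()

-- compare each element with its right neighbour (and the last with itself)
local pairs_idx = torch.LongTensor(dim, 2)
for i = 1, dim do
   pairs_idx[i][1] = i
   pairs_idx[i][2] = math.min(i + 1, dim)
end

local net = nn.headNetMulti(pairs_idx, head)
local arr1 = torch.CudaTensor(dim, featDim):normal()
local arr2 = torch.CudaTensor(dim, featDim):normal()

local energies = net:forward{arr1, arr2}  -- dim x dim, zero off the stored pairs
local gradOut = energies:clone():zero()
gradOut[1][2] = 1                         -- pretend only pair (1, 2) matters
local gradIn = net:backward({arr1, arr2}, gradOut)
--]]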