forked from clementfarabet/lua---nnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DiagHessian.lua
129 lines (114 loc) · 4.86 KB
/
DiagHessian.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
-- Module
--- Default diagonal-Hessian backward pass for modules that do not override it.
-- Treats the module as an identity with respect to curvature: the incoming
-- diagonal Hessian is passed straight through as this module's
-- diagHessianInput.
-- @param input             the module's input (unused by this default)
-- @param diagHessianOutput diagonal Hessian w.r.t. the module's output
-- @return diagHessianOutput (also stored in self.diagHessianInput)
function nn.Module.backwardDiagHessian(self, input, diagHessianOutput)
   -- Assign unconditionally: caching with `self.diagHessianInput or ...`
   -- kept the first tensor ever passed in and kept returning it even when a
   -- later call supplied a different diagHessianOutput.
   self.diagHessianInput = diagHessianOutput
   return self.diagHessianInput
end
--- Default parameter accumulation: a module that has no parameters has no
-- diagonal-Hessian terms to accumulate, so this hook does nothing.
nn.Module.accDiagHessianParameters = function(self, input, diagHessianOutput, scale)
end
--- Default parameter initialisation: nothing to allocate for a module without
-- parameters, so this hook does nothing.
nn.Module.initDiagHessianParameters = function(self)
end
-- Criterion
--- Default diagonal-Hessian backward pass for criteria that do not override it.
-- Returns a tensor of the same type as `input` but does not size or fill it;
-- concrete criteria (e.g. MSECriterion) supply actual values.
-- NOTE(review): downstream modules will receive an empty tensor from this
-- default; consider resizing to `input` and zero-filling if that is desired.
-- @param input             the criterion's input tensor
-- @param diagHessianOutput unused by this default
-- @return self.diagHessianInput
function nn.Criterion.backwardDiagHessian(self, input, diagHessianOutput)
   -- Allocate from `input`, not `self.output`: a criterion's output is the
   -- scalar loss (a plain Lua number), which has no `new` method.
   self.diagHessianInput = self.diagHessianInput or input.new()
   return self.diagHessianInput
end
-- MSECriterion
--- Diagonal-Hessian backward pass for MSECriterion.
-- Produces a tensor shaped like `input` with every entry set to 1; the
-- buffer is allocated on first use and reused afterwards.
-- NOTE(review): the exact second derivative of a squared error is 2 (or 2/n
-- with size averaging); this uses a constant 1 -- confirm that is intended.
-- @param input             the criterion's input tensor
-- @param diagHessianOutput unused
-- @return self.diagHessianInput
function nn.MSECriterion.backwardDiagHessian(self, input, diagHessianOutput)
   if not self.diagHessianInput then
      self.diagHessianInput = input.new()
   end
   local result = self.diagHessianInput
   result:resizeAs(input)
   result:fill(1)
   return result
end
-- Linear
--- Diagonal-Hessian backward pass for Linear.
-- Propagates curvature through the layer using the element-wise squared
-- weights: diagHessianInput = (weight .^ 2)^T * diagHessianOutput for a
-- vector input, and diagHessianOutput * (weight .^ 2) row-wise for a matrix.
-- @param input             1D (inputSize) or 2D (nframe x inputSize) tensor
-- @param diagHessianOutput diagonal Hessian w.r.t. the layer output
-- @return self.diagHessianInput, resized like `input`
function nn.Linear.backwardDiagHessian(self, input, diagHessianOutput)
   self.diagHessianInput = self.diagHessianInput or self.output.new()
   self.weightSq = self.weightSq or self.output.new()
   -- Resize on every call (a no-op when unchanged) so the buffer always
   -- matches the current weight shape, then square the weights in place.
   self.weightSq:resizeAs(self.weight):copy(self.weight):cmul(self.weightSq)
   if input:dim() == 1 then
      self.diagHessianInput:resizeAs(input)
      -- beta = 0: overwrite the buffer instead of accumulating into it
      self.diagHessianInput:addmv(0, 1, self.weightSq:t(), diagHessianOutput)
   elseif input:dim() == 2 then
      self.diagHessianInput:resizeAs(input)
      self.diagHessianInput:addmm(0, 1, diagHessianOutput, self.weightSq)
   else
      -- Previously fell through and silently returned a stale/empty buffer.
      error('input must be vector or matrix')
   end
   return self.diagHessianInput
end
--- Allocate the diagonal-Hessian accumulators for weight and bias.
-- Newly allocated buffers are zero-filled; buffers that already exist are
-- left untouched, so calling this again does not discard accumulated values.
function nn.Linear.initDiagHessianParameters(self)
   if not self.diagHessianWeight then
      -- resizeAs leaves memory uninitialised, and accDiagHessianParameters
      -- adds into these tensors, so start them from zero rather than garbage.
      self.diagHessianWeight = self.output.new():resizeAs(self.weight):zero()
   end
   if not self.diagHessianBias then
      self.diagHessianBias = self.output.new():resizeAs(self.bias):zero()
   end
end
--- Accumulate diagonal-Hessian terms for the weight and bias.
-- Adds scale * (diagHessianOutput outer input.^2) into diagHessianWeight and
-- scale * diagHessianOutput (summed over frames for a 2D input) into
-- diagHessianBias. initDiagHessianParameters must have been called first.
-- @param input             1D (inputSize) or 2D (nframe x inputSize) tensor
-- @param diagHessianOutput diagonal Hessian w.r.t. the layer output
-- @param scale             multiplier for the accumulated terms (default 1)
function nn.Linear.accDiagHessianParameters(self, input, diagHessianOutput, scale)
   scale = scale or 1
   self.inputSq = self.inputSq or self.output.new()
   self.inputSq:resizeAs(input):copy(input):cmul(self.inputSq)
   if input:dim() == 1 then
      self.diagHessianWeight:addr(scale, diagHessianOutput, self.inputSq)
      self.diagHessianBias:add(scale, diagHessianOutput)
   elseif input:dim() == 2 then
      local nframe = input:size(1)
      -- Reuse one ones-vector across calls (refilled only when the batch size
      -- changes) instead of allocating a new one on every call.
      self.frameOnes = self.frameOnes or self.output.new()
      if self.frameOnes:nElement() ~= nframe then
         self.frameOnes:resize(nframe):fill(1)
      end
      self.diagHessianWeight:addmm(scale, diagHessianOutput:t(), self.inputSq)
      -- Sum the per-frame contributions into the bias accumulator.
      self.diagHessianBias:addmv(scale, diagHessianOutput:t(), self.frameOnes)
   else
      -- Previously a silent no-op for unsupported shapes.
      error('input must be vector or matrix')
   end
end
-- Tanh
--- Diagonal-Hessian backward pass for Tanh.
-- Scales the incoming diagonal Hessian element-wise by the squared
-- derivative (1 - y^2)^2, where y = self.output; the term involving tanh's
-- second derivative is not included.
-- Assumes self.output holds the forward result for this `input`.
-- @param input             the module's input
-- @param diagHessianOutput diagonal Hessian w.r.t. the module's output
-- @return self.diagHessianInput, resized like `input`
function nn.Tanh.backwardDiagHessian(self, input, diagHessianOutput)
   if not self.diagHessianInput then
      self.diagHessianInput = self.output.new()
   end
   if not self.derivativeSq then
      self.derivativeSq = self.output.new()
   end
   local y = self.output
   local dsq = self.derivativeSq
   -- dsq = 1 - y^2  (the tanh derivative)
   dsq:resizeAs(y)
   dsq:copy(y)
   dsq:cmul(y)
   dsq:mul(-1)
   dsq:add(1)
   -- dsq = (1 - y^2)^2
   dsq:cmul(dsq)
   local result = self.diagHessianInput
   result:resizeAs(input)
   result:copy(diagHessianOutput)
   result:cmul(dsq)
   return result
end
-- Sequential
--- Diagonal-Hessian backward pass for Sequential.
-- Walks the children from last to first. Each child k > 1 receives child
-- k-1's forward output as its input; the first child receives `input`. Each
-- child also receives the diagonal Hessian returned by the child after it.
-- @param input             the container's input
-- @param diagHessianOutput diagonal Hessian w.r.t. the container's output
-- @return diagonal Hessian w.r.t. `input` (also stored in self.diagHessianInput)
function nn.Sequential.backwardDiagHessian(self, input, diagHessianOutput)
   local children = self.modules
   local dh = diagHessianOutput
   for k = #children, 2, -1 do
      dh = children[k]:backwardDiagHessian(children[k - 1].output, dh)
   end
   dh = children[1]:backwardDiagHessian(input, dh)
   self.diagHessianInput = dh
   return dh
end
--- Allocate diagonal-Hessian buffers in every child module, first to last.
function nn.Sequential.initDiagHessianParameters(self)
   for _, child in ipairs(self.modules) do
      child:initDiagHessianParameters()
   end
end
--- Accumulate diagonal-Hessian parameter terms in every child, last to first.
-- Each child's diagHessianInput feeds the child before it, so
-- backwardDiagHessian must have been run for the same input beforehand.
-- @param input             the container's input
-- @param diagHessianOutput diagonal Hessian w.r.t. the container's output
-- @param scale             multiplier passed to every child (default 1)
function nn.Sequential.accDiagHessianParameters(self, input, diagHessianOutput, scale)
   local s = scale or 1
   local children = self.modules
   local dh = diagHessianOutput
   for k = #children, 2, -1 do
      local child = children[k]
      child:accDiagHessianParameters(children[k - 1].output, dh, s)
      dh = child.diagHessianInput
   end
   children[1]:accDiagHessianParameters(input, dh, s)
end
-- ConcatTable
--- Diagonal-Hessian backward pass for ConcatTable.
-- Every branch sees the same input, so the branches' diagonal Hessians with
-- respect to that input are summed.
-- NOTE(review): assumes `input` is a tensor (table inputs are not handled).
-- @param input             tensor fed to every branch
-- @param diagHessianOutput table with one diagonal Hessian per branch
-- @return self.diagHessianInput
function nn.ConcatTable.backwardDiagHessian(self, input, diagHessianOutput)
   -- The buffer was never allocated before, so the first call indexed nil.
   self.diagHessianInput = self.diagHessianInput or input.new()
   for i, module in ipairs(self.modules) do
      -- Must be backwardDiagHessian: the previous module:backward call ran
      -- ordinary gradient backprop (and accumulated gradients) instead.
      local currentDiagHessianInput = module:backwardDiagHessian(input, diagHessianOutput[i])
      if i == 1 then
         self.diagHessianInput:resizeAs(currentDiagHessianInput):copy(currentDiagHessianInput)
      else
         self.diagHessianInput:add(currentDiagHessianInput)
      end
   end
   return self.diagHessianInput
end
--- Allocate diagonal-Hessian buffers in every branch module, in order.
function nn.ConcatTable.initDiagHessianParameters(self)
   for _, branch in ipairs(self.modules) do
      branch:initDiagHessianParameters()
   end
end
--- Accumulate diagonal-Hessian parameter terms in every branch.
-- @param input             input shared by all branches
-- @param diagHessianOutput table with one diagonal Hessian per branch
-- @param scale             multiplier passed to every branch (default 1)
function nn.ConcatTable.accDiagHessianParameters(self, input, diagHessianOutput, scale)
   local s = scale or 1
   local branches = self.modules
   for idx = 1, #branches do
      branches[idx]:accDiagHessianParameters(input, diagHessianOutput[idx], s)
   end
end