--[[
Bootstrap.lua

Deep Exploration via Bootstrapped DQN
Ian Osband, Charles Blundell, Alexander Pritzel, Benjamin Van Roy

Implemented by Yannis M. Assael (www.yannisassael.com), 2016

Wraps a module (e.g. nn.Linear) in k bootstrap "heads"; each forward pass
averages the outputs of a randomly sampled subset of heads.

Usage: nn.Bootstrap(nn.Linear(size_in, size_out), 10, 0.08)
]]--

require 'nn'

local Bootstrap, parent = torch.class('nn.Bootstrap', 'nn.Module')

function Bootstrap:__init(mod, k, param_init)
    parent.__init(self)

    self.k = k
    self.active = {}
    self.param_init = param_init

    -- keep a cleared template; the k heads are independent clones of it
    self.mod = mod:clearState()
    self.mods = {}
    self.mods_container = nn.Container()
    for i = 1, self.k do
        if self.param_init then
            -- nn.Linear:reset() multiplies the given stdv by math.sqrt(3),
            -- so divide here to initialise uniformly in [-param_init, param_init]
            self.mods[i] = self.mod:clone():reset(self.param_init / math.sqrt(3))
        else
            self.mods[i] = self.mod:clone():reset()
        end
        self.mods_container:add(self.mods[i])
    end
end

function Bootstrap:clearState()
    self.active = {}
    self._gradOutput = nil
    self.mods_container:clearState()
    return parent.clearState(self)
end

function Bootstrap:parameters(...)
    -- expose the parameters of all heads
    return self.mods_container:parameters(...)
end

function Bootstrap:type(type, tensorCache)
    -- the parent converts tensor and module fields (including the heads) recursively
    return parent.type(self, type, tensorCache)
end

function Bootstrap:updateOutput(input)
    -- resize the output to match the wrapped module's output size
    if input:dim() == 1 then
        self.output:resize(self.mod.weight:size(1))
    elseif input:dim() == 2 then
        local nframe = input:size(1)
        self.output:resize(nframe, self.mod.weight:size(1))
    end
    self.output:zero()

    -- sample how many heads participate in this pass
    local n_active = torch.random(self.k)

    -- sample that many heads with replacement (duplicates are possible)
    -- and average their outputs
    self.active = {}
    for i = 1, n_active do
        self.active[i] = torch.random(self.k)
        self.output:add(self.mods[self.active[i]]:updateOutput(input))
    end
    self.output:div(#self.active)

    return self.output
end

function Bootstrap:updateGradInput(input, gradOutput)
    -- scale a copy of the gradient, since the forward pass averaged over
    -- #self.active heads; dividing gradOutput in place would corrupt the
    -- caller's tensor and double-scale it in accGradParameters below
    self._gradOutput = self._gradOutput or gradOutput.new()
    self._gradOutput:resizeAs(gradOutput):copy(gradOutput):div(#self.active)

    -- accumulate the input gradients of the active heads
    self.gradInput:resizeAs(input):zero()
    for i = 1, #self.active do
        self.gradInput:add(self.mods[self.active[i]]:updateGradInput(input, self._gradOutput))
    end

    return self.gradInput
end

function Bootstrap:accGradParameters(input, gradOutput, scale)
    -- rescale into the shared buffer; recompute here in case this is
    -- called without a preceding updateGradInput
    self._gradOutput = self._gradOutput or gradOutput.new()
    self._gradOutput:resizeAs(gradOutput):copy(gradOutput):div(#self.active)

    -- accumulate parameter gradients for the active heads only
    for i = 1, #self.active do
        self.mods[self.active[i]]:accGradParameters(input, self._gradOutput, scale)
    end
end
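
--[[
Minimal usage sketch (not from the original repository; the layer sizes,
batch size, and module path below are illustrative). It builds a small
network whose output layer is bootstrapped over 10 heads and runs one
forward/backward pass:

    require 'nn'
    require 'Bootstrap'  -- assumes this file is on the Lua path

    local size_in, size_out = 4, 2
    local net = nn.Sequential()
        :add(nn.Linear(size_in, 16))
        :add(nn.ReLU())
        :add(nn.Bootstrap(nn.Linear(16, size_out), 10, 0.08))

    local input = torch.randn(5, size_in)          -- batch of 5 inputs
    local output = net:forward(input)              -- average of the sampled heads
    net:backward(input, torch.randn(5, size_out))  -- grads reach active heads only

Each forward pass resamples the active heads, so repeated calls on the same
input generally return different outputs.
]]--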