Use SeqLSTM
JoostvDoorn committed Jun 25, 2016
1 parent 6e66251 commit f167feb
Showing 3 changed files with 118 additions and 20 deletions.
44 changes: 24 additions & 20 deletions Model.lua
@@ -11,6 +11,8 @@ require 'dpnn' -- Adds gradParamClip method
require 'modules/GuidedReLU'
require 'modules/DeconvnetReLU'
require 'modules/GradientRescale'
require 'modules/MinDim'
require 'modules/Bottle'
--nn.FastLSTM.usenngraph = true -- Use faster FastLSTM TODO: Re-enable once nngraph #109 is resolved

local Model = classic.class('Model')
@@ -112,10 +114,19 @@ function Model:create()
end

-- Add network body
net:add(self:createBody())
local body = self:createBody()
-- Calculate body output size
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(net, _.append({histLen}, self.stateSpec[2]))))
net:add(nn.View(bodyOutputSize))
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(body, _.append({histLen}, self.stateSpec[2]))))
body:add(nn.View(-1, bodyOutputSize))
if not self.async and self.recurrent then
net:add(nn.MinDim(1, 4))
net:add(nn.Transpose({1, 2}))
body = nn.Bottle(body, 4, 2)
end
net:add(body)
if self.recurrent then
net:add(nn.MinDim(1, 3))
end

-- Network head
local head = nn.Sequential()
@@ -124,9 +135,10 @@
-- Value approximator V^(s)
local valStream = nn.Sequential()
if self.recurrent then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
valStream:add(lstm)
valStream:add(nn.Select(-3, -1)) -- Select the output of the last time step
else
valStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
valStream:add(nn.ReLU(true))
@@ -136,9 +148,10 @@
-- Advantage approximator A^(s, a)
local advStream = nn.Sequential()
if self.recurrent then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
advStream:add(lstm)
advStream:add(nn.Select(-3, -1)) -- Select the output of the last time step
else
advStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
advStream:add(nn.ReLU(true))
@@ -158,12 +171,10 @@
head:add(DuelAggregator(self.m))
else
if self.recurrent then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) -- Extra: high forget gate bias (Gers et al., 2000)
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
head:add(lstm)
if self.async then
lstm:remember('both')
end
head:add(nn.Select(-3, -1)) -- Select the output of the last time step
else
head:add(nn.Linear(bodyOutputSize, self.hiddenSize))
head:add(nn.ReLU(true)) -- DRQN paper reports worse performance with ReLU after LSTM
@@ -215,20 +226,13 @@ function Model:create()
if not self.a3c then
net:add(nn.JoinTable(1, 1))
net:add(nn.View(heads, self.m))

if not self.async and self.recurrent then
local sequencer = nn.Sequencer(net)
sequencer:remember('both') -- Keep hidden state between forward calls; requires manual calls to forget
net = nn.Sequential():add(nn.SplitTable(1, #self.stateSpec[2] + 1)):add(sequencer):add(nn.SelectTable(-1))
end
end

-- GPU conversion
if self.gpu > 0 then
require 'cunn'
net:cuda()
end

-- Save reference to network
self.net = net

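For context: nn.SeqLSTM from the rnn package consumes an entire sequence in one call, expecting a seqLen x batchSize x inputSize tensor with the time dimension first, and :remember('both') keeps its hidden state between forward calls until forget() is called. That is why this diff drops the per-step nn.FastLSTM / nn.Sequencer wrapping and instead reshapes with nn.MinDim, nn.Transpose and nn.Bottle. A rough sketch of that shape convention follows; the sizes are made up for illustration and are not the repository's actual state spec.

-- Illustrative only: SeqLSTM shape convention, assuming the 'rnn' package is installed
require 'rnn'

local seqLen, batchSize, inputSize, hiddenSize = 4, 2, 10, 8

local lstm = nn.SeqLSTM(inputSize, hiddenSize)
lstm:remember('both') -- Keep hidden state between forward calls; reset manually with forget()

-- SeqLSTM consumes the whole sequence at once: seqLen x batchSize x inputSize
local input = torch.randn(seqLen, batchSize, inputSize)
local output = lstm:forward(input) -- seqLen x batchSize x hiddenSize

-- nn.Select(-3, -1) then keeps only the last time step: batchSize x hiddenSize
local last = nn.Select(-3, -1):forward(output)
print(#output) -- 4 x 2 x 8
print(#last) -- 2 x 8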
61 changes: 61 additions & 0 deletions modules/Bottle.lua
@@ -0,0 +1,61 @@
local Bottle, parent = torch.class("nn.Bottle", "nn.Container")

function Bottle:__init(module, nInputDim, nOutputDim)
  parent.__init(self)
  self.nInputDim = nInputDim or 2
  self.nOutputDim = nOutputDim or self.nInputDim
  self.dimDelta = self.nInputDim-self.nOutputDim
  -- Used to reshape the gradients
  self.shape = torch.Tensor(self.nInputDim)
  self.outShape = torch.Tensor(self.nOutputDim)
  self.size = nil
  -- Add module to modules
  self.modules[1] = module
end

local function inShape(input)
  local size = input:size()
  local output = torch.LongTensor(#size)
  for i=1,#size do
    output[i] = size[i]
  end
  return output
end

function Bottle:updateOutput(input)
  local idx = input:dim()-self.nInputDim+1
  -- See if bottling is required
  if idx > 1 then
    -- Bottle the first dims
    local size = inShape(input)
    self.size = input:size()
    local shape = size[{{idx,size:size(1)}}]
    self.shape:copy(shape)
    local batchSize = size[{{1,idx-1}}]:prod()
    self.shape[{{1}}]:mul(batchSize)
    -- Forward with the module's dimension
    local output = self.modules[1]:updateOutput(input:view(unpack(torch.totable(self.shape))))
    assert(output:dim() == self.nOutputDim, "Wrong number of output dims on module: "..tostring(not output or output:dim()))
    self.outShape:copy(inShape(output))
    if math.abs(self.dimDelta)>0 then
      size:resize(size:size(1)-self.dimDelta)
    end
    size[{{idx,size:size(1)}}]:copy(self.outShape)
    size[{{idx}}]:div(batchSize)
    -- Unbottle
    self.output = output:view(unpack(torch.totable(size)))
  else
    self.output = self.modules[1]:updateOutput(input)
  end
  return self.output
end

function Bottle:updateGradInput(input, gradOutput)
  if input:dim()>self.nInputDim then
    self.modules[1]:updateGradInput(input:view(unpack(torch.totable(self.shape))), gradOutput:view(unpack(torch.totable(self.outShape))))
    self.gradInput = self.modules[1].gradInput:view(self.size)
  else
    self.gradInput = self.modules[1]:updateGradInput(input, gradOutput)
  end
  return self.gradInput
end
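The Bottle above collapses all leading dimensions of its input into the batch dimension, runs the wrapped module on the flattened view, and restores the leading dimensions on the output; Model.lua uses it so the 4D convolutional body can be applied across the time dimension as well. A minimal illustrative use, with made-up sizes and a plain nn.Linear standing in for the network body:

-- Illustrative only: assumes modules/Bottle.lua has been loaded
require 'nn'

local linear = nn.Linear(10, 3)
local bottled = nn.Bottle(linear, 2, 2) -- Wrapped module sees 2D input and produces 2D output

local input = torch.randn(4, 5, 10) -- Leading 4 x 5 dims are folded into the batch
local output = bottled:forward(input) -- Internally viewed as 20 x 10, producing 20 x 3
print(#output) -- 4 x 5 x 3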
33 changes: 33 additions & 0 deletions modules/MinDim.lua
@@ -0,0 +1,33 @@
local MinDim, parent = torch.class('nn.MinDim', 'nn.Module')

local function _assertTensor(t)
  assert(torch.isTensor(t), "This module only works on tensor")
end

function MinDim:__init(pos, minInputDims)
  parent.__init(self)
  self.pos = pos or error('the position to insert singleton dim not specified')
  self:setMinInputDims(minInputDims)
end

function MinDim:setMinInputDims(numInputDims)
  self.numInputDims = numInputDims
  return self
end

function MinDim:updateOutput(input)
  _assertTensor(input)
  self.output = input
  if input:dim() < self.numInputDims then
    nn.utils.addSingletonDimension(self.output, input, self.pos)
  end
  return self.output
end

function MinDim:updateGradInput(input, gradOutput)
  _assertTensor(input)
  _assertTensor(gradOutput)
  assert(input:nElement() == gradOutput:nElement())
  self.gradInput:view(gradOutput, input:size())
  return self.gradInput
end
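nn.MinDim acts as a conditional unsqueeze: if the input has fewer than minInputDims dimensions it inserts a singleton dimension at position pos, otherwise it passes the tensor through untouched. Model.lua uses nn.MinDim(1, 4) and nn.MinDim(1, 3), presumably to supply a missing leading (batch) dimension before the Bottle and the recurrent head. A small illustrative example with made-up sizes:

-- Illustrative only: assumes modules/MinDim.lua has been loaded
require 'nn'

local pad = nn.MinDim(1, 3) -- Ensure at least 3 dims, inserting a singleton at dim 1 if needed

local unbatched = torch.randn(5, 10) -- 2D: a leading singleton dimension is added
print(#pad:forward(unbatched)) -- 1 x 5 x 10

local batched = torch.randn(2, 5, 10) -- Already 3D: passed through unchanged
print(#pad:forward(batched)) -- 2 x 5 x 10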
