Use SeqLSTM
JoostvDoorn committed Sep 11, 2016
1 parent 5f1afc8 commit 00df464
Showing 2 changed files with 56 additions and 21 deletions.
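This commit replaces the step-wise nn.FastLSTM with nn.SeqLSTM, which consumes a whole seqLen x batchSize x inputSize tensor in one call and keeps its hidden state between calls via :remember('both'). A minimal sketch of the difference, assuming the Element-Research rnn package; the sizes and variable names are illustrative and do not come from this repository:

require 'rnn'

local seqLen, batchSize, inputSize, hiddenSize = 4, 2, 8, 16 -- illustrative sizes

-- Before: FastLSTM is stepped over a table of timesteps by a Sequencer
local old = nn.Sequencer(nn.FastLSTM(inputSize, hiddenSize))
local steps = {}
for t = 1, seqLen do steps[t] = torch.rand(batchSize, inputSize) end
local oldOut = old:forward(steps) -- table of seqLen tensors, each batchSize x hiddenSize

-- After: SeqLSTM consumes the whole sequence as a single tensor
local lstm = nn.SeqLSTM(inputSize, hiddenSize)
lstm:remember('both') -- keep hidden state across forward calls, as in this commit
local input = torch.rand(seqLen, batchSize, inputSize)
local out = lstm:forward(input) -- seqLen x batchSize x hiddenSize
local last = nn.Select(-3, -1):forward(out) -- last timestep: batchSize x hiddenSize

Because SeqLSTM handles the time dimension itself, the diff also drops the nn.SplitTable/nn.Sequencer/nn.SelectTable wrapper that previously fed the network one timestep at a time in the non-async recurrent path.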
44 changes: 23 additions & 21 deletions Model.lua
@@ -11,7 +11,7 @@ require 'dpnn' -- Adds gradParamClip method
require 'modules/GuidedReLU'
require 'modules/DeconvnetReLU'
require 'modules/GradientRescale'
--nn.FastLSTM.usenngraph = true -- Use faster FastLSTM TODO: Re-enable once nngraph #109 is resolved
require 'modules/MinDim'

local Model = classic.class('Model')

@@ -112,10 +112,19 @@ function Model:create()
end

-- Add network body
net:add(self:createBody())
local body = self:createBody()
-- Calculate body output size
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(net, _.append({histLen}, self.stateSpec[2]))))
net:add(nn.View(bodyOutputSize))
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(body, _.append({histLen}, self.stateSpec[2]))))
body:add(nn.View(-1, bodyOutputSize))
if not self.async and self.recurrent then
net:add(nn.MinDim(1, 4))
net:add(nn.Transpose({1, 2}))
body = nn.Bottle(body, 4, 2)
end
net:add(body)
if self.recurrent then
net:add(nn.MinDim(1, 3))
end

-- Network head
local head = nn.Sequential()
@@ -124,9 +133,10 @@
-- Value approximator V^(s)
local valStream = nn.Sequential()
if self.recurrent then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
valStream:add(lstm)
valStream:add(nn.Select(-3, -1)) -- Select last timestep
else
valStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
valStream:add(nn.ReLU(true))
@@ -136,9 +146,10 @@
-- Advantage approximator A^(s, a)
local advStream = nn.Sequential()
if self.recurrent then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
advStream:add(lstm)
advStream:add(nn.Select(-3, -1)) -- Select last timestep
else
advStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
advStream:add(nn.ReLU(true))
@@ -158,12 +169,10 @@
head:add(DuelAggregator(self.m))
else
if self.recurrent then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) -- Extra: high forget gate bias (Gers et al., 2000)
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
head:add(lstm)
if self.async then
lstm:remember('both')
end
head:add(nn.Select(-3, -1)) -- Select last timestep
else
head:add(nn.Linear(bodyOutputSize, self.hiddenSize))
head:add(nn.ReLU(true)) -- DRQN paper reports worse performance with ReLU after LSTM
@@ -215,20 +224,13 @@ function Model:create()
if not self.a3c then
net:add(nn.JoinTable(1, 1))
net:add(nn.View(heads, self.m))

if not self.async and self.recurrent then
local sequencer = nn.Sequencer(net)
sequencer:remember('both') -- Keep hidden state between forward calls; requires manual calls to forget
net = nn.Sequential():add(nn.SplitTable(1, #self.stateSpec[2] + 1)):add(sequencer):add(nn.SelectTable(-1))
end
end

-- GPU conversion
if self.gpu > 0 then
require 'cunn'
net:cuda()
end

-- Save reference to network
self.net = net

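In the non-async recurrent path, the body is now wrapped in nn.Bottle(body, 4, 2): the sequence and batch dimensions are folded into a single batch dimension before the (4-D in, 2-D out) body runs, then unfolded back to seqLen x batch x features for the SeqLSTM. nn.Transpose({1, 2}) swaps the batch and time dimensions beforehand, the new nn.MinDim module inserts a leading singleton dimension when the input arrives with fewer dimensions than expected, and nn.Select(-3, -1) keeps only the last timestep after each SeqLSTM. A rough shape walk-through, using a stand-in convolutional body and illustrative sizes rather than the repository's own:

require 'nn'
require 'rnn'

local seqLen, batch = 5, 2

-- stand-in body: 4-D (N x C x H x W) in, 2-D (N x features) out
local body = nn.Sequential()
body:add(nn.SpatialConvolution(4, 8, 3, 3))
body:add(nn.ReLU(true))
body:add(nn.View(-1, 8 * 6 * 6))

local pipeline = nn.Sequential()
pipeline:add(nn.Bottle(body, 4, 2)) -- fold seqLen x batch into one batch dim for the body
pipeline:add(nn.SeqLSTM(8 * 6 * 6, 32)) -- expects seqLen x batch x features
pipeline:add(nn.Select(-3, -1)) -- keep only the last timestep

local input = torch.rand(seqLen, batch, 4, 8, 8) -- seqLen x batch x C x H x W
print(pipeline:forward(input):size()) -- 2 x 32

In this sketch the Bottle folds the 5-D input to 10 x 4 x 8 x 8 for the body and unfolds the body's 10 x 288 output to 5 x 2 x 288 for the SeqLSTM.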
33 changes: 33 additions & 0 deletions modules/MinDim.lua
@@ -0,0 +1,33 @@
local MinDim, parent = torch.class('nn.MinDim', 'nn.Module')

local function _assertTensor(t)
  assert(torch.isTensor(t), "This module only works on tensor")
end

function MinDim:__init(pos, minInputDims)
  parent.__init(self)
  self.pos = pos or error('the position to insert singleton dim not specified')
  self:setMinInputDims(minInputDims)
end

function MinDim:setMinInputDims(numInputDims)
  self.numInputDims = numInputDims
  return self
end

function MinDim:updateOutput(input)
  _assertTensor(input)
  self.output = input
  if input:dim() < self.numInputDims then
    nn.utils.addSingletonDimension(self.output, input, self.pos)
  end
  return self.output
end

function MinDim:updateGradInput(input, gradOutput)
  _assertTensor(input)
  _assertTensor(gradOutput)
  assert(input:nElement() == gradOutput:nElement())
  self.gradInput:view(gradOutput, input:size())
  return self.gradInput
end
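nn.MinDim(pos, minInputDims) inserts a singleton dimension at position pos only when the input has fewer than minInputDims dimensions, and passes the input through unchanged otherwise; in both cases the output is a view sharing the input's storage. A minimal usage sketch, assuming it is run from the repository root so that modules/MinDim is on the path; the tensor sizes are illustrative:

require 'nn'
require 'modules/MinDim'

local m = nn.MinDim(1, 4) -- add a leading singleton dim when input has fewer than 4 dims
print(m:forward(torch.rand(3, 84, 84)):size()) -- 1 x 3 x 84 x 84 (dim inserted)
print(m:forward(torch.rand(2, 3, 84, 84)):size()) -- 2 x 3 x 84 x 84 (unchanged)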
