Merge pull request #46 from JoostvDoorn/rnn
Use SeqLSTM
Kaixhin authored Sep 25, 2016
2 parents 5204469 + aef9ec5 commit a271098
Showing 2 changed files with 70 additions and 18 deletions.
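
For context, the substantive change: on the non-async recurrent path, the model now uses nn.SeqLSTM, which consumes a whole seqlen x batch x inputsize tensor in one forward call, in place of nn.FastLSTM, which advances one timestep per call and previously required wrapping the entire network in nn.Sequencer. A minimal sketch of the two interfaces from the rnn package (editorial illustration, not part of the commit; sizes are assumptions):

require 'rnn'

local histLen, batchSize, inputSize, hiddenSize = 4, 2, 8, 16

-- SeqLSTM processes the full sequence in one call: seqlen x batch x inputsize
local seqlstm = nn.SeqLSTM(inputSize, hiddenSize)
seqlstm:remember('both') -- retain hidden state across forward calls
print(seqlstm:forward(torch.rand(histLen, batchSize, inputSize)):size()) -- histLen x batchSize x hiddenSize

-- FastLSTM advances a single timestep per call: batch x inputsize
local fastlstm = nn.FastLSTM(inputSize, hiddenSize)
print(fastlstm:forward(torch.rand(batchSize, inputSize)):size()) -- batchSize x hiddenSize
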
55 changes: 37 additions & 18 deletions Model.lua
@@ -12,6 +12,7 @@ require 'dpnn' -- Adds gradParamClip method
require 'modules/GuidedReLU'
require 'modules/DeconvnetReLU'
require 'modules/GradientRescale'
require 'modules/MinDim'

local Model = classic.class('Model')

@@ -86,22 +87,38 @@ function Model:create()
-- Add network body
log.info('Setting up ' .. self.modelBody)
local Body = require(self.modelBody)
net:add(Body(self):createBody())
local body = Body(self):createBody()

-- Calculate body output size
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(net, _.append({histLen}, self.stateSpec[2]))))
net:add(nn.View(bodyOutputSize))
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(body, _.append({histLen}, self.stateSpec[2]))))
if not self.async and self.recurrent then
body:add(nn.View(-1, bodyOutputSize))
net:add(nn.MinDim(1, 4))
net:add(nn.Transpose({1, 2}))
body = nn.Bottle(body, 4, 2)
net:add(body)
net:add(nn.MinDim(1, 3))
else
body:add(nn.View(bodyOutputSize))
net:add(body)
end

-- Network head
local head = nn.Sequential()
local heads = math.max(self.bootstraps, 1)
if self.duel then
-- Value approximator V^(s)
local valStream = nn.Sequential()
if self.recurrent then
if self.recurrent and self.async then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
lstm:remember('both')
valStream:add(lstm)
elseif self.recurrent then
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
valStream:add(lstm)
valStream:add(nn.Select(-3, -1)) -- Select last timestep
else
valStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
valStream:add(nn.ReLU(true))
@@ -110,10 +127,16 @@ function Model:create()

-- Advantage approximator A^(s, a)
local advStream = nn.Sequential()
if self.recurrent then
if self.recurrent and self.async then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) -- Extra: high forget gate bias (Gers et al., 2000)
lstm:remember('both')
advStream:add(lstm)
elseif self.recurrent then
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
advStream:add(lstm)
advStream:add(nn.Select(-3, -1)) -- Select last timestep
else
advStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
advStream:add(nn.ReLU(true))
@@ -132,13 +155,16 @@ function Model:create()
-- Add dueling streams aggregator module
head:add(DuelAggregator(self.m))
else
if self.recurrent then
if self.recurrent and self.async then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) -- Extra: high forget gate bias (Gers et al., 2000)
lstm:remember('both')
head:add(lstm)
if self.async then
lstm:remember('both')
end
elseif self.recurrent then
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
head:add(lstm)
head:add(nn.Select(-3, -1)) -- Select last timestep
else
head:add(nn.Linear(bodyOutputSize, self.hiddenSize))
head:add(nn.ReLU(true)) -- DRQN paper reports worse performance with ReLU after LSTM
@@ -190,14 +216,7 @@ function Model:create()
if not self.a3c then
net:add(nn.JoinTable(1, 1))
net:add(nn.View(heads, self.m))

if not self.async and self.recurrent then
local sequencer = nn.Sequencer(net)
sequencer:remember('both') -- Keep hidden state between forward calls; requires manual calls to forget
net = nn.Sequential():add(nn.SplitTable(1, #self.stateSpec[2] + 1)):add(sequencer):add(nn.SelectTable(-1))
end
end

-- GPU conversion
if self.gpu > 0 then
require 'cunn'
@@ -214,7 +233,7 @@
--]]
end
end

-- Save reference to network
self.net = net

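A note on the new non-async recurrent wiring above (editorial sketch, not part of the commit): nn.Bottle(body, 4, 2) lets the body, which expects 4D input and emits 2D output, run over a 5D time x batch x channel x height x width tensor by folding the two leading dimensions into one batch dimension and unfolding afterwards. Illustrative sizes below are assumptions:

require 'nn'

-- A 4D-in / 2D-out body, standing in for the convolutional body in Model.lua
local body = nn.Sequential()
  :add(nn.SpatialConvolution(3, 8, 3, 3))
  :add(nn.View(-1):setNumInputDims(3))

-- Bottle folds (time, batch) into one dimension, applies body, then unfolds
local bottled = nn.Bottle(body, 4, 2)
print(bottled:forward(torch.rand(5, 2, 3, 10, 10)):size()) -- 5 x 2 x 512, i.e. time x batch x features
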
33 changes: 33 additions & 0 deletions modules/MinDim.lua
@@ -0,0 +1,33 @@
-- Ensures an input tensor has at least minInputDims dimensions by inserting a
-- singleton dimension at position pos when it has fewer
local MinDim, parent = torch.class('nn.MinDim', 'nn.Module')

local function _assertTensor(t)
assert(torch.isTensor(t), "This module only works on tensor")
end

function MinDim:__init(pos, minInputDims)
parent.__init(self)
self.pos = pos or error('the position to insert singleton dim not specified')
self:setMinInputDims(minInputDims)
end

function MinDim:setMinInputDims(numInputDims)
self.numInputDims = numInputDims
return self
end

function MinDim:updateOutput(input)
_assertTensor(input)
self.output = input
if input:dim() < self.numInputDims then
nn.utils.addSingletonDimension(self.output, input, self.pos)
end
return self.output
end

function MinDim:updateGradInput(input, gradOutput)
_assertTensor(input)
_assertTensor(gradOutput)
assert(input:nElement() == gradOutput:nElement())
self.gradInput:view(gradOutput, input:size())
return self.gradInput
end
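
Usage sketch for the new module (editorial, not part of the commit): MinDim inserts a singleton dimension at pos only when the input has fewer than minInputDims dimensions, and passes the input through untouched otherwise. Model.lua applies nn.MinDim(1, 4) ahead of the time/batch transpose so inputs with a missing leading dimension are padded first.

local m = nn.MinDim(1, 4) -- ensure at least 4 dims, inserting at position 1
print(m:forward(torch.rand(3, 24, 24)):size())    -- 1 x 3 x 24 x 24
print(m:forward(torch.rand(2, 3, 24, 24)):size()) -- unchanged: 2 x 3 x 24 x 24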
