Use SeqLSTM #46

Merged: 1 commit, merged on Sep 25, 2016
55 changes: 37 additions & 18 deletions Model.lua
@@ -12,6 +12,7 @@ require 'dpnn' -- Adds gradParamClip method
require 'modules/GuidedReLU'
require 'modules/DeconvnetReLU'
require 'modules/GradientRescale'
require 'modules/MinDim'

local Model = classic.class('Model')

@@ -86,22 +87,38 @@ function Model:create()
-- Add network body
log.info('Setting up ' .. self.modelBody)
local Body = require(self.modelBody)
net:add(Body(self):createBody())
local body = Body(self):createBody()

-- Calculate body output size
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(net, _.append({histLen}, self.stateSpec[2]))))
net:add(nn.View(bodyOutputSize))
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(body, _.append({histLen}, self.stateSpec[2]))))
if not self.async and self.recurrent then
body:add(nn.View(-1, bodyOutputSize))
net:add(nn.MinDim(1, 4))
net:add(nn.Transpose({1, 2}))
body = nn.Bottle(body, 4, 2)
net:add(body)
net:add(nn.MinDim(1, 3))
else
body:add(nn.View(bodyOutputSize))
net:add(body)
end

-- Network head
local head = nn.Sequential()
local heads = math.max(self.bootstraps, 1)
if self.duel then
-- Value approximator V^(s)
local valStream = nn.Sequential()
if self.recurrent then
if self.recurrent and self.async then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
lstm:remember('both')
valStream:add(lstm)
elseif self.recurrent then
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
valStream:add(lstm)
valStream:add(nn.Select(-3, -1)) -- Select last timestep
else
valStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
valStream:add(nn.ReLU(true))
@@ -110,10 +127,16 @@ function Model:create()

-- Advantage approximator A^(s, a)
local advStream = nn.Sequential()
if self.recurrent then
if self.recurrent and self.async then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) -- Extra: high forget gate bias (Gers et al., 2000)
lstm:remember('both')
advStream:add(lstm)
elseif self.recurrent then
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
advStream:add(lstm)
advStream:add(nn.Select(-3, -1)) -- Select last timestep
else
advStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
advStream:add(nn.ReLU(true))
@@ -132,13 +155,16 @@ function Model:create()
-- Add dueling streams aggregator module
head:add(DuelAggregator(self.m))
else
if self.recurrent then
if self.recurrent and self.async then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) -- Extra: high forget gate bias (Gers et al., 2000)
lstm:remember('both')
head:add(lstm)
if self.async then
lstm:remember('both')
end
elseif self.recurrent then
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
head:add(lstm)
head:add(nn.Select(-3, -1)) -- Select last timestep
else
head:add(nn.Linear(bodyOutputSize, self.hiddenSize))
head:add(nn.ReLU(true)) -- DRQN paper reports worse performance with ReLU after LSTM
@@ -190,14 +216,7 @@ function Model:create()
if not self.a3c then
net:add(nn.JoinTable(1, 1))
net:add(nn.View(heads, self.m))

if not self.async and self.recurrent then
local sequencer = nn.Sequencer(net)
sequencer:remember('both') -- Keep hidden state between forward calls; requires manual calls to forget
net = nn.Sequential():add(nn.SplitTable(1, #self.stateSpec[2] + 1)):add(sequencer):add(nn.SelectTable(-1))
end
end

-- GPU conversion
if self.gpu > 0 then
require 'cunn'
@@ -214,7 +233,7 @@
--]]
end
end

-- Save reference to network
self.net = net

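For reference, the new non-async recurrent path above relies on a specific shape contract: nn.SeqLSTM (from the rnn package) consumes a whole sequence at once as a seqLen x batchSize x inputSize tensor, which is why the leading dimensions are transposed and the body network is wrapped in nn.Bottle so it runs once per timestep. The sketch below only illustrates the SeqLSTM head with nn.Select picking the last timestep; the sizes are made up for illustration and it assumes the nn and rnn packages are installed, rather than being taken from this repository's configuration.

require 'nn'
require 'rnn'

-- Illustrative sizes only (hypothetical, not the repository's defaults)
local histLen, batchSize, bodyOutputSize, hiddenSize = 4, 2, 64, 32

-- nn.SeqLSTM consumes a whole sequence: seqLen x batchSize x inputSize
local lstm = nn.SeqLSTM(bodyOutputSize, hiddenSize)
lstm:remember('both') -- keep hidden state between forward calls

local head = nn.Sequential()
head:add(lstm)
head:add(nn.Select(-3, -1)) -- take the last timestep: batchSize x hiddenSize

local input = torch.rand(histLen, batchSize, bodyOutputSize)
print(head:forward(input):size()) -- batchSize x hiddenSize

Because remember('both') keeps hidden state between forward calls, training code still has to call forget() manually at episode boundaries, as the comment removed from the old Sequencer-based path noted.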
33 changes: 33 additions & 0 deletions modules/MinDim.lua
@@ -0,0 +1,33 @@
local MinDim, parent = torch.class('nn.MinDim', 'nn.Module')

local function _assertTensor(t)
  assert(torch.isTensor(t), "This module only works on tensor")
end

function MinDim:__init(pos, minInputDims)
  parent.__init(self)
  self.pos = pos or error('the position to insert singleton dim not specified')
  self:setMinInputDims(minInputDims)
end

function MinDim:setMinInputDims(numInputDims)
  self.numInputDims = numInputDims
  return self
end

function MinDim:updateOutput(input)
  _assertTensor(input)
  self.output = input
  if input:dim() < self.numInputDims then
    nn.utils.addSingletonDimension(self.output, input, self.pos)
  end
  return self.output
end

function MinDim:updateGradInput(input, gradOutput)
  _assertTensor(input)
  _assertTensor(gradOutput)
  assert(input:nElement() == gradOutput:nElement())
  self.gradInput:view(gradOutput, input:size())
  return self.gradInput
end
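As a usage note, nn.MinDim only pads the dimensionality of its input: if the input already has at least minInputDims dimensions it passes through unchanged, otherwise a singleton dimension is inserted at position pos, which lets an input missing a leading dimension still flow through the recurrent path set up in Model.lua. A minimal sketch, with purely illustrative sizes; the relative require assumes it is run from the repository root:

require 'nn'
require 'modules/MinDim'

local m = nn.MinDim(1, 4) -- insert a singleton at position 1 if the input has fewer than 4 dims

local batched = torch.rand(2, 3, 24, 24) -- already 4D: passed through unchanged
print(m:forward(batched):size())         -- 2 x 3 x 24 x 24

local single = torch.rand(3, 24, 24)     -- only 3D: a leading singleton dimension is added
print(m:forward(single):size())          -- 1 x 3 x 24 x 24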