Use SeqLSTM
JoostvDoorn committed Jun 26, 2016
1 parent 6e66251 commit 0efbe9b
Showing 3 changed files with 98 additions and 21 deletions.
4 changes: 4 additions & 0 deletions Agent.lua
@@ -117,6 +117,7 @@ function Agent:training()
if self.recurrent then
self.policyNet:forget()
self.targetNet:forget()
self.policyNet:storeState()
end
end

@@ -198,8 +199,10 @@ function Agent:observe(reward, rawObservation, terminal)
self.saliencyMap:zero()
end
else
self.policyNet:restoreState()
-- Retrieve estimates from all heads
local QHeads = self.policyNet:forward(state)
self.policyNet:storeState()

-- Sample from current episode head (indexes on first dimension with no batch)
local Qs = QHeads:select(1, self.head)
@@ -261,6 +264,7 @@ function Agent:observe(reward, rawObservation, terminal)
elseif self.recurrent then
-- Forget last sequence
self.policyNet:forget()
self.policyNet:storeState()
end
end

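Taken together, the Agent.lua changes bracket every action-selection forward pass with restoreState/storeState, so that learning passes through the same network cannot clobber the hidden state of the episode currently being played, and they re-snapshot the cleared state after each forget. A minimal sketch of the intended call pattern, with simplified agent internals (storeState/restoreState are the methods added to the network in Model.lua below):

local function actGreedily(policyNet, state, head)
  policyNet:restoreState()                 -- resume this episode's hidden state
  local QHeads = policyNet:forward(state)  -- advances the LSTM state one step
  policyNet:storeState()                   -- snapshot it again for the next action
  local Qs = QHeads:select(1, head)        -- Q-values for the active bootstrap head
  local _, aIndex = Qs:max(1)
  return aIndex[1]
end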
82 changes: 61 additions & 21 deletions Model.lua
@@ -11,7 +11,7 @@ require 'dpnn' -- Adds gradParamClip method
require 'modules/GuidedReLU'
require 'modules/DeconvnetReLU'
require 'modules/GradientRescale'
--nn.FastLSTM.usenngraph = true -- Use faster FastLSTM TODO: Re-enable once nngraph #109 is resolved
require 'modules/MinDim'

local Model = classic.class('Model')

@@ -43,6 +43,31 @@ function Model:_init(opt)
end
end

-- Used to store the state of the RNNs
local function storeState(self)
for i=1,#self.rnn do
local c0 = self.rnn[i].c0
local h0 = self.rnn[i].h0
self.rnnMem[i].c0:resizeAs(c0):copy(c0)
self.rnnMem[i].h0:resizeAs(h0):copy(h0)
end
end

-- Used to restore the state of the RNNs
local function restoreState(self)
for i=1,#self.rnn do
local c0 = self.rnnMem[i].c0
local h0 = self.rnnMem[i].h0
if c0:dim() >= 2 then
self.rnn[i].cell:resize(1, c0:size(1), c0:size(2)):copy(c0)
self.rnn[i]._output:resize(1, h0:size(1), h0:size(2)):copy(h0)
else
self.rnn[i].cell:set()
self.rnn[i]._output:set()
end
end
end
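These helpers lean on nn.SeqLSTM internals from the rnn package: with remember('both') a new forward call seeds c0/h0 from the final timestep of the cell/_output buffers, so restoreState re-injects the snapshot as a one-timestep sequence in those buffers, and storeState copies c0/h0 back out. A small illustrative check of the buffer shapes involved (sizes assumed):

require 'rnn' -- provides nn.SeqLSTM

local lstm = nn.SeqLSTM(64, 32)
lstm:remember('both')
lstm:forward(torch.randn(5, 1, 64)) -- seqlen x batch x inputsize
print(lstm.cell:size())             -- 5 x 1 x 32: cell state at every timestep
print(lstm._output:size())          -- 5 x 1 x 32: hidden state at every timestep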

-- Processes a single frame for DQN input; must not return same memory to prevent side-effects
function Model:preprocess(observation)
local frame = observation:type(self.tensorType) -- Convert from CudaTensor if necessary
@@ -104,18 +129,34 @@ end
function Model:create()
-- Number of input frames for recurrent networks is always 1
local histLen = self.recurrent and 1 or self.histLen

local rnn = {} -- Stores references to LSTM modules
local rnnMem = {}

-- Network starting with convolutional layers/model body
local net = nn.Sequential()
if self.recurrent then
net.rnn = rnn
net.rnnMem = rnnMem
net.storeState = storeState
net.restoreState = restoreState
net:add(nn.Copy(nil, nil, true)) -- Needed when splitting batch x seq x input over seq for DRQN; better than nn.Contiguous
end

-- Add network body
net:add(self:createBody())
local body = self:createBody()
-- Calculate body output size
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(net, _.append({histLen}, self.stateSpec[2]))))
net:add(nn.View(bodyOutputSize))
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(body, _.append({histLen}, self.stateSpec[2]))))
body:add(nn.View(-1, bodyOutputSize))
if not self.async and self.recurrent then
net:add(nn.MinDim(1, 4))
net:add(nn.Transpose({1, 2}))
body = nn.Bottle(body, 4, 2)
end
net:add(body)
if self.recurrent then
net:add(nn.MinDim(1, 3))
end
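For the recurrent, non-async path the wrappers above arrange the tensor layout that SeqLSTM expects: MinDim pads a missing leading dimension, Transpose makes the input sequence-first, and Bottle lets the convolutional body treat (seq x batch) as one flat batch. A toy shape walk under assumed sizes (module paths relative to the repository root):

require 'nn'
require 'modules/MinDim'

local batch, seq, c, h, w = 2, 4, 1, 84, 84
local x = torch.randn(batch, seq, c, h, w)
x = nn.MinDim(1, 4):forward(x)      -- no-op here; pads a singleton when input < 4D
x = nn.Transpose({1, 2}):forward(x) -- sequence-first: seq x batch x c x h x w
print(x:size())                     -- 4 x 2 x 1 x 84 x 84
-- nn.Bottle(body, 4, 2) then folds seq x batch into one flat batch for the
-- convolutional body and unfolds its 2D output back to seq x batch x features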

-- Network head
local head = nn.Sequential()
@@ -124,9 +165,12 @@ function Model:create()
-- Value approximator V^(s)
local valStream = nn.Sequential()
if self.recurrent then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
rnn[1] = lstm
rnnMem[1] = {c0=torch.Tensor():type(self.tensorType), h0=torch.Tensor():type(self.tensorType)}
valStream:add(lstm)
valStream:add(nn.Select(-3, -1)) -- Select last timestep
else
valStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
valStream:add(nn.ReLU(true))
@@ -136,9 +180,12 @@
-- Advantage approximator A^(s, a)
local advStream = nn.Sequential()
if self.recurrent then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
rnn[2] = lstm
rnnMem[2] = {c0=torch.Tensor():type(self.tensorType), h0=torch.Tensor():type(self.tensorType)}
advStream:add(lstm)
advStream:add(nn.Select(-3, -1)) -- Select last timestep
else
advStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
advStream:add(nn.ReLU(true))
@@ -158,12 +205,12 @@
head:add(DuelAggregator(self.m))
else
if self.recurrent then
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) -- Extra: high forget gate bias (Gers et al., 2000)
local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize)
lstm:remember('both')
rnn[1] = lstm
rnnMem[1] = {c0=torch.Tensor():type(self.tensorType), h0=torch.Tensor():type(self.tensorType)}
head:add(lstm)
if self.async then
lstm:remember('both')
end
head:add(nn.Select(-3, -1)) -- Select last timestep
else
head:add(nn.Linear(bodyOutputSize, self.hiddenSize))
head:add(nn.ReLU(true)) -- DRQN paper reports worse performance with ReLU after LSTM
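The SeqLSTM that replaces FastLSTM in each of these streams consumes a whole seqlen x batch x features tensor per call, remembers its hidden state across calls, and leaves selecting the final timestep to nn.Select(-3, -1). A minimal sketch under assumed sizes (bodyOutputSize = 512, hiddenSize = 256):

require 'rnn'

local lstm = nn.SeqLSTM(512, 256)
lstm:remember('both')                          -- carry hidden state across calls
local y = lstm:forward(torch.randn(4, 1, 512)) -- seqlen x batch x features
print(y:size())                                -- 4 x 1 x 256
local last = nn.Select(-3, -1):forward(y)      -- final timestep only: 1 x 256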
@@ -215,20 +262,13 @@ function Model:create()
if not self.a3c then
net:add(nn.JoinTable(1, 1))
net:add(nn.View(heads, self.m))

if not self.async and self.recurrent then
local sequencer = nn.Sequencer(net)
sequencer:remember('both') -- Keep hidden state between forward calls; requires manual calls to forget
net = nn.Sequential():add(nn.SplitTable(1, #self.stateSpec[2] + 1)):add(sequencer):add(nn.SelectTable(-1))
end
end
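Because SeqLSTM processes the full sequence tensor in one call, the wrapper removed above, which split the sequence into a table of timesteps, ran the whole net through nn.Sequencer and selected the last table entry, is no longer needed; nn.Select(-3, -1) inside each stream now picks the final timestep directly.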

-- GPU conversion
if self.gpu > 0 then
require 'cunn'
net:cuda()
end

-- Save reference to network
self.net = net

33 changes: 33 additions & 0 deletions modules/MinDim.lua
@@ -0,0 +1,33 @@
local MinDim, parent = torch.class('nn.MinDim', 'nn.Module')

local function _assertTensor(t)
assert(torch.isTensor(t), "This module only works on tensors")
end

function MinDim:__init(pos, minInputDims)
parent.__init(self)
self.pos = pos or error('the position to insert singleton dim not specified')
self:setMinInputDims(minInputDims)
end

function MinDim:setMinInputDims(numInputDims)
self.numInputDims = numInputDims
return self
end

function MinDim:updateOutput(input)
_assertTensor(input)
self.output = input
if input:dim() < self.numInputDims then
nn.utils.addSingletonDimension(self.output, input, self.pos)
end
return self.output
end

function MinDim:updateGradInput(input, gradOutput)
_assertTensor(input)
_assertTensor(gradOutput)
assert(input:nElement() == gradOutput:nElement())
self.gradInput:view(gradOutput, input:size())
return self.gradInput
end
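A short illustrative use of the new module, which pads a leading singleton dimension only when the input falls below the requested rank:

require 'nn'
require 'modules/MinDim'

local m = nn.MinDim(1, 3) -- insert at dim 1 until the input has at least 3 dims
print(m:forward(torch.randn(5, 7)):size())    -- 1 x 5 x 7 (dimension added)
print(m:forward(torch.randn(2, 5, 7)):size()) -- 2 x 5 x 7 (already 3D, unchanged)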
