Skip to content

Commit

Permalink
Use validation memory for validation
Browse files Browse the repository at this point in the history
Fixes #16
  • Loading branch information
Kaixhin committed May 7, 2016
1 parent bc71583 commit a9dc1f2
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions Agent.lua
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ function Agent:observe(reward, rawObservation, terminal)
end

-- Learns from experience
function Agent:learn(x, indices, ISWeights)
function Agent:learn(x, indices, ISWeights, isValidation)
-- Copy x to parameters θ if necessary
if x ~= self.theta then
self.theta:copy(x)
Expand All @@ -260,7 +260,8 @@ function Agent:learn(x, indices, ISWeights)
self.dTheta:zero()

-- Retrieve experience tuples
local states, actions, rewards, transitions, terminals = self.memory:retrieve(indices) -- Terminal status is for transition (can't act in terminal state)
local memory = isValidation and self.valMemory or self.memory
local states, actions, rewards, transitions, terminals = memory:retrieve(indices) -- Terminal status is for transition (can't act in terminal state)
local N = actions:size(1)

-- Perform argmax action selection
Expand Down Expand Up @@ -337,6 +338,12 @@ function Agent:learn(x, indices, ISWeights)
-- Squared loss
loss = torch.mean(self.tdErr:clone():pow(2):mul(0.5)) -- Average over heads
end

-- Exit if being used for validation metrics
if isValidation then
return
end

-- Send TD-errors δ to be used as priorities
self.memory:updatePriorities(indices, torch.mean(self.tdErr, 2)) -- Use average error over heads

Expand Down

0 comments on commit a9dc1f2

Please sign in to comment.