diff --git a/Agent.lua b/Agent.lua index 9490a6d..df5a829 100644 --- a/Agent.lua +++ b/Agent.lua @@ -193,6 +193,11 @@ function Agent:observe(reward, rawObservation, terminal) -- Choose action by ε-greedy exploration (even with bootstraps) aIndex = torch.random(1, self.m) + -- Forward state anyway if recurrent + if self.recurrent then + self.policyNet:forward(state) + end + -- Reset saliency if action not chosen by network if self.saliency then self.saliencyMap:zero() diff --git a/async/QAgent.lua b/async/QAgent.lua index 0870131..ce60afb 100644 --- a/async/QAgent.lua +++ b/async/QAgent.lua @@ -20,6 +20,7 @@ function QAgent:_init(opt, policyNet, targetNet, theta, targetTheta, atomic, sha self.dTheta:zero() self.doubleQ = opt.doubleQ + self.recurrent = opt.recurrent self.epsilonStart = opt.epsilonStart self.epsilon = self.epsilonStart @@ -60,6 +61,10 @@ function QAgent:eGreedy(state, net) end if torch.uniform() < self.epsilon then + -- Forward state anyway if recurrent + if self.recurrent then + net:forward(state) + end return torch.random(1,self.m) end