Skip to content

Commit

Permalink
Merge pull request #663 from NVIDIA/dev/bug-662
Browse files Browse the repository at this point in the history
Fix Torch LR policy
  • Loading branch information
gheinrich committed Mar 30, 2016
2 parents 6b18cd0 + a614e48 commit b4dd9c7
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tools/torch/LRPolicy.lua
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,17 @@ end
function LRPolicy:GetLearningRate(iter)

local rate=0
local progress = 100 * (iter / self.max_iter) -- expressed in percent units

if self.policy == "fixed" then
rate = self.baselr
elseif self.policy == "step" then
local current_step = math.floor(iter/self.step_size)
rate = self.baselr * math.pow(self.gamma, current_step)
elseif self.policy == "exp" then
rate = self.baselr * math.pow(self.gamma, iter)
rate = self.baselr * math.pow(self.gamma, progress)
elseif self.policy == "inv" then
rate = self.baselr * math.pow(1 + self.gamma * iter, - self.power)
rate = self.baselr * math.pow(1 + self.gamma * progress, - self.power)
elseif self.policy == "multistep" then
if (self.current_step <= self.stepvalue_size and iter >= self.step_values[self.current_step]) then
self.current_step = self.current_step + 1
Expand All @@ -69,7 +70,7 @@ function LRPolicy:GetLearningRate(iter)
elseif self.policy == "poly" then
rate = self.baselr * math.pow(1.0 - (iter / self.max_iter), self.power)
elseif self.policy == "sigmoid" then
rate = self.baselr * (1.0 / (1.0 + exp(-self.gamma * (iter - self.step_size))));
rate = self.baselr * (1.0 / (1.0 + math.exp(self.gamma * (progress - 100*self.step_size/self.max_iter))));
else
--have to include additional comments
print("Unknown learning rate policy: " .. self.policy)
Expand Down

0 comments on commit b4dd9c7

Please sign in to comment.