refactor sgd.jl to exploit multiple dispatch
ZacCranko authored and vchuravy committed Jan 28, 2017
1 parent fb2a0d0 commit cac5625
Showing 1 changed file with 14 additions and 11 deletions.

src/optimizers/sgd.jl
@@ -49,18 +49,21 @@ function create_state(self :: SGD, index :: Int, weight :: NDArray)
   end
 end
 
-function update(self :: SGD, index :: Int, weight :: NDArray, grad :: NDArray, state :: Union{Void, NDArray})
+function update(self :: SGD, index :: Int, weight :: NDArray, grad :: NDArray, state :: Void)
   lr = get_learning_rate(self.opts.lr_scheduler, self.state)
   grad = normalized_gradient(self.opts, self.state, weight, grad)
 
-  if isa(state, Void)
-    # vanilla SGD, without momentum
-    @inplace weight += -lr * grad
-  else
-    mom = state :: NDArray
-    coef = get_momentum(self.opts.momentum_scheduler, self.state)
-    @inplace mom .*= coef
-    @inplace mom .+= -lr * grad
-    @inplace weight .+= mom
-  end
+  @inplace weight += -lr * grad
+end
+
+# update with momentum
+function update(self :: SGD, index :: Int, weight :: NDArray, grad :: NDArray, state :: NDArray)
+  lr = get_learning_rate(self.opts.lr_scheduler, self.state)
+  grad = normalized_gradient(self.opts, self.state, weight, grad)
+
+  mom = state :: NDArray
+  coef = get_momentum(self.opts.momentum_scheduler, self.state)
+  @inplace mom .*= coef
+  @inplace mom .+= -lr * grad
+  @inplace weight .+= mom
 end
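
The refactor replaces the runtime isa(state, Void) branch with two update methods, letting Julia's multiple dispatch pick the right one from the type of state. Below is a minimal, self-contained sketch of the same pattern, with hypothetical step! and Momentum names (not the MXNet.jl API) and in current Julia syntax, where Void has since been renamed Nothing:

# Dispatch replaces the manual isa() branch: Julia selects the method
# from the runtime type of `state`.
struct Momentum
    v::Vector{Float64}   # velocity buffer
end

# vanilla SGD: no optimizer state
step!(w, g, state::Nothing) = (w .-= 0.1 .* g)

# SGD with momentum: the state carries the velocity
function step!(w, g, state::Momentum)
    state.v .= 0.9 .* state.v .- 0.1 .* g
    w .+= state.v
end

w = [1.0, 2.0]
g = [0.5, -0.5]
step!(w, g, nothing)               # dispatches to the Nothing method
step!(w, g, Momentum(zero(w)))     # dispatches to the Momentum method

Each method then only handles one case, so the shared lr/grad preamble is the price paid for removing the branch and the Union-typed argument.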
