This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[R] Add NAG optimizer #14023

Merged
merged 1 commit, Feb 1, 2019
105 changes: 105 additions & 0 deletions R-package/R/optimizer.R
@@ -453,6 +453,110 @@ mx.opt.adadelta <- function(rho = 0.90,
}


#' Create a Nesterov Accelerated SGD (NAG) optimizer.
#'
#' The NAG optimizer is described in Aleksandar Botev et al. (2016),
#' *NAG: A Nesterov accelerated SGD*,
#' https://arxiv.org/pdf/1607.01981.pdf
#'
#' @param learning.rate float, default=0.01
#'   The initial learning rate.
#' @param momentum float, default=0
#'   The momentum value.
#' @param wd float, default=0.0
#'   L2 regularization coefficient added to all the weights.
#' @param rescale.grad float, default=1.0
#'   Rescaling factor applied to the gradient.
#' @param clip_gradient float, optional, default=-1 (no clipping if < 0)
#'   Clip the gradient to the range [-clip_gradient, clip_gradient].
#' @param lr_scheduler function, optional
#'   The learning rate scheduler.
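#' @details
#' A sketch of the update rule this implementation applies (read off the
#' code below, not quoted from the paper). With momentum > 0:
#' \preformatted{
#' mom    <- momentum * mom + grad + wd * weight
#' weight <- weight - lr * (grad + momentum * mom)
#' }
#' With momentum == 0 it reduces to plain SGD with weight decay.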
#'
mx.opt.nag <- function(learning.rate = 0.01,
                       momentum = 0,
                       wd = 0,
                       rescale.grad = 1,
                       clip_gradient = -1,
                       lr_scheduler = NULL) {

  lr <- learning.rate
  count <- 0
  num_update <- 0

  nag <- new.env()
  nag$lr <- learning.rate
  nag$count <- 0
  nag$num_update <- 0

  create_exec <- function(index, weight_dim, ctx) {

    # Build the update rule as a small symbolic graph; running it forward
    # performs one optimizer step.
    weight <- mx.symbol.Variable("weight")
    grad <- mx.symbol.Variable("grad")
    mom <- mx.symbol.Variable("mom")
    grad <- grad * rescale.grad

    if (!is.null(clip_gradient)) {
      if (clip_gradient >= 0) {
        grad <- mx.symbol.clip(data = grad, a.min = -clip_gradient, a.max = clip_gradient)
      }
    }

    if (momentum == 0) {

      # Without momentum, NAG reduces to plain SGD with weight decay.
      weight <- weight - lr * (grad + (wd * weight))
      w <- mx.symbol.identity(weight, name = "w")
      sym <- mx.symbol.Group(c(w))

    } else {

      # Nesterov update: refresh the momentum state first, then step using
      # the look-ahead gradient (grad + momentum * mom).
      mom <- momentum * mom + grad + wd * weight
      grad <- momentum * mom + grad
      weight <- weight - lr * grad

      w <- mx.symbol.identity(weight, name = "w")
      m <- mx.symbol.identity(mom, name = "m")
      sym <- mx.symbol.Group(c(w, m))

    }

    exec <- mx.simple.bind(symbol = sym, weight = weight_dim, ctx = ctx, grad.req = "null")
    return(exec)
  }

  update <- function(index, exec_w, weight, grad) {

    if (!is.null(lr_scheduler)) {
      lr_scheduler(nag)  ## changing lr
      lr <- nag$lr
      ## update count
      indexKey <- paste0('ik', index)
      if (!exists(envir = nag, x = indexKey, inherits = FALSE)) {
        nag[[indexKey]] <- 0
      } else {
        indexValue <- nag[[indexKey]]
        nag[[indexKey]] <- indexValue + 1
        nag$num_update <- max(nag$num_update, nag[[indexKey]])
      }
    }

    # Run one step of the update graph on the current weight and gradient.
    mx.exec.update.arg.arrays(exec_w,
                              arg.arrays = list(weight = weight, grad = grad),
                              match.name = T)
    mx.exec.forward(exec_w, is.train = F)

    # Carry the momentum state over to the next update.
    if (!is.null(exec_w$ref.outputs$m_output)) {
      mx.exec.update.arg.arrays(exec_w,
                                arg.arrays = list(mom = exec_w$ref.outputs$m_output),
                                match.name = T)
    }

    return(exec_w$ref.outputs$w_output)
  }
  return(list(create_exec = create_exec, update = update))
}


#' Create an optimizer by name and parameters
#'
#' @param name The name of the optimizer
@@ -466,6 +570,7 @@ mx.opt.create <- function(name, ...) {
"adam" = mx.opt.adam(...),
"adagrad" = mx.opt.adagrad(...),
"adadelta" = mx.opt.adadelta(...),
"nag" = mx.opt.nag(...),
stop("Unknown optimizer ", name))
}
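
For context, a minimal usage sketch of the new optimizer through this factory (the hyperparameter values here are hypothetical, chosen only to mirror the tests below):

library(mxnet)
# Any combination accepted by mx.opt.nag works; these values are illustrative.
optimizer <- mx.opt.create("nag", learning.rate = 0.01, momentum = 0.9,
                           wd = 1e-04, rescale.grad = 1, clip_gradient = 5)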

88 changes: 77 additions & 11 deletions R-package/tests/testthat/test_optimizer.R
@@ -17,6 +17,12 @@

context("optimizer")

if (Sys.getenv("R_GPU_ENABLE") != "" && as.integer(Sys.getenv("R_GPU_ENABLE")) == 1) {
  mx.ctx.default(new = mx.gpu())
  message("Using GPU for testing.")
}
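
With this guard in place, opting the suite into GPU testing is a matter of setting the variable before the tests run (a sketch):

Sys.setenv(R_GPU_ENABLE = "1")  # opt in to the GPU context for the test suite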

test_that("sgd", {

data <- mx.symbol.Variable("data")
@@ -30,14 +36,14 @@ test_that("sgd", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("sgd", learning.rate = 1, momentum = 0, wd = 0, rescale.grad = 1,
clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
@@ -63,14 +69,14 @@ test_that("rmsprop", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("rmsprop", learning.rate = 1, centered = TRUE, gamma1 = 0.95,
gamma2 = 0.9, epsilon = 1e-04, wd = 0, rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
@@ -97,14 +103,14 @@ test_that("adam", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("adam", learning.rate = 1, beta1 = 0.9, beta2 = 0.999,
epsilon = 1e-08, wd = 0, rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
@@ -131,14 +137,14 @@ test_that("adagrad", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("adagrad", learning.rate = 1, epsilon = 1e-08, wd = 0,
rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
@@ -164,22 +170,82 @@ test_that("adadelta", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("adadelta", rho = 0.9, epsilon = 1e-05, wd = 0, rescale.grad = 1,
clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)

arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)

expect_equal(as.array(arg.blocks[[2]]), array(c(1.11, 1.81), dim = c(2, 1)),
tolerance = 0.1)

})


test_that("nag_no_momentum", {
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc_weight <- mx.symbol.Variable("fc_weight")
fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
name = "fc1", num_hidden = 1)
loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = "loss")

x <- mx.nd.array(array(1:6, dim = 2:3))
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", "null"))

optimizer <- mx.opt.create("nag", learning.rate = 1, momentum = 0, wd = 0, rescale.grad = 1,
clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)

arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)

expect_equal(as.array(arg.blocks[[2]]), array(c(1.4, 2.6), dim = c(2, 1)), tolerance = 0.05)
})
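
A hand-worked check of the expected value above (a base-R sketch; it assumes LinearRegressionOutput backpropagates pred - label without batch averaging, which is what the expectation implies):

w <- c(1.1, 1.8)                     # initial fc1 weight
X <- matrix(1:6, nrow = 2)           # three samples, one per column
y <- c(5, 11, 16)
pred <- as.vector(t(X) %*% w)        # 4.7, 10.5, 16.3
grad <- as.vector(X %*% (pred - y))  # -0.3, -0.8
w - 1 * grad                         # 1.4, 2.6 -- the tested value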


test_that("nag_momentum", {
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc_weight <- mx.symbol.Variable("fc_weight")
fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
name = "fc1", num_hidden = 1)
loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = "loss")

x <- mx.nd.array(array(1:6, dim = 2:3))
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", "null"))

optimizer <- mx.opt.create("nag", learning.rate = 1, momentum = 0.1, wd = 0, rescale.grad = 1,
clip_gradient = 5)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)

arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)

expect_equal(as.array(arg.blocks[[2]]), array(c(1.45, 2.65), dim = c(2, 1)), tolerance = 0.1)
})
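
The same hand check with momentum = 0.1 (the gradient is unchanged from the previous test and well inside the clip range of 5, so clipping is a no-op here):

w <- c(1.1, 1.8)
grad <- c(-0.3, -0.8)        # gradient as derived in the previous sketch
mom <- 0.1 * 0 + grad        # momentum state, starting from zero
w - 1 * (grad + 0.1 * mom)   # 1.43, 2.68 -- within the 0.1 tolerance of c(1.45, 2.65)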