Skip to content

Commit

Permalink
add NAG optimizer to r api (apache#14023)
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudhacharya authored and hetong007 committed Feb 1, 2019
1 parent 2a4634b commit 439377d
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 11 deletions.
105 changes: 105 additions & 0 deletions R-package/R/optimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,110 @@ mx.opt.adadelta <- function(rho = 0.90,
}


#' Create a Nesterov Accelerated SGD( NAG) optimizer.
#'
#' NAG optimizer is described in Aleksandar Botev. et al (2016).
#' *NAG: A Nesterov accelerated SGD.*
#' https://arxiv.org/pdf/1607.01981.pdf
#'
#' @param learning.rate float, default=0.01
#' The initial learning rate.
#' @param momentum float, default=0
#' The momentum value
#' @param wd float, default=0.0
#' L2 regularization coefficient added to all the weights.
#' @param rescale.grad float, default=1.0
#' rescaling factor of gradient.
#' @param clip_gradient float, optional, default=-1 (no clipping if < 0)
#' clip gradient in range [-clip_gradient, clip_gradient].
#' @param lr_scheduler function, optional
#' The learning rate scheduler.
#'
mx.opt.nag <- function(learning.rate = 0.01,
momentum = 0,
wd = 0,
rescale.grad = 1,
clip_gradient = -1,
lr_scheduler = NULL) {

lr <- learning.rate
count <- 0
num_update <- 0

nag <- new.env()
nag$lr <- learning.rate
nag$count <- 0
nag$num_update <- 0

create_exec <- function(index, weight_dim, ctx) {

weight <- mx.symbol.Variable("weight")
grad <- mx.symbol.Variable("grad")
mom <- mx.symbol.Variable("mom")
grad <- grad * rescale.grad

if (!is.null(clip_gradient)) {
if (clip_gradient >= 0) {
grad <- mx.symbol.clip(data = grad, a.min = -clip_gradient, a.max = clip_gradient)
}
}

if (momentum == 0) {

weight <- weight - lr * (grad + (wd * weight))
w <- mx.symbol.identity(weight, name = "w")
sym <- mx.symbol.Group(c(w))

} else {

mom <- momentum * mom + grad + wd * weight
grad <- momentum * mom + grad
weight <- weight - lr * grad

w <- mx.symbol.identity(weight, name = "w")
m <- mx.symbol.identity(mom, name = "m")
sym <- mx.symbol.Group(c(w, m))

}

exec <- mx.simple.bind(symbol = sym, weight = weight_dim, ctx = ctx, grad.req = "null")
return(exec)
}

update <- function(index, exec_w, weight, grad) {

if (!is.null(lr_scheduler)){
lr_scheduler(nag) ## changing lr
lr <- nag$lr
## update count
indexKey <- paste0('ik', index)
if (!exists(envir = nag, x = indexKey, inherits = FALSE)){
nag[[indexKey]] <- 0
} else {
indexValue <- nag[[indexKey]]
nag[[indexKey]] <- indexValue + 1
nag$num_update <- max(nag$num_update, nag[[indexKey]])
}
}

mx.exec.update.arg.arrays(exec_w,
arg.arrays = list(weight = weight,grad = grad),
match.name = T)
mx.exec.forward(exec_w, is.train = F)

# update state
if (!is.null(exec_w$ref.outputs$m_output)){
mx.exec.update.arg.arrays(exec_w,
arg.arrays = list(mom = exec_w$ref.outputs$m_output),
match.name = T)
}

return(exec_w$ref.outputs$w_output)
}
return(list(create_exec = create_exec, update = update))
}


#' Create an optimizer by name and parameters
#'
#' @param name The name of the optimizer
Expand All @@ -466,6 +570,7 @@ mx.opt.create <- function(name, ...) {
"adam" = mx.opt.adam(...),
"adagrad" = mx.opt.adagrad(...),
"adadelta" = mx.opt.adadelta(...),
"nag" = mx.opt.nag(...),
stop("Unknown optimizer ", name))
}

Expand Down
88 changes: 77 additions & 11 deletions R-package/tests/testthat/test_optimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

context("optimizer")

if (Sys.getenv("R_GPU_ENABLE") != "" & as.integer(Sys.getenv("R_GPU_ENABLE")) ==
1) {
mx.ctx.default(new = mx.gpu())
message("Using GPU for testing.")
}

test_that("sgd", {

data <- mx.symbol.Variable("data")
Expand All @@ -30,14 +36,14 @@ test_that("sgd", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x,
exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("sgd", learning.rate = 1, momentum = 0, wd = 0, rescale.grad = 1,
clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
Expand All @@ -63,14 +69,14 @@ test_that("rmsprop", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x,
exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("rmsprop", learning.rate = 1, centered = TRUE, gamma1 = 0.95,
gamma2 = 0.9, epsilon = 1e-04, wd = 0, rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
Expand All @@ -97,14 +103,14 @@ test_that("adam", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x,
exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("adam", learning.rate = 1, beta1 = 0.9, beta2 = 0.999,
epsilon = 1e-08, wd = 0, rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
Expand All @@ -131,14 +137,14 @@ test_that("adagrad", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x,
exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("adagrad", learning.rate = 1, epsilon = 1e-08, wd = 0,
rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
Expand All @@ -164,22 +170,82 @@ test_that("adadelta", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x,
exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("adadelta", rho = 0.9, epsilon = 1e-05, wd = 0, rescale.grad = 1,
clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)

arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)

expect_equal(as.array(arg.blocks[[2]]), array(c(1.11, 1.81), dim = c(2, 1)),
tolerance = 0.1)

})


test_that("nag_no_momentum", {
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc_weight <- mx.symbol.Variable("fc_weight")
fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
name = "fc1", num_hidden = 1)
loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = "loss")

x <- mx.nd.array(array(1:6, dim = 2:3))
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", "null"))

optimizer <- mx.opt.create("nag", learning.rate = 1, momentum = 0, wd = 0, rescale.grad = 1,
clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)

arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)

expect_equal(as.array(arg.blocks[[2]]), array(c(1.4, 2.6), dim = c(2, 1)), tolerance = 0.05)
})


test_that("nag_momentum", {
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc_weight <- mx.symbol.Variable("fc_weight")
fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
name = "fc1", num_hidden = 1)
loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = "loss")

x <- mx.nd.array(array(1:6, dim = 2:3))
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", "null"))

optimizer <- mx.opt.create("nag", learning.rate = 1, momentum = 0.1, wd = 0, rescale.grad = 1,
clip_gradient = 5)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)

arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)

expect_equal(as.array(arg.blocks[[2]]), array(c(1.45, 2.65), dim = c(2, 1)), tolerance = 0.1)
})

0 comments on commit 439377d

Please sign in to comment.