…to fix-depthwise-conv
Ubuntu committed Feb 1, 2019
2 parents 00ccd84 + f95e794 commit ffd2952
Showing 15 changed files with 932 additions and 127 deletions.
105 changes: 105 additions & 0 deletions R-package/R/optimizer.R
@@ -453,6 +453,110 @@ mx.opt.adadelta <- function(rho = 0.90,
}


#' Create a Nesterov Accelerated SGD (NAG) optimizer.
#'
#' The NAG optimizer is described in Aleksandar Botev et al. (2016),
#' *NAG: A Nesterov accelerated SGD*,
#' https://arxiv.org/pdf/1607.01981.pdf
#'
#' @param learning.rate float, default=0.01
#'     The initial learning rate.
#' @param momentum float, default=0
#'     The momentum value.
#' @param wd float, default=0.0
#'     L2 regularization coefficient added to all the weights.
#' @param rescale.grad float, default=1.0
#'     Rescaling factor of the gradient.
#' @param clip_gradient float, optional, default=-1 (no clipping if < 0)
#'     Clip the gradient to the range [-clip_gradient, clip_gradient].
#' @param lr_scheduler function, optional
#'     The learning rate scheduler.
#'
mx.opt.nag <- function(learning.rate = 0.01,
                       momentum = 0,
                       wd = 0,
                       rescale.grad = 1,
                       clip_gradient = -1,
                       lr_scheduler = NULL) {

  lr <- learning.rate
  count <- 0
  num_update <- 0

  nag <- new.env()
  nag$lr <- learning.rate
  nag$count <- 0
  nag$num_update <- 0

  # Build a small symbolic graph that performs one NAG step; it is bound once
  # per weight shape and then reused by `update`.
  create_exec <- function(index, weight_dim, ctx) {

    weight <- mx.symbol.Variable("weight")
    grad <- mx.symbol.Variable("grad")
    mom <- mx.symbol.Variable("mom")
    grad <- grad * rescale.grad

    if (!is.null(clip_gradient)) {
      if (clip_gradient >= 0) {
        grad <- mx.symbol.clip(data = grad, a.min = -clip_gradient, a.max = clip_gradient)
      }
    }

    if (momentum == 0) {

      # Plain SGD step with weight decay.
      weight <- weight - lr * (grad + (wd * weight))
      w <- mx.symbol.identity(weight, name = "w")
      sym <- mx.symbol.Group(c(w))

    } else {

      # NAG update: refresh the momentum, then take a look-ahead step.
      mom <- momentum * mom + grad + wd * weight
      grad <- momentum * mom + grad
      weight <- weight - lr * grad

      w <- mx.symbol.identity(weight, name = "w")
      m <- mx.symbol.identity(mom, name = "m")
      sym <- mx.symbol.Group(c(w, m))

    }

    exec <- mx.simple.bind(symbol = sym, weight = weight_dim, ctx = ctx, grad.req = "null")
    return(exec)
  }

  # Feed (weight, grad) through the bound update graph, persist the momentum
  # state, and return the updated weight.
  update <- function(index, exec_w, weight, grad) {

    if (!is.null(lr_scheduler)) {
      lr_scheduler(nag)  ## changing lr
      lr <- nag$lr
      ## update count
      indexKey <- paste0('ik', index)
      if (!exists(envir = nag, x = indexKey, inherits = FALSE)) {
        nag[[indexKey]] <- 0
      } else {
        indexValue <- nag[[indexKey]]
        nag[[indexKey]] <- indexValue + 1
        nag$num_update <- max(nag$num_update, nag[[indexKey]])
      }
    }

    mx.exec.update.arg.arrays(exec_w,
                              arg.arrays = list(weight = weight, grad = grad),
                              match.name = T)
    mx.exec.forward(exec_w, is.train = F)

    # update state
    if (!is.null(exec_w$ref.outputs$m_output)) {
      mx.exec.update.arg.arrays(exec_w,
                                arg.arrays = list(mom = exec_w$ref.outputs$m_output),
                                match.name = T)
    }

    return(exec_w$ref.outputs$w_output)
  }
  return(list(create_exec = create_exec, update = update))
}
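
For context, a minimal usage sketch of the new optimizer (an editorial illustration, not part of this commit). It mirrors the updater pattern exercised in R-package/tests/testthat/test_optimizer.R below; the toy regression symbol and the learning.rate, momentum, wd, and clip_gradient values are illustrative assumptions.

# Minimal sketch: create a NAG optimizer, bind a tiny regression graph, and
# apply a single update step (hyperparameter values chosen for illustration).
library(mxnet)

data      <- mx.symbol.Variable("data")
label     <- mx.symbol.Variable("label")
fc_weight <- mx.symbol.Variable("fc_weight")
fc   <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
                                 name = "fc1", num_hidden = 1)
loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = "loss")

x  <- mx.nd.array(array(1:6, dim = 2:3))   # 2 features x 3 samples
y  <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(),
                               arg.arrays = list(data = x, fc1_weight = w1, label = y),
                               aux.arrays = NULL,
                               grad.reqs = c("null", "write", "null"))

optimizer <- mx.opt.create("nag", learning.rate = 0.01, momentum = 0.9,
                           wd = 1e-4, rescale.grad = 1, clip_gradient = 5)
updaters  <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)   # one NAG step
mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)       # write weights back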


#' Create an optimizer by name and parameters
#'
#' @param name The name of the optimizer
@@ -466,6 +570,7 @@ mx.opt.create <- function(name, ...) {
"adam" = mx.opt.adam(...),
"adagrad" = mx.opt.adagrad(...),
"adadelta" = mx.opt.adadelta(...),
"nag" = mx.opt.nag(...),
stop("Unknown optimizer ", name))
}

88 changes: 77 additions & 11 deletions R-package/tests/testthat/test_optimizer.R
@@ -17,6 +17,12 @@

context("optimizer")

if (Sys.getenv("R_GPU_ENABLE") != "" && as.integer(Sys.getenv("R_GPU_ENABLE")) == 1) {
  mx.ctx.default(new = mx.gpu())
  message("Using GPU for testing.")
}

test_that("sgd", {

data <- mx.symbol.Variable("data")
@@ -30,14 +36,14 @@ test_that("sgd", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x,
exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("sgd", learning.rate = 1, momentum = 0, wd = 0, rescale.grad = 1,
clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
@@ -63,14 +69,14 @@ test_that("rmsprop", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x,
exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("rmsprop", learning.rate = 1, centered = TRUE, gamma1 = 0.95,
gamma2 = 0.9, epsilon = 1e-04, wd = 0, rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
@@ -97,14 +103,14 @@ test_that("adam", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x,
exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("adam", learning.rate = 1, beta1 = 0.9, beta2 = 0.999,
epsilon = 1e-08, wd = 0, rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
@@ -131,14 +137,14 @@ test_that("adagrad", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x,
exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("adagrad", learning.rate = 1, epsilon = 1e-08, wd = 0,
rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)
@@ -164,22 +170,82 @@ test_that("adadelta", {
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x,
exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("adadelta", rho = 0.9, epsilon = 1e-05, wd = 0, rescale.grad = 1,
clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)

arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)

expect_equal(as.array(arg.blocks[[2]]), array(c(1.11, 1.81), dim = c(2, 1)),
tolerance = 0.1)

})


test_that("nag_no_momentum", {
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc_weight <- mx.symbol.Variable("fc_weight")
fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
name = "fc1", num_hidden = 1)
loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = "loss")

x <- mx.nd.array(array(1:6, dim = 2:3))
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", "null"))

optimizer <- mx.opt.create("nag", learning.rate = 1, momentum = 0, wd = 0, rescale.grad = 1,
clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)

arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)

expect_equal(as.array(arg.blocks[[2]]), array(c(1.4, 2.6), dim = c(2, 1)), tolerance = 0.05)
})
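
As a sanity check, the expected values in nag_no_momentum can be reproduced by hand. The sketch below is an editorial illustration in plain R (no mxnet required); it assumes LinearRegressionOutput backpropagates the raw residual pred - label, which is consistent with the expected array above.

# Hand check of nag_no_momentum: with momentum = 0, wd = 0 and lr = 1 the NAG
# update reduces to weight <- weight - grad.
x <- matrix(1:6, nrow = 2)       # samples are the columns: (1,2), (3,4), (5,6)
w <- c(1.1, 1.8)
y <- c(5, 11, 16)

pred  <- as.vector(t(w) %*% x)   # 4.7 10.5 16.3
resid <- pred - y                # -0.3 -0.5  0.3  (assumed loss gradient)
grad  <- as.vector(x %*% resid)  # -0.3 -0.8
w - 1 * grad                     # 1.4  2.6  -- matches the expected array above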


test_that("nag_momentum", {
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc_weight <- mx.symbol.Variable("fc_weight")
fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
name = "fc1", num_hidden = 1)
loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = "loss")

x <- mx.nd.array(array(1:6, dim = 2:3))
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", "null"))

optimizer <- mx.opt.create("nag", learning.rate = 1, momentum = 0.1, wd = 0, rescale.grad = 1,
clip_gradient = 5)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

mx.exec.forward(exec, is.train = T)
mx.exec.backward(exec)

arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)

expect_equal(as.array(arg.blocks[[2]]), array(c(1.45, 2.65), dim = c(2, 1)), tolerance = 0.1)
})
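
A similar hand check for nag_momentum, under the same assumption about the loss gradient; clip_gradient = 5 does not bind here because the raw gradients are well inside that range.

# Hand check of nag_momentum: mom starts at 0, momentum = 0.1, wd = 0, lr = 1,
# and the gradient is the same (-0.3, -0.8) as in the no-momentum case.
grad <- c(-0.3, -0.8)
mom  <- 0.1 * 0 + grad           # first step: mom equals grad
step <- 0.1 * mom + grad         # -0.33 -0.88
c(1.1, 1.8) - 1 * step           # 1.43 2.68 -- within the 0.1 tolerance of c(1.45, 2.65)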
34 changes: 13 additions & 21 deletions python/mxnet/gluon/data/vision/transforms.py
@@ -262,8 +262,8 @@ def forward(self, x):
return image.center_crop(x, *self._args)[0]


class Resize(Block):
"""Resize an image to the given size.
class Resize(HybridBlock):
"""Resize an image or a batch of image NDArray to the given size.
Should be applied before `mxnet.gluon.data.vision.transforms.ToTensor`.
Parameters
@@ -276,44 +276,36 @@ class Resize(Block):
interpolation : int
Interpolation method for resizing. By default uses bilinear
interpolation. See OpenCV's resize function for available choices.
Note that Resize on GPU uses the contrib.bilinearResize2D operator, which
only supports bilinear interpolation (1). The result can therefore differ
slightly between GPU and CPU: OpenCV tends to align centers, while
bilinearResize2D uses an algorithm that aligns corners.
Inputs:
- **data**: input tensor with (Hi x Wi x C) shape.
- **data**: input tensor with (H x W x C) or (N x H x W x C) shape.
Outputs:
- **out**: output tensor with (H x W x C) shape.
- **out**: output tensor with (H x W x C) or (N x H x W x C) shape.
Examples
--------
>>> transformer = vision.transforms.Resize(size=(1000, 500))
>>> image = mx.nd.random.uniform(0, 255, (224, 224, 3)).astype(dtype=np.uint8)
>>> transformer(image)
<NDArray 500x1000x3 @cpu(0)>
>>> image = mx.nd.random.uniform(0, 255, (3, 224, 224, 3)).astype(dtype=np.uint8)
>>> transformer(image)
<NDArray 3x500x1000x3 @cpu(0)>
"""
def __init__(self, size, keep_ratio=False, interpolation=1):
super(Resize, self).__init__()
self._keep = keep_ratio
self._size = size
self._interpolation = interpolation

def forward(self, x):
if isinstance(self._size, numeric_types):
if not self._keep:
wsize = self._size
hsize = self._size
else:
h, w, _ = x.shape
if h > w:
wsize = self._size
hsize = int(h * wsize / w)
else:
hsize = self._size
wsize = int(w * hsize / h)
else:
wsize, hsize = self._size
return image.imresize(x, wsize, hsize, self._interpolation)

def hybrid_forward(self, F, x):
return F.image.resize(x, self._size, self._keep, self._interpolation)

class RandomFlipLeftRight(HybridBlock):
"""Randomly flip the input image left to right with a probability
@@ -63,7 +63,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
dataName: String = "data", labelName: String = "label") {
this(IO.initDataDesc(data, allowEmpty = false, dataName,
if (data == null || data.isEmpty) MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED),
IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED),
IO.initDataDesc(label, allowEmpty = true, labelName,
if (label == null || label.isEmpty) MX_REAL_TYPE else label(0).dtype, Layout.UNDEFINED),
dataBatchSize, shuffle, lastBatchHandle)
}

@@ -175,7 +176,8 @@ private def _padData(ndArray: NDArray): NDArray = {
private def _padData(ndArray: NDArray): NDArray = {
val padNum = cursor + dataBatchSize - numData
val shape = Shape(dataBatchSize) ++ ndArray.shape.slice(1, ndArray.shape.size)
val newArray = NDArray.zeros(shape)
// The new NDArray has to be created such that it inherits dtype from the passed in array
val newArray = NDArray.zeros(shape, dtype = ndArray.dtype)
NDArrayCollector.auto().withScope {
val batch = ndArray.slice(cursor, numData)
val padding = ndArray.slice(0, padNum)
@@ -237,7 +237,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
val shape0 = Shape(Array(1000, 2, 2))
val data = IndexedSeq(NDArray.ones(shape0), NDArray.zeros(shape0))
val shape1 = Shape(Array(1000, 1))
val label = IndexedSeq(NDArray.ones(shape1))
val label = IndexedSeq(NDArray.ones(shape1, dtype = DType.Int32))
val batchData0 = NDArray.ones(Shape(Array(128, 2, 2)))
val batchData1 = NDArray.zeros(Shape(Array(128, 2, 2)))
val batchLabel = NDArray.ones(Shape(Array(128, 1)))
@@ -254,6 +254,7 @@
assert(tBatch.data(0).toArray === batchData0.toArray)
assert(tBatch.data(1).toArray === batchData1.toArray)
assert(tBatch.label(0).toArray === batchLabel.toArray)
assert(tBatch.label(0).dtype == DType.Int32)
}

assert(batchCount === nBatch0)