Skip to content

Commit

Permalink
[R] custom iter in model training; MF demo in R (apache#6673)
Browse files Browse the repository at this point in the history
  • Loading branch information
thirdwing committed Jul 11, 2017
1 parent 4ede26c commit 3865356
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 29 deletions.
96 changes: 67 additions & 29 deletions R-package/R/model.R
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
# Slice a shape on its highest (last) dimension into `nsplit` pieces.
#
# @param shape  Either a numeric vector of dimensions, or a list of such
#               vectors (for example one entry per input name).
# @param nsplit Number of slices to produce.
# @return A list of length `nsplit`. Each element is
#         list(begin, end, shape): `begin`/`end` are offsets along the last
#         dimension (slice size is end - begin) and `shape` mirrors the
#         input with the last dimension replaced by that slice size.
#         For a list input the batch size is read from the first entry and
#         every entry has its own last dimension resized.
mx.model.slice.shape <- function(shape, nsplit) {
  if (is.numeric(shape)) {
    reference <- shape
  } else if (is.list(shape)) {
    if (length(shape) == 0) {
      stop("shape list must not be empty")
    }
    reference <- shape[[1]]
  } else {
    # Previously this case fell through and silently returned NULL.
    stop("shape must be a numeric vector or a list of numeric vectors")
  }

  batchsize <- reference[[length(reference)]]
  # Ceiling division so that nsplit slices always cover the whole batch.
  step <- as.integer((batchsize + nsplit - 1) / nsplit)

  # Replace the last dimension of a single shape vector. Using the vector's
  # own length keeps entries of different rank (e.g. a 1-d label next to a
  # 2-d input) correct instead of padding them out to the first entry's rank.
  resize.last <- function(s, size) {
    s[[length(s)]] <- size
    s
  }

  lapply(seq_len(nsplit) - 1L, function(k) {
    begin <- min(k * step, batchsize)
    end <- min((k + 1) * step, batchsize)
    sliced <- if (is.list(shape)) {
      lapply(shape, resize.last, size = end - begin)
    } else {
      resize.last(shape, end - begin)
    }
    list(begin = begin, end = end, shape = sliced)
  })
}

# get the argument name of data and label
Expand Down Expand Up @@ -102,12 +118,13 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
# create the executors
sliceinfo <- mx.model.slice.shape(input.shape, ndevice)
sliceinfo2 <- mx.model.slice.shape(output.shape, ndevice)

arg_names <- arguments(symbol)
label_name <- arg_names[endsWith(arg_names, "label")]
train.execs <- lapply(1:ndevice, function(i) {
arg_lst <- list(symbol = symbol, ctx = ctx[[i]], grad.req = "write",
data = sliceinfo[[i]]$shape)
arg_lst[[label_name]] = sliceinfo2[[i]]$shape
arg_lst <- list(symbol = symbol, ctx = ctx[[i]], grad.req = "write")
arg_lst <- append(arg_lst, sliceinfo[[i]]$shape)
arg_lst <- append(arg_lst, sliceinfo2[[i]]$shape)
arg_lst[["fixed.param"]] = fixed.param
do.call(mx.simple.bind, arg_lst)
})
Expand Down Expand Up @@ -135,7 +152,9 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
kvstore$init(params.index, train.execs[[1]]$ref.arg.arrays[params.index])
}
# Get the input names
input.names <- mx.model.check.arguments(symbol)
# input.names <- mx.model.check.arguments(symbol)
arg_names <- arguments(symbol)
label_name <- arg_names[endsWith(arg_names, "label")]

for (iteration in begin.round:end.round) {
nbatch <- 0
Expand All @@ -147,14 +166,13 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
dlist <- train.data$value()
slices <- lapply(1:ndevice, function(i) {
s <- sliceinfo[[i]]
ret <- list(data=mx.nd.slice(dlist$data, s$begin, s$end),
label=mx.nd.slice(dlist$label, s$begin, s$end))
ret <- sapply(names(dlist), function(n) {mx.nd.slice(dlist[[n]], s$begin, s$end)})
return(ret)
})
# copy data to executor
for (i in 1:ndevice) {
s <- slices[[i]]
names(s) <- input.names
names(s)[endsWith(names(s), "label")] = label_name
mx.exec.update.arg.arrays(train.execs[[i]], s, match.name=TRUE)
}
for (texec in train.execs) {
Expand Down Expand Up @@ -218,13 +236,12 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
dlist <- eval.data$value()
slices <- lapply(1:ndevice, function(i) {
s <- sliceinfo[[i]]
ret <- list(data=mx.nd.slice(dlist$data, s$begin, s$end),
label=mx.nd.slice(dlist$label, s$begin, s$end))
ret <- sapply(names(dlist), function(n) {mx.nd.slice(dlist[[n]], s$begin, s$end)})
return(ret)
})
for (i in 1:ndevice) {
s <- slices[[i]]
names(s) <- input.names
names(s)[endsWith(names(s), "label")] = label_name
mx.exec.update.arg.arrays(train.execs[[i]], s, match.name=TRUE)
}
for (texec in train.execs) {
Expand Down Expand Up @@ -265,10 +282,10 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
# Initialize parameters
mx.model.init.params <- function(symbol, input.shape, output.shape, initializer, ctx) {
if (!is.MXSymbol(symbol)) stop("symbol need to be MXSymbol")
arg_names <- arguments(symbol)
label_name <- arg_names[endsWith(arg_names, "label")]
arg_lst <- list(symbol = symbol, data=input.shape)
arg_lst[[label_name]] = output.shape

arg_lst <- list(symbol = symbol)
arg_lst <- append(arg_lst, input.shape)
arg_lst <- append(arg_lst, output.shape)

slist <- do.call(mx.symbol.infer.shape, arg_lst)
if (is.null(slist)) stop("Not enough information to get shapes")
Expand Down Expand Up @@ -393,6 +410,10 @@ mx.model.select.layout.predict <- function(X, model) {
#' Model parameter, list of name to NDArray of net's weights.
#' @param aux.params list, optional
#' Model parameter, list of name to NDArray of net's auxiliary states.
#' @param input.names optional
#' The names of the input symbols.
#' @param output.names optional
#' The names of the output symbols.
#' @return model A trained mxnet model.
#'
#' @export
Expand All @@ -405,7 +426,9 @@ function(symbol, X, y=NULL, ctx=NULL, begin.round=1,
epoch.end.callback=NULL, batch.end.callback=NULL,
array.batch.size=128, array.layout="auto",
kvstore = "local", verbose = TRUE,
arg.params = NULL, aux.params = NULL, fixed.param = NULL,
arg.params = NULL, aux.params = NULL,
input.names=NULL, output.names = NULL,
fixed.param = NULL,
...) {
if (is.array(X) || is.matrix(X)) {
if (array.layout == "auto") {
Expand All @@ -420,8 +443,18 @@ function(symbol, X, y=NULL, ctx=NULL, begin.round=1,
X$reset()
if (!X$iter.next()) stop("Empty input")
}
input.shape <- dim((X$value())$data)
output.shape <- dim((X$value())$label)
if (is.null(input.names)) {
input.names <- "data"
}
input.shape <- sapply(input.names, function(n){dim(X$value()[[n]])}, simplify = FALSE)
if (is.null(output.names)) {
arg_names <- arguments(symbol)
output.names <- arg_names[endsWith(arg_names, "label")]
output.shape <- list()
output.shape[[output.names]] <- dim((X$value())$label)
} else {
output.shape <- sapply(output.names, function(n){dim(X$value()[[n]])}, simplify = FALSE)
}
params <- mx.model.init.params(symbol, input.shape, output.shape, initializer, mx.cpu())
if (!is.null(arg.params)) params$arg.params <- arg.params
if (!is.null(aux.params)) params$aux.params <- aux.params
Expand All @@ -431,8 +464,13 @@ function(symbol, X, y=NULL, ctx=NULL, begin.round=1,
}
if (!is.list(ctx)) stop("ctx must be mx.context or list of mx.context")
if (is.character(optimizer)) {
ndim <- length(input.shape)
batchsize = input.shape[[ndim]]
if (is.numeric(input.shape)) {
ndim <- length(input.shape)
batchsize = input.shape[[ndim]]
} else {
ndim <- length(input.shape[[1]])
batchsize = input.shape[[1]][[ndim]]
}
optimizer <- mx.opt.create(optimizer, rescale.grad=(1/batchsize), ...)
}
if (!is.null(eval.data) && !is.list(eval.data) && !is.mx.dataiter(eval.data)) {
Expand Down
11 changes: 11 additions & 0 deletions R-package/tests/testthat/get_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,14 @@ GetCifar10 <- function() {
unzip('data/cifar10.zip', exdir = 'data/')
}
}

# Make sure the MovieLens 100k ratings are available under ./data:
# create the directory if needed, then download and unpack the archive
# unless data/ml-100k/u.data already exists.
GetMovieLens <- function() {
  if (!dir.exists("data")) {
    dir.create("data/")
  }
  # Nothing to do when the ratings file is already in place.
  if (file.exists("data/ml-100k/u.data")) {
    return(invisible(NULL))
  }
  archive <- "data/ml-100k.zip"
  download.file("http://files.grouplens.org/datasets/movielens/ml-100k.zip",
                destfile = archive)
  unzip(archive, exdir = "data/")
}
75 changes: 75 additions & 0 deletions R-package/tests/testthat/test_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,78 @@ test_that("Classification", {
eval.metric = mx.metric.accuracy)
})

# Trains a small matrix-factorization model on MovieLens 100k through a
# custom data iterator that exposes the named inputs "user", "item" and
# "label", exercising the input.names / output.names arguments of
# mx.model.FeedForward.create.
test_that("Matrix Factorization", {
  GetMovieLens()
  DF <- read.table("./data/ml-100k/u.data", header = FALSE, sep = "\t")
  names(DF) <- c("user", "item", "score", "time")
  max_user <- max(DF$user)
  max_item <- max(DF$item)
  # k is used both as the embedding width and as the iterator batch size.
  k <- 64

  user <- mx.symbol.Variable("user")
  item <- mx.symbol.Variable("item")
  score <- mx.symbol.Variable("label")
  # NOTE(review): ids in u.data appear to start at 1; confirm that
  # input_dim = max_user / max_item is large enough for the Embedding index
  # range (it may need to be max + 1).
  user1 <- mx.symbol.Embedding(data = mx.symbol.BlockGrad(user),
                               input_dim = max_user,
                               output_dim = k, name = "user1")
  item1 <- mx.symbol.Embedding(data = mx.symbol.BlockGrad(item),
                               input_dim = max_item,
                               output_dim = k, name = "item1")
  pred <- user1 * item1
  pred1 <- mx.symbol.sum_axis(pred, axis = 1, name = "pred1")
  pred2 <- mx.symbol.Flatten(pred1, name = "pred2")
  pred3 <- mx.symbol.LinearRegressionOutput(data = pred2, label = score,
                                            name = "pred3")
  devices <- lapply(1:2, function(i) {
    mx.cpu(i)
  })
  mx.set.seed(123)

  # Wraps two array iterators and presents them as one iterator whose
  # value() is a named list with one entry per symbol input.
  CustomIter <- setRefClass("CustomIter", fields = c("iter1", "iter2"),
    contains = "Rcpp_MXArrayDataIter",
    methods = list(
      initialize = function(iter1, iter2) {
        .self$iter1 <- iter1
        .self$iter2 <- iter2
        .self
      },
      value = function() {
        # Fetch the first iterator's batch once; it supplies both the
        # user ids and the label.
        batch1 <- .self$iter1$value()
        list(user = batch1$data,
             item = .self$iter2$value()$data,
             label = batch1$label)
      },
      iter.next = function() {
        # Advance both iterators unconditionally, then report whether
        # both still have data.
        has1 <- .self$iter1$iter.next()
        has2 <- .self$iter2$iter.next()
        has1 && has2
      },
      reset = function() {
        .self$iter1$reset()
        .self$iter2$reset()
      },
      num.pad = function() {
        .self$iter1$num.pad()
      },
      finalize = function() {
        .self$iter1$finalize()
        .self$iter2$finalize()
      }
    )
  )

  user_iter <- mx.io.arrayiter(data = DF[, 1], label = DF[, 3], batch.size = k)
  item_iter <- mx.io.arrayiter(data = DF[, 2], label = DF[, 3], batch.size = k)
  train_iter <- CustomIter$new(user_iter, item_iter)

  model <- mx.model.FeedForward.create(pred3, X = train_iter, ctx = devices,
                                       num.round = 10,
                                       initializer = mx.init.uniform(0.07),
                                       learning.rate = 0.07,
                                       eval.metric = mx.metric.rmse,
                                       momentum = 0.9,
                                       epoch.end.callback = mx.callback.log.train.metric(1),
                                       input.names = c("user", "item"),
                                       output.names = "label")
})
67 changes: 67 additions & 0 deletions example/recommenders/demo-MF.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
library(mxnet)
# Matrix-factorization demo on MovieLens 100k.
# Expects the unpacked data set at ./ml-100k/u.data (tab separated:
# user, item, score, time).
DF <- read.table("./ml-100k/u.data", header = FALSE, sep = "\t")
names(DF) <- c("user", "item", "score", "time")
max_user <- max(DF$user)
max_item <- max(DF$item)
# k is used both as the embedding width and as the iterator batch size.
k <- 64

# Network: embed user and item ids, multiply the embeddings elementwise,
# sum over the embedding axis and regress the result on the score.
user <- mx.symbol.Variable("user")
item <- mx.symbol.Variable("item")
score <- mx.symbol.Variable("label")
# NOTE(review): ids in u.data appear to start at 1; confirm that
# input_dim = max_user / max_item is large enough for the Embedding index
# range (it may need to be max + 1).
user1 <- mx.symbol.Embedding(data = mx.symbol.BlockGrad(user),
                             input_dim = max_user,
                             output_dim = k, name = "user1")
item1 <- mx.symbol.Embedding(data = mx.symbol.BlockGrad(item),
                             input_dim = max_item,
                             output_dim = k, name = "item1")
pred <- user1 * item1
pred1 <- mx.symbol.sum_axis(pred, axis = 1, name = "pred1")
pred2 <- mx.symbol.Flatten(pred1, name = "pred2")
pred3 <- mx.symbol.LinearRegressionOutput(data = pred2, label = score,
                                          name = "pred3")
devices <- mx.cpu()
mx.set.seed(123)

# Wraps two array iterators and presents them as one iterator whose
# value() is a named list with one entry per symbol input.
CustomIter <- setRefClass("CustomIter", fields = c("iter1", "iter2"),
  contains = "Rcpp_MXArrayDataIter",
  methods = list(
    initialize = function(iter1, iter2) {
      .self$iter1 <- iter1
      .self$iter2 <- iter2
      .self
    },
    value = function() {
      # Fetch the first iterator's batch once; it supplies both the
      # user ids and the label.
      batch1 <- .self$iter1$value()
      list(user = batch1$data,
           item = .self$iter2$value()$data,
           label = batch1$label)
    },
    iter.next = function() {
      # Advance both iterators unconditionally, then report whether
      # both still have data.
      has1 <- .self$iter1$iter.next()
      has2 <- .self$iter2$iter.next()
      has1 && has2
    },
    reset = function() {
      .self$iter1$reset()
      .self$iter2$reset()
    },
    num.pad = function() {
      .self$iter1$num.pad()
    },
    finalize = function() {
      .self$iter1$finalize()
      .self$iter2$finalize()
    }
  )
)

user_iter <- mx.io.arrayiter(data = DF[, 1], label = DF[, 3], batch.size = k)
item_iter <- mx.io.arrayiter(data = DF[, 2], label = DF[, 3], batch.size = k)
train_iter <- CustomIter$new(user_iter, item_iter)

model <- mx.model.FeedForward.create(pred3, X = train_iter, ctx = devices,
                                     num.round = 10,
                                     initializer = mx.init.uniform(0.07),
                                     learning.rate = 0.07,
                                     eval.metric = mx.metric.rmse,
                                     momentum = 0.9,
                                     epoch.end.callback = mx.callback.log.train.metric(1),
                                     input.names = c("user", "item"),
                                     output.names = "label")

0 comments on commit 3865356

Please sign in to comment.