Skip to content

Commit

Permalink
[R] fix image classification examples ( close apache#5080 ) (apache#5855
Browse files Browse the repository at this point in the history
)

[R] fix image classification examples ( close apache#5080 )
  • Loading branch information
thirdwing authored Apr 15, 2017
1 parent 37dfcda commit 19e0e8a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
32 changes: 30 additions & 2 deletions example/image-classification/train_mnist.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ parse_args <- function() {
help='the batch size')
parser$add_argument('--lr', type='double', default=.05,
help='the initial learning rate')
parser$add_argument('--mom', type='double', default=.9,
help='momentum for sgd')
parser$add_argument('--model-prefix', type='character',
help='the prefix of the model to load/save')
parser$add_argument('--num-round', type='integer', default=10,
Expand All @@ -113,6 +115,32 @@ if (args$network == 'mlp') {
data_shape <- c(28, 28, 1)
net <- get_lenet()
}

# train
source("train_model.R")
train_model.fit(args, net, get_iterator(data_shape))
data_loader <- get_iterator(data_shape)
data <- data_loader(args)
train <- data$train
val <- data$value

if (is.null(args$gpus)) {
devs <- mx.cpu()
} else {
devs <- lapply(unlist(strsplit(args$gpus, ",")), function(i) {
mx.gpu(as.integer(i))
})
}

mx.set.seed(0)

model <- mx.model.FeedForward.create(
X = train,
eval.data = val,
ctx = devs,
symbol = net,
num.round = args$num_round,
array.batch.size = args$batch_size,
learning.rate = args$lr,
momentum = args$mom,
eval.metric = mx.metric.accuracy,
initializer = mx.init.uniform(0.07),
batch.end.callback = mx.callback.log.train.metric(100))
4 changes: 2 additions & 2 deletions example/image-classification/train_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ train_model.fit <- function(args, network, data_loader) {
}

# data
data <- data_loader
data <- data_loader(args)
train <- data$train
val <- data$value
val <- data$value

# devices
if (is.null(args$gpus)) {
Expand Down

0 comments on commit 19e0e8a

Please sign in to comment.