Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[R] fix image classification examples ( close #5080 ) (#5855)
Browse files Browse the repository at this point in the history
[R] fix image classification examples ( close #5080 )
  • Loading branch information
thirdwing authored Apr 15, 2017
1 parent 4948c30 commit 59d0670
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
32 changes: 30 additions & 2 deletions example/image-classification/train_mnist.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ parse_args <- function() {
help='the batch size')
parser$add_argument('--lr', type='double', default=.05,
help='the initial learning rate')
parser$add_argument('--mom', type='double', default=.9,
help='momentum for sgd')
parser$add_argument('--model-prefix', type='character',
help='the prefix of the model to load/save')
parser$add_argument('--num-round', type='integer', default=10,
Expand All @@ -113,6 +115,32 @@ if (args$network == 'mlp') {
data_shape <- c(28, 28, 1)
net <- get_lenet()
}

# train
source("train_model.R")
train_model.fit(args, net, get_iterator(data_shape))
data_loader <- get_iterator(data_shape)
data <- data_loader(args)
train <- data$train
val <- data$value

if (is.null(args$gpus)) {
devs <- mx.cpu()
} else {
devs <- lapply(unlist(strsplit(args$gpus, ",")), function(i) {
mx.gpu(as.integer(i))
})
}

mx.set.seed(0)

model <- mx.model.FeedForward.create(
X = train,
eval.data = val,
ctx = devs,
symbol = net,
num.round = args$num_round,
array.batch.size = args$batch_size,
learning.rate = args$lr,
momentum = args$mom,
eval.metric = mx.metric.accuracy,
initializer = mx.init.uniform(0.07),
batch.end.callback = mx.callback.log.train.metric(100))
4 changes: 2 additions & 2 deletions example/image-classification/train_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ train_model.fit <- function(args, network, data_loader) {
}

# data
data <- data_loader
data <- data_loader(args)
train <- data$train
val <- data$value
val <- data$value

# devices
if (is.null(args$gpus)) {
Expand Down

0 comments on commit 59d0670

Please sign in to comment.