Skip to content

Commit

Permalink
[R] captcha example (apache#6443)
Browse files Browse the repository at this point in the history
  • Loading branch information
thirdwing authored May 26, 2017
1 parent a5c2030 commit 9828a46
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 7 deletions.
32 changes: 25 additions & 7 deletions R-package/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ mx.model.create.kvstore <- function(kvstore, arg.params, ndevice, verbose=TRUE)
}

# Internal function to do multiple device training.
mx.model.train <- function(symbol, ctx, input.shape,
mx.model.train <- function(symbol, ctx, input.shape, output.shape,
arg.params, aux.params,
begin.round, end.round, optimizer,
train.data, eval.data,
Expand All @@ -104,8 +104,14 @@ mx.model.train <- function(symbol, ctx, input.shape,
if(verbose) message(paste0("Start training with ", ndevice, " devices"))
# create the executors
sliceinfo <- mx.model.slice.shape(input.shape, ndevice)
sliceinfo2 <- mx.model.slice.shape(output.shape, ndevice)
arg_names <- arguments(symbol)
label_name <- arg_names[endsWith(arg_names, "label")]
train.execs <- lapply(1:ndevice, function(i) {
mx.simple.bind(symbol, ctx=ctx[[i]], data=sliceinfo[[i]]$shape, grad.req="write")
arg_lst <- list(symbol = symbol, ctx = ctx[[i]], grad.req = "write",
data=sliceinfo[[i]]$shape)
arg_lst[[label_name]] = sliceinfo2[[i]]$shape
do.call(mx.simple.bind, arg_lst)
})
# set the parameters into executors
for (texec in train.execs) {
Expand Down Expand Up @@ -259,9 +265,14 @@ mx.model.train <- function(symbol, ctx, input.shape,
}

# Initialize parameters
mx.model.init.params <- function(symbol, input.shape, initializer, ctx) {
mx.model.init.params <- function(symbol, input.shape, output.shape, initializer, ctx) {
if (!is.MXSymbol(symbol)) stop("symbol need to be MXSymbol")
slist <- mx.symbol.infer.shape(symbol, data=input.shape)
arg_names <- arguments(symbol)
label_name <- arg_names[endsWith(arg_names, "label")]
arg_lst <- list(symbol = symbol, data=input.shape)
arg_lst[[label_name]] = output.shape

slist <- do.call(mx.symbol.infer.shape, arg_lst)
if (is.null(slist)) stop("Not enough information to get shapes")
arg.params <- mx.init.create(initializer, slist$arg.shapes, ctx, skip.unknown=TRUE)
aux.params <- mx.init.create(initializer, slist$aux.shapes, ctx, skip.unknown=FALSE)
Expand Down Expand Up @@ -413,7 +424,8 @@ function(symbol, X, y=NULL, ctx=NULL, begin.round=1,
if (!X$iter.next()) stop("Empty input")
}
input.shape <- dim((X$value())$data)
params <- mx.model.init.params(symbol, input.shape, initializer, mx.cpu())
output.shape <- dim((X$value())$label)
params <- mx.model.init.params(symbol, input.shape, output.shape, initializer, mx.cpu())
if (!is.null(arg.params)) params$arg.params <- arg.params
if (!is.null(aux.params)) params$aux.params <- aux.params
if (is.null(ctx)) ctx <- mx.ctx.default()
Expand Down Expand Up @@ -444,7 +456,7 @@ function(symbol, X, y=NULL, ctx=NULL, begin.round=1,
eval.data <- mx.model.init.iter(eval.data$data, eval.data$label, batch.size=array.batch.size, is.train = TRUE)
}
kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx), verbose=verbose)
model <- mx.model.train(symbol, ctx, input.shape,
model <- mx.model.train(symbol, ctx, input.shape, output.shape,
params$arg.params, params$aux.params,
begin.round, num.round, optimizer=optimizer,
train.data=X, eval.data=eval.data,
Expand Down Expand Up @@ -484,7 +496,13 @@ predict.MXFeedForwardModel <- function(model, X, ctx=NULL, array.batch.size=128,
X$reset()
if (!X$iter.next()) stop("Cannot predict on empty iterator")
dlist = X$value()
pexec <- mx.simple.bind(model$symbol, ctx=ctx, data=dim(dlist$data), grad.req="null")
arg_names <- arguments(model$symbol)
label_name <- arg_names[endsWith(arg_names, "label")]
arg_lst <- list(symbol = model$symbol, ctx = ctx, data = dim(dlist$data), grad.req="null")
arg_lst[[label_name]] <- dim(dlist$label)


pexec <- do.call(mx.simple.bind, arg_lst)
mx.exec.update.arg.arrays(pexec, model$arg.params, match.name=TRUE)
mx.exec.update.aux.arrays(pexec, model$aux.params, match.name=TRUE)
packer <- mx.nd.arraypacker()
Expand Down
5 changes: 5 additions & 0 deletions example/captcha/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
This is the R version of [captcha recognition](http://blog.xlvector.net/2016-05/mxnet-ocr-cnn/) example by xlvector and it can be used as an example of multi-label training. For a captcha below, we consider it as an image with 4 labels and train a CNN over the data set.

![](captcha_example.png)

You can download the images and `.rec` files from [here](https://drive.google.com/open?id=0B_52ppM3wSXBdHctQmhUdmlTbDQ). Since each image has 4 labels, please remember to use `label_width=4` when generating the `.rec` files.
Binary file added example/captcha/captcha_example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
68 changes: 68 additions & 0 deletions example/captcha/mxnet_captcha.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
library(mxnet)

data <- mx.symbol.Variable('data')
label <- mx.symbol.Variable('label')
conv1 <- mx.symbol.Convolution(data = data, kernel = c(5, 5), num_filter = 32)
pool1 <- mx.symbol.Pooling(data = conv1, pool_type = "max", kernel = c(2, 2), stride = c(1, 1))
relu1 <- mx.symbol.Activation(data = pool1, act_type = "relu")

conv2 <- mx.symbol.Convolution(data = relu1, kernel = c(5, 5), num_filter = 32)
pool2 <- mx.symbol.Pooling(data = conv2, pool_type = "avg", kernel = c(2, 2), stride = c(1, 1))
relu2 <- mx.symbol.Activation(data = pool2, act_type = "relu")

flatten <- mx.symbol.Flatten(data = relu2)
fc1 <- mx.symbol.FullyConnected(data = flatten, num_hidden = 120)
fc21 <- mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
fc22 <- mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
fc23 <- mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
fc24 <- mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
fc2 <- mx.symbol.Concat(c(fc21, fc22, fc23, fc24), dim = 0, num.args = 4)
label <- mx.symbol.transpose(data = label)
label <- mx.symbol.Reshape(data = label, target_shape = c(0))
captcha_net <- mx.symbol.SoftmaxOutput(data = fc2, label = label, name = "softmax")

mx.metric.acc2 <- mx.metric.custom("accuracy", function(label, pred) {
ypred <- max.col(t(pred)) - 1
ypred <- matrix(ypred, nrow = nrow(label), ncol = ncol(label), byrow = TRUE)
return(sum(colSums(label == ypred) == 4) / ncol(label))
})

data.shape <- c(80, 30, 3)

batch_size <- 40

train <- mx.io.ImageRecordIter(
path.imgrec = "train.rec",
path.imglist = "train.lst",
batch.size = batch_size,
label.width = 4,
data.shape = data.shape,
mean.img = "mean.bin"
)

val <- mx.io.ImageRecordIter(
path.imgrec = "test.rec",
path.imglist = "test.lst",
batch.size = batch_size,
label.width = 4,
data.shape = data.shape,
mean.img = "mean.bin"
)

mx.set.seed(42)

model <- mx.model.FeedForward.create(
X = train,
eval.data = val,
ctx = mx.gpu(),
symbol = captcha_net,
eval.metric = mx.metric.acc2,
num.round = 10,
learning.rate = 0.0001,
momentum = 0.9,
wd = 0.00001,
batch.end.callback = mx.callback.log.train.metric(50),
initializer = mx.init.Xavier(factor_type = "in", magnitude = 2.34),
optimizer = "sgd",
clip_gradient = 10
)

0 comments on commit 9828a46

Please sign in to comment.