Commit 9c26111

add 5 min tutorial and fix docs

hetong007 committed Oct 18, 2015
1 parent fa782b0 commit 9c26111
Showing 6 changed files with 286 additions and 113 deletions.
121 changes: 121 additions & 0 deletions R-package/vignettes/fiveMinutesNeuralNetwork.Rmd
@@ -0,0 +1,121 @@
Neural Network with MXNet in Five Minutes
=============================================

This is the first tutorial for new users of the R package `mxnet`. You will learn how to construct and train a neural network in five minutes.

We will show how to handle a classification task and a regression task, respectively. The data we use come from the package `mlbench`.

## Classification

First of all, let us load in the data and preprocess it:

```{r}
require(mlbench)
require(mxnet)
data(Sonar, package="mlbench")
# Convert the factor labels into numeric 0/1
Sonar[,61] = as.numeric(Sonar[,61])-1
train.ind = c(1:50, 100:150)
train.x = data.matrix(Sonar[train.ind, 1:60])
train.y = Sonar[train.ind, 61]
test.x = data.matrix(Sonar[-train.ind, 1:60])
test.y = Sonar[-train.ind, 61]
```
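Before configuring the network, a quick look at the label balance helps confirm the preprocessing worked (plain base R):

```{r}
# Count the 0/1 labels in the training set
table(train.y)
```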

The next step is to define the structure of the neural network.

```{r}
# Define the input data
data <- mx.symbol.Variable("data")
# A fully connected hidden layer
# data: input source
# name: fc1
# num_hidden: number of neurons in this hidden layer
fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=20)
# An activation function
# fc1: input source
# name: tanh1
# act_type: type of the activation function
act1 <- mx.symbol.Activation(fc1, name="tanh1", act_type="tanh")
fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=2)
# Softmax function for the output layer
softmax <- mx.symbol.Softmax(fc2, name="sm")
```

The comments in the code explain the meaning of each function and its arguments; they can easily be modified to fit your needs.

After the network configuration, we can start the training process:

```{r}
mx.set.seed(0)
model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,
ctx=mx.cpu(), num.round=20, array.batch.size=10,
learning.rate=0.1, momentum=0.9,
epoch.end.callback=mx.callback.log.train.metric(100))
```

Note that `mx.set.seed` is the correct function to control the random process in `mxnet`. You can see the accuracy of each round during training. It is also easy to make predictions and evaluate them:

```{r}
preds = predict(model, test.x)
pred.label = max.col(preds)-1
table(pred.label, test.y)
```
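The confusion table gives a per-class view; for a single summary number you can compute the accuracy directly (plain base R, nothing mxnet-specific):

```{r}
# Fraction of correctly predicted test labels
mean(pred.label == test.y)
```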

## Regression

Again, let us preprocess the data first.

```{r}
data(BostonHousing, package="mlbench")
train.ind = seq(1, 506, 3)
# data.matrix converts the data.frame (including the factor column chas) into a numeric matrix
train.x = data.matrix(BostonHousing[train.ind, -14])
train.y = BostonHousing[train.ind, 14]
test.x = data.matrix(BostonHousing[-train.ind, -14])
test.y = BostonHousing[-train.ind, 14]
```

We can configure a network similar to the one above. The only difference is in the output layer:

```{r}
# Define the input data
data <- mx.symbol.Variable("data")
# A fully connected hidden layer
# data: input source
# name: fc1
# num_hidden: number of neurons in this hidden layer
fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=20)
# An activation function
# fc1: input source
# name: tanh1
# act_type: type of the activation function
act1 <- mx.symbol.Activation(fc1, name="tanh1", act_type="tanh")
# A single output neuron, since we predict one real value
fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=1)
# Linear regression output with squared loss
lro <- mx.symbol.LinearRegressionOutput(fc2, name="lro")
```

The main change is the last function: `mx.symbol.LinearRegressionOutput` makes the new network optimize the squared loss. We can now train on this simple data set.

```{r}
mx.set.seed(0)
model <- mx.model.FeedForward.create(lro, X=train.x, y=train.y,
ctx=mx.cpu(), num.round=20, array.batch.size=10,
learning.rate=0.1, momentum=0.9,
epoch.end.callback=mx.callback.log.train.metric(100))
```
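One caveat: the training log above reports the default metric, which is classification accuracy and thus meaningless for regression. Assuming your version of the package provides the RMSE metric `mx.metric.rmse` (treat this as a sketch otherwise), you can log RMSE instead:

```{r}
model <- mx.model.FeedForward.create(lro, X=train.x, y=train.y,
ctx=mx.cpu(), num.round=20, array.batch.size=10,
learning.rate=0.1, momentum=0.9,
eval.metric=mx.metric.rmse,
epoch.end.callback=mx.callback.log.train.metric(100))
```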

It is also easy to make predictions and evaluate them. Here we compute the root mean squared error (RMSE) on the test set:

```{r}
preds = predict(model, test.x)
sqrt(mean((preds-test.y)^2))
```
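To put this number in context, compare it with a trivial baseline that always predicts the mean of the training targets (plain base R, no mxnet involved):

```{r}
# RMSE of a constant mean predictor; the network should beat this
sqrt(mean((mean(train.y)-test.y)^2))
```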

Congratulations! Now you have learned the basics of using `mxnet`.


101 changes: 84 additions & 17 deletions R-package/vignettes/mnistCompetition.Rmd
@@ -1,9 +1,5 @@
Handwritten Digits Classification Competition
======================================================

[MNIST](http://yann.lecun.com/exdb/mnist/) is a handwritten digits image data set created by Yann LeCun. Every digit is represented by a 28x28 image. It has become a standard data set for testing classifiers on simple image input. Neural networks are, without doubt, strong models for image classification tasks. There is a [long-term hosted competition](https://www.kaggle.com/c/digit-recognizer) on Kaggle using this data set. We will present the basic usage of `mxnet` to compete in this challenge.

@@ -14,6 +10,7 @@ First, let us download the data from [here](https://www.kaggle.com/c/digit-recog
Then we can read them in R and convert to matrices.

```{r, eval=FALSE}
require(mxnet)
train <- read.csv('data/train.csv', header=TRUE)
test <- read.csv('data/test.csv', header=TRUE)
train <- data.matrix(train)
@@ -25,7 +22,7 @@ train.y <- train[,1]

Here every image is represented as a single row in train/test. The greyscale of each image falls in the range [0, 255]; we can linearly transform it into [0, 1] by

```{r, eval=FALSE}
train.x <- train.x/255
test <- test/255
```
@@ -40,14 +37,14 @@ table(train.y)

Now we have the data. The next step is to configure the structure of our network.

```{r, eval=FALSE}
data <- mx.symbol.Variable("data")
fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=128)
act1 <- mx.symbol.Activation(fc1, name="relu1", act_type="relu")
fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=64)
act2 <- mx.symbol.Activation(fc2, name="relu2", act_type="relu")
fc3 <- mx.symbol.FullyConnected(act2, name="fc3", num_hidden=10)
softmax <- mx.symbol.Softmax(fc3, name="sm")
```

1. In `mxnet`, we use its own data type `symbol` to configure the network. `data <- mx.symbol.Variable("data")` uses `data` to represent the input data, i.e. the input layer.
@@ -62,16 +59,16 @@ softmax <- mx.symbol.Softmax(fc3, name = "sm")

We are almost ready for the training process. Before we start the computation, let's decide which device we should use.

```{r, eval=FALSE}
devices <- lapply(1:2, function(i) {
mx.cpu(i)
})
```

Here we assign two threads of our CPU to `mxnet`. After all this preparation, you can run the following command to train the neural network! Note that `mx.set.seed` is the correct function to control the random process in `mxnet`.

```{r, eval=FALSE}
mx.set.seed(0)
model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,
ctx=devices, num.round=10, array.batch.size=100,
learning.rate=0.07, momentum=0.9,
@@ -83,31 +80,101 @@ model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,

To make predictions, we can simply write

```{r, eval=FALSE}
preds <- predict(model, test)
dim(preds)
```

It is a matrix with 28000 rows and 10 columns, containing the predicted classification probabilities from the output layer. To extract the label with the highest probability in each row, we can use `max.col` in R:

```{r, eval=FALSE}
pred.label <- max.col(preds) - 1
table(pred.label)
```

With a little extra effort to write the predictions into the required csv format, we can have our submission to the competition!

```{r, eval=FALSE}
submission <- data.frame(ImageId=1:nrow(test), Label=pred.label)
write.csv(submission, file='submission.csv', row.names=FALSE, quote=FALSE)
```

## LeNet

Next we are going to introduce a new network structure: [LeNet](http://yann.lecun.com/exdb/lenet/). It was proposed by Yann LeCun for recognizing handwritten digits. Now we are going to demonstrate how to construct and train a LeNet in `mxnet`.

First we construct the network:

```{r, eval=FALSE}
# input
data <- mx.symbol.Variable('data')
# first conv
conv1 <- mx.symbol.Convolution(data=data, kernel=c(5,5), num_filter=20)
tanh1 <- mx.symbol.Activation(data=conv1, act_type="tanh")
pool1 <- mx.symbol.Pooling(data=tanh1, pool_type="max",
kernel=c(2,2), stride=c(2,2))
# second conv
conv2 <- mx.symbol.Convolution(data=pool1, kernel=c(5,5), num_filter=50)
tanh2 <- mx.symbol.Activation(data=conv2, act_type="tanh")
pool2 <- mx.symbol.Pooling(data=tanh2, pool_type="max",
kernel=c(2,2), stride=c(2,2))
# first fullc
flatten <- mx.symbol.Flatten(data=pool2)
fc1 <- mx.symbol.FullyConnected(data=flatten, num_hidden=500)
tanh3 <- mx.symbol.Activation(data=fc1, act_type="tanh")
# second fullc
fc2 <- mx.symbol.FullyConnected(data=tanh3, num_hidden=10)
# loss
lenet <- mx.symbol.Softmax(data=fc2)
```

Then let us reshape the matrices into arrays:

```{r, eval=FALSE}
train.array <- t(train.x)
dim(train.array) <- c(1,28,28,nrow(train.x))
train.array <- aperm(train.array, c(4,1,2,3))
test.array <- t(test)
dim(test.array) <- c(1,28,28,nrow(test))
test.array <- aperm(test.array, c(4,1,2,3))
```
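As a quick sanity check, you can confirm that the `aperm` calls moved the sample index to the first dimension (the expected shape follows directly from the code above):

```{r, eval=FALSE}
dim(train.array) # expected: nrow(train.x) 1 28 28
```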

Next we are going to compare the training speed on different devices, so we define the devices first:

```{r, eval=FALSE}
device.cpu <- mx.cpu()
device.gpu <- lapply(1:4, function(i) {
mx.gpu(i)
})
```

Training on CPU:

```{r, eval=FALSE}
mx.set.seed(0)
model <- mx.model.FeedForward.create(lenet, X=train.array, y=train.y,
ctx=device.cpu, num.round=5, array.batch.size=100,
learning.rate=0.05, momentum=0.9, wd=0.00001,
epoch.end.callback=mx.callback.log.train.metric(100))
```

Training on GPU:

```{r, eval=FALSE}
mx.set.seed(0)
model <- mx.model.FeedForward.create(lenet, X=train.array, y=train.y,
ctx=device.gpu, num.round=5, array.batch.size=100,
learning.rate=0.05, momentum=0.9, wd=0.00001,
epoch.end.callback=mx.callback.log.train.metric(100))
```
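To quantify the comparison, you can wrap each training call in base R's `system.time()`. A minimal sketch, assuming both devices are available (only the `ctx` argument differs between the runs; `num.round` is reduced to keep the timing quick):

```{r, eval=FALSE}
cpu.time <- system.time(
  mx.model.FeedForward.create(lenet, X=train.array, y=train.y,
                              ctx=device.cpu, num.round=1, array.batch.size=100,
                              learning.rate=0.05, momentum=0.9, wd=0.00001)
)
print(cpu.time)
```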

Finally we can submit the result to Kaggle again to see the improvement in our ranking!

```{r, eval=FALSE}
preds <- predict(model, test.array)
pred.label <- max.col(preds) - 1
submission <- data.frame(ImageId=1:nrow(test), Label=pred.label)
write.csv(submission, file='submission.csv', row.names=FALSE, quote=FALSE)
```

![](../web-data/mxnet/knitr/mnistCompetition-kaggle-submission.png)
7 changes: 3 additions & 4 deletions R-package/vignettes/ndarrayAndSymbolTutorial.Rmd
@@ -30,9 +30,8 @@ Let's create `NDArray` on either GPU or CPU
```{r}
require(mxnet)
a <- mx.nd.zeros(c(2, 3)) # create a 2-by-3 matrix on cpu
b <- mx.nd.zeros(c(2, 3), mx.cpu()) # create a 2-by-3 matrix on cpu
c <- mx.nd.zeros(c(2, 3), mx.gpu(1)) # create a 2-by-3 matrix on gpu 1
```

We can also initialize an `NDArray` object in various ways:
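For instance, a minimal sketch (assuming `mx.nd.array`, which converts a plain R matrix or array into an `NDArray`, and `mx.nd.ones`, the all-ones counterpart of `mx.nd.zeros`, behave as in the released package):

```{r}
a <- mx.nd.ones(c(2, 3))         # a 2-by-3 matrix filled with ones on cpu
b <- mx.nd.array(matrix(1:6, 2)) # convert an existing R matrix
as.array(b)                      # copy the NDArray back into a plain R array
```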
@@ -72,7 +71,7 @@ as.array(d)
If two `NDArray`s sit on different devices, we need to explicitly move them onto the same one. For instance:

```{r, eval=FALSE}
a <- mx.nd.ones(c(2, 3)) * 2
b <- mx.nd.ones(c(2, 3), mx.gpu()) / 8
c <- mx.nd.copyto(a, mx.gpu()) * b
2 changes: 1 addition & 1 deletion doc/R-package/Makefile
@@ -12,5 +12,5 @@ ndarrayAndSymbolTutorial.Rmd:
Rscript -e \
"require(knitr);"\
"knitr::opts_knit\$$set(root.dir=\".\");"\
"knitr::opts_chunk\$$set(fig.path=\"../web-data/mxnet/knitr/$(basename $@)-\");"\
"knitr::knit(\"$+\")"
2 changes: 1 addition & 1 deletion doc/R-package/classifyRealImageWithPretrainedModel.md
@@ -90,7 +90,7 @@ im <- load.image(system.file("extdata/parrots.png", package="imager"))
plot(im)
```

![plot of chunk unnamed-chunk-5](../web-data/mxnet/knitr/classifyRealImageWithPretrainedModel-unnamed-chunk-5-1.png)

Before feeding the image to the deep net, we need to do some preprocessing
to make the image fit the net's input requirements. The preprocessing