Merge pull request apache#9 from terrytangyuan/terry
Adam Optimizer
terrytangyuan committed Dec 22, 2015
2 parents 67abdee + ef46e26 commit 30ca7f2
Showing 6 changed files with 113 additions and 13 deletions.
18 changes: 9 additions & 9 deletions example/image-classification/README.md
@@ -147,31 +147,31 @@ model.fit(X=train_data, y=train_label)
 The following factors may significantly affect the performance:
-- Use a fast backend. A fast BLAS library, e.g. openblas, atlas,
+1. Use a fast backend. A fast BLAS library, e.g. openblas, atlas,
 and mkl, is necessary if only using CPU. While for Nvidia GPUs, we strongly
 recommend using CUDNN.
-- Three important things for the input data:
-- data format. If you are using the `rec` format, then everything should be
+2. Three important things for the input data:
+1. data format. If you are using the `rec` format, then everything should be
 fine.
-- decoding. By default MXNet uses 4 CPU threads for decoding the images, which
+2. decoding. By default MXNet uses 4 CPU threads for decoding the images, which
 are often able to decode over 1k images per second. You
 may increase the number of threads if either you are using a low-end CPU or
 your GPUs are very powerful.
-- place to store the data. Any local or distributed filesystem (HDFS, Amazon
+3. place to store the data. Any local or distributed filesystem (HDFS, Amazon
 S3) should be fine. There may be a problem if multiple machines read the
 data from the network shared filesystem (NFS) at the same time.
-- Use a large batch size. We often choose the largest one which can fit into
+3. Use a large batch size. We often choose the largest one which can fit into
 the GPU memory. But too large a value may slow down convergence. For
 example, the safe batch size for CIFAR 10 is around 200, while for ImageNet
 1K, the batch size can go beyond 1K.
-- Choose the proper `kvstore` if using more than one GPU. (See
+4. Choose the proper `kvstore` if using more than one GPU. (See
 [doc/developer-guide/multi_node.md](../../doc/developer-guide/multi_node.md)
 for more information)
-- For a single machine, often the default `local` is good enough. But you may want
+1. For a single machine, often the default `local` is good enough. But you may want
 to use `local_allreduce_device` for models with size >> 100MB such as AlexNet
 and VGG. Also note that `local_allreduce_device` takes more GPU memory than
 others.
-- For multiple machines, we recommend trying `dist_sync` first. But if the
+2. For multiple machines, we recommend trying `dist_sync` first. But if the
 model size is quite large or you use a large number of machines, you may want to use `dist_async`.
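
To make points 2-4 of this list concrete, here is a rough Python sketch of where the decoding threads, batch size, and `kvstore` choice are typically passed. The specific calls and argument names (`mx.io.ImageRecordIter`, `preprocess_threads`, `path_imgrec`, `mx.model.FeedForward`, the `kvstore` keyword of `fit`) are assumptions about the Python API, not something this commit touches.

```python
# Hypothetical sketch (not part of this commit): wiring the knobs from
# points 2-4 above into a plain Python training script.
import mxnet as mx

# A tiny network just to keep the example self-contained.
data = mx.symbol.Variable("data")
fc = mx.symbol.FullyConnected(data=data, num_hidden=1000)
net = mx.symbol.SoftmaxOutput(data=fc, name="softmax")

# 2. Input data: a `rec` file, the number of decoding threads,
# 3. and the largest batch size that fits in GPU memory.
train_data = mx.io.ImageRecordIter(
    path_imgrec="data/train.rec",
    data_shape=(3, 224, 224),
    batch_size=256,
    preprocess_threads=4)

model = mx.model.FeedForward(
    symbol=net,
    ctx=[mx.gpu(i) for i in range(4)],  # four GPUs on one machine
    num_epoch=10)

# 4. Pick the kvstore: `local` by default, `local_allreduce_device` for
# large models such as AlexNet/VGG, `dist_sync`/`dist_async` across machines.
model.fit(X=train_data, kvstore="local_allreduce_device")
```
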
## Results
13 changes: 13 additions & 0 deletions example/rnn/get_ptb_data.sh
@@ -0,0 +1,13 @@
#!/usr/bin/env bash

RNN_DIR=$(cd `dirname $0`; pwd)
DATA_DIR="${RNN_DIR}/data/"

if [[ ! -d "${DATA_DIR}" ]]; then
  echo "${DATA_DIR} doesn't exist, will create one";
  mkdir -p ${DATA_DIR}
fi

wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt;
wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt;
wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt;
4 changes: 2 additions & 2 deletions example/rnn/lstm.py
@@ -187,8 +187,8 @@ def calc_nll(seq_label_probs, X, begin):
 def train_lstm(model, X_train_batch, X_val_batch,
                num_round, update_period,
                optimizer='rmsprop', half_life=2, max_grad_norm=5.0, **kwargs):
-    print("Training swith train.shape=%s" % str(X_train_batch.shape))
-    print("Training swith val.shape=%s" % str(X_val_batch.shape))
+    print("Training with train.shape=%s" % str(X_train_batch.shape))
+    print("Training with val.shape=%s" % str(X_val_batch.shape))
     m = model
     seq_len = len(m.seq_data)
     batch_size = m.seq_data[0].shape[0]
86 changes: 86 additions & 0 deletions (new Scala file: Adam optimizer)
@@ -0,0 +1,86 @@
import ml.dmlc.mxnet.{NDArray, Optimizer, LRScheduler}
import ml.dmlc.mxnet.NDArrayConversions._

/**
 * Adam optimizer as described in [King2014]
 *
 * [King2014] Diederik Kingma, Jimmy Ba,
 * Adam: A Method for Stochastic Optimization,
 * http://arxiv.org/abs/1412.6980
 *
 * @author Yuan Tang
 *
 * @param learningRate Float, Step size.
 * @param beta1 Float, Exponential decay rate for the first moment estimates.
 * @param beta2 Float, Exponential decay rate for the second moment estimates.
 * @param epsilon Float
 * @param decayFactor Float
 * @param wd Float, L2 regularization coefficient added to all the weights
 * @param rescaleGrad Float, rescaling factor of gradient.
 * @param clipGradient Float, clip gradient in range [-clip_gradient, clip_gradient]
 * @param lrScheduler The learning rate scheduler
 */
class Adam(var learningRate: Float = 0.002f, val beta1: Float = 0.9f, val beta2: Float = 0.999f,
           val epsilon: Float = 0.00000001f, val decayFactor: Float = 1 - 0.00000001f, val wd: Float = 0.0f,
           rescaleGrad: Float = 1f, val clipGradient: Float = 0f,
           val lrScheduler: LRScheduler = null) extends Optimizer(rescaleGrad: Float) {

  protected var time: Int = 0
  protected var timeFirstIndex: Int = 0

  /**
   * Update the parameters.
   * @param index A unique integer key used to index the parameters
   * @param weight weight ndarray
   * @param grad grad ndarray
   * @param state NDArray or other objects returned by initState
   *              The auxiliary state used in optimization.
   */
  override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = {
    val lr =
      (if (lrScheduler != null) {
        val scheduledLr = lrScheduler(numUpdate)
        updateCount(index)
        scheduledLr
      } else {
        this.learningRate
      }) * lrScale.getOrElse(index, 1f)

    var (mean, variance) = state

    if (timeFirstIndex == 0) {
      timeFirstIndex = index
      time = 0
    } else if (timeFirstIndex == index) {
      time += 1
    }

    val t1: Int = time + 1
    learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1)) / (1.0 - math.pow(beta1, t1))) toFloat
    val beta1t = beta1 * math.pow(decayFactor, t1 - 1) toFloat

    var resdGrad = grad * rescaleGrad
    if (clipGradient != 0f) {
      resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
    }

    val meanT = beta1t * mean.asInstanceOf[NDArray] + (1.0 - beta1t) * resdGrad toScalar
    val varianceT = beta2 * variance.asInstanceOf[NDArray] + (1.0f - beta2) * resdGrad * resdGrad toScalar

    var step = learningRate * meanT / (math.sqrt(varianceT) + epsilon)

    if (wd > 0.0f) {
      step += (lr * wd * weight).toScalar
    }

    weight += -step.toFloat
    mean = meanT
    variance = varianceT
  }

  // Create additional optimizer state: mean, variance
  override def createState(index: Int, weight: NDArray): AnyRef = {
    timeFirstIndex = 0
    (NDArray.zeros(weight.shape, weight.context), // mean
     NDArray.zeros(weight.shape, weight.context)) // variance
  }
}
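
As a reference for what `update()` computes, the Adam rule from the cited [King2014] paper (gradient $g_t$, step size $\alpha$, timestep $t$) is:

```latex
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t, &
v_t &= \beta_2\, v_{t-1} + (1 - \beta_2)\, g_t^2, \\
\alpha_t &= \alpha \cdot \frac{\sqrt{1 - \beta_2^{t}}}{1 - \beta_1^{t}}, &
\theta_t &= \theta_{t-1} - \alpha_t \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}
\end{aligned}
```

The reassignment of `learningRate` in the code appears to be the folded bias correction $\alpha_t$ above, `beta1t` corresponds to the decayed $\beta_{1,t} = \beta_1 \lambda^{t-1}$ used in the paper's analysis (with $\lambda$ being `decayFactor`), and a nonzero `wd` adds the `lr * wd * weight` L2 term to the step, as noted in the ScalaDoc.
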
2 changes: 1 addition & 1 deletion src/symbol/graph_executor.cc
@@ -300,7 +300,7 @@ void GraphExecutor::InitGraph(const Symbol &symbol,
   }
   std::sort(head_nodes.begin(), head_nodes.end());
   head_nodes.resize(std::unique(head_nodes.begin(), head_nodes.end()) - head_nodes.begin());
-  std::vector<uint32_t> fwd_nodes = graph_.PostDFSOrder(head_nodes, {});
+  std::vector<uint32_t> fwd_nodes = graph_.PostDFSOrder(head_nodes, std::unordered_set<uint32_t>());
   num_forward_nodes_ = fwd_nodes.size();
 
   std::unordered_set<uint32_t> fwd_set(fwd_nodes.begin(), fwd_nodes.end());
3 changes: 2 additions & 1 deletion src/symbol/static_graph.h
@@ -183,7 +183,8 @@ class StaticGraph {
    * \return a post DFS visit order of nodes that can reach heads.
    */
   std::vector<uint32_t> PostDFSOrder(const std::vector<uint32_t>& head_nodes,
-                                     const std::unordered_set<uint32_t>& banned = {}) const;
+                                     const std::unordered_set<uint32_t>& banned
+                                     = std::unordered_set<uint32_t>()) const;
   /*!
    * \brief infer the node shapes in the computation graph.
    *
