sparse Adam optimizer (apache#164)
* add sparse adam
* register gpu op
* add comments
* cr comments
eric-haibin-lin committed Aug 14, 2017
1 parent 889a09e commit 6b0cac1
Showing 5 changed files with 168 additions and 19 deletions.
141 changes: 141 additions & 0 deletions src/operator/optimizer_op-inl.h
@@ -700,6 +700,147 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
});
}

/*!
* Note: this kernel performs the sparse Adam update. For each row slice in the
* row_sparse gradient, it finds the corresponding elements in weight, mean and var
* and performs the update.
* The kernel assumes dense weight/mean/var and a row_sparse gradient.
*/
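// The per-element update applied to each retained row is, in effect:
//   g_hat = rescale_grad * g + wd * w     (clipped to [-clip_gradient, clip_gradient] if enabled)
//   m     = beta1 * m + (1 - beta1) * g_hat
//   v     = beta2 * v + (1 - beta2) * g_hat^2
//   w     = w - lr * m / (sqrt(v) + epsilon)
// Bias correction is not applied here; the caller is expected to fold the
// sqrt(1 - beta2^t) / (1 - beta1^t) factor into lr, as the PyAdam reference in
// tests/python/unittest/test_optimizer.py does. Rows absent from the row_sparse
// gradient are skipped entirely, leaving their weight/mean/var untouched (lazy update).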
template<int req>
struct AdamDnsRspDnsKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
using nnvm::dim_t;
using namespace mshadow_op;
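// grad_idx[i] is the original row id of the i-th stored row in the row_sparse
// gradient; row_offset is where that row's slice starts in the dense weight/mean/var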
const dim_t row_offset = grad_idx[i] * row_length;
for (dim_t j = 0; j < row_length; j++) {
// index in data/mean/var
const dim_t data_i = row_offset + j;
// index in grad
const dim_t grad_i = i * row_length + j;
const DType grad_rescaled = grad_data[grad_i] * rescale_grad + weight_data[data_i] * wd;
if (clip_gradient >= 0.0f) {
mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
clip::Map(grad_rescaled, clip_gradient);
var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
clip::Map(grad_rescaled, clip_gradient));
} else {
mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
var_data[data_i] = beta2 * var_data[data_i] +
(1.f - beta2) * grad_rescaled * grad_rescaled;
}
KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
(square_root::Map(var_data[data_i]) + epsilon));
}
}
};


template<typename xpu>
inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
const OpContext& ctx,
const TBlob& weight,
const NDArray& grad,
const TBlob& mean,
const TBlob& var,
const OpReqType& req,
TBlob *out) {
using namespace mxnet_op;
using namespace rowsparse;
Stream<xpu>* s = ctx.get_stream<xpu>();
if (!grad.storage_initialized() || req == kNullOp) return;
CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
CHECK_GT(weight.shape_.Size(), 0);
CHECK_GT(mean.shape_.Size(), 0);
CHECK_GT(var.shape_.Size(), 0);

MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
const DType* weight_data = weight.dptr<DType>();
const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
const DType* grad_val = grad.data().dptr<DType>();
DType* mean_data = mean.dptr<DType>();
DType* var_data = var.dptr<DType>();
DType* out_data = out->dptr<DType>();
nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
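// row_length is the product of all trailing dims, so each "row" here is the
// full slice weight[idx, ...] even when the weight has more than two dimensions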
Kernel<AdamDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, row_length,
out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
static_cast<DType>(param.rescale_grad));
});
});
});
}

template<typename xpu>
inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
const OpContext& ctx,
const NDArray& weight,
const NDArray& grad,
const NDArray& mean,
const NDArray& var,
const OpReqType& req,
NDArray *out) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
using namespace rowsparse;
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamUpdate", "weights");
Stream<xpu>* s = ctx.get_stream<xpu>();
// fill mean and variance with zero values in order to reuse the adam dns impl
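// NDArray copies are shallow (they share the underlying storage), so filling
// mean_zeros / var_zeros below also materializes the zero rows in mean / var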
if (!mean.storage_initialized()) {
NDArray mean_zeros = mean;
FillDnsZerosRspImpl(s, &mean_zeros);
}
if (!var.storage_initialized()) {
NDArray var_zeros = var;
FillDnsZerosRspImpl(s, &var_zeros);
}
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
AdamUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
var.data(), req, &out_blob);
}


template<typename xpu>
inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
const auto weight_stype = inputs[0].storage_type();
const auto grad_stype = inputs[1].storage_type();
const auto mean_stype = inputs[2].storage_type();
const auto var_stype = inputs[3].storage_type();

const auto out_stype = outputs[0].storage_type();
CHECK_EQ(mean_stype, weight_stype) << "Inconsistent storage type detected between "
<< "mean.stype = " << mean_stype << " and weight.stype = " << weight_stype;
CHECK_EQ(var_stype, weight_stype) << "Inconsistent storage type detected between "
<< "var.stype = " << var_stype << " and weight.stype = " << weight_stype;
if (weight_stype == kRowSparseStorage && mean_stype == kRowSparseStorage &&
var_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
out_stype == kRowSparseStorage) {
NDArray out = outputs[0];
AdamUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
inputs[3], req[0], &out);
} else {
LOG(FATAL) << "Unexpected storage types: weight.stype = " << weight_stype
<< ", var.stype = " << var_stype << ", mean.stype = " << mean_stype
<< ", grad.stype = " << grad_stype;
}
}

// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
// by Alex Graves, 2013.
1 change: 1 addition & 0 deletions src/operator/optimizer_op.cc
@@ -160,6 +160,7 @@ It updates the weights using::
return std::vector<uint32_t>{2, 3};
})
.set_attr<FCompute>("FCompute<cpu>", AdamUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", AdamUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
3 changes: 2 additions & 1 deletion src/operator/optimizer_op.cu
@@ -42,7 +42,8 @@ NNVM_REGISTER_OP(mp_sgd_mom_update)
.set_attr<FCompute>("FCompute<gpu>", MP_SGDMomUpdate<gpu>);

NNVM_REGISTER_OP(adam_update)
.set_attr<FCompute>("FCompute<gpu>", AdamUpdate<gpu>);
.set_attr<FCompute>("FCompute<gpu>", AdamUpdate<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", AdamUpdateEx<gpu>);

NNVM_REGISTER_OP(rmsprop_update)
.set_attr<FCompute>("FCompute<gpu>", RMSPropUpdate<gpu>);
2 changes: 1 addition & 1 deletion src/operator/tensor/init_op.h
@@ -162,7 +162,7 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
auto val = dst->data();
Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1));
});
});
}
40 changes: 23 additions & 17 deletions tests/python/unittest/test_optimizer.py
@@ -312,12 +312,13 @@ def test_sparse_sgd():
class PyAdam(mx.optimizer.Optimizer):
"""python reference implemenation of adam"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
decay_factor=(1 - 1e-8), **kwargs):
decay_factor=(1 - 1e-8), sparse_update=False, **kwargs):
super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.decay_factor = decay_factor
self.sparse_update = sparse_update

def create_state(self, index, weight):
"""Create additional optimizer state: mean, variance
@@ -355,21 +356,28 @@ def update(self, index, weight, grad, state):
mean, variance = state

wd = self._get_wd(index)
grad = grad * self.rescale_grad + wd * weight
if self.clip_gradient is not None:
mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient, out=grad)

mean *= self.beta1
mean += grad * (1. - self.beta1)

variance *= self.beta2
variance += (1 - self.beta2) * mx.nd.square(grad, out=grad)

num_rows = weight.shape[0]
coef1 = 1. - self.beta1**t
coef2 = 1. - self.beta2**t
lr *= math.sqrt(coef2)/coef1

weight -= lr*mean/(mx.nd.sqrt(variance) + self.epsilon)
for row in range(num_rows):
# check whether this row slice is all zeros
all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy()))
# skip zeros during sparse update
if all_zeros and self.sparse_update:
continue
grad[row] = grad[row] * self.rescale_grad + wd * weight[row]
# clip gradients
if self.clip_gradient is not None:
mx.nd.clip(grad[row], -self.clip_gradient, self.clip_gradient, out=grad[row])
# update mean
mean[row] *= self.beta1
mean[row] += grad[row] * (1. - self.beta1)
# update variance
variance[row] *= self.beta2
variance[row] += (1 - self.beta2) * mx.nd.square(grad[row], out=grad[row])
# update weight
weight[row] -= lr*mean[row]/(mx.nd.sqrt(variance[row]) + self.epsilon)


def test_adam():
@@ -386,10 +394,8 @@ def test_adam():
{'rescale_grad': 0.8, 'wd': 0.05}]
for kwarg in kwargs:
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32)
# test operator fallback on cpu
if (default_context() == mx.cpu()):
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape,
np.float32, g_stype='row_sparse')
compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
np.float32, w_stype='row_sparse', g_stype='row_sparse')
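
For illustration, a minimal imperative sketch of the sparse path exercised above. This is hypothetical: it assumes the sparse NDArray helpers (NDArray.tostype and mx.nd.sparse.row_sparse_array) are available, and that adam_update dispatches to AdamUpdateEx when weight, grad, mean and var are all row_sparse.

import mxnet as mx

shape = (4, 2)
# weight must be row_sparse with every row present (CHECK_RSP_ALL_ROWS_NON_ZERO)
weight = mx.nd.ones(shape).tostype('row_sparse')
mean = mx.nd.zeros(shape).tostype('row_sparse')   # no rows stored yet
var = mx.nd.zeros(shape).tostype('row_sparse')    # zero-filled on first update
# gradient with only row 1 populated
grad = mx.nd.sparse.row_sparse_array(([[0.1, 0.2]], [1]), shape=shape)

# in-place update is required (kWriteInplace), hence out=weight
mx.nd.adam_update(weight, grad, mean, var, lr=0.01, beta1=0.9, beta2=0.999,
                  epsilon=1e-8, wd=0.0, rescale_grad=1.0, out=weight)
# only row 1 of weight/mean/var changes; rows 0, 2 and 3 keep their old values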

# RMSProp
class PyRMSProp(mx.optimizer.Optimizer):
