From 6b0cac1f662eb7a7d2da5dcd7aa6ad01fe3c5d69 Mon Sep 17 00:00:00 2001
From: Haibin Lin
Date: Mon, 14 Aug 2017 15:41:59 -0700
Subject: [PATCH] sparse Adam optimizer (#164)

* add sparse adam

* register gpu op

* add comments

* cr comments
---
 src/operator/optimizer_op-inl.h         | 141 ++++++++++++++++++++++++
 src/operator/optimizer_op.cc            |   1 +
 src/operator/optimizer_op.cu            |   3 +-
 src/operator/tensor/init_op.h           |   2 +-
 tests/python/unittest/test_optimizer.py |  40 ++++---
 5 files changed, 168 insertions(+), 19 deletions(-)

diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index a253bfe25c8f..3911510e1bfd 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -700,6 +700,147 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
   });
 }
 
+/*!
+ * Note: this kernel performs a sparse Adam update. For each row slice in the row_sparse
+ * gradient, it finds the corresponding elements in weight, mean and var and performs
+ * the update.
+ * The kernel assumes dense weight/mean/var and a row_sparse gradient.
+ */
+template<int req>
+struct AdamDnsRspDnsKernel {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
+    const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
+    using nnvm::dim_t;
+    using namespace mshadow_op;
+    const dim_t row_offset = grad_idx[i] * row_length;
+    for (dim_t j = 0; j < row_length; j++) {
+      // index in data/mean/var
+      const dim_t data_i = row_offset + j;
+      // index in grad
+      const dim_t grad_i = i * row_length + j;
+      const DType grad_rescaled = grad_data[grad_i] * rescale_grad + weight_data[data_i] * wd;
+      if (clip_gradient >= 0.0f) {
+        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
+                            clip::Map(grad_rescaled, clip_gradient);
+        var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
+                           clip::Map(grad_rescaled, clip_gradient));
+      } else {
+        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
+        var_data[data_i] = beta2 * var_data[data_i] +
+                           (1.f - beta2) * grad_rescaled * grad_rescaled;
+      }
+      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
+                    (square_root::Map(var_data[data_i]) + epsilon));
+    }
+  }
+};
+
+
+template<typename xpu>
+inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
+                                    const OpContext& ctx,
+                                    const TBlob& weight,
+                                    const NDArray& grad,
+                                    const TBlob& mean,
+                                    const TBlob& var,
+                                    const OpReqType& req,
+                                    TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  if (!grad.storage_initialized() || req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mean.shape_.Size(), 0);
+  CHECK_GT(var.shape_.Size(), 0);
+
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        const DType* weight_data = weight.dptr<DType>();
+        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        const DType* grad_val = grad.data().dptr<DType>();
+        DType* mean_data = mean.dptr<DType>();
+        DType* var_data = var.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
+        const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
+        Kernel<AdamDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, row_length,
+          out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
+          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
+          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
+          static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
+template<typename xpu>
+inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
+                                    const OpContext& ctx,
+                                    const NDArray& weight,
+                                    const NDArray& grad,
+                                    const NDArray& mean,
+                                    const NDArray& var,
+                                    const OpReqType& req,
+                                    NDArray *out) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamUpdate", "weights");
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  // fill mean and variance with zero values in order to reuse the sgd mom dns impl
+  if (!mean.storage_initialized()) {
+    NDArray mean_zeros = mean;
+    FillDnsZerosRspImpl(s, &mean_zeros);
+  }
+  if (!var.storage_initialized()) {
+    NDArray var_zeros = var;
+    FillDnsZerosRspImpl(s, &var_zeros);
+  }
+  TBlob out_blob = out->data();
+  // reuse dns rsp implementation when storage_shape == shape
+  AdamUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
+                               var.data(), req, &out_blob);
+}
+
+
+template<typename xpu>
+inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
+                         const OpContext &ctx,
+                         const std::vector<NDArray> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<NDArray> &outputs) {
+  const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  const auto weight_stype = inputs[0].storage_type();
+  const auto grad_stype = inputs[1].storage_type();
+  const auto mean_stype = inputs[2].storage_type();
+  const auto var_stype = inputs[3].storage_type();
+
+  const auto out_stype = outputs[0].storage_type();
+  CHECK_EQ(mean_stype, weight_stype) << "Inconsistent storage type detected between "
+          << " mean.stype = " << mean_stype << " and weight.stype = " << weight_stype;
+  CHECK_EQ(var_stype, weight_stype) << "Inconsistent storage type detected between "
+          << " var.stype = " << var_stype << " and weight.stype = " << weight_stype;
+  if (weight_stype == kRowSparseStorage && mean_stype == kRowSparseStorage &&
+      var_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
+      out_stype == kRowSparseStorage) {
+    NDArray out = outputs[0];
+    AdamUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+                                 inputs[3], req[0], &out);
+  } else {
+    LOG(FATAL) << "Unexpected storage types: weight.stype = " << weight_stype
+               << ", var.stype = " << var_stype << ", mean.stype = " << mean_stype
+               << ", grad.stype = " << grad_stype;
+  }
+}
+
 // This RMSProp code follows the version in
 // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
 // by Alex Graves, 2013.
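
Reader's note (not part of the patch): the per-row update applied by AdamDnsRspDnsKernel is the standard Adam rule restricted to the rows listed in the gradient's index array; bias correction is not done in the kernel, since the Python frontend folds it into lr. A rough NumPy sketch of the same logic, purely illustrative, with made-up function and argument names:

import numpy as np

def rowsparse_adam_update(weight, mean, var, grad_idx, grad_val,
                          lr, beta1, beta2, epsilon, wd, rescale_grad,
                          clip_gradient=-1.0):
    # grad_idx: indices of the non-zero rows; grad_val: the corresponding row slices
    for i, row in enumerate(grad_idx):
        g = grad_val[i] * rescale_grad + wd * weight[row]
        if clip_gradient >= 0:
            g = np.clip(g, -clip_gradient, clip_gradient)
        mean[row] = beta1 * mean[row] + (1. - beta1) * g
        var[row] = beta2 * var[row] + (1. - beta2) * g * g
        # rows not present in grad_idx keep their weight, mean and var unchanged
        weight[row] -= lr * mean[row] / (np.sqrt(var[row]) + epsilon)
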
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 9bcd1bcb33f6..9b2b088c5095 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -160,6 +160,7 @@ It updates the weights using::
     return std::vector<uint32_t>{2, 3};
   })
 .set_attr<FCompute>("FCompute", AdamUpdate<cpu>)
+.set_attr<FComputeEx>("FComputeEx", AdamUpdateEx<cpu>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index f23218ac5fac..fe45f4be8c66 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -42,7 +42,8 @@ NNVM_REGISTER_OP(mp_sgd_mom_update)
 .set_attr<FCompute>("FCompute", MP_SGDMomUpdate<gpu>);
 
 NNVM_REGISTER_OP(adam_update)
-.set_attr<FCompute>("FCompute", AdamUpdate<gpu>);
+.set_attr<FCompute>("FCompute", AdamUpdate<gpu>)
+.set_attr<FComputeEx>("FComputeEx", AdamUpdateEx<gpu>);
 
 NNVM_REGISTER_OP(rmsprop_update)
 .set_attr<FCompute>("FCompute", RMSPropUpdate<gpu>);
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 1ac933ddaef5..0cd81d77133c 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -162,7 +162,7 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
       auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
       auto val = dst->data();
       Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
-      ASSIGN_DISPATCH(idx, kWriteTo, range(0, num_rows, 1, 1))
+      ASSIGN_DISPATCH(idx, kWriteTo, range(0, num_rows, 1, 1));
     });
   });
 }
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 7b9f4a1e6f43..055f6464f0ef 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -312,12 +312,13 @@ def test_sparse_sgd():
 class PyAdam(mx.optimizer.Optimizer):
     """python reference implemenation of adam"""
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
-                 decay_factor=(1 - 1e-8), **kwargs):
+                 decay_factor=(1 - 1e-8), sparse_update=False, **kwargs):
         super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
         self.decay_factor = decay_factor
+        self.sparse_update = sparse_update
 
     def create_state(self, index, weight):
         """Create additional optimizer state: mean, variance
@@ -355,21 +356,28 @@ def update(self, index, weight, grad, state):
         mean, variance = state
 
         wd = self._get_wd(index)
-        grad = grad * self.rescale_grad + wd * weight
-        if self.clip_gradient is not None:
-            mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient, out=grad)
-
-        mean *= self.beta1
-        mean += grad * (1. - self.beta1)
-
-        variance *= self.beta2
-        variance += (1 - self.beta2) * mx.nd.square(grad, out=grad)
-
+        num_rows = weight.shape[0]
         coef1 = 1. - self.beta1**t
         coef2 = 1. - self.beta2**t
         lr *= math.sqrt(coef2)/coef1
-
-        weight -= lr*mean/(mx.nd.sqrt(variance) + self.epsilon)
+        for row in range(num_rows):
+            # check row slices of all zeros
+            all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy()))
+            # skip zeros during sparse update
+            if all_zeros and self.sparse_update:
+                continue
+            grad[row] = grad[row] * self.rescale_grad + wd * weight[row]
+            # clip gradients
+            if self.clip_gradient is not None:
+                mx.nd.clip(grad[row], -self.clip_gradient, self.clip_gradient, out=grad[row])
+            # update mean
+            mean[row] *= self.beta1
+            mean[row] += grad[row] * (1. - self.beta1)
+            # update variance
+            variance[row] *= self.beta2
+            variance[row] += (1 - self.beta2) * mx.nd.square(grad[row], out=grad[row])
+            # update weight
+            weight[row] -= lr*mean[row]/(mx.nd.sqrt(variance[row]) + self.epsilon)
 
 
 def test_adam():
@@ -386,10 +394,8 @@ def test_adam():
               {'rescale_grad': 0.8, 'wd': 0.05}]
     for kwarg in kwargs:
         compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32)
-        # test operator fallback on cpu
-        if (default_context() == mx.cpu()):
-            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape,
-                              np.float32, g_stype='row_sparse')
+        compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
+                          np.float32, w_stype='row_sparse', g_stype='row_sparse')
 
 # RMSProp
 class PyRMSProp(mx.optimizer.Optimizer):
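
Usage note (not part of the patch): with FComputeEx registered, the sparse path is taken only when weight, grad, mean, var and the output are all row_sparse; any other combination hits the LOG(FATAL) branch or falls back to the dense FCompute. A minimal sketch of driving the operator directly, assuming the mx.nd.cast_storage helper and the adam_update keyword arguments as they exist on this branch:

import mxnet as mx

shape = (10, 4)
# weight must have all rows present (CHECK_RSP_ALL_ROWS_NON_ZERO); mean/var may start empty
weight = mx.nd.cast_storage(mx.nd.ones(shape), stype='row_sparse')
mean = mx.nd.cast_storage(mx.nd.zeros(shape), stype='row_sparse')
var = mx.nd.cast_storage(mx.nd.zeros(shape), stype='row_sparse')

# gradient with only two non-zero rows, stored as row_sparse
dense_grad = mx.nd.zeros(shape)
dense_grad[1:2] = 0.5
dense_grad[7:8] = -0.25
grad = mx.nd.cast_storage(dense_grad, stype='row_sparse')

# in-place update (the sparse kernel expects kWriteInplace); only rows 1 and 7
# of weight, mean and var are touched
mx.nd.adam_update(weight, grad, mean, var, out=weight,
                  lr=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8,
                  wd=0.0, rescale_grad=1.0)
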