Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
standard update for sparse sgd_mom_update (#9189)
Browse files Browse the repository at this point in the history
* standard sparse sgd mom update

* update

* update comments

* address comments

* revise

* more general infer stype

* fix

* fix

* add comments for stype inference func

* update
  • Loading branch information
ZiyueHuang authored and eric-haibin-lin committed Jan 5, 2018
1 parent 9be50e0 commit df9f79a
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 17 deletions.
25 changes: 15 additions & 10 deletions python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,14 +433,8 @@ def _get_wd(self, index):
class SGD(Optimizer):
"""The SGD optimizer with momentum and weight decay.
The optimizer updates the weight by::
rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
state = momentum * state + rescaled_grad
weight = weight - state
If the storage types of weight, state and grad are all ``row_sparse``, \
**sparse updates** are applied by::
If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
**lazy updates** are applied by::
for row in grad.indices:
rescaled_grad[row] = lr * rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row]
Expand All @@ -454,6 +448,12 @@ class SGD(Optimizer):
provides slightly different semantics than the original update, and
may lead to different empirical results.
Otherwise, **standard updates** are applied by::
rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
state = momentum * state + rescaled_grad
weight = weight - state
For details of the update algorithm see
:class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`.
Expand All @@ -464,16 +464,20 @@ class SGD(Optimizer):
----------
momentum : float, optional
The momentum value.
lazy_update : bool, optional
Default is True. If True, lazy updates are applied \
if the storage types of weight and grad are both ``row_sparse``.
multi_precision: bool, optional
Flag to control the internal precision of the optimizer.
``False`` results in using the same precision as the weights (default),
``True`` makes internal 32-bit copy of the weights and applies gradients \
in 32-bit precision even if actual weights used in the model have lower precision.\
Turning this on can improve convergence and accuracy when training with float16.
"""
def __init__(self, momentum=0.0, **kwargs):
def __init__(self, momentum=0.0, lazy_update=True, **kwargs):
super(SGD, self).__init__(**kwargs)
self.momentum = momentum
self.lazy_update = lazy_update

def create_state_multi_precision(self, index, weight):
weight_master_copy = None
Expand All @@ -489,8 +493,9 @@ def create_state_multi_precision(self, index, weight):

def create_state(self, index, weight):
momentum = None
stype = weight.stype if self.lazy_update else 'default'
if self.momentum != 0.0:
momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
return momentum

def _update_impl(self, index, weight, grad, state, multi_precision=False):
Expand Down
112 changes: 108 additions & 4 deletions src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "./elemwise_op_common.h"
#include "mxnet_op.h"
#include "./tensor/init_op.h"
#include "./tensor/util/tensor_util-inl.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -460,6 +461,106 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
mom.data(), req, &out_blob);
}

/*!
 * \brief Storage type inference function shared by optimizer operators.
 *
 * The dispatch decision is attempted in the following order:
 *  1. all inputs are of default storage -> default output, DispatchMode::kFCompute;
 *  2. the first n_rsp inputs are all row_sparse and the remaining n_rsp_dns inputs are
 *     either all row_sparse or all default -> row_sparse output, DispatchMode::kFComputeEx;
 *  3. anything else -> dispatch_fallback, followed by LogStorageFallback.
 *
 * \tparam n_rsp The number of (leading) inputs that should be of row_sparse storage type
 *         if kFComputeEx is dispatched
 * \tparam n_rsp_dns The number of (trailing) inputs that should be of row_sparse or default
 *         storage type if kFComputeEx is dispatched
 * \param attrs node attributes, only forwarded to LogStorageFallback
 * \param dev_mask device mask, only forwarded to LogStorageFallback
 * \param dispatch_mode dispatch mode to be assigned
 * \param in_attrs storage types of the inputs; must hold n_rsp + n_rsp_dns entries
 * \param out_attrs storage types of the outputs; must hold exactly one entry
 * \return always true
 */
template<int n_rsp, int n_rsp_dns>
inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
                              const int dev_mask,
                              DispatchMode* dispatch_mode,
                              std::vector<int>* in_attrs,
                              std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_rsp + n_rsp_dns));
  CHECK_EQ(out_attrs->size(), 1U);

  // dns, ... -> dns
  bool assigned = false;
  if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
    assigned = storage_type_assign(out_attrs, kDefaultStorage,
                                   dispatch_mode, DispatchMode::kFCompute);
  }

  // Split the inputs into the "must be rsp" head and the "rsp or dns" tail.
  const auto boundary = in_attrs->begin() + n_rsp;
  const std::vector<int> head_stypes(in_attrs->begin(), boundary);
  const std::vector<int> tail_stypes(boundary, in_attrs->end());
  if (!assigned && common::ContainsOnlyStorage(head_stypes, kRowSparseStorage)) {
    const bool tail_is_uniform =
        common::ContainsOnlyStorage(tail_stypes, kRowSparseStorage) ||
        common::ContainsOnlyStorage(tail_stypes, kDefaultStorage);
    if (tail_is_uniform) {
      // rsp, ..., rsp/dns, ... -> rsp
      assigned = storage_type_assign(out_attrs, kRowSparseStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
    }
  }

  if (assigned) {
    return true;
  }
  // No supported combination matched: fall back and log it.
  dispatch_fallback(out_attrs, dispatch_mode);
  LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
  return true;
}

/*!
 * \brief Per-row kernel of the standard SGD momentum update with a dense weight,
 *        a row_sparse gradient and a dense momentum.
 *
 * For every element of row i it computes
 *   mom    = momentum * mom - lr * wd * weight - lr * clip(rescale_grad * grad)
 *   weight = weight + mom
 * where grad is taken as zero for rows that carry no gradient, and clipping is
 * skipped when clip_gradient is negative.
 *
 * prefix_sum is read as an inclusive prefix sum of per-row "has gradient" flags:
 * row i has a gradient iff the sum increases at i, and prefix_sum[i] - 1 is then
 * the position of that row inside grad_data. grad_idx is not read by this kernel.
 *
 * \tparam req the OpReqType used by KERNEL_ASSIGN for writing out_data
 */
template<int req>
struct SGDMomStdDnsRspDnsKernel {
  template<typename DType, typename IType, typename RType>
  MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
    DType* mom_data, const DType* weight_data, const IType* grad_idx,
    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
    const DType momentum, const DType lr, const DType wd, const DType rescale_grad) {
    const DType rate = lr * wd;

    // Does row i appear in the sparse gradient?
    bool has_grad;
    if (i == 0) {
      has_grad = prefix_sum[0] > 0;
    } else {
      has_grad = prefix_sum[i] > prefix_sum[i-1];
    }

    // Offsets of the first element of row i in the dense arrays and in grad_data.
    // sparse_base is only meaningful (and only used) when has_grad is true.
    const index_t dense_base = i * row_length;
    const RType sparse_base = (prefix_sum[i]-1) * row_length;
    for (index_t col = 0; col < row_length; col++) {
      const index_t pos = dense_base + col;
      DType g = static_cast<DType>(0);
      if (has_grad) {
        g = grad_data[sparse_base + col];
      }
      if (clip_gradient >= 0.0f) {
        mom_data[pos] = momentum * mom_data[pos]
                        - rate * weight_data[pos]
                        - lr *
                        mshadow_op::clip::Map(rescale_grad * g,
                                              clip_gradient);
      } else {
        mom_data[pos] = momentum * mom_data[pos]
                        - rate * weight_data[pos]
                        - lr * rescale_grad * g;
      }
      KERNEL_ASSIGN(out_data[pos], req, weight_data[pos] + mom_data[pos]);
    }
  }
};

/*!
 * \brief Standard (non-lazy) SGD momentum update for a dense weight, a row_sparse
 *        gradient and a dense momentum.
 *
 * Only declared here. The explicit specializations for cpu and gpu are defined in
 * optimizer_op.cc and optimizer_op.cu respectively.
 *
 * \param param parameters of the sgd_mom_update operator
 * \param ctx operator context
 * \param weight dense weight data
 * \param grad gradient NDArray (its row index aux data is read by the specializations)
 * \param mom dense momentum data, updated by the operation
 * \param req write request type for the output
 * \param out destination blob for the updated weight
 */
template<typename xpu>
void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
const OpContext& ctx,
const TBlob& weight,
const NDArray& grad,
const TBlob& mom,
const OpReqType& req,
TBlob *out);

template<typename xpu>
inline void SGDMomStdUpdateRspRspDnsImpl(const SGDMomParam& param,
const OpContext& ctx,
const NDArray& weight,
const NDArray& grad,
const NDArray& mom,
const OpReqType& req,
NDArray *out) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
using namespace rowsparse;
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
Stream<xpu>* s = ctx.get_stream<xpu>();
TBlob out_blob = out->data();
SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), req, &out_blob);
}

template<typename xpu>
inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
Expand All @@ -474,12 +575,15 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
const auto weight_stype = weight.storage_type();
const auto mom_stype = mom.storage_type();
const auto out_stype = outputs[0].storage_type();
CHECK_EQ(weight_stype, mom_stype) << "Inconsistent storage type detected between mom.stype = "
<< mom_stype << " and weight.stype = " << weight_stype;
NDArray out = outputs[0];
if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
out_stype == kRowSparseStorage) {
NDArray out = outputs[0];
SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
} else if (weight.storage_type() == kRowSparseStorage &&
grad.storage_type() == kRowSparseStorage &&
mom.storage_type() == kDefaultStorage &&
out_stype == kRowSparseStorage) {
SGDMomStdUpdateRspRspDnsImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
Expand Down
62 changes: 60 additions & 2 deletions src/operator/optimizer_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,57 @@ DMLC_REGISTER_PARAMETER(RMSPropParam);
DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
DMLC_REGISTER_PARAMETER(FtrlParam);

/*!
 * \brief CPU specialization of the standard SGD momentum update with a dense weight,
 *        a row_sparse gradient and a dense momentum.
 *
 * Steps:
 *  1. request a temp workspace holding one nnvm::dim_t per weight row and zero it;
 *  2. if the gradient storage is initialized, mark the rows listed in the gradient's
 *     row index and turn the marks into an inclusive prefix sum (sequentially);
 *  3. launch SGDMomStdDnsRspDnsKernel over all rows of the weight.
 * If the gradient storage is not initialized, the marks stay zero and the kernel is
 * still launched for every row.
 *
 * Returns immediately for kNullOp; any other request than kWriteInplace is rejected.
 */
template<>
void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
                                       const OpContext& ctx,
                                       const TBlob& weight,
                                       const NDArray& grad,
                                       const TBlob& mom,
                                       const OpReqType& req,
                                       TBlob *out) {
  using namespace mxnet_op;
  using namespace rowsparse;
  using namespace mshadow;
  Stream<cpu>* stream = ctx.get_stream<cpu>();
  if (req == kNullOp) return;
  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
  CHECK_GT(weight.shape_.Size(), 0);
  CHECK_GT(mom.shape_.Size(), 0);
  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        DType* w_ptr = weight.dptr<DType>();
        IType* g_idx_ptr = grad.aux_data(kIdx).dptr<IType>();
        DType* g_val_ptr = grad.data().dptr<DType>();
        DType* m_ptr = mom.dptr<DType>();
        DType* o_ptr = out->dptr<DType>();
        nnvm::dim_t row_count = weight.shape_[0];
        auto row_width = weight.shape_.ProdShape(1, weight.ndim());

        // one nnvm::dim_t slot per weight row
        Tensor<cpu, 1, char> scratch = ctx.requested[0]
          .get_space_typed<cpu, 1, char>(Shape1(row_count * sizeof(nnvm::dim_t)), stream);
        nnvm::dim_t* row_marks = reinterpret_cast<nnvm::dim_t*>(scratch.dptr_);

        // clear the marks, then flag the rows present in the gradient
        Kernel<set_zero, cpu>::Launch(stream, row_count, row_marks);
        if (grad.storage_initialized()) {
          Kernel<MarkRowFlgKernel, cpu>::Launch(stream, grad.aux_shape(kIdx)[0],
                                                row_marks, g_idx_ptr);
          // in-place inclusive prefix sum over the marks
          for (nnvm::dim_t r = 1; r < row_count; r++) {
            row_marks[r] += row_marks[r - 1];
          }
        }

        Kernel<SGDMomStdDnsRspDnsKernel<req_type>, cpu>::Launch(stream, row_count, row_width,
          o_ptr, m_ptr, w_ptr, g_idx_ptr, g_val_ptr, row_marks,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
          static_cast<DType>(param.rescale_grad));
      });
    });
  });
}


NNVM_REGISTER_OP(sgd_update)
MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
.describe(R"code(Update function for Stochastic Gradient Descent (SDG) optimizer.
Expand Down Expand Up @@ -84,7 +135,10 @@ It updates the weights using::
Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
If weight and momentum are both of ``row_sparse`` storage type,
If weight and grad are both of ``row_sparse`` storage type and momentum is of ``default`` storage type,
standard update is applied.
If weight, grad and momentum are all of ``row_sparse`` storage type,
only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::
for row in gradient.indices:
Expand All @@ -97,11 +151,15 @@ only the row slices whose indices appear in grad.indices are updated (for both w
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<3, 1, false, true, false>)
.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SGDMomUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
Expand Down
66 changes: 66 additions & 0 deletions src/operator/optimizer_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,76 @@
* \author Junyuan Xie
*/
#include "./optimizer_op-inl.h"
#include <cub/cub.cuh>

namespace mxnet {
namespace op {

/*!
 * \brief GPU specialization of the standard SGD momentum update with a dense weight,
 *        a row_sparse gradient and a dense momentum.
 *
 * A single temp workspace is requested and split into two regions:
 *   [ num_rows * sizeof(nnvm::dim_t) bytes : row flags / prefix sum ]
 *   [ temp_storage_bytes bytes             : scratch space for cub::DeviceScan ]
 * The row flags are zero-filled, the rows listed in the gradient's row index are
 * marked, the flags are turned into an inclusive prefix sum on the device, and
 * SGDMomStdDnsRspDnsKernel is launched over all rows of the weight.
 *
 * Returns immediately for kNullOp; any other request than kWriteInplace is rejected.
 */
template<>
void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
const OpContext& ctx,
const TBlob& weight,
const NDArray& grad,
const TBlob& mom,
const OpReqType& req,
TBlob *out) {
using namespace mxnet_op;
using namespace rowsparse;
using namespace mshadow;
Stream<gpu>* s = ctx.get_stream<gpu>();
if (req == kNullOp) return;
CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
CHECK_GT(weight.shape_.Size(), 0);
CHECK_GT(mom.shape_.Size(), 0);

MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
DType* weight_data = weight.dptr<DType>();
IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
DType* grad_val = grad.data().dptr<DType>();
DType* mom_data = mom.dptr<DType>();
DType* out_data = out->dptr<DType>();
nnvm::dim_t num_rows = weight.shape_[0];
nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());

nnvm::dim_t* prefix_sum = NULL;
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
// First call is made with d_temp_storage == NULL: following the CUB convention this
// is a size query that fills temp_storage_bytes; the value is needed below to size
// the workspace before any real pointer exists.
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
num_rows,
Stream<gpu>::GetStream(s));
// One workspace for both the per-row flags and the CUB scratch space.
Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t) +
temp_storage_bytes), s);
// Flags occupy the front of the workspace, the CUB scratch space follows directly.
prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
d_temp_storage = workspace.dptr_ + num_rows*sizeof(nnvm::dim_t);
// mark row flags: zero all num_rows entries first, then flag the rows in grad_idx
Fill<false>(s, TBlob(prefix_sum, Shape1(num_rows), gpu::kDevMask), kWriteTo, 0);
if (grad.storage_initialized()) {
Kernel<MarkRowFlgKernel, gpu>::Launch(s, grad.aux_shape(kIdx)[0],
prefix_sum, grad_idx);
// calculate inclusive prefix sum (in place, this time with real scratch storage)
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
num_rows,
mshadow::Stream<gpu>::GetStream(s));
}
// Launched even when the gradient storage is not initialized; prefix_sum is then
// all zeros.
Kernel<SGDMomStdDnsRspDnsKernel<req_type>, gpu>::Launch(s, num_rows, row_length,
out_data, mom_data, weight_data, grad_idx, grad_val, prefix_sum,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad));
});
});
});
}

NNVM_REGISTER_OP(sgd_update)
.set_attr<FCompute>("FCompute<gpu>", SGDUpdate<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", SGDUpdateEx<gpu>);
Expand Down
24 changes: 23 additions & 1 deletion tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,29 @@ def test_sparse_sgd():
w_stype='row_sparse', g_stype='row_sparse')


def test_std_sparse_sgd():
    """Compare PySGD against mx.optimizer.SGD with ``lazy_update=False``.

    Weights and gradients are both ``row_sparse``; every combination of the
    momentum, clip_gradient, rescale_grad and wd options below is checked.
    """
    mx.random.seed(0)
    reference_opt = PySGD
    tested_opt = mx.optimizer.SGD
    shape = (3, 4, 5)
    mom_options = [{'momentum': 0.9}]
    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
    # Same iteration order as four nested loops: momentum outermost, wd innermost.
    option_grid = [(mom_option, cg_option, rg_option, wd_option)
                   for mom_option in mom_options
                   for cg_option in cg_options
                   for rg_option in rg_options
                   for wd_option in wd_options]
    for dtype in [np.float32]:
        for option_group in option_grid:
            kwarg = {}
            for option in option_group:
                kwarg.update(option)
            compare_optimizer(reference_opt(**kwarg),
                              tested_opt(lazy_update=False, **kwarg),
                              shape, dtype,
                              w_stype='row_sparse', g_stype='row_sparse')


# FTML

class PyFTML(mx.optimizer.Optimizer):
Expand Down Expand Up @@ -400,7 +423,6 @@ def test_ftml():
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)



# ADAM

class PyAdam(mx.optimizer.Optimizer):
Expand Down

0 comments on commit df9f79a

Please sign in to comment.