diff --git a/benchmark/python/sparse/updater.py b/benchmark/python/sparse/updater.py
new file mode 100644
index 000000000000..72f2bfd04a27
--- /dev/null
+++ b/benchmark/python/sparse/updater.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import time
+import mxnet as mx
+from mxnet.ndarray.sparse import adam_update
+import numpy as np
+import argparse
+
+mx.random.seed(0)
+np.random.seed(0)
+
+parser = argparse.ArgumentParser(description='Benchmark adam updater')
+parser.add_argument('--dim-in', type=int, default=240000, help='weight.shape[0]')
+parser.add_argument('--dim-out', type=int, default=512, help='weight.shape[1]')
+parser.add_argument('--nnr', type=int, default=5000, help='grad.indices.shape[0]')
+parser.add_argument('--repeat', type=int, default=1000, help='num repeat')
+parser.add_argument('--dense-grad', action='store_true',
+                    help='if set to true, both gradient and weight are dense.')
+parser.add_argument('--dense-state', action='store_true',
+                    help='if set to true, states are dense, indicating standard update')
+parser.add_argument('--cpu', action='store_true')
+
+
+args = parser.parse_args()
+dim_in = args.dim_in
+dim_out = args.dim_out
+nnr = args.nnr
+ctx = mx.cpu() if args.cpu else mx.gpu()
+
+ones = mx.nd.ones((dim_in, dim_out), ctx=ctx)
+
+if not args.dense_grad:
+    weight = ones.tostype('row_sparse')
+    indices = np.arange(dim_in)
+    np.random.shuffle(indices)
+    indices = np.unique(indices[:nnr])
+    indices = mx.nd.array(indices, ctx=ctx)
+    grad = mx.nd.sparse.retain(weight, indices)
+else:
+    weight = ones.copy()
+    grad = ones.copy()
+
+if args.dense_state:
+    mean = ones.copy()
+else:
+    mean = ones.tostype('row_sparse')
+
+var = mean.copy()
+
+# warmup
+for i in range(10):
+    adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
+                beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
+weight.wait_to_read()
+
+# measure speed
+a = time.time()
+for i in range(args.repeat):
+    adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
+                beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
+weight.wait_to_read()
+b = time.time()
+print(b - a)
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 55d215602eef..104f20a61eeb 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -749,6 +749,9 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<int req, typename xpu>
+struct AdamDnsRspDnsKernel;
+
 /*!
  * Note: this kernel performs sparse adam update. For each row-slice in row_sparse
  * gradient, it finds the corresponding elements in weight, mean and var and performs
@@ -756,7 +759,7 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
  * The kernel assumes dense weight/mean/var, and row_sparse gradient
  */
 template<int req>
-struct AdamDnsRspDnsKernel {
+struct AdamDnsRspDnsKernel<req, cpu> {
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
     DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
@@ -788,6 +791,33 @@ struct AdamDnsRspDnsKernel {
 };
 
 
+template<int req>
+struct AdamDnsRspDnsKernel<req, gpu> {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
+    const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
+    using nnvm::dim_t;
+    using namespace mshadow_op;
+    const dim_t row_id = i / row_length;
+    const dim_t col_id = i % row_length;
+    const dim_t row_offset = grad_idx[row_id] * row_length;
+    // index in data/mean/var
+    const dim_t data_i = row_offset + col_id;
+    // index in grad
+    DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[data_i] * wd;
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+    mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
+    var_data[data_i] = beta2 * var_data[data_i] +
+                       (1.f - beta2) * grad_rescaled * grad_rescaled;
+    KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
+                  (square_root::Map(var_data[data_i]) + epsilon));
+  }
+};
+
 template<typename xpu>
 inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
                                     const OpContext& ctx,
@@ -817,8 +847,12 @@ inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
         DType* out_data = out->dptr<DType>();
         nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
         const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
-        Kernel<AdamDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, row_length,
-          out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
+        size_t num_threads = num_rows;
+        if (std::is_same<xpu, gpu>::value) {
+          num_threads = num_rows * row_length;
+        }
+        Kernel<AdamDnsRspDnsKernel<req_type, xpu>, xpu>::Launch(s, num_threads,
+          row_length, out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
           static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
           static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
           static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
@@ -858,42 +892,8 @@ inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
                               var.data(), req, &out_blob);
 }
 
-template<int req>
-struct AdamStdDnsRspDnsKernel {
-  template<typename DType, typename IType, typename RType>
-  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
-    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
-    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
-    const DType beta1, const DType beta2, const DType lr, const DType wd,
-    const DType epsilon, const DType rescale_grad) {
-    using namespace mshadow_op;
-    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
-                                   : prefix_sum[i] > prefix_sum[i-1];
-
-    const index_t row_i = i * row_length;
-    const RType grad_i = (prefix_sum[i]-1) * row_length;
-    for (index_t j = 0; j < row_length; j++) {
-      const index_t data_i = row_i + j;
-      const DType grad_rescaled = non_zero ? static_cast<DType>(
-                                               grad_data[grad_i + j] * rescale_grad +
-                                               weight_data[data_i] * wd)
-                                           : static_cast<DType>(weight_data[data_i] * wd);
-      if (clip_gradient >= 0.0f) {
-        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
-                            clip::Map(grad_rescaled, clip_gradient);
-        var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
-                           clip::Map(grad_rescaled, clip_gradient));
-      } else {
-        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
-        var_data[data_i] = beta2 * var_data[data_i] +
-                           (1.f - beta2) * square::Map(grad_rescaled);
-      }
-      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
-                    (square_root::Map(var_data[data_i]) + epsilon));
-    }
-  }
-};
-
+template<int req, typename xpu>
+struct AdamStdDnsRspDnsKernel;
 
 template<typename xpu>
 void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 741092ad7844..f7ccbbb739d6 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -149,6 +149,43 @@ void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
   });
 }
 
+template<int req>
+struct AdamStdDnsRspDnsKernel<req, cpu> {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
+    const DType beta1, const DType beta2, const DType lr, const DType wd,
+    const DType epsilon, const DType rescale_grad) {
+    using namespace mshadow_op;
+    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
+                                   : prefix_sum[i] > prefix_sum[i-1];
+
+    const index_t row_i = i * row_length;
+    const RType grad_i = (prefix_sum[i]-1) * row_length;
+    for (index_t j = 0; j < row_length; j++) {
+      const index_t data_i = row_i + j;
+      const DType grad_rescaled = non_zero ? static_cast<DType>(
+                                               grad_data[grad_i + j] * rescale_grad +
+                                               weight_data[data_i] * wd)
+                                           : static_cast<DType>(weight_data[data_i] * wd);
+      if (clip_gradient >= 0.0f) {
+        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
+                            clip::Map(grad_rescaled, clip_gradient);
+        var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
+                           clip::Map(grad_rescaled, clip_gradient));
+      } else {
+        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
+        var_data[data_i] = beta2 * var_data[data_i] +
+                           (1.f - beta2) * square::Map(grad_rescaled);
+      }
+      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
+                    (square_root::Map(var_data[data_i]) + epsilon));
+    }
+  }
+};
+
+
 template<>
 void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
                                      const OpContext& ctx,
@@ -194,7 +231,7 @@ void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
         }
       }
 
-      Kernel<AdamStdDnsRspDnsKernel<req_type>, cpu>::Launch(s, num_rows, row_length,
+      Kernel<AdamStdDnsRspDnsKernel<req_type, cpu>, cpu>::Launch(s, num_rows, row_length,
         out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
         static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
         static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index c49af68a5f68..18ee66a729c2 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -94,6 +94,35 @@ void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
   });
 }
 
+template<int req>
+struct AdamStdDnsRspDnsKernel<req, gpu> {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
+    const DType beta1, const DType beta2, const DType lr, const DType wd,
+    const DType epsilon, const DType rescale_grad) {
+    using namespace mshadow_op;
+    using nnvm::dim_t;
+    const dim_t row_id = i / row_length;
+    const dim_t col_id = i % row_length;
+    const bool non_zero = (row_id == 0) ? prefix_sum[0] > 0
+                                        : prefix_sum[row_id] > prefix_sum[row_id - 1];
+    const RType grad_offset = (prefix_sum[row_id] - 1) * row_length + col_id;
+    DType grad_rescaled = non_zero ? static_cast<DType>(grad_data[grad_offset] * rescale_grad +
+                                                        weight_data[i] * wd)
+                                   : static_cast<DType>(weight_data[i] * wd);
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+    var_data[i] = beta2 * var_data[i] +
+                  (1.f - beta2) * square::Map(grad_rescaled);
+    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * mean_data[i] /
+                  (square_root::Map(var_data[i]) + epsilon));
+  }
+};
+
 template<>
 void AdamStdUpdateDnsRspDnsImpl<gpu>(const AdamParam& param,
                                      const OpContext& ctx,
@@ -122,8 +151,8 @@ void AdamStdUpdateDnsRspDnsImpl<gpu>(const AdamParam& param,
         DType* mean_data = mean.dptr<DType>();
         DType* var_data = var.dptr<DType>();
         DType* out_data = out->dptr<DType>();
-        nnvm::dim_t num_rows = weight.shape_[0];
-        nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
+        const nnvm::dim_t num_rows = weight.shape_[0];
+        const nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
         nnvm::dim_t* prefix_sum = NULL;
         void* d_temp_storage = NULL;
         size_t temp_storage_bytes = 0;
@@ -152,8 +181,8 @@ void AdamStdUpdateDnsRspDnsImpl<gpu>(const AdamParam& param,
                                         Stream<gpu>::GetStream(s));
         }
 
-        Kernel<AdamStdDnsRspDnsKernel<req_type>, gpu>::Launch(s, num_rows, row_length,
-          out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
+        Kernel<AdamStdDnsRspDnsKernel<req_type, gpu>, gpu>::Launch(s, weight.shape_.Size(),
+          row_length, out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
          static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index f71e2c81e27e..bbd7845f66f3 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -543,13 +543,13 @@ def test_ftml():
 class PyAdam(mx.optimizer.Optimizer):
     """python reference implemenation of adam"""
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
-                 decay_factor=(1 - 1e-8), sparse_update=False, **kwargs):
+                 decay_factor=(1 - 1e-8), lazy_update=False, **kwargs):
         super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
         self.decay_factor = decay_factor
-        self.sparse_update = sparse_update
+        self.lazy_update = lazy_update
 
     def create_state(self, index, weight):
         """Create additional optimizer state: mean, variance
@@ -595,7 +595,7 @@ def update(self, index, weight, grad, state):
             # check row slices of all zeros
             all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy()))
             # skip zeros during sparse update
-            if all_zeros and self.sparse_update:
+            if all_zeros and self.lazy_update:
                 continue
             grad[row] = grad[row] * self.rescale_grad + wd * weight[row]
             # clip gradients
@@ -638,7 +638,7 @@ def test_adam():
             compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
                               rtol=1e-4, atol=2e-5)
             # atol 2e-5 needed to pass with seed 781809840
-            compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
+            compare_optimizer(opt1(lazy_update=True, **kwarg), opt2(**kwarg), shape,
                               dtype, w_stype='row_sparse', g_stype='row_sparse',
                               rtol=1e-4, atol=2e-5)
             compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape,
@@ -883,12 +883,12 @@ class PyFtrl(mx.optimizer.Optimizer):
     \\eta_{t,i} = \\frac{learningrate}{\\beta+\\sqrt{\\sum_{s=1}^tg_{s,i}^t}}
     """
 
-    def __init__(self, lamda1=0.01, learning_rate=0.1, beta=1, sparse_update=False, **kwargs):
+    def __init__(self, lamda1=0.01, learning_rate=0.1, beta=1, lazy_update=False, **kwargs):
         super(PyFtrl, self).__init__(**kwargs)
         self.lamda1 = lamda1
         self.beta = beta
         self.lr = learning_rate
-        self.sparse_update = sparse_update
+        self.lazy_update = lazy_update
 
     def create_state(self, index, weight):
         return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype),  # dn
@@ -903,7 +903,7 @@ def update(self, index, weight, grad, state):
         dn, n = state
         for row in range(num_rows):
             all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy()))
-            if all_zeros and self.sparse_update:
+            if all_zeros and self.lazy_update:
                 continue
             grad[row] = grad[row] * self.rescale_grad
             if self.clip_gradient is not None:
@@ -933,7 +933,7 @@ def test_ftrl():
               {'clip_gradient': 0.5, 'wd': 0.07, 'lamda1': 1.0}]
     for kwarg in kwargs:
         compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32)
-        compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
+        compare_optimizer(opt1(lazy_update=True, **kwarg), opt2(**kwarg), shape,
                           np.float32, w_stype='row_sparse', g_stype='row_sparse')
 
 @with_seed(1234)