diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index d953e9247900..c3a1f3374a94 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -26,18 +26,21 @@
 import os
 import numpy
 from ..base import py_str
-from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
+from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply,
+                       multi_sum_sq, multi_lars, norm as NDnorm)
 from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                        mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                        signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update,
                        multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
-                       multi_mp_sgd_mom_update)
+                       multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
+                       preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
+                       preloaded_multi_mp_sgd_mom_update)
 from ..ndarray import sparse
 from ..random import normal
 from ..util import is_np_array
 
 __all__ = [
-    'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD',
+    'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
     'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'Test', 'Updater',
     'ccSGD', 'create', 'get_updater', 'register'
 ]
@@ -781,6 +784,274 @@ def update(self, index, weight, grad, state):
         ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight,
                     lr=lr, wd=wd, **kwargs)
 
+@register
+class LARS(Optimizer):
+    """The LARS optimizer from 'Large Batch Training of Convolution Networks' \
+    (https://arxiv.org/abs/1708.03888)
+
+    Behaves mostly like SGD with momentum and weight decay but adaptively scales \
+    the learning rate for each layer (except bias and batch norm parameters):
+    w_norm = L2norm(weights)
+    g_norm = L2norm(gradients)
+    if w_norm > 0 and g_norm > 0:
+        lr_layer = lr * lr_mult * eta * w_norm / (g_norm + weight_decay * w_norm + eps)
+    else:
+        lr_layer = lr * lr_mult
+
+    Parameters
+    ----------
+    momentum : float, optional
+        The momentum value.
+    lazy_update : bool, optional
+        Default is True. If True, lazy updates are applied \
+        if the storage types of weight and grad are both ``row_sparse``.
+    eta : float, optional
+        LARS coefficient used to scale the learning rate. Default set to 0.001.
+    eps : float, optional
+        Optional epsilon in case of very small gradients. Default set to 0.
+    momentum_correction : bool, optional
+        If True scale momentum w.r.t global learning rate change (with an lr_scheduler) \
+        as indicated in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour' \
+        (https://arxiv.org/pdf/1706.02677.pdf)
+        Default set to True.
+    """
+    def __init__(self, momentum=0.0, lazy_update=True, eta=0.001, eps=0,
+                 momentum_correction=True, **kwargs):
+        super(LARS, self).__init__(**kwargs)
+        self.momentum = momentum
+        self.momentum_correction = momentum_correction
+        self.lazy_update = lazy_update
+        self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4"))
+        self.eta = eta
+        self.eps = eps
+        self.skip = 0
+        self.last_lr = None
+        self.cur_lr = None
+
+
+    def _get_lrs(self, indices):
+        """Gets the learning rates given the indices of the weights.
+
+        Also records the previous and current global learning rate in
+        ``self.last_lr`` / ``self.cur_lr`` for momentum correction.
+
+        Parameters
+        ----------
+        indices : list of int
+            Indices corresponding to weights.
+
+        Returns
+        -------
+        lrs : list of float
+            Learning rates for those indices.
+        """
+        if self.cur_lr is not None:
+            self.last_lr = self.cur_lr
+
+        if self.lr_scheduler is not None:
+            lr = self.lr_scheduler(self.num_update)
+        else:
+            lr = self.lr
+
+        if self.cur_lr is None:
+            self.last_lr = lr
+        self.cur_lr = lr
+
+        lrs = [lr for _ in indices]
+        for i, index in enumerate(indices):
+            if index in self.param_dict:
+                lrs[i] *= self.param_dict[index].lr_mult
+            elif index in self.lr_mult:
+                lrs[i] *= self.lr_mult[index]
+            elif index in self.idx2name:
+                lrs[i] *= self.lr_mult.get(self.idx2name[index], 1.0)
+        return lrs
+
+    def set_wd_mult(self, args_wd_mult):
+        """Sets wd multipliers: 0 for names not ending in '_weight', then symbol/arg overrides."""
+        self.wd_mult = {}
+        for n in self.idx2name.values():
+            is_weight = n.endswith('_weight')
+
+            if not is_weight:
+                self.wd_mult[n] = 0.0
+
+        if self.sym_info:
+            attr, arg_names = self.sym_info
+            for name in arg_names:
+                if name in attr and '__wd_mult__' in attr[name]:
+                    self.wd_mult[name] = float(attr[name]['__wd_mult__'])
+        self.wd_mult.update(args_wd_mult)
+
+    def create_state_multi_precision(self, index, weight):
+        """Creates the state; with multi_precision and fp16 weights, also an fp32 master copy."""
+        weight_master_copy = None
+        if self.multi_precision and weight.dtype == numpy.float16:
+            weight_master_copy = weight.astype(numpy.float32)
+            return (self.create_state(index, weight_master_copy), weight_master_copy)
+        if weight.dtype == numpy.float16 and not self.multi_precision:
+            warnings.warn("Accumulating with float16 in optimizer can lead to "
+                          "poor accuracy or slow convergence. "
+                          "Consider using multi_precision=True option of the "
+                          "LARS optimizer")
+        return self.create_state(index, weight)
+
+    def create_state(self, index, weight):
+        """Creates the momentum buffer for `weight`, or None when momentum is 0."""
+        momentum = None
+        if self.momentum != 0.0:
+            stype = weight.stype if self.lazy_update else 'default'
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
+        return momentum
+
+    def _l2norm(self, v, rescale=False):
+        """Returns the L2 norm of a float32 copy of v, optionally scaled by rescale_grad first."""
+        v = v.astype('float32')
+        if rescale:
+            v *= self.rescale_grad
+        norm = NDnorm(v).asnumpy()[0]
+        return norm
+
+    def _get_lars(self, i, weight, g, lr, wd):
+        """Returns this layer's LARS-scaled learning rate (plain lr for gamma/beta/bias)."""
+        name = self.idx2name[i] if i in self.idx2name else str(i)
+        if name.endswith('gamma') or name.endswith('beta') or name.endswith('bias'):
+            return lr
+
+        w_norm = self._l2norm(weight)
+        g_norm = self._l2norm(g, rescale=True)
+
+        if w_norm > 0.0 and g_norm > 0.0:
+            lars = self.eta * w_norm/(g_norm + wd * w_norm + self.eps)
+        else:
+            lars = 1.0
+        return lars * lr
+
+    def _update_impl(self, indices, weights, grads, states, multi_precision=False):
+        """Runs one LARS step; dense parameters go through the preloaded multi-SGD ops."""
+        aggregate = True
+        if not isinstance(indices, (tuple, list)):
+            indices = [indices]
+            weights = [weights]
+            grads = [grads]
+            states = [states]
+        for weight, grad in zip(weights, grads):
+            assert(isinstance(weight, NDArray))
+            assert(isinstance(grad, NDArray))
+            aggregate = (aggregate and
+                         weight.stype == 'default' and
+                         grad.stype == 'default')
+        self._update_count(indices)
+        lrs = self._get_lrs(indices)
+        wds = self._get_wds(indices)
+
+        kwargs = {'rescale_grad': self.rescale_grad}
+        # Momentum correction: scale momentum by the ratio of current to previous global lr.
+        if self.momentum > 0:
+            kwargs['momentum'] = (self.momentum * (self.cur_lr / self.last_lr)) \
+                if (self.momentum_correction and self.last_lr != 0) else \
+                self.momentum
+
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
+
+        if aggregate:
+            nb_params = len(indices)
+            names = [self.idx2name[i] if i in self.idx2name else str(i) for i in indices]
+            # LARS-scaled parameters are ordered first; gamma/beta/bias keep their plain lr.
+            lars_idx = [i for i in range(nb_params) if
+                        not(names[i].endswith('gamma') or names[i].endswith('beta') or
+                            names[i].endswith('bias'))]
+            nb_lars = len(lars_idx)
+            no_lars_idx = [i for i in range(nb_params) if
+                           (names[i].endswith('gamma') or names[i].endswith('beta') or
+                            names[i].endswith('bias'))]
+            cur_ctx = weights[0].context
+            full_idx = lars_idx + no_lars_idx
+            new_lrs = array([lrs[i] for i in full_idx], ctx=cur_ctx, dtype='float32')
+            new_wds = array([wds[i] for i in full_idx], ctx=cur_ctx, dtype='float32')
+            new_weights = [weights[i] for i in full_idx]
+            new_grads = [grads[i] for i in full_idx]
+            new_states = [states[i] for i in full_idx]
+            if nb_lars > 0:
+                w_sum_sq = multi_sum_sq(*new_weights[:nb_lars], num_arrays=nb_lars)
+                g_sum_sq = multi_sum_sq(*new_grads[:nb_lars], num_arrays=nb_lars)
+                multi_lars(new_lrs[:nb_lars], w_sum_sq, g_sum_sq, new_wds[:nb_lars],
+                           eta=self.eta, eps=self.eps, rescale_grad=self.rescale_grad,
+                           out=new_lrs[:nb_lars])
+            # Chunked updates of aggregate_num parameters; lrs/wds are passed as NDArray slices.
+            sidx = 0
+            while sidx < len(indices):
+                eidx = sidx + len(new_weights[sidx:sidx+self.aggregate_num])
+                if not multi_precision:
+                    if self.momentum > 0:
+                        preloaded_multi_sgd_mom_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx],
+                                                new_states[sidx:eidx])) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                    else:
+                        preloaded_multi_sgd_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx])) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                else:
+                    if self.momentum > 0:
+                        preloaded_multi_mp_sgd_mom_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx],
+                                                *zip(*new_states[sidx:eidx]))) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                    else:
+                        preloaded_multi_mp_sgd_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx],
+                                                list(zip(*new_states[sidx:eidx]))[1])) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                sidx += self.aggregate_num
+        else:
+            lrs = [self._get_lars(i, w, g, lr, wd) for (i, w, g, lr, wd) in
+                   zip(indices, weights, grads, lrs, wds)]
+
+            for weight, grad, state, lr, wd in zip(weights, grads, states, lrs, wds):
+                if not multi_precision:
+                    if state is not None:
+                        sgd_mom_update(weight, grad, state, out=weight,
+                                       lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
+                    else:
+                        sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
+                                   lr=lr, wd=wd, **kwargs)
+                else:
+                    if state[0] is not None:
+                        mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight,
+                                          lr=lr, wd=wd, **kwargs)
+                    else:
+                        mp_sgd_update(weight, grad, state[1], out=weight,
+                                      lr=lr, wd=wd, **kwargs)
+
+    def update(self, index, weight, grad, state):
+        """Updates the given weight(s) without multi-precision handling."""
+        self._update_impl(index, weight, grad, state, multi_precision=False)
+
+    def update_multi_precision(self, index, weight, grad, state):
+        """Updates the weight(s); multi-precision is used only for fp16 with multi_precision."""
+        if not isinstance(index, (tuple, list)):
+            use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
+        else:
+            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
+        self._update_impl(index, weight, grad, state,
+                          multi_precision=use_multi_precision)
+
+#
 @register
 class LBSGD(Optimizer):
     """The Large Batch SGD optimizer with momentum and weight decay.
@@ -812,7 +1083,7 @@ class LBSGD(Optimizer):
     warmup_strategy: string ('linear', 'power2', 'sqrt'. , 'lars' default : 'linear')
     warmup_epochs: unsigned, default: 5
-    batch_scale: unsigned, default: 1 (same as batch size*numworkers)
+    batch_scale: unsigned, default: 1 (same as batch size * numworkers)
     updates_per_epoch: updates_per_epoch (default: 32, Default might not reflect true number
     batches per epoch. Used for warmup.)
     begin_epoch: unsigned, default 0, starting epoch.
     """
diff --git a/src/operator/contrib/multi_lars-inl.h b/src/operator/contrib/multi_lars-inl.h
new file mode 100644
index 000000000000..c78bd7086368
--- /dev/null
+++ b/src/operator/contrib/multi_lars-inl.h
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file multi_lars-inl.h
+ * \brief vectorized lars coefficient computed from sums of squared weights and grads
+ * \author Clement Fuji Tsang
+ */
+#ifndef MXNET_OPERATOR_CONTRIB_MULTI_LARS_INL_H_
+#define MXNET_OPERATOR_CONTRIB_MULTI_LARS_INL_H_
+
+#include  // NOTE(review): header names / template arguments look stripped in this patch text
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "../operator_common.h"
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../tensor/init_op.h"
+#include "../tensor/util/tensor_util-inl.h"
+
+namespace mxnet {
+namespace op {
+
+struct LARSParam : public dmlc::Parameter {  // parameters of the multi_lars operator
+  float eta;  // numerator factor of the LARS ratio (see MultiLARSKernel)
+  float eps;  // added to the LARS denominator
+  float rescale_grad;  // multiplies the gradient norm in the LARS denominator
+  DMLC_DECLARE_PARAMETER(LARSParam) {
+    DMLC_DECLARE_FIELD(eta)
+    .describe("LARS eta");
+    DMLC_DECLARE_FIELD(eps)
+    .describe("LARS eps");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Gradient rescaling factor");
+  }
+};
+
+struct MultiLARSKernel {  // element i: scales lrs[i] by the LARS ratio, or copies it through
+  MSHADOW_XINLINE static void Map(int i, float* out_data, const float* lrs,
+                                  const float* weights_sum_sq, const float* grads_sum_sq,
+                                  const float* wds, const float eta, const float eps,
+                                  const float rescale_grad, const OpReqType req) {
+    float w_norm = sqrtf(weights_sum_sq[i]);  // norm recovered from the supplied sum of squares
+    bool is_lars_valid = w_norm > 0. && grads_sum_sq[i] > 0.;  // else lrs[i] is passed through
+    KERNEL_ASSIGN(out_data[i], req, is_lars_valid ?
+      lrs[i] * eta * w_norm / (sqrtf(grads_sum_sq[i]) * rescale_grad + wds[i] * w_norm + eps) :
+      lrs[i]);
+  }
+};
+
+template
+inline void MultiLARS(const nnvm::NodeAttrs& attrs,
+                      const OpContext &ctx,
+                      const std::vector &inputs,  // order: lrs, weights_sum_sq, grads_sum_sq, wds
+                      const std::vector &req,
+                      const std::vector &outputs) {
+  using namespace mxnet_op;
+  auto param = nnvm::get(attrs.parsed);
+  Stream* s = ctx.get_stream();
+  if (inputs[0].type_flag_ != mshadow::kFloat32)  // only inputs[0] (lrs) is type-checked here
+    LOG(FATAL) << "MultiLARS only support float";
+  Tensor lrs = inputs[0].FlatTo2D(s);
+  Tensor weights_sum_sq = inputs[1].FlatTo2D(s);
+  Tensor grads_sum_sq = inputs[2].FlatTo2D(s);
+  Tensor wds = inputs[3].FlatTo2D(s);
+  Tensor out_data = outputs[0].FlatTo2D(s);  // the launch below runs once per weights_sum_sq entry
+  Kernel::Launch(s, weights_sum_sq.shape_.Size(), out_data.dptr_,
+                 lrs.dptr_, weights_sum_sq.dptr_, grads_sum_sq.dptr_,
+                 wds.dptr_, param.eta, param.eps,
+                 param.rescale_grad, req[0]);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+
+#endif  // MXNET_OPERATOR_CONTRIB_MULTI_LARS_INL_H_
diff --git a/src/operator/contrib/multi_lars.cc b/src/operator/contrib/multi_lars.cc
new file mode 100644
index 000000000000..d4ee49358971
--- /dev/null
+++ b/src/operator/contrib/multi_lars.cc
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file multi_lars.cc
+ * \brief vectorized LARS coefficient computed from sums of squared weights and grads
+ * \author Clement Fuji Tsang
+ */
+
+#include "./multi_lars-inl.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(LARSParam);
+
+NNVM_REGISTER_OP(multi_lars)
+.describe(R"code(Compute the LARS coefficients of multiple weights and grads from their sums of squares
+)code" ADD_FILELINE)
+.set_num_inputs(4)  // lrs, weights_sum_sq, grads_sum_sq, wds
+.set_num_outputs(1)
+.set_attr_parser(ParamParser)
+.set_attr("FInferShape", ElemwiseShape<4, 1>)
+.set_attr("FInferType", ElemwiseType<4, 1>)
+.set_attr("FInferStorageType", ElemwiseStorageType<4, 1, false, false, false>)
+.set_attr("FListInputNames", [](const nnvm::NodeAttrs& attrs) {
+    std::vector list_input_names = {"lrs", "weights_sum_sq", "grads_sum_sq", "wds"};
+    return list_input_names;
+  })
+.set_attr("FCompute", MultiLARS)
+.add_argument("lrs", "NDArray-or-Symbol", "Learning rates to scale by LARS coefficient")
+.add_argument("weights_sum_sq", "NDArray-or-Symbol", "sum of square of weights arrays")
+.add_argument("grads_sum_sq", "NDArray-or-Symbol", "sum of square of gradients arrays")
+.add_argument("wds", "NDArray-or-Symbol", "weight decays")
+.add_arguments(LARSParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/multi_lars.cu b/src/operator/contrib/multi_lars.cu
new file mode 100644
index 000000000000..292f09092e76
--- /dev/null
+++ b/src/operator/contrib/multi_lars.cu
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file multi_lars.cu
+ * \brief CUDA-side registration of the vectorized lars coefficient operator
+ * \author Clement Fuji Tsang
+ */
+
+#include "./multi_lars-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(multi_lars)  // extends the operator registered in multi_lars.cc
+.set_attr("FCompute", MultiLARS);  // NOTE(review): device template arguments not visible here
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/multi_sum_sq-inl.h b/src/operator/contrib/multi_sum_sq-inl.h
new file mode 100644
index 000000000000..876155215d1c
--- /dev/null
+++ b/src/operator/contrib/multi_sum_sq-inl.h
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file multi_sum_sq-inl.h
+ * \brief vectorized sums of squares over multiple arrays operators
+ * \author Clement Fuji Tsang, Andrei Ivanov
+ */
+
+
+#ifndef MXNET_OPERATOR_CONTRIB_MULTI_SUM_SQ_INL_H_
+#define MXNET_OPERATOR_CONTRIB_MULTI_SUM_SQ_INL_H_
+
+#include  // NOTE(review): header names / template arguments look stripped in this patch text
+#include
+#include "../operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct MultiSumSqParam : public dmlc::Parameter {
+  int num_arrays;  // checked against the number of inputs in shape and type inference
+  DMLC_DECLARE_PARAMETER(MultiSumSqParam) {
+    DMLC_DECLARE_FIELD(num_arrays)
+    .describe("number of input arrays.");
+  }
+};
+
+inline bool MultiSumSqShape(const NodeAttrs& attrs,
+                            std::vector* in_shape,
+                            std::vector* out_shape) {
+  const auto &p = dmlc::get(attrs.parsed);
+  out_shape->resize(1);
+
+  SHAPE_ASSIGN_CHECK(*out_shape, 0, mxnet::TShape{p.num_arrays});  // one element per input array
+
+  CHECK_EQ(in_shape->size(), p.num_arrays);
+  for (auto s : *in_shape) {
+    if (s.ndim() == 0)  // presumably a not-yet-known input shape — TODO confirm
+      return false;
+  }
+  return true;
+}
+
+inline bool MultiSumSqType(const NodeAttrs& attrs,
+                           std::vector* in_type,
+                           std::vector* out_type) {
+  const auto& p = dmlc::get(attrs.parsed);
+  CHECK_EQ(in_type->size(), p.num_arrays);
+  int dtype = (*in_type)[0];  // every input must share the first input's type
+  CHECK_NE(dtype, -1) << "First input must have specified type";
+  for (size_t i = 0; i < in_type->size(); ++i) {
+    if ((*in_type)[i] == -1) {
+      (*in_type)[i] = dtype;
+    } else {
+      UNIFORM_TYPE_CHECK((*in_type)[i], dtype, "array_" + std::to_string(i));
+    }
+  }
+  out_type->clear();
+  out_type->push_back(mshadow::kFloat32);  // output is float32 whatever the input type
+  return true;
+}
+
+template
+void MultiSumSqRun(const std::vector &inputs, int nInputs,  // specialized in the .cc / .cu files
+                   float *out_ptr, mshadow::Stream *s);
+
+template
+void MultiSumSq(const nnvm::NodeAttrs& attrs,
+                const OpContext &ctx,
+                const std::vector &inputs,
+                const std::vector &req,  // not consulted: the output is always overwritten
+                const std::vector &outputs) {
+  auto s = ctx.get_stream();
+  const auto& p = dmlc::get(attrs.parsed);
+  float* out_ptr = outputs[0].FlatTo2D(s).dptr_;
+  MultiSumSqRun(inputs, p.num_arrays, out_ptr, s);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_MULTI_SUM_SQ_INL_H_
diff --git a/src/operator/contrib/multi_sum_sq.cc b/src/operator/contrib/multi_sum_sq.cc
new file mode 100644
index 000000000000..cdb5423db23f
--- /dev/null
+++ b/src/operator/contrib/multi_sum_sq.cc
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file multi_sum_sq.cc
+ * \brief vectorized sums of squares over multiple arrays operators
+ * \author Clement Fuji Tsang, Andrei Ivanov
+ */
+
+#include "./multi_sum_sq-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(MultiSumSqParam);
+
+NNVM_REGISTER_OP(multi_sum_sq)
+.describe(R"code(Compute the sums of squares of multiple arrays
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    return static_cast(dmlc::get(attrs.parsed).num_arrays);
+  })
+.set_num_outputs(1)
+.set_attr_parser(ParamParser)
+.set_attr("FInferShape", MultiSumSqShape)
+.set_attr("FInferType", MultiSumSqType)
+.set_attr("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const auto& param = dmlc::get(attrs.parsed);
+    const uint32_t num_args = param.num_arrays;
+    std::vector ret;
+    for (uint32_t i = 0; i < num_args; ++i) {
+      ret.push_back(std::string("array_") + std::to_string(i));
+    }
+    return ret;
+  })
+.set_attr("FCompute", MultiSumSq)
+.add_argument("data", "NDArray-or-Symbol[]", "Arrays")
+.add_arguments(MultiSumSqParam::__FIELDS__());
+
+template
+inline void CalcSumSq(const std::vector &inputs, int nInputs,
+                      float *out_ptr, mshadow::Stream *s) {
+  int i;
+  size_t j;
+#pragma omp parallel for private(i, j)
+  for (i = 0; i < nInputs; ++i) {  // array index in inputs
+    float sum = 0;
+    const auto address = inputs[i].FlatTo2D(s).dptr_;
+    const auto jMax = inputs[i].shape_.Size();
+    for (j = 0; j < jMax; ++j)  // square in float, as the CUDA kernel does: a DType product can overflow
+      sum += static_cast<float>(address[j]) * static_cast<float>(address[j]);
+
+    out_ptr[i] = sum;  // one float result per input array
+  }
+}
+
+template<>
+void MultiSumSqRun(const std::vector &inputs, int nInputs,
+                   float *out_ptr, mshadow::Stream *s) {
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType,  // dispatch on the first input's type
+    CalcSumSq(inputs, nInputs, out_ptr, s);
+  )
+}
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/multi_sum_sq.cu b/src/operator/contrib/multi_sum_sq.cu
new file mode 100644
index 000000000000..6f6fe56bfd81
--- /dev/null
+++
b/src/operator/contrib/multi_sum_sq.cu
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file multi_sum_sq.cu
+ * \brief vectorized sums of squares norm over multiple arrays operators
+ * \author Clement Fuji Tsang, Andrei Ivanov
+ */
+#include "./multi_sum_sq-inl.h"
+#include
+
+#define ILP 4  // elements each thread accumulates per pass of the outer kernel loop
+#define BLOCK_LIMIT 320  // capacity of block_to_tensor / block_to_chunk for one launch
+#define ARRAY_LIMIT 110  // capacity of addresses / sizes for one launch
+
+namespace mxnet {
+namespace op {
+
+// Shamelessly gotten from:
+// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
+// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu
+// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
+template
+struct MultiSumSqKernelParam {
+  DType* addresses[ARRAY_LIMIT];  // start of each array handled by this launch
+  int sizes[ARRAY_LIMIT];  // element count of each array
+  unsigned char block_to_tensor[BLOCK_LIMIT];  // per block: slot in addresses / sizes
+  int block_to_chunk[BLOCK_LIMIT];  // per block: chunk index within that array
+};
+
+template
+__device__ __forceinline__ DType reduce_block_into_lanes(DType* x,
+                                                         DType val,
+                                                         int lanes = 1,
+                                                         bool share_result = false) {
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+  int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
+
+  if (blockSize >= 64) {
+    x[tid] = val;
+    __syncthreads();
+  }
+
+  #pragma unroll
+  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {  // halve through x until 64 partials remain
+    if (tid < i)
+      x[tid] = x[tid] + x[tid+i];
+    __syncthreads();
+  }
+
+  DType final;
+
+  if (tid < 32) {
+    if (blockSize >= 64)
+      final = x[tid] + x[tid+32];
+    else
+      final = val;
+    // __SYNCWARP();
+
+    #pragma unroll
+    for (int i = 16; i >= lanes; i >>= 1)  // finish with warp shuffles
+      final = final + __shfl_down_sync(0xffffffff, final, i);
+  }
+
+  if (share_result) {
+    if (tid < lanes)
+      x[tid] = final;  // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+
+  return final;  // only assigned for tid < 32
+}
+
+template
+__global__ void MultiSumSqKernel(int chunk_size,
+                                 MultiSumSqKernelParam param,
+                                 float* output) {
+  const int tensor_loc = param.block_to_tensor[blockIdx.x];
+  const int chunk_len = param.block_to_chunk[blockIdx.x] * chunk_size;  // offset of this chunk
+  const int n = param.sizes[tensor_loc] - chunk_len;  // elements left from that offset
+  const DType* x = param.addresses[tensor_loc] + chunk_len;
+  const auto iMax = n <= chunk_size? n : chunk_size;
+  __shared__ float vals[512];  // presumably sized to block_size in MultiSumSqRun — TODO confirm
+
+  // Non-divergent exit condition for __syncthreads, not necessary here
+  float val = 0;
+  for (int i_start = 0;
+       i_start < iMax;
+       i_start += blockDim.x * ILP) {
+    int i = i_start + threadIdx.x;
+    // #pragma unroll
+    for (int ii = 0; ii < ILP && i < iMax; ++ii, i += blockDim.x) {
+      const auto incoming_val = static_cast(x[i]);
+      val += incoming_val * incoming_val;
+    }
+  }
+
+  const float final = reduce_block_into_lanes(vals, val);
+  if (threadIdx.x == 0)  // several blocks may share one array, hence the atomic accumulation
+    atomicAdd(output + tensor_loc, final);
+}
+
+template<>
+void MultiSumSqRun(const std::vector &inputs, int nInputs,
+                   float *out_ptr, mshadow::Stream *s) {
+  const int chunk_size = 32768;
+  const int block_size = 512;
+  using namespace mxnet_op;
+  auto stream = mshadow::Stream::GetStream(s);
+  CUDA_CALL(cudaMemsetAsync(out_ptr, 0, nInputs * sizeof(float), stream));  // kernel accumulates
+
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MultiSumSqKernelParam param;
+    int loc_block_info = 0;   // position in param.block_to_tensor and param.block_to_chunk
+    int loc_tensor_info = 0;  // position in param.sizes and param.addresses
+    int output_offset = 0;    // array index of the first block pointed on by param.addresses
+    for (int t = 0; t < nInputs; t++, loc_tensor_info++) {  // array index in inputs
+      param.sizes[loc_tensor_info] = inputs[t].shape_.Size();
+      param.addresses[loc_tensor_info] = inputs[t].FlatTo2D(s).dptr_;
+      // NOTE(review): Size() - 1 presumably wraps for an empty array — confirm inputs are non-empty
+      const int chunks_this_tensor = (inputs[t].shape_.Size() - 1) / chunk_size;
+      for (int chunk = 0; chunk <= chunks_this_tensor; ++chunk) {  // array chunk index
+        param.block_to_tensor[loc_block_info] = loc_tensor_info;
+        param.block_to_chunk[loc_block_info] = chunk;
+        loc_block_info++;
+
+        const bool last_curr_chunk = chunk == chunks_this_tensor;
+        const bool tensors_full = last_curr_chunk && loc_tensor_info == ARRAY_LIMIT - 1;
+        const bool blocks_full = (loc_block_info == BLOCK_LIMIT);
+        const bool last_chunk = last_curr_chunk && t == nInputs - 1;
+        if (!(tensors_full || blocks_full || last_chunk))
+          continue;  // keep batching until a table is full or the input is exhausted
+
+        MultiSumSqKernel<<>>
+          (chunk_size, param, out_ptr + output_offset);
+        MSHADOW_CUDA_POST_KERNEL_CHECK(MultiSumSqKernel);
+        loc_block_info = 0;
+        if (last_curr_chunk) {  // if you start from a new tensor
+          loc_tensor_info = -1;  // the outer loop's increment brings this back to 0
+          output_offset = t + 1;
+        } else {  // if you start from the same tensor
+          param.sizes[0] = param.sizes[loc_tensor_info];
+          param.addresses[0] = param.addresses[loc_tensor_info];
+          loc_tensor_info = 0;
+          output_offset = t;
+        }
+      }
+    }
+  });
+}
+
+NNVM_REGISTER_OP(multi_sum_sq)
+.set_attr("FCompute", MultiSumSq);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/preloaded_multi_sgd-inl.h b/src/operator/contrib/preloaded_multi_sgd-inl.h
new file mode 100644
index 000000000000..840ca215c520
--- /dev/null
+++ b/src/operator/contrib/preloaded_multi_sgd-inl.h
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors + * \file preloaded_multi_sgd-inl.h + * \brief Multi-sgd optimizers with lrs and wds as mxnet inputs + * \author Clement Fuji Tsang + */ +#ifndef MXNET_OPERATOR_CONTRIB_PRELOADED_MULTI_SGD_INL_H_ +#define MXNET_OPERATOR_CONTRIB_PRELOADED_MULTI_SGD_INL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "../operator_common.h" +#include "../mshadow_op.h" +#include "../elemwise_op_common.h" +#include "../mxnet_op.h" +#include "../tensor/init_op.h" +#include "../tensor/util/tensor_util-inl.h" + +namespace mxnet { +namespace op { + +struct PreloadedMultiSGDParam : public dmlc::Parameter { + float rescale_grad; + float clip_gradient; + int num_weights; + DMLC_DECLARE_PARAMETER(PreloadedMultiSGDParam) { + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(num_weights) + .set_default(1) + .describe("Number of updated weights."); + } +}; + +struct PreloadedMultiSGDMomParam : public dmlc::Parameter { + float momentum; + float rescale_grad; + float clip_gradient; + int num_weights; + DMLC_DECLARE_PARAMETER(PreloadedMultiSGDMomParam) { + DMLC_DECLARE_FIELD(momentum) + .set_default(0.0f) + .describe("The decay rate of momentum estimates at each epoch."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. 
" + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(num_weights) + .set_default(1) + .describe("Number of updated weights."); + } +}; + +template +inline bool PreloadedMultiSGDShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const ParamType& param = dmlc::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights + 2); + CHECK_EQ(out_attrs->size(), param.num_weights); + + bool all_inferred = true; + auto& input_shapes = *in_attrs; + auto& output_shapes = *out_attrs; + // Learning rates + CHECK_EQ(in_attrs->at(param.num_weights * input_stride).Size(), param.num_weights) + << "Number of learning rates is inconsistent with num_weights " + << "parameter passed. Expected number of learning rates: " + << param.num_weights << ", and got " << in_attrs->at(param.num_weights * input_stride).Size(); + // Weight decays + CHECK_EQ(in_attrs->at(param.num_weights * input_stride + 1).Size(), param.num_weights) + << "Number of weight decays is inconsistent with num_weights " + << "parameter passed. 
Expected number of weight decays: " + << param.num_weights << ", and got " + << in_attrs->at(param.num_weights * input_stride + 1).Size(); + // Weights and gradients + for (int i = 0; i < param.num_weights; ++i) { + std::vector input_vec; + std::vector output_vec({output_shapes[i]}); + for (int j = 0; j < input_stride; ++j) { + input_vec.push_back(input_shapes[i * input_stride + j]); + } + all_inferred = all_inferred && ElemwiseShape(attrs, &input_vec, &output_vec); + } + return all_inferred; +} + +template +inline bool MP_PreloadedMultiSGD_InferType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const ParamType& param = dmlc::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights + 2); + CHECK_EQ(out_attrs->size(), param.num_weights); + + bool all_inferred = true; + auto& input_types = *in_attrs; + auto& output_types = *out_attrs; + // Weights and gradients + for (int i = 0; i < param.num_weights; ++i) { + std::vector input_vec; + std::vector output_vec({output_types[i]}); + for (int j = 0; j < input_stride - num_fp32_inputs; ++j) { + input_vec.push_back(input_types[i * input_stride + j]); + } + all_inferred = all_inferred && + ElemwiseType(attrs, &input_vec, &output_vec); + } + // master copies of weights + for (int i = 0; i < param.num_weights; ++i) { + for (int j = 0; j < num_fp32_inputs; ++j) { + TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j, mshadow::kFloat32); + } + } + TYPE_ASSIGN_CHECK(input_types, input_stride * param.num_weights, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(input_types, input_stride * param.num_weights + 1, mshadow::kFloat32); + return all_inferred; +} + +template +struct PreloadedMultiSGDKernelParam { + static const int N = 60; + int count; + size_t max_size; + size_t sizes[N]; + DType * weights[N]; + DType * grads[N]; + MPDType * mom[N]; + MPDType * weights32[N]; + DType * out_data[N]; + float * lrs; + float * wds; + MPDType clip_gradient; + MPDType 
rescale_grad; + MPDType momentum; +}; + +template +struct PreloadedMultiSGDKernel { + template + MSHADOW_XINLINE static void Map(int i, const PreloadedMultiSGDKernelParam& param, + const OpReqType req) { + for (int index = 0; index < param.count; ++index) { + if ((size_t)i < param.sizes[index]) { + MPDType w = has_mixed_precision ? param.weights32[index][i] : + MPDType(param.weights[index][i]); + MPDType mom = has_momentum ? param.mom[index][i] : MPDType(0); + if (param.clip_gradient >= 0.0f) { + mom = param.momentum*mom + - param.lrs[index]*param.wds[index]*w + - param.lrs[index] + *mshadow_op::clip::Map(param.rescale_grad * + static_cast(param.grads[index][i]), + param.clip_gradient); + } else { + mom = param.momentum*mom + - param.lrs[index]*param.wds[index]*w + - param.lrs[index]*param.rescale_grad*static_cast(param.grads[index][i]); + } + if (has_momentum) { + param.mom[index][i] = mom; + } + w = w + mom; + if (has_mixed_precision) { + param.weights32[index][i] = w; + } + KERNEL_ASSIGN(param.out_data[index][i], req, w); + } + } + } +}; + +template +PreloadedMultiSGDKernelParam FillPreloadedMultiSGDKernelParam( + const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, + const std::vector &outputs) { + using namespace mxnet_op; + const ParamType& p = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + PreloadedMultiSGDKernelParam param; + param.clip_gradient = p.clip_gradient; + param.rescale_grad = p.rescale_grad; + param.momentum = 0; + param.count = p.num_weights; + param.max_size = 0; + for (int i = 0; i < param.count; ++i) { + param.sizes[i] = inputs[i * input_stride].shape_.Size(); + if (param.max_size < param.sizes[i]) { + param.max_size = param.sizes[i]; + } + param.weights[i] = inputs[i * input_stride].FlatTo2D(s).dptr_; + param.grads[i] = inputs[i * input_stride + 1].FlatTo2D(s).dptr_; + // if mixed precision, then the last input in a set + // is 32-bit master copy of the weights + if (!std::is_same::value) { + 
param.weights32[i] = inputs[i * input_stride + input_stride - 1] + .FlatTo2D(s).dptr_; + } + param.out_data[i] = outputs[i].FlatTo2D(s).dptr_; + } + const int lrs_idx = param.count * input_stride; + const int wds_idx = param.count * input_stride + 1; + param.lrs = inputs[lrs_idx].FlatTo2D(s).dptr_; + param.wds = inputs[wds_idx].FlatTo2D(s).dptr_; + return param; +} + + +template +PreloadedMultiSGDKernelParam FillPreloadedMultiSGDMomKernelParam( + const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, + const std::vector &outputs) { + using namespace mxnet_op; + const PreloadedMultiSGDMomParam& p = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + PreloadedMultiSGDKernelParam param = + FillPreloadedMultiSGDKernelParam(attrs, ctx, inputs, outputs); + param.momentum = p.momentum; + for (int i = 0; i < param.count; ++i) { + param.mom[i] = inputs[i * input_stride + 2].FlatTo2D(s).dptr_; + } + + return param; +} + +template +class preloaded_type_identity { + public: + using type = T; +}; + +template +class preloaded_single_precision { + public: + using type = float; +}; + +template class MPTypeChooser, int input_stride> +inline void PreloadedMultiSGDUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + using MPDType = typename MPTypeChooser::type; + PreloadedMultiSGDKernelParam param = + FillPreloadedMultiSGDKernelParam(attrs, ctx, inputs, outputs); + Kernel::value>, + xpu>::Launch(s, param.max_size, param, req[0]); + }); +} + +template class MPTypeChooser, int input_stride> +inline void PreloadedMultiSGDMomUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + 
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + using MPDType = typename MPTypeChooser::type; + PreloadedMultiSGDKernelParam param = + FillPreloadedMultiSGDMomKernelParam(attrs, ctx, inputs, outputs); + Kernel::value>, + xpu>::Launch(s, param.max_size, param, req[0]); + }); +} + +} // namespace op +} // namespace mxnet + + +#endif // MXNET_OPERATOR_CONTRIB_PRELOADED_MULTI_SGD_INL_H_ diff --git a/src/operator/contrib/preloaded_multi_sgd.cc b/src/operator/contrib/preloaded_multi_sgd.cc new file mode 100755 index 000000000000..768d288c7c27 --- /dev/null +++ b/src/operator/contrib/preloaded_multi_sgd.cc @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file preloaded_multi_sgd.cc + * \brief Multi-sgd optimizers with lrs and wds as mxnet inputs + * \author Clement Fuji Tsang + */ +#include "./preloaded_multi_sgd-inl.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(PreloadedMultiSGDParam); +DMLC_REGISTER_PARAMETER(PreloadedMultiSGDMomParam); + +NNVM_REGISTER_OP(preloaded_multi_sgd_update) +.describe(R"code(Update function for Stochastic Gradient Descent (SDG) optimizer. 
+ +It updates the weights using:: + + weight = weight - learning_rate * (gradient + wd * weight) + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const PreloadedMultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 2 + 2); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const PreloadedMultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", PreloadedMultiSGDShape) +.set_attr("FInferType", + MP_PreloadedMultiSGD_InferType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + ret.reserve(num_args * 2 + 2); + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + } + ret.emplace_back("lrs"); + ret.emplace_back("wds"); + return ret; + }) +.set_attr("FCompute", PreloadedMultiSGDUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "Weights, gradients, learning rates and weight decays") +.add_arguments(PreloadedMultiSGDParam::__FIELDS__()); + +NNVM_REGISTER_OP(preloaded_multi_sgd_mom_update) +.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer. + +Momentum update has better convergence rates on neural networks. Mathematically it looks +like below: + +.. math:: + + v_1 = \alpha * \nabla J(W_0)\\ + v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\ + W_t = W_{t-1} + v_t + +It updates the weights using:: + + v = momentum * v - learning_rate * gradient + weight += v + +Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. 
+ +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const PreloadedMultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 3 + 2); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const PreloadedMultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", PreloadedMultiSGDShape) +.set_attr("FInferType", + MP_PreloadedMultiSGD_InferType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + ret.reserve(num_args * 3 + 2); + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("mom_") + std::to_string(i)); + } + ret.emplace_back("lrs"); + ret.emplace_back("wds"); + return ret; + }) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const PreloadedMultiSGDMomParam& param = dmlc::get(attrs.parsed); + ret.reserve(param.num_weights); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 3 + 2); + } + return ret; + }) +.set_attr("FCompute", PreloadedMultiSGDMomUpdate) +.add_argument("data", "NDArray-or-Symbol[]", + "Weights, gradients, momentum, learning rates and weight decays") +.add_arguments(PreloadedMultiSGDMomParam::__FIELDS__()); + +NNVM_REGISTER_OP(preloaded_multi_mp_sgd_update) +.describe(R"code(Update function for multi-precision Stochastic Gradient Descent (SDG) optimizer. 
+ +It updates the weights using:: + + weight = weight - learning_rate * (gradient + wd * weight) + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const PreloadedMultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 3 + 2); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const PreloadedMultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", PreloadedMultiSGDShape) +.set_attr("FInferType", + MP_PreloadedMultiSGD_InferType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + ret.reserve(num_args * 3 + 2); + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("weight32_") + std::to_string(i)); + } + ret.emplace_back("lrs"); + ret.emplace_back("wds"); + return ret; + }) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const PreloadedMultiSGDParam& param = dmlc::get(attrs.parsed); + ret.reserve(param.num_weights); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 3 + 2); + } + return ret; + }) +.set_attr("FCompute", PreloadedMultiSGDUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "Weights, gradients, learning rates and weight decays") +.add_arguments(PreloadedMultiSGDParam::__FIELDS__()); + +NNVM_REGISTER_OP(preloaded_multi_mp_sgd_mom_update) +.describe(R"code(Momentum update function for multi-precision Stochastic Gradient Descent (SGD) optimizer. + +Momentum update has better convergence rates on neural networks. Mathematically it looks +like below: + +.. 
math:: + + v_1 = \alpha * \nabla J(W_0)\\ + v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\ + W_t = W_{t-1} + v_t + +It updates the weights using:: + + v = momentum * v - learning_rate * gradient + weight += v + +Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const PreloadedMultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 4 + 2); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const PreloadedMultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", PreloadedMultiSGDShape) +.set_attr("FInferType", + MP_PreloadedMultiSGD_InferType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + ret.reserve(num_args * 4 + 2); + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("mom_") + std::to_string(i)); + ret.push_back(std::string("weight32_") + std::to_string(i)); + } + ret.emplace_back("lrs"); + ret.emplace_back("wds"); + return ret; + }) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const PreloadedMultiSGDMomParam& param = dmlc::get(attrs.parsed); + ret.reserve(param.num_weights * 2); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 4 + 2); + ret.push_back(i * 4 + 3); + } + return ret; + }) +.set_attr("FCompute", PreloadedMultiSGDMomUpdate) +.add_argument("data", "NDArray-or-Symbol[]", + "Weights, gradients, momentums, learning rates and weight decays") +.add_arguments(PreloadedMultiSGDMomParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/preloaded_multi_sgd.cu 
b/src/operator/contrib/preloaded_multi_sgd.cu new file mode 100644 index 000000000000..3335d632cd98 --- /dev/null +++ b/src/operator/contrib/preloaded_multi_sgd.cu @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file preloaded_multi_sgd.cu + * \brief Multi-sgd optimizers with lrs and wds as mxnet inputs + * \author Clement Fuji Tsang + */ +#include "./preloaded_multi_sgd-inl.h" +#include + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(preloaded_multi_sgd_update) +.set_attr("FCompute", PreloadedMultiSGDUpdate); +NNVM_REGISTER_OP(preloaded_multi_sgd_mom_update) +.set_attr("FCompute", PreloadedMultiSGDMomUpdate); +NNVM_REGISTER_OP(preloaded_multi_mp_sgd_update) +.set_attr("FCompute", PreloadedMultiSGDUpdate); +NNVM_REGISTER_OP(preloaded_multi_mp_sgd_mom_update) +.set_attr("FCompute", + PreloadedMultiSGDMomUpdate); + +} // namespace op +} // namespace mxnet diff --git a/tests/nightly/test_optimizer.py b/tests/nightly/test_optimizer.py new file mode 100644 index 000000000000..c4e264b79b65 --- /dev/null +++ b/tests/nightly/test_optimizer.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import mxnet as mx + +import sys +import os +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(curr_path, '../unittest')) +from common import setup_module, with_seed + +# This script is testing the efficiency of LARS +# We are training LeNet-5 at batch-size 8000 in 10 epochs above 98% accuracy +# Which is not doable with simple SGD + momentum (from what have been tested so far) + +def lenet5(): + """LeNet-5 Symbol""" + #pylint: disable=no-member + data = mx.sym.Variable('data') + conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20) + tanh1 = mx.sym.Activation(data=conv1, act_type="tanh") + pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", + kernel=(2, 2), stride=(2, 2)) + # second conv + conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50) + tanh2 = mx.sym.Activation(data=conv2, act_type="tanh") + pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", + kernel=(2, 2), stride=(2, 2)) + # first fullc + flatten = mx.sym.Flatten(data=pool2) + fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500) + tanh3 = mx.sym.Activation(data=fc1, act_type="tanh") + # second fullc + fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10) + # loss + lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax') + #pylint: enable=no-member + return lenet + +@with_seed() +def test_lars(): + num_epochs = 10 + batch_size = 8000 + mnist = mx.test_utils.get_mnist() + train_iter = mx.io.NDArrayIter(mnist['train_data'], + mnist['train_label'], + batch_size, + shuffle=True) + test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) + ctx = mx.gpu(0) + lenet_model = mx.mod.Module(lenet5(), context=ctx) + warmup_epochs = 1 + epoch_it = int(train_iter.num_data / batch_size) + # LARS works best with Polynomial scheduler and warmup + base_lr = 0.01 + optimizer_params={ + 'learning_rate': base_lr, + 'lr_scheduler': mx.lr_scheduler.PolyScheduler(base_lr=base_lr, + max_update=epoch_it * 
num_epochs, + warmup_steps=epoch_it * warmup_epochs), + 'momentum': 0.9, + 'eta': 14., + } + lenet_model.fit(train_iter, + eval_data=test_iter, + optimizer='lars', + optimizer_params=optimizer_params, + eval_metric='acc', + num_epoch=num_epochs) + + # predict accuracy for lenet + acc = mx.metric.Accuracy() + lenet_model.score(test_iter, acc) + accuracy = acc.get()[1] + assert accuracy > 0.98, "LeNet-5 training accuracy on MNIST was too low" + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index e3862abbebd1..7aac23acd549 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -268,6 +268,159 @@ def test_fft(): shape = tuple(np.random.randint(1, maxdim, size=order)) check_fft(shape) +def _make_ndarrays(input_list, ctx=mx.gpu(0)): + return [mx.nd.array(arr, dtype=arr.dtype, ctx=ctx) for arr in input_list] + +def check_fast_lars(w_dtype, g_dtype, shapes, ctx, tol1, tol2): + weights_arr = [np.random.rand(*shape).astype(w_dtype) * 10. for shape in shapes] + grads_arr = [np.random.rand(*shape).astype(g_dtype) for shape in shapes] + + lrs = (np.random.rand(len(shapes)).astype('float32') + 0.1) / 100. + wds = (np.random.rand(len(shapes)).astype('float32') + 0.1) / 1000. + eta = (np.random.rand() + 0.1) + eps = (np.random.rand() + 0.1) / 10000. 
+ + mx_w = _make_ndarrays(weights_arr, ctx=ctx) + mx_g = _make_ndarrays(grads_arr, ctx=ctx) + mx_lrs = mx.nd.array(lrs, dtype='float32', ctx=ctx) + mx_wds = mx.nd.array(wds, dtype='float32', ctx=ctx) + + w_sum_sq = mx.nd.multi_sum_sq(*mx_w, num_arrays=len(shapes)) + g_sum_sq = mx.nd.multi_sum_sq(*mx_g, num_arrays=len(shapes)) + + ref_w_sum_sq = mx.nd.array([(w.astype('float32') ** 2).sum() for w in weights_arr], + dtype='float32', ctx=ctx) + ref_g_sum_sq = mx.nd.array([(g.astype('float32') ** 2).sum() for g in grads_arr], + dtype='float32', ctx=ctx) + assert_almost_equal(ref_w_sum_sq.asnumpy(), w_sum_sq.asnumpy(), atol=tol1, rtol=tol1) + assert_almost_equal(ref_g_sum_sq.asnumpy(), g_sum_sq.asnumpy(), atol=tol1, rtol=tol1) + + rescale_grad = (np.random.rand() + 0.5) * 100. + mx_new_lrs = mx.nd.multi_lars(mx_lrs, w_sum_sq, g_sum_sq, mx_wds, eta=eta, eps=eps, + rescale_grad=rescale_grad) + ref_w_l2norm = mx.nd.sqrt(ref_w_sum_sq) + ref_g_l2norm = mx.nd.sqrt(ref_g_sum_sq * rescale_grad * rescale_grad) + ref_new_lrs = mx.nd.zeros(ref_w_l2norm.shape, dtype='float32', ctx=ctx) + for i in range(ref_w_l2norm.size): + _w = ref_w_l2norm[i] + _g = ref_g_l2norm[i] + if _w > 0.0 and _g > 0.0: + ref_new_lrs[i] = lrs[i] * eta * _w / (_g + wds[i] * _w + eps) + else: + ref_new_lrs[i] = lrs[i] + assert_almost_equal(ref_new_lrs.asnumpy(), mx_new_lrs.asnumpy(), atol=tol2, rtol=tol2) + +@with_seed() +def test_fast_lars(): + min_nparam = 50 + max_nparam = 60 + maxdim = 10000 + maxndim = 1 + + dtypes = ['float16','float32', 'float64'] + for ctx in [mx.cpu(0), mx.gpu(0)]: + for w_dtype in dtypes: + for g_dtype in dtypes: + nparam = np.random.randint(min_nparam + 1, max_nparam + 1) + shapes = [np.random.randint(1, maxdim + 1, size=maxndim) for i in range(nparam)] + lowTol = ctx == mx.cpu(0) and ('float16'in [w_dtype, g_dtype]) + tol1 = 1e-3 if lowTol else 1e-5 + tol2 = 1e-6 if lowTol else 1e-7 + check_fast_lars(w_dtype, g_dtype, shapes, ctx, tol1, tol2) + +def 
check_preloaded_multi_sgd(dtype, shapes, momentum, use_master_weights): + def _flatten_list(nested_list): + return [item for sublist in nested_list for item in sublist] + weights_arr = [np.random.rand(*shape).astype(dtype) * 100. for shape in shapes] + grads_arr = [np.random.rand(*shape).astype(dtype) * 100. for shape in shapes] + rescale_grad = (np.random.random() + 1.0) + mx_w = _make_ndarrays(weights_arr) + mx_g = _make_ndarrays(grads_arr) + mx_p_w = _make_ndarrays(weights_arr) + mx_p_g = _make_ndarrays(grads_arr) + lrs = list((np.random.random(size=len(shapes)).astype('float32') + 0.1) / 100.) + mx_lrs = mx.nd.array(lrs, dtype='float32', ctx=mx.gpu(0)) + wds = list((np.random.random(size=len(shapes)).astype('float32') + 0.1) / 1000.) + mx_wds = mx.nd.array(wds, dtype='float32', ctx=mx.gpu(0)) + if use_master_weights: + weights32_arr = [arr.astype('float32') for arr in weights_arr] + mx_w32 = _make_ndarrays(weights32_arr) + mx_p_w32 = _make_ndarrays(weights32_arr) + if momentum is None: + if use_master_weights: + mx.nd.multi_mp_sgd_update( + *_flatten_list(zip(mx_w, mx_g, mx_w32)), + num_weights=len(shapes), lrs=lrs, wds=wds, + rescale_grad=rescale_grad, out=mx_w) + mx.nd.preloaded_multi_mp_sgd_update( + *(_flatten_list(zip(mx_p_w, mx_p_g, mx_p_w32)) + + [mx_lrs, mx_wds]), num_weights=len(shapes), + rescale_grad=rescale_grad, out=mx_p_w) + else: + out = mx.nd.multi_sgd_update( + *_flatten_list(zip(mx_w, mx_g)), + num_weights=len(shapes), lrs=lrs, wds=wds, + rescale_grad=rescale_grad, out=mx_w) + preloaded_out = mx.nd.preloaded_multi_sgd_update( + *(_flatten_list(zip(mx_p_w, mx_p_g)) + + [mx_lrs, mx_wds]), num_weights=len(shapes), + rescale_grad=rescale_grad, out=mx_p_w) + else: + if use_master_weights: + momentums_arr = [np.random.rand(*shape).astype("float32") for shape in shapes] + mx_m = _make_ndarrays(momentums_arr) + mx_p_m = _make_ndarrays(momentums_arr) + out = mx.nd.multi_mp_sgd_mom_update( + *_flatten_list(zip(mx_w, mx_g, mx_m, mx_w32)), + 
num_weights=len(shapes), lrs=lrs, wds=wds, + rescale_grad=0.95, momentum=momentum, out=mx_w) + preloaded_out = mx.nd.preloaded_multi_mp_sgd_mom_update( + *(_flatten_list(zip(mx_p_w, mx_p_g, mx_p_m, mx_p_w32)) + + [mx_lrs, mx_wds]), num_weights=len(shapes), + rescale_grad=0.95, momentum=momentum, out=mx_p_w) + else: + momentums_arr = [np.random.rand(*shape).astype(dtype) for shape in shapes] + mx_m = _make_ndarrays(momentums_arr) + mx_p_m = _make_ndarrays(momentums_arr) + mx.nd.multi_sgd_mom_update( + *_flatten_list(zip(mx_w, mx_g, mx_m)), + num_weights=len(shapes), lrs=lrs, wds=wds, + rescale_grad=0.95, momentum=momentum, out=mx_w) + mx.nd.preloaded_multi_sgd_mom_update( + *(_flatten_list(zip(mx_p_w, mx_p_g, mx_p_m)) + + [mx_lrs, mx_wds]), num_weights=len(shapes), + rescale_grad=0.95, momentum=momentum, out=mx_p_w) + + def _assert_all_almost_equal(lhs_list, rhs_list, rtol, atol): + for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)): + assert_almost_equal(lhs.asnumpy(), rhs.asnumpy(), rtol=rtol, atol=atol) + if dtype == 'float16': + rtol = 1e-3 + atol = 1e-3 + else: + rtol = 1e-5 + atol = 1e-6 + _assert_all_almost_equal(mx_p_w, mx_w, rtol, atol) + if momentum is not None: + _assert_all_almost_equal(mx_p_m, mx_m, rtol, atol) + if use_master_weights: + _assert_all_almost_equal(mx_p_w32, mx_w32, 1e-5, 1e-6) + +@with_seed() +def test_preloaded_multi_sgd(): + dtypes = ['float16', 'float32'] + momentums = [None, 0.9] + min_nparam = 5 + max_nparam = 10 + maxdim = 6 + maxndim = 4 + for dtype in dtypes: + use_master_weights_list = [False,] if dtype == 'float32' else [True, False] + for use_master_weights in use_master_weights_list: + for momentum in momentums: + nparam = np.random.randint(min_nparam + 1, max_nparam + 1) + shapes = [np.random.randint(1, maxdim + 1, size=maxndim) for i in range(nparam)] + check_preloaded_multi_sgd(dtype, shapes, momentum, use_master_weights) @with_seed() def test_batchnorm_with_type():