diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 601bc682db38..0297a1dbe575 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -21,6 +21,7 @@ from __future__ import absolute_import
 import math
 import numpy as np
+import mxnet as mx
 from ..context import current_context
 from ..random import uniform
 from ..base import _as_list
@@ -32,6 +33,9 @@ __all__ = ["rand_zipfian", "foreach", "while_loop", "cond", "isinf",
            "isfinite", "isnan"]
 
+def _flatten_list(nested_list):
+    return [item for sublist in nested_list for item in sublist]
+
 # pylint: disable=line-too-long
 def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
     """Draw random samples from an approximately log-uniform or Zipfian distribution.
@@ -514,7 +518,7 @@ def isfinite(data):
     [0. 0. 0. 1.]
     """
-    is_data_not_nan = data == data # pylint: disable=comparison-with-itself
+    is_data_not_nan = data == data  # pylint: disable=comparison-with-itself
     is_data_not_infinite = data.abs() != np.inf
     return ndarray.logical_and(is_data_not_infinite, is_data_not_nan)
 
@@ -542,14 +546,17 @@ def isnan(data):
     [1. 0.]
     """
-    return data != data # pylint: disable=comparison-with-itself
+    return data != data  # pylint: disable=comparison-with-itself
 
-def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
-                 epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+def _get_rescale_grad(rescale_grad, ctx=mx.cpu()):
     if not isinstance(rescale_grad, ndarray.NDArray):
-        rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
+        return ndarray.full(shape=(1,), val=rescale_grad, ctx=ctx)
     else:
-        rescale_grad = rescale_grad.as_in_context(weight.context)
+        return rescale_grad.as_in_context(ctx)
+
+def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
+                 epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weight.context)
     return ndarray._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
                                            rescale_grad=rescale_grad, lr=lr, eta=eta,
                                            beta1=beta1, beta2=beta2, epsilon=epsilon,
@@ -559,13 +566,42 @@ def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta
 def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
                     beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
                     name=None, **kwargs):
-    if not isinstance(rescale_grad, ndarray.NDArray):
-        rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
-    else:
-        rescale_grad = rescale_grad.as_in_context(weight.context)
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weight.context)
     return ndarray._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
                                               weight32=weight32,
                                               rescale_grad=rescale_grad, lr=lr, eta=eta,
                                               beta1=beta1, beta2=beta2, epsilon=epsilon,
                                               wd=wd, clip_gradient=clip_gradient, out=out,
                                               name=name, **kwargs)
+
+def multi_adamw_update(weights, grads, mean, var, rescale_grad, lrs, wds, etas,
+                       out=None, name=None, size=0, **kwargs):
+    if not size:
+        size = len(weights)
+
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context)
+    temp_list = _flatten_list(zip(weights, grads, mean, var)) + [rescale_grad]
+    return ndarray._internal._multi_adamw_update(*temp_list,
+                                                 out=out,
+                                                 num_weights=size,
+                                                 lrs=lrs,
+                                                 wds=wds,
+                                                 etas=etas,
+                                                 name=name,
+                                                 **kwargs)
+
+def multi_mp_adamw_update(weights, grads, mean, var, weights32, rescale_grad, lrs, wds, etas,
+                          out=None, name=None, size=0, **kwargs):
+    if not size:
+        size = len(weights)
+
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context)
+    temp_list = _flatten_list(zip(weights, grads, mean, var, weights32)) + [rescale_grad]
+    return ndarray._internal._multi_mp_adamw_update(*temp_list,
+                                                    out=out,
+                                                    num_weights=size,
+                                                    lrs=lrs,
+                                                    wds=wds,
+                                                    etas=etas,
+                                                    name=name,
+                                                    **kwargs)
diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
index 6ae9e46b7def..fd139de3390f 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adamw-inl.h
@@ -18,26 +18,17 @@
  */
 
 /*!
- * Copyright (c) 2016 by Contributors
- * \file optimizer_op-inl.h
+ * Copyright (c) 2018 by Contributors
+ * \file adamw-inl.h
  * \brief Optimizer operators
- * \author Haibin Lin
+ * \author Haibin Lin, Moises Hernandez, Andrei Ivanov
 */
 #ifndef MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
 #define MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
-#include <dmlc/parameter.h>
 #include <mxnet/operator.h>
-#include <mxnet/operator_util.h>
-#include <mxnet/op_attr_types.h>
-#include <mshadow/base.h>
-#include <nnvm/op.h>
-#include <nnvm/op_attr_types.h>
 #include <vector>
-#include <cmath>
-#include "../operator_common.h"
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
-#include "../mxnet_op.h"
 
 namespace mxnet {
 namespace op {
@@ -87,17 +78,12 @@ inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs,
                                mxnet::ShapeVector *out_attrs) {
   CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
   CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
-  // rescale_grad.shape = ()
   SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mxnet::TShape());
   // TODO(@reminisce): change "none" behavior in ElemwiseAttr
   return ElemwiseAttr<mxnet::TShape, shape_is_none, shape_assign, true, shape_string, n_in, n_out>(
     attrs, in_attrs, out_attrs, mxnet::TShape());
 }
 
-// rescale_grad is a reserved argument at position -1. Example:
-// n_in = 2: weight, grad (fp16)
-// n_out = 1: weight (fp16)
-// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
 template<int n_in, int n_out, int total_in>
 inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs,
                               std::vector<int> *in_attrs,
                               std::vector<int> *out_attrs) {
@@ -120,20 +106,14 @@ struct MPAdamWKernel {
                                   const float param_eta, const float param_lr, const float param_wd,
                                   const float param_rescale_grad, const float param_epsilon) {
     float w = weight32[i];
-    float mean = mean_data[i];
-    float var = var_data[i];
     float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
-    if (param_clip_gradient >= 0.0f) {
-      mean = param_beta1 * mean +
-             (1 - param_beta1) * mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
-      var = param_beta2 * var + (1 - param_beta2) *
-            mshadow_op::square::Map(mshadow_op::clip::Map(scaled_grad, param_clip_gradient));
-    } else {
-      mean = param_beta1 * mean + (1 - param_beta1) * scaled_grad;
-      var = param_beta2 * var + (1 - param_beta2) * mshadow_op::square::Map(scaled_grad);
-    }
-    mean_data[i] = mean;
-    var_data[i] = var;
+    if (param_clip_gradient >= 0.0f)
+      scaled_grad = mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
+
+    float mean = mean_data[i] = param_beta1 * mean_data[i] + (1.0f - param_beta1) * scaled_grad;
+    float var = var_data[i] = param_beta2 * var_data[i] +
+                              (1.0f - param_beta2) * mshadow_op::square::Map(scaled_grad);
+
     w = w - param_eta * (param_lr *
           mean / (mshadow_op::square_root::Map(var) + param_epsilon) + param_wd * w);
     weight32[i] = w;
@@ -141,7 +121,6 @@ struct MPAdamWKernel {
   }
 };
 
-
 template<typename xpu>
 struct MPAdamWUpdate {
   static inline void Forward(const nnvm::NodeAttrs& attrs,
@@ -151,7 +130,7 @@ struct MPAdamWUpdate {
                              const std::vector<TBlob> &outputs,
                              const float rescale_grad) {
     using namespace mxnet_op;
-    AdamWParam param = nnvm::get<AdamWParam>(attrs.parsed);
+    const auto& param = nnvm::get<AdamWParam>(attrs.parsed);
     Stream<xpu>* s = ctx.get_stream<xpu>();
     MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
@@ -183,25 +162,22 @@ struct AdamWUpdate {
     using namespace mshadow;
     using namespace mshadow::expr;
     using namespace mshadow_op;
-    const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
+    const auto &param = nnvm::get<AdamWParam>(attrs.parsed);
     Stream<xpu>* s = ctx.get_stream<xpu>();
     MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+      const Tensor<xpu, 2, DType> &weight = inputs[0].FlatTo2D<xpu, DType>(s);
       Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
       Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
       Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
       Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
       grad = scalar<DType>(rescale_grad) * grad;
-      if (param.clip_gradient >= 0.0f) {
-        mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
-            F<clip>(grad, DType(param.clip_gradient));
-        var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
-            F<clip>(grad, DType(param.clip_gradient)));
-      } else {
-        mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
-        var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
-      }
+      if (param.clip_gradient >= 0.0f)
+        grad = F<clip>(grad, DType(param.clip_gradient));
+
+      mean = scalar<DType>(param.beta1) * mean + scalar<DType>(1.f-param.beta1) * grad;
+      var = scalar<DType>(param.beta2) * var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
+
       Assign(out, req[0],
              weight -
              scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
@@ -211,6 +187,312 @@ struct AdamWUpdate {
   }
 };
 
+////
+// Multiple gradients in single kernel
+////
+struct MultiAdamWParam : public dmlc::Parameter<MultiAdamWParam> {
+  mxnet::Tuple<float> lrs;
+  mxnet::Tuple<float> wds;
+  mxnet::Tuple<float> etas;
+  float beta1;
+  float beta2;
+  float epsilon;
+  float clip_gradient;
+  int num_weights;
+  DMLC_DECLARE_PARAMETER(MultiAdamWParam) {
+    DMLC_DECLARE_FIELD(lrs)
+    .describe("Learning rates");
+    DMLC_DECLARE_FIELD(beta1)
+    .set_default(0.9f)
+    .describe("The decay rate for the 1st moment estimates.");
+    DMLC_DECLARE_FIELD(beta2)
+    .set_default(0.999f)
+    .describe("The decay rate for the 2nd moment estimates.");
+    DMLC_DECLARE_FIELD(epsilon)
+    .set_default(1e-8f)
+    .describe("A small constant for numerical stability.");
+    DMLC_DECLARE_FIELD(wds)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+              "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(etas)
+    .describe("Learning rates schedule multiplier");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(num_weights)
+    .set_default(1)
+    .describe("Number of updated weights.");
+  }
+};
+
+
+template<typename ParamType, int input_stride>
+inline bool MP_MultiAdamW_InferShape(const nnvm::NodeAttrs& attrs,
+                                     mxnet::ShapeVector *in_attrs,
+                                     mxnet::ShapeVector *out_attrs) {
+  const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), input_stride * param.num_weights + 1);
+  CHECK_EQ(out_attrs->size(), param.num_weights);
+
+  bool all_inferred = true;
+  auto& input_shapes = *in_attrs;
+  auto& output_shapes = *out_attrs;
+
+  // Learning rates
+  CHECK_EQ(param.lrs.ndim(), param.num_weights)
+    << "Number of learning rates is inconsistent with num_weights "
+    << "parameter passed. Expected number of learning rates: "
+    << param.num_weights << ", and got " << param.lrs.ndim();
+  // Weight decays
+  CHECK_EQ(param.wds.ndim(), param.num_weights)
+    << "Number of weight decays is inconsistent with num_weights "
+    << "parameter passed. Expected number of weight decays: "
+    << param.num_weights << ", and got " << param.wds.ndim();
+  // Learning rates schedule multipliers
+  CHECK_EQ(param.etas.ndim(), param.num_weights)
+    << "Number of learning rates schedule multipliers is inconsistent with num_weights "
+    << "parameter passed. Expected number of learning rates schedule multipliers: "
+    << param.num_weights << ", and got " << param.etas.ndim();
+
+  // Weights, gradients, mean and variance
+  for (int i = 0; i < param.num_weights; ++i) {
+    mxnet::ShapeVector input_vec;
+    mxnet::ShapeVector output_vec({output_shapes[i]});
+    for (int j = 0; j < input_stride; ++j) {
+      input_vec.push_back(input_shapes[i * input_stride + j]);
+    }
+    all_inferred = all_inferred && ElemwiseShape<input_stride, 1>(attrs, &input_vec, &output_vec);
+  }
+
+  SHAPE_ASSIGN_CHECK(*in_attrs, param.num_weights * input_stride, mxnet::TShape());
+  return all_inferred;
+}
+
+template<typename ParamType, int input_stride, int num_fp32_inputs>
+inline bool MP_MultiAdamW_InferType(const nnvm::NodeAttrs& attrs,
+                                    std::vector<int> *in_attrs,
+                                    std::vector<int> *out_attrs) {
+  const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), input_stride * param.num_weights + 1);
+  CHECK_EQ(out_attrs->size(), param.num_weights);
+
+  bool all_inferred = true;
+  auto& input_types = *in_attrs;
+  auto& output_types = *out_attrs;
+
+  // Weights, gradients
+  for (int i = 0; i < param.num_weights; ++i) {
+    std::vector<int> input_vec;
+    std::vector<int> output_vec({output_types[i]});
+    for (int j = 0; j < input_stride - 2 - num_fp32_inputs; ++j) {
+      input_vec.push_back(input_types[i * input_stride + j]);
+    }
+    all_inferred = all_inferred &&
+                   ElemwiseType<input_stride - 2 - num_fp32_inputs, 1>(attrs, &input_vec, &output_vec);
+  }
+  // mean, var
+  for (int i = 0; i < param.num_weights; ++i) {
+    TYPE_ASSIGN_CHECK(input_types, input_stride * i + 2, mshadow::kFloat32);
+    TYPE_ASSIGN_CHECK(input_types, input_stride * i + 3, mshadow::kFloat32);
+  }
+
+  // master copies of weights
+  for (int i = 0; i < param.num_weights; ++i) {
+    for (int j = 0; j < num_fp32_inputs; ++j) {
+      TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j, mshadow::kFloat32);
+    }
+  }
+
+  TYPE_ASSIGN_CHECK(input_types, param.num_weights * input_stride, mshadow::kFloat32);
+  return all_inferred;
+}
+
+
+template<typename T>
+class Adam_type_identity {
+ public:
+  using type = T;
+};
+
+
+template<typename T>
+class Adam_single_precision {
+ public:
+  using type = float;
+};
+
+template<typename DType, typename MPDType>
+struct MultiAdamKernelParam {
+  static const int N = 50;
+  int count;
+  size_t max_size;
+  size_t sizes[N];
+  DType* weights[N];
+  DType* grad_data[N];
+  MPDType* mean_data[N];
+  MPDType* var_data[N];
+  MPDType* weights32[N];
+  DType* out_data[N];
+  MPDType clip_gradient;
+  MPDType beta1;
+  MPDType beta2;
+  MPDType etas[N];
+  MPDType lrs[N];
+  MPDType wds[N];
+  MPDType epsilon;
+};
+
+template<typename MPDType, bool has_mixed_precision>
+struct MultiMPAdamWKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, const MultiAdamKernelParam<DType, MPDType>& param,
+                                  const OpReqType req, const float rescale_grad) {
+    for (int index = 0; index < param.count; ++index) {
+      if ((size_t)i < param.sizes[index]) {
+        MPDType w = has_mixed_precision ? param.weights32[index][i]:
+                                          MPDType(param.weights[index][i]);
+        MPDType scaled_grad = static_cast<MPDType>(rescale_grad)*
+                              static_cast<MPDType>(param.grad_data[index][i]);
+
+        if (param.clip_gradient >= 0.0f)
+          scaled_grad = mshadow_op::clip::Map(scaled_grad, param.clip_gradient);
+
+        const auto mean = param.beta1 * (param.mean_data[index][i] - scaled_grad) + scaled_grad;
+        const auto adj = mshadow_op::square::Map(scaled_grad);
+        const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj;
+
+        param.mean_data[index][i] = mean;
+        param.var_data[index][i] = var;
+        w = w - param.etas[index] * (param.lrs[index] *
+            mean / (mshadow_op::square_root::Map(var) + param.epsilon) +
+            param.wds[index] * w);
+        if (has_mixed_precision)
+          param.weights32[index][i] = w;
+
+        KERNEL_ASSIGN(param.out_data[index][i], req, w);
+      }
+    }
+  }
+};
+
+template<typename xpu, typename DType, typename MPDType, typename ParamType, int input_stride>
+void FillMultiAdamKernelParam(const nnvm::NodeAttrs& attrs,
+                              const OpContext &ctx,
+                              const std::vector<TBlob> &inputs,
+                              const std::vector<TBlob> &outputs,
+                              MultiAdamKernelParam<DType, MPDType> *pParam) {
+  const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
+  mxnet_op::Stream<xpu>* s = ctx.get_stream<xpu>();
+  pParam->clip_gradient = p.clip_gradient;
+  pParam->beta1 = p.beta1;
+  pParam->beta2 = p.beta2;
+
+  pParam->epsilon = p.epsilon;
+
+  pParam->count = p.num_weights;
+  pParam->max_size = 0;
+  constexpr bool isSame = std::is_same<DType, MPDType>::value;
+  for (int i = 0; i < pParam->count; ++i) {
+    const auto idx = i * input_stride;
+    pParam->sizes[i] = inputs[idx].shape_.Size();
+    if (pParam->max_size < pParam->sizes[i])
+      pParam->max_size = pParam->sizes[i];
+
+    pParam->weights[i] = inputs[idx].FlatTo2D<xpu, DType>(s).dptr_;
+    pParam->grad_data[i] = inputs[idx + 1].FlatTo2D<xpu, DType>(s).dptr_;
+    pParam->mean_data[i] = inputs[idx + 2].FlatTo2D<xpu, MPDType>(s).dptr_;
+    pParam->var_data[i] = inputs[idx + 3].FlatTo2D<xpu, MPDType>(s).dptr_;
+    // if mixed precision, then the last input in a set
+    // is 32-bit master copy of the weights
+    if (!isSame)
+      pParam->weights32[i] = inputs[idx + input_stride - 1].FlatTo2D<xpu, MPDType>(s).dptr_;
+
+    pParam->out_data[i] = outputs[i].FlatTo2D<xpu, DType>(s).dptr_;
+  }
+  memcpy(pParam->etas, p.etas.begin(), pParam->count * sizeof(p.etas[0]));
+  memcpy(pParam->lrs, p.lrs.begin(), pParam->count * sizeof(p.lrs[0]));
+  memcpy(pParam->wds, p.wds.begin(), pParam->count * sizeof(p.wds[0]));
+}
+
+template<typename xpu, typename ParamType, template<typename> class MPTypeChooser, int input_stride>
+static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<TBlob> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<TBlob> &outputs,
+                                    const float rescale_grad) {
+  using namespace mxnet_op;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    using MPDType = typename MPTypeChooser<DType>::type;
+    MultiAdamKernelParam<DType, MPDType> param;
+    FillMultiAdamKernelParam<xpu, DType, MPDType, ParamType, input_stride>
+      (attrs, ctx, inputs, outputs, &param);
+
+    Kernel<MultiMPAdamWKernel<MPDType, !std::is_same<DType, MPDType>::value>, xpu>::
+      Launch(s, param.max_size, param, req[0], rescale_grad);
+  });
+}
+
+template<typename xpu>
+void GetScaleFloat(const TBlob &scale_blob, float *pScalef);
+
+template<typename xpu>
+bool PrepareInputBlobs(const std::vector<TBlob> &inputs,
+                       std::vector<TBlob> *inputs_wo_scale,
+                       float *pScalef) {
+  const size_t num_in = inputs.size() - 1;
+  GetScaleFloat<xpu>(inputs[num_in], pScalef);
+  if (!std::isfinite(*pScalef) || *pScalef == 0)
+    return false;
+
+  inputs_wo_scale->reserve(num_in);
+  for (size_t i = 0; i < num_in; i++)
+    inputs_wo_scale->emplace_back(inputs[i]);
+
+  return true;
+}
+
+template<typename xpu, typename F>
+inline void MPUpdate(const nnvm::NodeAttrs& attrs,
+                     const OpContext &ctx,
+                     const std::vector<TBlob> &inputs,
+                     const std::vector<OpReqType> &req,
+                     const std::vector<TBlob> &outputs) {
+  std::vector<TBlob> inputs_wo_scale;
+  float scalef;
+  if (!PrepareInputBlobs<xpu>(inputs, &inputs_wo_scale, &scalef))
+    return;
+
+  F::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
+}
+
+template<typename xpu, bool MP>
+inline void multiMPUpdate(const nnvm::NodeAttrs& attrs,
+                          const OpContext &ctx,
+                          const std::vector<TBlob> &inputs,
+                          const std::vector<OpReqType> &req,
+                          const std::vector<TBlob> &outputs) {
+  std::vector<TBlob> inputs_wo_scale;
+  float scalef;
+  if (!PrepareInputBlobs<xpu>(inputs, &inputs_wo_scale, &scalef))
+    return;
+
+  if (!MP)
+    MultiAdamWUpdate<xpu, MultiAdamWParam, Adam_type_identity, 4>
+      (attrs, ctx, inputs_wo_scale, req, outputs, scalef);
+  else
+    MultiAdamWUpdate<xpu, MultiAdamWParam, Adam_single_precision, 5>
+      (attrs, ctx, inputs_wo_scale, req, outputs, scalef);
+}
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index f0716c6020f9..2c730f0b3e7b 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -18,37 +18,18 @@
  */
 
 /*!
- * Copyright (c) 2016 by Contributors
- * \file optimizer_op.cc
+ * Copyright (c) 2018 by Contributors
+ * \file adamw.cc
 * \brief Optimizer operators
- * \author Haibin Lin
+ * \author Haibin Lin, Moises Hernandez, Andrei Ivanov
 */
 #include "./adamw-inl.h"
-#include "../optimizer_op-inl.h"
 
 namespace mxnet {
 namespace op {
 
 DMLC_REGISTER_PARAMETER(AdamWParam);
-
-template
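
Usage sketch for the new Python wrappers (not part of the patch): the multi-tensor update takes lists of weights, gradients, and Adam state plus per-tensor lrs/wds/etas tuples. Shapes, learning rates, and other hyperparameters below are illustrative assumptions only.

    import mxnet as mx
    from mxnet.ndarray import contrib

    # Two independently shaped parameters, each with its own grad/mean/var state.
    weights = [mx.nd.ones((4, 4)), mx.nd.ones((5,))]
    grads = [w * 0.1 for w in weights]
    means = [mx.nd.zeros(w.shape) for w in weights]
    variances = [mx.nd.zeros(w.shape) for w in weights]

    # One fused AdamW step over both tensors; lrs/wds/etas carry one entry per weight
    # and the updated values are written back into `weights` via `out`.
    contrib.multi_adamw_update(weights, grads, means, variances,
                               rescale_grad=1.0,
                               lrs=(0.01, 0.01), wds=(0.0, 0.0), etas=(1.0, 1.0),
                               out=weights)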