diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 3cc5b85c8384..6af8ef8b8f2b 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -35,7 +35,7 @@ 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', - 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', + 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'average', 'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', @@ -3464,6 +3464,97 @@ def argmin(a, axis=None, out=None): return _npi.argmin(a, axis=axis, keepdims=False, out=out) +@set_module('mxnet.ndarray.numpy') +def average(a, axis=None, weights=None, returned=False, out=None): + """ + Compute the weighted average along the specified axis. + + Parameters + -------- + a : ndarray + Array containing data to be averaged. + axis : None or int or tuple of ints, optional + Axis or axes along which to average a. + The default, axis=None, will average over + all of the elements of the input array. + If axis is negative it counts from the last to the first axis. + New in version 1.7.0. + If axis is a tuple of ints, averaging is + performed on all of the axes specified in the tuple + instead of a single axis or all the axes as before. + weights : ndarray, optional + An array of weights associated with the values in a, must be the same dtype with a. + Each value in a contributes to the average according to its associated weight. + The weights array can either be 1-D (in which case its length must be + the size of a along the given axis) or of the same shape as a. + If weights=None, then all data in a are assumed to have a weight equal to one. + The 1-D calculation is: avg = sum(a * weights) / sum(weights) + The only constraint on weights is that sum(weights) must not be 0. + returned : bool, optional + Default is False. + If True, the tuple (average, sum_of_weights) is returned, + otherwise only the average is returned. + If weights=None, sum_of_weights is equivalent to + the number of elements over which the average is taken. + out : ndarray, optional + If provided, the calculation is done into this array. + + Returns + -------- + retval, [sum_of_weights] : ndarray + Return the average along the specified axis. + When returned is True, return a tuple with the average as the first element + and the sum of the weights as the second element. sum_of_weights is of the same type as retval. + If a is integral, the result dtype will be float32, otherwise it will be the same as dtype of a. + + Raises + -------- + MXNetError + - When all weights along axis sum to zero. + - When the length of 1D weights is not the same as the shape of a along axis. + - When given 1D weights, the axis is not specified or is not int. + - When the shape of weights and a differ, but weights are not 1D. 
+ + See also + -------- + mean + + Notes + -------- + This function differs from the original `numpy.average` + `_ in + the following way(s): + + - Does not guarantee the same behavior with numpy when given float16 dtype and overflow happens + - Does not support complex dtype + - The dtypes of a and weights must be the same + - Integral a results in float32 returned dtype, not float64 + + Examples + -------- + >>> data = np.arange(1, 5) + >>> data + array([1., 2., 3., 4.]) + >>> np.average(data) + array(2.5) + >>> np.average(np.arange(1, 11), weights=np.arange(10, 0, -1)) + array(4.) + >>> data = np.arange(6).reshape((3,2)) + >>> data + array([[0., 1.], + [2., 3.], + [4., 5.]]) + >>> weights = np.array([0.25, 0.75]) + array([0.25, 0.75]) + >>> np.average(data, axis=1, weights=weights) + array([0.75, 2.75, 4.75]) + """ + if weights is None: + return _npi.average(a, axis=axis, weights=None, returned=returned, weighted=False, out=out) + else: + return _npi.average(a, axis=axis, weights=weights, returned=returned, out=out) + + @set_module('mxnet.ndarray.numpy') def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ """ diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index e94d4c8341b4..9676bbe920ab 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -51,7 +51,7 @@ 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', - 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', + 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'average', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', @@ -5402,6 +5402,94 @@ def argmin(a, axis=None, out=None): return _mx_nd_np.argmin(a, axis, out) +@set_module('mxnet.numpy') +def average(a, axis=None, weights=None, returned=False, out=None): + """ + Compute the weighted average along the specified axis. + + Parameters + -------- + a : ndarray + Array containing data to be averaged. + axis : None or int or tuple of ints, optional + Axis or axes along which to average a. + The default, axis=None, will average over + all of the elements of the input array. + If axis is negative it counts from the last to the first axis. + New in version 1.7.0. + If axis is a tuple of ints, averaging is + performed on all of the axes specified in the tuple + instead of a single axis or all the axes as before. + weights : ndarray, optional + An array of weights associated with the values in a, must be the same dtype with a. + Each value in a contributes to the average according to its associated weight. + The weights array can either be 1-D (in which case its length must be + the size of a along the given axis) or of the same shape as a. + If weights=None, then all data in a are assumed to have a weight equal to one. + The 1-D calculation is: avg = sum(a * weights) / sum(weights) + The only constraint on weights is that sum(weights) must not be 0. + returned : bool, optional + Default is False. 
+ If True, the tuple (average, sum_of_weights) is returned, + otherwise only the average is returned. + If weights=None, sum_of_weights is equivalent to + the number of elements over which the average is taken. + out : ndarray, optional + If provided, the calculation is done into this array. + + Returns + -------- + retval, [sum_of_weights] : ndarray + Return the average along the specified axis. + When returned is True, return a tuple with the average as the first element + and the sum of the weights as the second element. sum_of_weights is of the same type as retval. + If a is integral, the result dtype will be float32, otherwise it will be the same as dtype of a. + + Raises + -------- + MXNetError + - When all weights along axis sum to zero. + - When the length of 1D weights is not the same as the shape of a along axis. + - When given 1D weights, the axis is not specified or is not int. + - When the shape of weights and a differ, but weights are not 1D. + + See also + -------- + mean + + Notes + -------- + This function differs from the original `numpy.average` + `_ in + the following way(s): + + - Does not guarantee the same behavior with numpy when given float16 dtype and overflow happens + - Does not support complex dtype + - The dtypes of a and weights must be the same + - Integral a results in float32 returned dtype, not float64 + + Examples + -------- + >>> data = np.arange(1, 5) + >>> data + array([1., 2., 3., 4.]) + >>> np.average(data) + array(2.5) + >>> np.average(np.arange(1, 11), weights=np.arange(10, 0, -1)) + array(4.) + >>> data = np.arange(6).reshape((3,2)) + >>> data + array([[0., 1.], + [2., 3.], + [4., 5.]]) + >>> weights = np.array([0.25, 0.75]) + array([0.25, 0.75]) + >>> np.average(data, axis=1, weights=weights) + array([0.75, 2.75, 4.75]) + """ + return _mx_nd_np.average(a, axis=axis, weights=weights, returned=returned, out=out) + + @set_module('mxnet.numpy') def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ """ diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 7da771966f1f..362c9b4af6d8 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -34,7 +34,7 @@ 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', - 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', + 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', 'average', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', @@ -3356,6 +3356,96 @@ def argmin(a, axis=None, out=None): return _npi.argmin(a, axis=axis, keepdims=False, out=out) +def average(a, axis=None, weights=None, returned=False, out=None): + """ + Compute the weighted average along the specified axis. + + Parameters + -------- + a : _Symbol + Array containing data to be averaged. + axis : None or int or tuple of ints, optional + Axis or axes along which to average a. + The default, axis=None, will average over + all of the elements of the input array. 
+ If axis is negative it counts from the last to the first axis. + New in version 1.7.0. + If axis is a tuple of ints, averaging is + performed on all of the axes specified in the tuple + instead of a single axis or all the axes as before. + weights : _Symbol, optional + An array of weights associated with the values in a, must be the same dtype with a. + Each value in a contributes to the average according to its associated weight. + The weights array can either be 1-D (in which case its length must be + the size of a along the given axis) or of the same shape as a. + If weights=None, then all data in a are assumed to have a weight equal to one. + The 1-D calculation is: avg = sum(a * weights) / sum(weights) + The only constraint on weights is that sum(weights) must not be 0. + returned : bool, optional + Default is False. + If True, the tuple (average, sum_of_weights) is returned, + otherwise only the average is returned. + If weights=None, sum_of_weights is equivalent to + the number of elements over which the average is taken. + out : _Symbol, optional + If provided, the calculation is done into this array. + + Returns + -------- + retval, [sum_of_weights] : _Symbol + Return the average along the specified axis. + When returned is True, return a tuple with the average as the first element + and the sum of the weights as the second element. sum_of_weights is of the same type as retval. + If a is integral, the result dtype will be float32, otherwise it will be the same as dtype of a. + + Raises + -------- + MXNetError + - When all weights along axis sum to zero. + - When the length of 1D weights is not the same as the shape of a along axis. + - When given 1D weights, the axis is not specified or is not int. + - When the shape of weights and a differ, but weights are not 1D. + + See also + -------- + mean + + Notes + -------- + This function differs from the original `numpy.average` + `_ in + the following way(s): + + - Does not guarantee the same behavior with numpy when given float16 dtype and overflow happens + - Does not support complex dtype + - The dtypes of a and weights must be the same + - Integral a results in float32 returned dtype, not float64 + + Examples + -------- + >>> data = np.arange(1, 5) + >>> data + array([1., 2., 3., 4.]) + >>> np.average(data) + array(2.5) + >>> np.average(np.arange(1, 11), weights=np.arange(10, 0, -1)) + array(4.) 
+ >>> data = np.arange(6).reshape((3,2)) + >>> data + array([[0., 1.], + [2., 3.], + [4., 5.]]) + >>> weights = np.array([0.25, 0.75]) + array([0.25, 0.75]) + >>> np.average(data, axis=1, weights=weights) + array([0.75, 2.75, 4.75]) + """ + if weights is None: + return _npi.average(a, axis=axis, weights=None, returned=returned, weighted=False, out=out) + else: + return _npi.average(a, axis=axis, weights=weights, returned=returned, out=out) + + @set_module('mxnet.symbol.numpy') def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ """ diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index 3566323f1eb3..df9a7c932490 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -30,6 +30,7 @@ #include #include "../nn/moments-inl.h" #include "../tensor/broadcast_reduce_op.h" +#include "../tensor/elemwise_binary_broadcast_op.h" namespace mxnet { namespace op { @@ -406,6 +407,353 @@ void ReduceAxesComputeWithWorkspaceImpl(const OpContext& ctx, }); } +struct NumpyWeightedAverageParam : public dmlc::Parameter { + dmlc::optional> axis; + bool returned; + bool weighted; + + DMLC_DECLARE_PARAMETER(NumpyWeightedAverageParam) { + DMLC_DECLARE_FIELD(axis) + .set_default(dmlc::optional>()) + .describe("Axis or axes along which a average is performed. " + "The default, axis=None, will average " + "all of the elements of the input array. If axis is negative it counts from the " + "last to the first axis."); + DMLC_DECLARE_FIELD(returned) + .set_default(false) + .describe("If True, the tuple (average, sum_of_weights) is returned," + "otherwise only the average is returned." + "If weights=None, sum_of_weights is equivalent to" + "the number of elements over which the average is taken."); + DMLC_DECLARE_FIELD(weighted) + .set_default(true) + .describe("Auxiliary flag to deal with none weights."); + } +}; + +inline bool NumpyWeightedAverageShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const auto& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), (param.weighted ? 2U : 1U)); + CHECK_EQ(out_attrs->size(), 2U); + if (!shape_is_known(in_attrs->at(0))) { + return false; + } + + const TShape& a_shape = (*in_attrs)[0]; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, + NumpyReduceAxesShapeImpl(a_shape, param.axis, false)); + + if (param.weighted) { + const TShape& w_shape = (*in_attrs)[1]; + if (w_shape.ndim() != a_shape.ndim()) { + CHECK_EQ(w_shape.ndim(), 1U) + << "1D weights expected when shapes of a and weights differ."; + CHECK_EQ(param.axis.has_value(), true) + << "Axis must be specified when shapes of a and weights differ."; + mxnet::Tuple axes(param.axis.value()); + CHECK_EQ(axes.ndim(), 1U) << "Axis must be int when shapes of a and weights differ."; + int red_axis = axes[0] < 0 ? 
axes[0] + a_shape.ndim() : axes[0]; + CHECK_EQ(a_shape[red_axis], w_shape[0]) + << "Length of weights not compatible with specified axis."; + SHAPE_ASSIGN_CHECK(*out_attrs, 1, + NumpyReduceAxesShapeImpl( + w_shape, dmlc::optional>(), false)); + } else { + for (int i = 0; i < w_shape.ndim(); i++) { + CHECK_EQ(w_shape[i], a_shape[i]); + } + SHAPE_ASSIGN_CHECK(*out_attrs, 1, + NumpyReduceAxesShapeImpl(w_shape, param.axis, false)); + } + } else { + SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(0, -1)); + } + + return shape_is_known(out_attrs->at(0)) && shape_is_known(out_attrs->at(1)); +} + +template +struct avg_grad_a_kernel { + template + MSHADOW_XINLINE static void Map(int i, + DType* out, + const DType* w, + const DType* scl, + const DType* ograd, + mshadow::Shape small, + mshadow::Shape big) { + // partial a = w / sum(w) + size_t big_idx = i; + size_t small_idx = i; + size_t big_stride = 1; + size_t small_stride = 1; + size_t red_axis_idx = 0; + for (int axis = NDim-1; axis >= 0; --axis) { + size_t axis_idx = big_idx % big[axis]; + small_idx -= axis_idx * big_stride; + if (small[axis] != 1) { + small_idx += axis_idx * small_stride; + } else if (onedim && small[axis] != big[axis]) { + red_axis_idx = axis_idx; + } + big_idx /= big[axis]; + big_stride *= big[axis]; + small_stride *= small[axis]; + } + if (onedim) { + KERNEL_ASSIGN(out[i], req, (ograd[small_idx] * (w[red_axis_idx] / *scl))); + } else { + KERNEL_ASSIGN(out[i], req, (ograd[small_idx] * (w[i] / scl[small_idx]))); + } + } +}; + +template +struct avg_grad_w_kernel { + template + MSHADOW_XINLINE static void Map(int i, + DType* out, + const DType* a, + const DType* scl, + const DType* sum_of_wa, + const DType* ograd, + mshadow::Shape small, + mshadow::Shape big) { + // partial w = (a * sum(w) - sum(a*w)) / (sum(w) * sum(w)) + size_t big_idx = i; + size_t small_idx = i; + size_t big_stride = 1; + size_t small_stride = 1; + for (int axis = NDim-1; axis >= 0; --axis) { + size_t axis_idx = big_idx % big[axis]; + small_idx -= axis_idx * big_stride; + if (small[axis] != 1) { + small_idx += axis_idx * small_stride; + } + big_idx /= big[axis]; + big_stride *= big[axis]; + small_stride *= small[axis]; + } + DType ret = ograd[small_idx] * + (((a[i] * scl[small_idx] - sum_of_wa[small_idx]) / scl[small_idx]) / scl[small_idx]); + KERNEL_ASSIGN(out[i], req, ret); + } +}; + +template +struct avg_grad_w_1D_kernel { + template + MSHADOW_XINLINE static void Map(int i, + DType* out, + const DType* a, + const DType* scl, + const DType* sum_of_wa, + const DType* ograd, + mshadow::Shape big, + const int red_axis) { + DType scl_val = *scl; + size_t tail = 1; + size_t head = 1; + for (int axis = NDim-1; axis > red_axis; --axis) { + tail *= big[axis]; + } + for (int axis = 0; axis < red_axis; ++axis) { + head *= big[axis]; + } + DType ret = 0; + for (size_t j = 0; j < head; ++j) { + for (size_t k = 0; k < tail; ++k) { + size_t a_idx = j*(tail*big[red_axis]) + i * tail + k; + size_t small_idx = j*tail + k; + ret += (ograd[small_idx] * + (((a[a_idx] * scl_val - sum_of_wa[small_idx]) / scl_val) / scl_val)); + } + } + KERNEL_ASSIGN(out[i], req, ret); + } +}; + +template +void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const dmlc::optional>& axis) { + using namespace mshadow; + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + const TBlob& data = inputs[0]; + TShape small1 = NumpyReduceAxesShapeImpl(data.shape_, axis, true); + // 
Reshape weights + TShape small2 = small1; + TBlob weights = inputs[1]; + + bool one_dim = weights.shape_.ndim() != data.shape_.ndim(); + + int red_axis = -1; + + if (one_dim) { + CHECK_EQ(weights.shape_.ndim(), 1U) + << "1D weights expected when shapes of a and weights differ."; + CHECK_EQ(axis.has_value(), true) + << "Axis must be specified when shapes of a and weights differ."; + Tuple axes(axis.value()); + CHECK_EQ(axes.ndim(), 1U) + << "Axis must be int when shapes of a and weights differ."; + red_axis = axes[0] < 0 ? axes[0] + data.shape_.ndim() : axes[0]; + CHECK_EQ(weights.shape_[0], data.shape_[red_axis]) + << "Length of weights not compatible with specified axis."; + TShape new_w_shape(data.shape_.ndim(), 1); + new_w_shape[red_axis] = weights.shape_[0]; + weights = weights.reshape(new_w_shape); + small2 = TShape(new_w_shape.ndim(), 1); + } + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + // Get temp space + size_t temp_data_size = data.shape_.Size() * sizeof(DType); + size_t temp_sum_size = small1.Size() * sizeof(DType); + TShape src_shape, dst_shape; + BroadcastReduceShapeCompact(data.shape_, small1, &src_shape, &dst_shape); + size_t workspace_size = 0; + MXNET_NDIM_SWITCH(dst_shape.ndim(), NDim, { + workspace_size = broadcast::ReduceWorkspaceSize( + s, dst_shape, {kWriteTo}, src_shape); + }); + size_t temp_mem_size = temp_data_size + temp_sum_size + workspace_size; + Tensor temp_mem = + ctx.requested[0].get_space_typed(Shape1(temp_mem_size), s); + DType *temp_data_ptr = reinterpret_cast(temp_mem.dptr_); + DType *temp_sum_ptr = reinterpret_cast(temp_mem.dptr_ + temp_data_size); + char *workspace_ptr = temp_mem.dptr_ + temp_data_size + temp_sum_size; + Tensor workspace(workspace_ptr, Shape1(workspace_size), s); + + // Compute weighted data + TBlob wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask); + BinaryBroadcastCompute( + attrs, ctx, {data, weights}, {kWriteTo}, {wa}); + + // Compute sum of weighted data + TBlob sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask); + ReduceAxesComputeWithWorkspaceImpl( + ctx, {wa}, {kWriteTo}, {sum_of_wa}, workspace, src_shape, dst_shape); + if (!back) { + const TBlob& avg = outputs[0]; + const TBlob& sum_of_weights = outputs[1]; + TShape w_src_shape, w_dst_shape; + BroadcastReduceShapeCompact(weights.shape_, small2, &w_src_shape, &w_dst_shape); + // Compute sum of weight + TBlob scl = sum_of_weights.reshape(small2); + ReduceAxesComputeWithWorkspaceImpl( + ctx, {weights}, {kWriteTo}, {scl}, workspace, w_src_shape, w_dst_shape); + + // Compute avg and assign output + BinaryBroadcastCompute( + attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); + } else { + // Compute and assign the derivatives of a and weights + const TBlob& igrad_a = outputs[0]; + const TBlob& igrad_w = outputs[1]; + const TBlob& scl = inputs[2]; + const TBlob& ograd = inputs[3]; + MXNET_NDIM_SWITCH(igrad_a.shape_.ndim(), NDim, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_a, { + if (one_dim) { + // 1D weights + Kernel, xpu>::Launch( + s, igrad_a.shape_.Size(), igrad_a.dptr(), + weights.dptr(), scl.dptr(), ograd.dptr(), + small1.get(), + igrad_a.shape_.get()); + } else { + Kernel, xpu>::Launch( + s, igrad_a.shape_.Size(), igrad_a.dptr(), + weights.dptr(), scl.dptr(), ograd.dptr(), + small1.get(), + igrad_a.shape_.get()); + } + }); + MXNET_ASSIGN_REQ_SWITCH(req[1], req_w, { + if (one_dim) { + Kernel, xpu>::Launch( + s, igrad_w.shape_.Size(), igrad_w.dptr(), + data.dptr(), scl.dptr(), sum_of_wa.dptr(), ograd.dptr(), + data.shape_.get(), + red_axis); + } else { + Kernel, 
xpu>::Launch( + s, igrad_w.shape_.Size(), igrad_w.dptr(), + data.dptr(), scl.dptr(), sum_of_wa.dptr(), ograd.dptr(), + small1.get(), + igrad_w.shape_.get()); + } + }); + }) + } + }); +} + +template +void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + if (req[0] == kNullOp) return; + CHECK_NE(req[0], kWriteInplace) << "Average does not support write in-place"; + const auto& param = nnvm::get(attrs.parsed); + const TBlob& data = inputs[0]; + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + if (!param.weighted) { + TShape small = NumpyReduceAxesShapeImpl(data.shape_, param.axis, true); + // Compute sum of weights which equals to the product of sizes of reduced axes + Stream* s = ctx.get_stream(); + auto ret = outputs[1].FlatTo1D(s); + ret = scalar(data.shape_.Size()/small.Size()); + // Compute mean + ReduceAxesComputeImpl( + ctx, inputs, req, {outputs[0]}, small); + } else { + NumpyWeightedAverageComputeImpl( + attrs, ctx, inputs, req, outputs, param.axis); + } + }); +} + +template +void NumpyWeightedAverageBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const auto& param = nnvm::get(attrs.parsed); + if (req[0] == kNullOp && !param.weighted) return; + CHECK_EQ(inputs.size(), (param.weighted ? 6U : 5U)); + CHECK_EQ(outputs.size(), (param.weighted ? 2U : 1U)); + const TBlob& ograd = inputs[0]; + const TBlob& data = inputs[2]; + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + if (!param.weighted) { + TShape small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true); + Stream* s = ctx.get_stream(); + auto ograd_tensor = ograd.FlatTo1D(s); + ograd_tensor /= scalar(data.shape_.Size()/small.Size()); + BroadcastComputeImpl(attrs, ctx, {ograd}, req, {outputs[0]}, small); + } else { + const TBlob& weights = inputs[3]; + const TBlob& scl = inputs[5]; + NumpyWeightedAverageComputeImpl( + attrs, ctx, {data, weights, scl, ograd}, req, outputs, param.axis); + } + }); +} + template void NumpyMomentsForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index fb133568a7a5..2a1bc5261701 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -35,6 +35,7 @@ namespace op { DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam); DMLC_REGISTER_PARAMETER(NumpyReduceAxesNoDTypeParam); DMLC_REGISTER_PARAMETER(NumpyMomentsParam); +DMLC_REGISTER_PARAMETER(NumpyWeightedAverageParam); inline bool NumpySumType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, @@ -249,6 +250,76 @@ inline bool IsIntType(const int dtype) { dtype == mshadow::kInt64); } +inline bool NumpyWeightedAverageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const auto ¶m = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), (param.weighted ? 
2U : 1U)); + CHECK_EQ(out_attrs->size(), 2U); + + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + if (param.weighted) { + TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0)); + } + TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0)); + + return in_attrs->at(0) != -1 && out_attrs->at(0) != -1 && + (!param.weighted || (in_attrs->at(1) != -1)) && + out_attrs->at(1) != -1; +} + +NNVM_REGISTER_OP(_npi_average) +.set_num_inputs( + [](const NodeAttrs& attrs) { + const auto& param = nnvm::get(attrs.parsed); + return param.weighted ? 2 : 1; + }) +.set_num_outputs(2) +.set_attr("FNumVisibleOutputs", + [](const NodeAttrs& attrs) { + const auto& param = nnvm::get(attrs.parsed); + return param.returned ? 2 : 1; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyWeightedAverageShape) +.set_attr("FInferType", NumpyWeightedAverageType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const auto& param = nnvm::get(attrs.parsed); + return param.weighted ? + std::vector{"a", "weights"} : + std::vector{"a"}; + }) +.add_argument("a", "NDArray-or-Symbol", "The input") +.add_argument("weights", "NDArray-or-Symbol", "The weights to calculate average") +.add_arguments(NumpyWeightedAverageParam::__FIELDS__()) +.set_attr("FCompute", NumpyWeightedAverageForward) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_np_average"}); + +NNVM_REGISTER_OP(_backward_np_average) +.set_num_outputs( + [](const NodeAttrs& attrs) { + const auto& param = nnvm::get(attrs.parsed); + return param.weighted ? 2 : 1; + }) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_num_inputs( + [](const NodeAttrs& attrs) { + const auto& param = nnvm::get(attrs.parsed); + return param.weighted ? 
6 : 5; + }) +.set_attr("FCompute", NumpyWeightedAverageBackward) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; +}); + inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu index 53e78787d47d..56194ff34a7e 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu @@ -50,6 +50,12 @@ NNVM_REGISTER_OP(_np_prod) NNVM_REGISTER_OP(_backward_np_prod) .set_attr("FCompute", NumpyReduceAxesBackwardUseInOut); +NNVM_REGISTER_OP(_npi_average) +.set_attr("FCompute", NumpyWeightedAverageForward); + +NNVM_REGISTER_OP(_backward_np_average) +.set_attr("FCompute", NumpyWeightedAverageBackward); + NNVM_REGISTER_OP(_npi_mean) .set_attr("FCompute", NumpyReduceAxesCompute); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index d019081ec3ee..0b9389463818 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -599,6 +599,119 @@ def _test_np_exception(func, shape, dim): _test_np_exception(func, shape, dim) +@with_seed() +@use_np +def test_np_average(): + class TestAverage(HybridBlock): + def __init__(self, axis=None, returned=False): + super(TestAverage, self).__init__() + # necessary initializations + self._axis = axis + self._returned = returned + + def hybrid_forward(self, F, a, weights): + return F.np.average(a, weights=weights, axis=self._axis, returned=self._returned) + + def avg_backward(a, w, avg, axes, init_a_grad=None, init_w_grad=None): + # avg = sum(a * w) / sum(w) + if axes is not None and not isinstance(axes, tuple) and axes < 0: + axes += a.ndim + if w is None: + a_grad = _np.ones(shape=a.shape, dtype=a.dtype)/(a.size/avg.size) + if init_a_grad is not None: + a_grad += init_a_grad.asnumpy() + return [a_grad, None] + onedim = a.ndim != w.ndim + if onedim: + new_shape = [a.shape[i] if i == axes else 1 for i in range(a.ndim)] + w = w.reshape(new_shape) + w = _np.broadcast_to(w, a.shape) + + # partial a = w / sum(w) + # partial w = (a*sum(w) - sum(a*w)) / (sum(w) * sum(w)) + scl = _np.sum(w, axis=axes, keepdims=True) + a_grad = _np.divide(w, scl) + w_grad = _np.divide(a*scl-_np.sum(a*w, axis=axes, keepdims=True), scl*scl) + + if onedim: + axis = list(range(a.ndim)) + axis.remove(axes) + w_grad = _np.sum(w_grad, axis=tuple(axis)) + if init_a_grad is not None: + a_grad += init_a_grad.asnumpy() + if init_w_grad is not None: + w_grad += init_w_grad.asnumpy() + return [a_grad, w_grad] + + tensor_shapes = [ + ((3, 5), (3, 5), None), # (a_shape, w_shape, axes) + ((4, 5, 6), (4, 5, 6), (0, 2)), + ((3,), (3,), 0), + ((2, 3), (3,), 1), + ((2, 3, 4), (2,), 0), + ((2, 3, 4), (3,), 1), + ((2, 3, 4), (4,), -1), + ((2, 3, 4, 5), (5,), 3) + ] + + flags = [True, False] + dtypes = ['float32', 'float64'] + reqs = ['null', 'add', 'write'] + for hybridize, returned, (a_shape, w_shape, axes), dtype, is_weighted, req_a in \ + itertools.product(flags, flags, tensor_shapes, dtypes, flags, reqs): + if req_a == 'null' and not is_weighted: + continue + rtol, atol = 1e-3, 1e-4 + test_average = TestAverage(axes, returned) + if hybridize: + test_average.hybridize() + a = np.random.uniform(-1.0, 1.0, size=a_shape, dtype=dtype) + a.attach_grad(req_a) + init_a_grad = np.random.uniform(-1.0, 1.0, size=a_shape, dtype=dtype) if req_a == 'add' else None + 
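# An 'add' grad request is exercised by pre-seeding the grad buffers (init_a_grad / init_w_grad) with random values, so backward() below is checked for accumulating into, not overwriting, existing gradients. +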
init_w_grad = None + req_w = req_a + w, np_w = None, None + if is_weighted: + w = np.random.uniform(-1.0, 1.0, size=w_shape, dtype=dtype) + if req_a == 'null': + req_w = random.choice(['add', 'write']) + w.attach_grad(req_w) + if req_w == 'add': + init_w_grad = np.random.uniform(-1.0, 1.0, size=w_shape, dtype=dtype) + np_w = w.asnumpy() + np_out = _np.average(a.asnumpy(), axis=axes, weights=np_w, returned=returned) + with mx.autograd.record(): + mx_out = test_average(a, w) + if returned: + np_out, np_sum_of_weights = np_out + mx_out, mx_sum_of_weights = mx_out + assert_almost_equal(mx_sum_of_weights.asnumpy(), np_sum_of_weights, rtol=rtol, atol=atol) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out.astype(dtype), rtol=rtol, atol=atol) + if req_a == 'add': + a.grad[:] = init_a_grad + if is_weighted and req_w == 'add': + w.grad[:] = init_w_grad + mx_out.backward() + # Code to get reference backward value + a_grad, w_grad = avg_backward(a.asnumpy(), np_w, np_out, axes, init_a_grad, init_w_grad) + if is_weighted: + assert_almost_equal(w.grad.asnumpy(), w_grad, rtol=rtol*10, atol=atol*10) + if req_a == 'null': + assert a.grad is None + else: + assert_almost_equal(a.grad.asnumpy(), a_grad, rtol=rtol, atol=atol) + + # Test imperative once again + np_out = _np.average(a.asnumpy(), weights=np_w, axis=axes, returned=returned) + mx_out = np.average(a, weights=w, axis=axes, returned=returned) + if returned: + np_out, np_sum_of_weights = np_out + mx_out, mx_sum_of_weights = mx_out + assert_almost_equal(mx_sum_of_weights.asnumpy(), np_sum_of_weights, rtol=rtol, atol=atol) + assert_almost_equal(mx_out.asnumpy(), np_out.astype(dtype), rtol=rtol, atol=atol) + + @with_seed() @use_np def test_np_mean():
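For reference (not part of the patch), a minimal standalone NumPy sketch that cross-checks the formulas the new operator and the test's avg_backward rely on: avg = sum(a * weights) / sum(weights), partial a = w / sum(w), and partial w = (a * sum(w) - sum(a * w)) / (sum(w) * sum(w)). The array values reuse the docstring example, and _np is plain NumPy as in the test file.

# Standalone sketch, assuming plain NumPy (imported as _np, matching the test file's convention).
import numpy as _np

a = _np.arange(6, dtype=_np.float64).reshape(3, 2)   # docstring example data
w = _np.array([0.25, 0.75])                          # 1-D weights along axis=1

scl = w.sum()                                        # sum(w)
wa = (a * w).sum(axis=1)                             # sum(a * w) along the reduced axis
avg = wa / scl                                       # avg = sum(a * w) / sum(w) -> [0.75, 2.75, 4.75]

# Gradients for an all-ones output gradient, using the formulas from the kernels:
#   partial avg / partial a = w / sum(w)
#   partial avg / partial w = (a * sum(w) - sum(a * w)) / (sum(w) * sum(w))
ograd = _np.ones_like(avg)
grad_a = _np.broadcast_to(w / scl, a.shape) * ograd[:, None]
grad_w = (((a * scl - wa[:, None]) / (scl * scl)) * ograd[:, None]).sum(axis=0)
# grad_w is summed over the non-reduced axes, mirroring what avg_grad_w_1D_kernel accumulates.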