diff --git a/julia/deps/build.jl b/julia/deps/build.jl
index a87343d9dab5..a79d2a062c18 100644
--- a/julia/deps/build.jl
+++ b/julia/deps/build.jl
@@ -54,7 +54,7 @@ if Sys.isunix()
     nvcc_path = Sys.which("nvcc")
     if nvcc_path ≢ nothing
         @info "Found nvcc: $nvcc_path"
-        push!(CUDAPATHS, replace(nvcc_path, "bin/nvcc", "lib64"))
+        push!(CUDAPATHS, replace(nvcc_path, "bin/nvcc" => "lib64"))
     end
 end
diff --git a/julia/src/metric.jl b/julia/src/metric.jl
index f1cdc68d947f..2ae7fc85144b 100644
--- a/julia/src/metric.jl
+++ b/julia/src/metric.jl
@@ -260,7 +260,7 @@ end
 function get(metric::MSE)
   # Delay copy until last possible moment
-  mse_sum = mapreduce(nda->copy(nda)[1], +, 0.0, metric.mse_sum)
+  mse_sum = mapreduce(nda->copy(nda)[1], +, metric.mse_sum; init = zero(MX_float))
   [(:MSE, mse_sum / metric.n_sample)]
 end
diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py
index 397f4775f8cd..2146853a6866 100644
--- a/python/mxnet/contrib/amp/lists/symbol.py
+++ b/python/mxnet/contrib/amp/lists/symbol.py
@@ -368,6 +368,9 @@
     'arctanh',
     '_sparse_arcsin',
     '_sparse_arctanh',
+    '_contrib_MultiBoxDetection',
+    '_contrib_MultiBoxPrior',
+    '_contrib_MultiBoxTarget',
 
     # Exponents
     'exp',
@@ -575,9 +578,6 @@
     'stack',
     '_Maximum',
     '_Minimum',
-    '_contrib_MultiBoxDetection',
-    '_contrib_MultiBoxPrior',
-    '_contrib_MultiBoxTarget',
     '_contrib_MultiProposal',
     '_contrib_PSROIPooling',
     '_contrib_Proposal',
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 7e02922b5e66..9f794b408e8c 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -34,11 +34,12 @@
     'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
     'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
     'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate',
-    'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin',
-    'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
+    'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
+    'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
     'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
     'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
-    'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'append']
+    'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff',
+    'append']
 
 
@@ -3058,6 +3059,36 @@ def get_list(arrays):
     return _npi.vstack(*arrays)
 
 
+@set_module('mxnet.ndarray.numpy')
+def column_stack(tup):
+    """
+    Stack 1-D arrays as columns into a 2-D array.
+
+    Take a sequence of 1-D arrays and stack them as columns
+    to make a single 2-D array. 2-D arrays are stacked as-is,
+    just like with `hstack`. 1-D arrays are turned into 2-D columns
+    first.
+
+    Parameters
+    ----------
+    tup : sequence of 1-D or 2-D arrays.
+        Arrays to stack. All of them must have the same first dimension.
+
+    Returns
+    -------
+    stacked : 2-D array
+        The array formed by stacking the given arrays.
+
+    See Also
+    --------
+    stack, hstack, vstack, concatenate
+
+    Examples
+    --------
+    >>> a = np.array((1,2,3))
+    >>> b = np.array((2,3,4))
+    >>> np.column_stack((a,b))
+    array([[1., 2.],
+           [2., 3.],
+           [3., 4.]])
+    """
+    return _npi.column_stack(*tup)
+
+
 @set_module('mxnet.ndarray.numpy')
 def dstack(arrays):
     """
@@ -5037,3 +5068,51 @@ def may_share_memory(a, b, max_work=None):
     - Actually it is same as `shares_memory` in MXNet DeepNumPy
     """
     return _npi.share_memory(a, b).item()
+
+
+def diff(a, n=1, axis=-1, prepend=None, append=None):
+    r"""
+    numpy.diff(a, n=1, axis=-1, prepend=<no value>, append=<no value>)
+
+    Calculate the n-th discrete difference along the given axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array
+    n : int, optional
+        The number of times values are differenced. If zero, the input is returned as-is.
+    axis : int, optional
+        The axis along which the difference is taken, default is the last axis.
+    prepend, append : ndarray, optional
+        Not supported yet
+
+    Returns
+    -------
+    diff : ndarray
+        The n-th differences.
+        The shape of the output is the same as a except along axis where the dimension is smaller by n.
+        The type of the output is the same as the type of the difference between any two elements of a.
+
+    Examples
+    --------
+    >>> x = np.array([1, 2, 4, 7, 0])
+    >>> np.diff(x)
+    array([ 1,  2,  3, -7])
+    >>> np.diff(x, n=2)
+    array([  1,   1, -10])
+
+    >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]])
+    >>> np.diff(x)
+    array([[2, 3, 4],
+           [5, 1, 2]])
+    >>> np.diff(x, axis=0)
+    array([[-1,  2,  0, -2]])
+
+    Notes
+    -----
+    Optional inputs `prepend` and `append` are not supported yet
+    """
+    if prepend is not None or append is not None:
+        raise NotImplementedError('prepend and append options are not supported yet')
+    return _npi.diff(a, n=n, axis=axis)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 9de428878292..301e0101f90e 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -51,12 +51,12 @@
     'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
     'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
     'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
-    'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum',
+    'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum',
     'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning',
     'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm',
     'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater',
     'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
-    'may_share_memory', 'append']
+    'may_share_memory', 'diff', 'append']
 
 # Return code for dispatching indexing function call
 _NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -577,7 +577,7 @@ def __setitem__(self, key, value):
             if not isinstance(key, tuple) or len(key) != 0:
                 raise IndexError('scalar tensor can only accept `()` as index')
             if isinstance(value, numeric_types):
-                self.full(value)
+                self._full(value)
             elif isinstance(value, ndarray) and value.size == 1:
                 if value.shape != self.shape:
                     value = value.reshape(self.shape)
@@ -1993,15 +1993,20 @@ def array(object, dtype=None, ctx=None):
     """
     if ctx is None:
         ctx = current_context()
-    if isinstance(object, ndarray):
+    if isinstance(object, (ndarray, _np.ndarray)):
         dtype = object.dtype if dtype is None else dtype
+    elif isinstance(object, NDArray):
+        raise ValueError("If you're trying to create a mxnet.numpy.ndarray "
+                         "from mx.nd.NDArray, please use the zero-copy as_np_ndarray function.")
     else:
-        dtype = _np.float32 if dtype is None else dtype
-        if not isinstance(object, (ndarray, _np.ndarray)):
-            try:
-                object = _np.array(object, dtype=dtype)
-            except Exception as e:
-                raise TypeError('{}'.format(str(e)))
+        if dtype is None:
+            dtype = object.dtype if hasattr(object, "dtype") else _np.float32
+        try:
+            object = _np.array(object, dtype=dtype)
+        except Exception as e:
+            # printing out the error raised by official NumPy's array function
+            # for transparency on users' side
+            raise TypeError('{}'.format(str(e)))
     ret = empty(object.shape, dtype=dtype, ctx=ctx)
     if len(object.shape) == 0:
         ret[()] = object
@@ -4940,6 +4945,42 @@ def vstack(arrays, out=None):
     return _mx_nd_np.vstack(arrays)
 
 
+@set_module('mxnet.numpy')
+def column_stack(tup):
+    """
+    Stack 1-D arrays as columns into a 2-D array.
+
+    Take a sequence of 1-D arrays and stack them as columns
+    to make a single 2-D array. 2-D arrays are stacked as-is,
+    just like with `hstack`. 1-D arrays are turned into 2-D columns
+    first.
+
+    Parameters
+    ----------
+    tup : sequence of 1-D or 2-D arrays.
+        Arrays to stack. All of them must have the same first dimension.
+
+    Returns
+    -------
+    stacked : 2-D array
+        The array formed by stacking the given arrays.
+
+    See Also
+    --------
+    stack, hstack, vstack, concatenate
+
+    Examples
+    --------
+    >>> a = np.array((1,2,3))
+    >>> b = np.array((2,3,4))
+    >>> np.column_stack((a,b))
+    array([[1., 2.],
+           [2., 3.],
+           [3., 4.]])
+    """
+    return _mx_nd_np.column_stack(tup)
+
+
 @set_module('mxnet.numpy')
 def dstack(arrays):
     """
@@ -7016,3 +7057,52 @@ def may_share_memory(a, b, max_work=None):
     - Actually it is same as `shares_memory` in MXNet DeepNumPy
     """
     return _mx_nd_np.may_share_memory(a, b, max_work)
+
+
+def diff(a, n=1, axis=-1, prepend=None, append=None):
+    r"""
+    numpy.diff(a, n=1, axis=-1, prepend=<no value>, append=<no value>)
+
+    Calculate the n-th discrete difference along the given axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array
+    n : int, optional
+        The number of times values are differenced. If zero, the input is returned as-is.
+    axis : int, optional
+        The axis along which the difference is taken, default is the last axis.
+    prepend, append : ndarray, optional
+        Not supported yet
+
+    Returns
+    -------
+    diff : ndarray
+        The n-th differences.
+        The shape of the output is the same as a except along axis where the dimension is smaller by n.
+        The type of the output is the same as the type of the difference between any two elements of a.
+        This is the same as the type of a in most cases.
+
+    Examples
+    --------
+    >>> x = np.array([1, 2, 4, 7, 0])
+    >>> np.diff(x)
+    array([ 1,  2,  3, -7])
+    >>> np.diff(x, n=2)
+    array([  1,   1, -10])
+
+    >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]])
+    >>> np.diff(x)
+    array([[2, 3, 4],
+           [5, 1, 2]])
+    >>> np.diff(x, axis=0)
+    array([[-1,  2,  0, -2]])
+
+    Notes
+    -----
+    Optional inputs `prepend` and `append` are not supported yet
+    """
+    if prepend is not None or append is not None:
+        raise NotImplementedError('prepend and append options are not supported yet')
+    return _mx_nd_np.diff(a, n=n, axis=axis)
diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py
index ee1153007ea7..cdd21af829de 100644
--- a/python/mxnet/numpy_dispatch_protocol.py
+++ b/python/mxnet/numpy_dispatch_protocol.py
@@ -122,6 +122,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
     'var',
     'vdot',
     'vstack',
+    'column_stack',
     'zeros_like',
     'linalg.norm',
     'trace',
@@ -131,6 +132,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
     'einsum',
     'shares_memory',
     'may_share_memory',
+    'diff',
 ]
 
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 507ab30c2e83..505cc1f5bd10 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -36,11 +36,11 @@
     'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc',
     'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', 'linspace',
     'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate',
-    'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin',
-    'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
+    'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
+    'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
     'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
     'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
-    'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory',
+    'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff',
     'append']
 
 
@@ -3129,6 +3129,42 @@ def get_list(arrays):
     return _npi.vstack(*arrays)
 
 
+@set_module('mxnet.symbol.numpy')
+def column_stack(tup):
+    """
+    Stack 1-D arrays as columns into a 2-D array.
+
+    Take a sequence of 1-D arrays and stack them as columns
+    to make a single 2-D array. 2-D arrays are stacked as-is,
+    just like with `hstack`. 1-D arrays are turned into 2-D columns
+    first.
+
+    Parameters
+    ----------
+    tup : sequence of 1-D or 2-D arrays.
+        Arrays to stack. All of them must have the same first dimension.
+
+    Returns
+    -------
+    stacked : 2-D array
+        The array formed by stacking the given arrays.
+
+    See Also
+    --------
+    stack, hstack, vstack, concatenate
+
+    Examples
+    --------
+    >>> a = np.array((1,2,3))
+    >>> b = np.array((2,3,4))
+    >>> np.column_stack((a,b))
+    array([[1., 2.],
+           [2., 3.],
+           [3., 4.]])
+    """
+    return _npi.column_stack(*tup)
+
+
 @set_module('mxnet.symbol.numpy')
 def dstack(arrays):
     """
@@ -4686,4 +4722,53 @@ def may_share_memory(a, b, max_work=None):
     return _npi.share_memory(a, b)
 
 
+def diff(a, n=1, axis=-1, prepend=None, append=None):
+    r"""
+    numpy.diff(a, n=1, axis=-1, prepend=<no value>, append=<no value>)
+
+    Calculate the n-th discrete difference along the given axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array
+    n : int, optional
+        The number of times values are differenced. If zero, the input is returned as-is.
+    axis : int, optional
+        The axis along which the difference is taken, default is the last axis.
+    prepend, append : ndarray, optional
+        Not supported yet
+
+    Returns
+    -------
+    diff : ndarray
+        The n-th differences.
+        The shape of the output is the same as a except along axis where the dimension is smaller by n.
+        The type of the output is the same as the type of the difference between any two elements of a.
+        This is the same as the type of a in most cases.
+
+    Examples
+    --------
+    >>> x = np.array([1, 2, 4, 7, 0])
+    >>> np.diff(x)
+    array([ 1,  2,  3, -7])
+    >>> np.diff(x, n=2)
+    array([  1,   1, -10])
+
+    >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]])
+    >>> np.diff(x)
+    array([[2, 3, 4],
+           [5, 1, 2]])
+    >>> np.diff(x, axis=0)
+    array([[-1,  2,  0, -2]])
+
+    Notes
+    -----
+    Optional inputs `prepend` and `append` are not supported yet
+    """
+    if prepend is not None or append is not None:
+        raise NotImplementedError('prepend and append options are not supported yet')
+    return _npi.diff(a, n=n, axis=axis)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu
index 2a1461cc8c48..da7b60db7f13 100644
--- a/src/ndarray/ndarray_function.cu
+++ b/src/ndarray/ndarray_function.cu
@@ -76,7 +76,7 @@ void Copy<gpu, gpu>(const TBlob &from, TBlob *to,
                     from.FlatTo1D<gpu, DType>(s),
                     s);
   } else {
-    MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(from.type_flag_, SrcDType, {
       to->FlatTo1D<gpu, DType>(s) =
           mshadow::expr::tcast<DType>(from.FlatTo1D<gpu, SrcDType>(s));
     })
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 463c71b5b0eb..91478660a123 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -671,7 +671,7 @@ template <typename xpu>
 MSHADOW_CINLINE void copy(mshadow::Stream<xpu> *s, const TBlob& to, const TBlob& from) {
   CHECK_EQ(from.Size(), to.Size());
   CHECK_EQ(from.dev_mask(), to.dev_mask());
-  MSHADOW_TYPE_SWITCH(to.type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(to.type_flag_, DType, {
     if (to.type_flag_ == from.type_flag_) {
       mshadow::Copy(to.FlatTo1D<xpu, DType>(s), from.FlatTo1D<xpu, DType>(s), s);
     } else {
diff --git a/src/operator/numpy/np_cumsum-inl.h b/src/operator/numpy/np_cumsum-inl.h
index 6c6b56d46e76..375d83b2240f 100644
--- a/src/operator/numpy/np_cumsum-inl.h
+++ b/src/operator/numpy/np_cumsum-inl.h
@@ -98,7 +98,7 @@ void CumsumForwardImpl(const OpContext& ctx,
   }
 
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(in.type_flag_, IType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(in.type_flag_, IType, {
     MSHADOW_TYPE_SWITCH(out.type_flag_, OType, {
       Kernel<cumsum_forward, xpu>::Launch(
           s, out.Size() / middle, out.dptr<OType>(),
@@ -157,7 +157,7 @@ void CumsumBackwardImpl(const OpContext& ctx,
     }
   }
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(igrad.type_flag_, IType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(igrad.type_flag_, IType, {
     MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
       Kernel<cumsum_backward, xpu>::Launch(
           s, igrad.Size() / middle, igrad.dptr<IType>(),
diff --git a/src/operator/numpy/np_cumsum.cc b/src/operator/numpy/np_cumsum.cc
index 0ddbf521186c..2d5dbb99f90a 100644
--- a/src/operator/numpy/np_cumsum.cc
+++ b/src/operator/numpy/np_cumsum.cc
@@ -55,6 +55,9 @@ inline bool CumsumType(const nnvm::NodeAttrs& attrs,
   } else {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
     TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    if (out_attrs->at(0) == mshadow::kBool) {
+      (*out_attrs)[0] = mshadow::kInt64;
+    }
   }
 
   return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
diff --git a/src/operator/numpy/np_diff-inl.h b/src/operator/numpy/np_diff-inl.h
new file mode 100644
index 000000000000..69f175e802dd
--- /dev/null
+++ b/src/operator/numpy/np_diff-inl.h
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_diff-inl.h
+ * \brief Function definition of numpy-compatible diff operator
+ */
+
+#ifndef MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_
+#define MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_
+
+#include <dmlc/parameter.h>
+#include <vector>
+#include <string>
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+struct DiffParam : public dmlc::Parameter<DiffParam> {
+  int n, axis;
+  dmlc::optional<mxnet::TShape> prepend;
+  dmlc::optional<mxnet::TShape> append;
+  DMLC_DECLARE_PARAMETER(DiffParam) {
+    DMLC_DECLARE_FIELD(n).set_default(1).describe(
+        "The number of times values are differenced."
+        " If zero, the input is returned as-is.");
+    DMLC_DECLARE_FIELD(axis).set_default(-1).describe(
+        "Axis along which the difference is taken;"
+        " by default the last axis is used.");
+  }
+};
+
+inline void YanghuiTri(std::vector<int>* buffer, int n) {
+  // Fill buffer with row n of Yanghui's (Pascal's) triangle: the binomial
+  // coefficients used as the n-th order differencing factors.
+  (*buffer)[0] = 1;
+  for (int i = 1; i <= n; ++i) {
+    (*buffer)[i] = 1;
+    for (int j = i - 1; j > 0; --j) {
+      (*buffer)[j] += (*buffer)[j - 1];
+    }
+  }
+}
+
+struct diff_forward {
+  template <typename IType, typename OType, int ndim>
+  MSHADOW_XINLINE static void Map(int i, int* diffFactor, OType* out,
+                                  const IType* in, const int n,
+                                  const int stride,
+                                  const mshadow::Shape<ndim> oshape,
+                                  const mshadow::Shape<ndim> ishape) {
+    using namespace broadcast;
+
+    // j represents the memory index of the corresponding input entry
+    int j = ravel(unravel(i, oshape), ishape);
+    int indicator = 1;
+    out[i] = 0;
+    for (int k = n; k >= 0; --k) {
+      out[i] += in[j + stride * k] * indicator * diffFactor[k];
+      indicator *= -1;
+    }
+  }
+};
+
+template <typename xpu>
+void DiffForwardImpl(const OpContext& ctx, const TBlob& in, const TBlob& out,
+                     const int n, const int axis) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+
+  // undefined behavior for n < 0
+  CHECK_GE(n, 0);
+  int axis_checked = CheckAxis(axis, in.ndim());
+  // nothing in the output
+  if (n >= in.shape_[axis_checked]) return;
+  // stride for elements on the given axis, same in input and output
+  int stride = 1;
+  for (int i = in.ndim() - 1; i > axis_checked; --i) {
+    stride *= in.shape_[i];
+  }
+
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  std::vector<int> buffer(n + 1, 0);
+  YanghuiTri(&buffer, n);
+  Tensor<xpu, 1, int> diffFactor =
+      ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(n + 1), s);
+  Copy(diffFactor, Tensor<cpu, 1, int>(&buffer[0], Shape1(n + 1), 0), s);
+
+  MSHADOW_TYPE_SWITCH(in.type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH(out.type_flag_, OType, {
+      MXNET_NDIM_SWITCH(in.ndim(), ndim, {
Kernel::Launch( + s, out.Size(), diffFactor.dptr_, + out.dptr(), in.dptr(), + n, stride, out.shape_.get(), + in.shape_.get()); + }); + }); + }); +} + +template +void DiffForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const DiffParam& param = nnvm::get(attrs.parsed); + + DiffForwardImpl(ctx, inputs[0], outputs[0], param.n, param.axis); +} + +struct diff_backward { + template + MSHADOW_XINLINE static void Map(int i, int* diffFactor, OType* igrad, + const IType* ograd, const int n, + const int stride, const int axis, + const mshadow::Shape oshape, + const mshadow::Shape ishape) { + using namespace broadcast; + if (n == 0) { + igrad[i] = ograd[i]; + return; + } + + Shape coor = unravel(i, oshape); + // one head thread for a whole sequence along the axis + if (coor[axis] != 0) return; + int j = ravel(coor, ishape); + // initialize the elements of output array + for (int k = 0; k < oshape[axis]; ++k) igrad[i + k * stride] = 0; + for (int k = 0; k < ishape[axis]; ++k) { + int indicator = 1; + for (int m = n; m >= 0; --m) { + igrad[i + (m + k) * stride] += + ograd[j + k * stride] * indicator * diffFactor[m]; + indicator *= -1; + } + } + } +}; + +template +void DiffBackwardImpl(const OpContext& ctx, const TBlob& ograd, + const TBlob& igrad, const int n, const int axis) { + using namespace mshadow; + using namespace mxnet_op; + + // undefined behavior for n < 0 + CHECK_GE(n, 0); + int axis_checked = CheckAxis(axis, igrad.ndim()); + // nothing in the ograd and igrad + if (n >= igrad.shape_[axis_checked]) return; + // stride for elements on the given axis, same in input and output + int stride = 1; + for (int i = igrad.ndim() - 1; i > axis_checked; --i) { + stride *= igrad.shape_[i]; + } + + Stream* s = ctx.get_stream(); + std::vector buffer(n+1, 0); + YanghuiTri(&buffer, n); + Tensor diffFactor = + ctx.requested[0].get_space_typed(Shape1(n + 1), s); + Copy(diffFactor, Tensor(&buffer[0], Shape1(n + 1), 0), s); + + MSHADOW_TYPE_SWITCH(ograd.type_flag_, IType, { + MSHADOW_TYPE_SWITCH(igrad.type_flag_, OType, { + MXNET_NDIM_SWITCH(igrad.ndim(), ndim, { + Kernel::Launch( + s, igrad.Size(), diffFactor.dptr_, + igrad.dptr(), ograd.dptr(), + n, stride, axis_checked, + igrad.shape_.get(), ograd.shape_.get()); + }); + }); + }); +} + +template +void DiffBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const DiffParam& param = nnvm::get(attrs.parsed); + + DiffBackwardImpl(ctx, inputs[0], outputs[0], param.n, param.axis); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_ diff --git a/src/operator/numpy/np_diff.cc b/src/operator/numpy/np_diff.cc new file mode 100644 index 000000000000..a3dae332d842 --- /dev/null +++ b/src/operator/numpy/np_diff.cc @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_diff.cc + * \brief CPU implementation of numpy-compatible diff operator + */ + +#include "./np_diff-inl.h" + +namespace mxnet { +namespace op { + +inline TShape NumpyDiffShapeImpl(const TShape& ishape, + const int n, + const int axis) { + CHECK_GE(n, 0); + int axis_checked = CheckAxis(axis, ishape.ndim()); + + TShape oshape = ishape; + if (n >= ishape[axis_checked]) { + oshape[axis_checked] = 0; + } else { + oshape[axis_checked] -= n; + } + return oshape; +} + +inline bool DiffShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (!shape_is_known(in_attrs->at(0))) { + return false; + } + const DiffParam& param = nnvm::get(attrs.parsed); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, + NumpyDiffShapeImpl((*in_attrs)[0], param.n, param.axis)); + return shape_is_known(out_attrs->at(0)); +} + +inline bool DiffType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + + return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; +} + +DMLC_REGISTER_PARAMETER(DiffParam); + +NNVM_REGISTER_OP(_npi_diff) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.set_attr("FInferShape", DiffShape) +.set_attr("FInferType", DiffType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DiffForward) +.set_attr("FGradient", + ElemwiseGradUseNone{"_backward_npi_diff"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.add_argument("a", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(DiffParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_diff) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DiffBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_diff.cu b/src/operator/numpy/np_diff.cu new file mode 100644 index 000000000000..daea6e368e05 --- /dev/null +++ b/src/operator/numpy/np_diff.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_diff.cu + * \brief GPU implementation of numpy-compatible diff operator + */ + +#include "./np_diff-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_diff) +.set_attr("FCompute", DiffForward); + +NNVM_REGISTER_OP(_backward_npi_diff) +.set_attr("FCompute", DiffBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h index 69999ae8710e..df30d611aa02 100644 --- a/src/operator/numpy/np_init_op.h +++ b/src/operator/numpy/np_init_op.h @@ -205,7 +205,7 @@ void IdentityCompute(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); const TBlob& out_data = outputs[0]; int n = out_data.shape_[0]; - MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(out_data.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { Kernel, xpu>::Launch( s, out_data.Size(), out_data.dptr(), n); diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 60daba102367..a9828f40436d 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -52,6 +52,14 @@ struct NumpyVstackParam : public dmlc::Parameter { } }; +struct NumpyColumnStackParam : public dmlc::Parameter { + int num_args; + DMLC_DECLARE_PARAMETER(NumpyColumnStackParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) + .describe("Number of inputs to be column stacked"); + } +}; + struct NumpyReshapeParam : public dmlc::Parameter { mxnet::TShape newshape; std::string order; @@ -124,6 +132,78 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, } } +template +void NumpyColumnStackForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + + const NumpyColumnStackParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), param.num_args); + CHECK_EQ(outputs.size(), 1); + CHECK_EQ(req.size(), 1); + + // reshape if necessary + std::vector data(param.num_args); + for (int i = 0; i < param.num_args; i++) { + if (inputs[i].shape_.ndim() == 0 || inputs[i].shape_.ndim() == 1) { + TShape shape = Shape2(inputs[i].shape_.Size(), 1); + data[i] = inputs[i].reshape(shape); + } else { + data[i] = inputs[i]; + } + } + + // initialize ConcatOp + ConcatParam cparam; + cparam.num_args = param.num_args; + cparam.dim = 1; + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + ConcatOp op; + op.Init(cparam); + op.Forward(ctx, data, req, outputs); + }); +} + +template +void NumpyColumnStackBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + + const NumpyColumnStackParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), param.num_args); + CHECK_EQ(req.size(), param.num_args); + + // reshape if necessary + std::vector data(param.num_args); + for (int i = 0; i < param.num_args; i++) { + if 
(outputs[i].shape_.ndim() == 0 || outputs[i].shape_.ndim() == 1) { + TShape shape = Shape2(outputs[i].shape_.Size(), 1); + data[i] = outputs[i].reshape(shape); + } else { + data[i] = outputs[i]; + } + } + + // initialize ConcatOp + ConcatParam cparam; + cparam.num_args = param.num_args; + cparam.dim = 1; + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + ConcatOp op; + op.Init(cparam); + op.Backward(ctx, inputs[0], req, data); + }); +} + template void NumpyVstackForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index ca89ddaa3178..3967cde91d2a 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -690,6 +690,152 @@ Examples:: .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack") .add_arguments(StackParam::__FIELDS__()); +bool NumpyColumnStackType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const NumpyColumnStackParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_type->size(), param.num_args); + CHECK_EQ(out_type->size(), 1); + int dtype = -1; + for (int i = 0; i < param.num_args; i++) { + if (dtype == -1) { + dtype = in_type->at(i); + } + } + if (dtype == -1) { + dtype = out_type->at(0); + } + for (int i = 0; i < param.num_args; i++) { + TYPE_ASSIGN_CHECK(*in_type, i, dtype); + } + TYPE_ASSIGN_CHECK(*out_type, 0, dtype); + return dtype != -1; +} + +bool NumpyColumnStackShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(out_attrs->size(), 1U); + const NumpyColumnStackParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), param.num_args); + std::vector in_attrs_tmp(param.num_args); + TShape dshape; + // For each array in the input, reshape to 2D if ndim < 2. + for (int i = 0; i < param.num_args; i++) { + if ((*in_attrs)[i].ndim() == 0) { + in_attrs_tmp[i] = TShape(2, 1); + } else if ((*in_attrs)[i].ndim() == 1) { + // Transpose 1D row into a column. + in_attrs_tmp[i] = TShape(2, 1); + in_attrs_tmp[i][0] = (*in_attrs)[i][0]; + } else { + in_attrs_tmp[i] = (*in_attrs)[i]; + } + TShape tmp(in_attrs_tmp[i].ndim(), -1); + shape_assign(&dshape, tmp); + } + TShape tmp((*out_attrs)[0].ndim(), -1); + shape_assign(&dshape, tmp); + for (int i = 0; i < param.num_args; i++) { + SHAPE_ASSIGN_CHECK(in_attrs_tmp, i, dshape) + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape) + if (dshape.ndim() == -1) { + return false; + } + // Accumulate along column axis. 
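+  // The pass below counts inputs whose column size is still unknown (cnt),
+  // accumulates the known column sizes (sum), and records the position of the
+  // single unknown input (pos) so it can be inferred from the output shape.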
+  int cnt = 0, sum = 0, pos = -1;
+  for (int i = 0; i < param.num_args; i++) {
+    TShape tmp = in_attrs_tmp[i];
+    if (!dim_size_is_known(tmp, 1)) {
+      cnt++;
+      pos = i;
+    } else {
+      sum += tmp[1];
+    }
+    tmp[1] = -1;
+    shape_assign(&dshape, tmp);
+  }
+  tmp = out_attrs->at(0);
+  if (!dim_size_is_known(tmp, 1)) {
+    cnt++;
+    pos = -1;
+  } else {
+    sum += tmp[1];
+  }
+  tmp[1] = -1;
+  shape_assign(&dshape, tmp);
+  for (int i = 0; i < param.num_args; i++) {
+    SHAPE_ASSIGN_CHECK(in_attrs_tmp, i, dshape)
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape)
+  dshape[1] = 0;
+  if (!shape_is_known(dshape)) {
+    return false;
+  }
+  dshape[1] = sum;
+  if (cnt == 0) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  } else if (cnt == 1) {
+    // Infer the missing dimension when exactly one column dimension is unknown
+    if (pos >= 0) {
+      in_attrs_tmp[pos][1] = out_attrs->at(0)[1] - sum;
+    } else {
+      out_attrs->at(0)[1] = sum;
+    }
+  } else {
+    return false;
+  }
+  for (int i = 0; i < param.num_args; i++) {
+    if (in_attrs->at(i).ndim() == 1) {
+      in_attrs->at(i)[0] = in_attrs_tmp[i][1];
+    } else if (in_attrs->at(i).ndim() >= 2) {
+      in_attrs->at(i) = in_attrs_tmp[i];
+    }
+  }
+
+  return true;
+}
+
+DMLC_REGISTER_PARAMETER(NumpyColumnStackParam);
+
+NNVM_REGISTER_OP(_npi_column_stack)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr_parser(ParamParser<NumpyColumnStackParam>)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+  const NumpyColumnStackParam& param = dmlc::get<NumpyColumnStackParam>(attrs.parsed);
+  return static_cast<uint32_t>(param.num_args);
+})
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    int num_args = dmlc::get<NumpyColumnStackParam>(attrs.parsed).num_args;
+    std::vector<std::string> ret;
+    for (int i = 0; i < num_args; ++i) {
+      ret.push_back(std::string("arg") + std::to_string(i));
+    }
+    return ret;
+  })
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyColumnStackShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyColumnStackType)
+.set_attr<FCompute>("FCompute", NumpyColumnStackForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_column_stack"})
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to column_stack")
+.add_arguments(NumpyColumnStackParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_np_column_stack)
+.set_attr_parser(ParamParser<NumpyColumnStackParam>)
+.set_num_inputs(1)
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+  const NumpyColumnStackParam& param = dmlc::get<NumpyColumnStackParam>(attrs.parsed);
+  return static_cast<uint32_t>(param.num_args);
+})
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute", NumpyColumnStackBackward<cpu>);
+
+DMLC_REGISTER_PARAMETER(NumpyVstackParam);
+
 bool NumpyVstackType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_type,
                      std::vector<int> *out_type) {
@@ -795,8 +941,6 @@ bool NumpyVstackShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-DMLC_REGISTER_PARAMETER(NumpyVstackParam);
-
 NNVM_REGISTER_OP(_npi_vstack)
 .describe(R"code()code" ADD_FILELINE)
 .set_attr_parser(ParamParser<NumpyVstackParam>)
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 36c2b6518009..7ca205565413 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -59,6 +59,12 @@ NNVM_REGISTER_OP(_npi_dstack)
 NNVM_REGISTER_OP(_backward_np_dstack)
 .set_attr<FCompute>("FCompute", DStackGradCompute<gpu>);
 
+NNVM_REGISTER_OP(_npi_column_stack)
+.set_attr<FCompute>("FCompute", NumpyColumnStackForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_np_column_stack)
+.set_attr<FCompute>("FCompute", NumpyColumnStackBackward<gpu>);
+
 NNVM_REGISTER_OP(_np_roll)
 .set_attr<FCompute>("FCompute", NumpyRollCompute<gpu>);
 
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index
b448261f215d..d5fd351986e3 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -1376,21 +1376,12 @@ class RNNOp { seed_)); // RNN descriptors - cudnnDataType_t dtype_with_fallback_; + // adopt pseudo-fp16 for all architectures + cudnnDataType_t dtype_with_fallback_ = + (cudnnGetVersion() >= 7500 && dtype_ == CUDNN_DATA_HALF) ? CUDNN_DATA_FLOAT + : dtype_; cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD; dgrad_sync_needed_ = (rnn_algo == CUDNN_RNN_ALGO_STANDARD) && param_.bidirectional; - // On arch's 50 and 52(Maxwell), the gpu doesn't support native fp16 compute. - // Before cuDNN 7.5.0, when running fp16, cuDNN fallback to fp32 under the hood on Maxwell. - // That's not the case begining from 7.5.0. Thereby adding fallback explicitly here. -#if __CUDA_ARCH__ < 530 && CUDNN_VERSION >= 7500 - if (dtype_ == CUDNN_DATA_HALF) { - dtype_with_fallback_ = CUDNN_DATA_FLOAT; - } else { - dtype_with_fallback_ = dtype_; - } -#else - dtype_with_fallback_ = dtype_; -#endif CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_, rnn_desc_, param_.state_size, diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index b7625fccf258..188ccd68a340 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -451,7 +451,7 @@ void CastCompute(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DstDType, { Tensor out = outputs[0].FlatTo1D(s); - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, SrcDType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, { Tensor data = inputs[0].FlatTo1D(s); if (outputs[0].type_flag_ != inputs[0].type_flag_ || req[0] != kWriteInplace) { diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 9f2c8807ebb6..0b272d2633ab 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -112,6 +112,7 @@ def _add_workload_copy(): def _add_workload_expand_dims(): OpArgMngr.add_workload('expand_dims', np.random.uniform(size=(4, 1)), -1) + OpArgMngr.add_workload('expand_dims', np.random.uniform(size=(4, 1)) > 0.5, -1) for axis in range(-5, 4): OpArgMngr.add_workload('expand_dims', np.empty((2, 3, 4, 5)), axis) @@ -871,8 +872,8 @@ def _signs(dt): # test_float_remainder_corner_cases # Check remainder magnitude. 
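+    # Build the corner-case operands in the target dtype ct (see the change
+    # below): a float64 nextafter() result rounds to zero when cast down to
+    # fp16/fp32, so the remainder corner case would no longer be exercised.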
for ct in _FLOAT_DTYPES: - b = _np.array(1.0) - a = np.array(_np.nextafter(_np.array(0.0), -b), dtype=ct) + b = _np.array(1.0, dtype=ct) + a = np.array(_np.nextafter(_np.array(0.0, dtype=ct), -b), dtype=ct) b = np.array(b, dtype=ct) OpArgMngr.add_workload('remainder', a, b) OpArgMngr.add_workload('remainder', -a, -b) @@ -1039,6 +1040,11 @@ def _add_workload_vstack(array_pool): OpArgMngr.add_workload('vstack', array_pool['4x1']) OpArgMngr.add_workload('vstack', array_pool['1x1x0']) +def _add_workload_column_stack(): + OpArgMngr.add_workload('column_stack', (np.array([1, 2, 3]), np.array([2, 3, 4]))) + OpArgMngr.add_workload('column_stack', (np.array([[1], [2], [3]]), np.array([[2], [3], [4]]))) + OpArgMngr.add_workload('column_stack', [np.array(_np.arange(3)) for _ in range(2)]) + def _add_workload_equal(array_pool): # TODO(junwu): fp16 does not work yet with TVM generated ops @@ -1104,6 +1110,29 @@ def _add_workload_nonzero(): OpArgMngr.add_workload('nonzero', np.array([True, False, False], dtype=np.bool_)) +def _add_workload_diff(): + x = np.array([1, 4, 6, 7, 12]) + OpArgMngr.add_workload('diff', x) + OpArgMngr.add_workload('diff', x, 2) + OpArgMngr.add_workload('diff', x, 3) + OpArgMngr.add_workload('diff', np.array([1.1, 2.2, 3.0, -0.2, -0.1])) + x = np.zeros((10, 20, 30)) + x[:, 1::2, :] = 1 + OpArgMngr.add_workload('diff', x) + OpArgMngr.add_workload('diff', x, axis=-1) + OpArgMngr.add_workload('diff', x, axis=0) + OpArgMngr.add_workload('diff', x, axis=1) + OpArgMngr.add_workload('diff', x, axis=-2) + x = 20 * np.random.uniform(size=(10,20,30)) + OpArgMngr.add_workload('diff', x) + OpArgMngr.add_workload('diff', x, n=2) + OpArgMngr.add_workload('diff', x, axis=0) + OpArgMngr.add_workload('diff', x, n=2, axis=0) + x = np.array([list(range(3))]) + for n in range(1, 5): + OpArgMngr.add_workload('diff', x, n=n) + + @use_np def _prepare_workloads(): array_pool = { @@ -1204,12 +1233,14 @@ def _prepare_workloads(): _add_workload_logical_not(array_pool) _add_workload_vdot() _add_workload_vstack(array_pool) + _add_workload_column_stack() _add_workload_equal(array_pool) _add_workload_not_equal(array_pool) _add_workload_greater(array_pool) _add_workload_greater_equal(array_pool) _add_workload_less(array_pool) _add_workload_less_equal(array_pool) + _add_workload_diff() _prepare_workloads() diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 6077f4df13ae..239f300e028e 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -18,6 +18,7 @@ # pylint: skip-file from __future__ import absolute_import from __future__ import division +import itertools import os import unittest import numpy as _np @@ -87,6 +88,7 @@ def test_np_array_creation(): [], (), [[1, 2], [3, 4]], + _np.random.randint(-10, 10, size=rand_shape_nd(3)), _np.random.uniform(size=rand_shape_nd(3)), _np.random.uniform(size=(3, 0, 4)) ] @@ -94,10 +96,12 @@ def test_np_array_creation(): for src in objects: mx_arr = np.array(src, dtype=dtype) assert mx_arr.ctx == mx.current_context() + if dtype is None: + dtype = src.dtype if isinstance(src, _np.ndarray) else _np.float32 if isinstance(src, mx.nd.NDArray): - np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32) + np_arr = _np.array(src.asnumpy(), dtype=dtype) else: - np_arr = _np.array(src, dtype=dtype if dtype is not None else _np.float32) + np_arr = _np.array(src, dtype=dtype) assert mx_arr.dtype == np_arr.dtype assert same(mx_arr.asnumpy(), np_arr) 
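Note: the hunk above pins down the new np.array dtype behavior. A minimal check of what this buys (a sketch, assuming a build with this patch applied):

    import numpy as _np
    from mxnet import np, npx
    npx.set_np()

    src = _np.random.randint(-10, 10, size=(2, 3))   # platform default int dtype
    arr = np.array(src)          # dtype now inherited from src, not forced to float32
    assert arr.dtype == src.dtype
    arr32 = np.array([[1, 2], [3, 4]])               # plain Python lists still default to float32
    assert arr32.dtype == _np.float32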
@@ -471,9 +475,6 @@ def test_np_grad_ndarray_type():
 @with_seed()
 @use_np
 def test_np_ndarray_astype():
-    mx_data = np.array([2, 3, 4, 5], dtype=_np.int32)
-    np_data = mx_data.asnumpy()
-
     class TestAstype(HybridBlock):
         def __init__(self, dtype, copy):
             super(TestAstype, self).__init__()
@@ -483,24 +484,29 @@ def __init__(self, dtype, copy):
         def hybrid_forward(self, F, x):
             return x.astype(dtype=self._dtype, copy=self._copy)
 
-    def check_astype_equal(dtype, copy, expect_zero_copy=False, hybridize=False):
-        test_astype = TestAstype(dtype, copy)
+    def check_astype_equal(itype, otype, copy, hybridize=False):
+        expect_zero_copy = copy is False and itype == otype
+        mx_data = np.array([2, 3, 4, 5], dtype=itype)
+        np_data = mx_data.asnumpy()
+        test_astype = TestAstype(otype, copy)
         if hybridize:
             test_astype.hybridize()
         mx_ret = test_astype(mx_data)
         assert type(mx_ret) is np.ndarray
-        np_ret = np_data.astype(dtype=dtype, copy=copy)
+        np_ret = np_data.astype(dtype=otype, copy=copy)
         assert mx_ret.dtype == np_ret.dtype
         assert same(mx_ret.asnumpy(), np_ret)
         if expect_zero_copy and not hybridize:
             assert id(mx_ret) == id(mx_data)
             assert id(np_ret) == id(np_data)
 
-    for dtype in [np.int8, np.uint8, np.int32, np.float16, np.float32, np.float64, np.bool, np.bool_,
-                  'int8', 'uint8', 'int32', 'float16', 'float32', 'float64', 'bool']:
+    dtypes = [np.int8, np.uint8, np.int32, np.float16, np.float32, np.float64, np.bool, np.bool_,
+              'int8', 'uint8', 'int32', 'float16', 'float32', 'float64', 'bool']
+
+    for itype, otype in itertools.product(dtypes, dtypes):
         for copy in [True, False]:
             for hybridize in [True, False]:
-                check_astype_equal(dtype, copy, copy is False and mx_data.dtype == dtype, hybridize)
+                check_astype_equal(itype, otype, copy, hybridize)
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index ab1eacf58889..029e7946f2c3 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -442,7 +442,7 @@ def is_int(dtype):
     for axis in ([i for i in range(in_data_dim)] + [(), None]):
         for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool']:
             for dtype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']:
-                if (is_int(dtype) and not is_int(itype))\
+                if (is_int(dtype) and not is_int(itype)) or (is_windows and is_int(itype))\
                         or (itype == 'bool' and\
                             (dtype not in ('float32', 'float64', 'int32', 'int64') or is_windows)):
                     continue
@@ -2403,6 +2403,16 @@ def hybrid_forward(self, F, a):
                 np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype)
                 assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
+    for shape in shapes:
+        for axis in [None] + [i for i in range(0, len(shape))]:
+            for otype in [None, _np.int32, _np.int64]:
+                for itype in [_np.bool, _np.int8, _np.int32, _np.int64]:
+                    x = rand_ndarray(shape).astype(itype).as_np_ndarray()
+                    np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype)
+                    mx_out = np.cumsum(x, axis=axis, dtype=otype)
+                    assert mx_out.shape == np_out.shape
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
 
 @with_seed()
 @use_np
@@ -3746,6 +3756,61 @@ def get_new_shape(shape, axis):
             np_out = _np.append(a.asnumpy(), b.asnumpy(), axis=axis)
             assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
+
+@with_seed()
+@use_np
+def test_np_column_stack():
+    class TestColumnStack(HybridBlock):
+        def __init__(self):
+            super(TestColumnStack, self).__init__()
+
+        def hybrid_forward(self, F, a, *args):
+            return F.np.column_stack([a] + list(args))
+
+    def g(data):
+        return _np.ones_like(data)
+
+    configs = [
+        ((), (), ()),
+        ((2), (2), (2)),
+        ((0), (0), (0)),
+        ((0, 3, 0), (0, 0, 0), (0, 1, 0)),
+        ((2, 2), (2, 1), (2, 3)),
+        ((4, 3), (4, 0), (4, 1)),
+        ((2, 2, 2), (2, 4, 2), (2, 2, 2)),
+        ((0, 1, 1), (0, 1, 1), (0, 1, 1))
+    ]
+    types = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']
+    for config, hybridize, dtype in itertools.product(configs, [True, False], types):
+        test_column_stack = TestColumnStack()
+        if hybridize:
+            test_column_stack.hybridize()
+        rtol = 1e-3
+        atol = 1e-5
+        v = []
+        v_np = []
+        for i in range(3):
+            v_np.append(_np.array(_np.random.uniform(-10.0, 10.0, config[i]), dtype=dtype))
+            v.append(mx.nd.array(v_np[i]).as_np_ndarray())
+            v[i].attach_grad()
+        expected_np = _np.column_stack(v_np)
+        with mx.autograd.record():
+            mx_out = test_column_stack(*v)
+        assert mx_out.shape == expected_np.shape
+        assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+        # Test gradient
+        mx_out.backward()
+        for i in range(3):
+            expected_grad = g(v_np[i])
+            assert_almost_equal(v[i].grad.asnumpy(), expected_grad, rtol=rtol, atol=atol)
+
+        # Test imperative once again
+        mx_out = np.column_stack(v)
+        expected_np = _np.column_stack(v_np)
+        assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+
 @with_seed()
 @use_np
 def test_npx_reshape():
@@ -3832,6 +3897,112 @@ def test_np_share_memory():
                     assert not op(np.ones((5, 0), dtype=dt), np.ones((0, 3, 0), dtype=adt))
 
 
+@with_seed()
+@use_np
+def test_np_diff():
+    def np_diff_backward(ograd, n, axis):
+        res = ograd
+        for i in range(n):
+            res = _np.negative(_np.diff(res, n=1, axis=axis, prepend=0, append=0))
+        return res
+
+    class TestDiff(HybridBlock):
+        def __init__(self, n=1, axis=-1):
+            super(TestDiff, self).__init__()
+            self._n = n
+            self._axis = axis
+
+        def hybrid_forward(self, F, a):
+            return F.np.diff(a, n=self._n, axis=self._axis)
+
+    shapes = [tuple(random.randrange(10) for i in range(random.randrange(6))) for j in range(5)]
+    for hybridize in [True, False]:
+        for shape in shapes:
+            for axis in [i for i in range(-len(shape), len(shape))]:
+                for n in [i for i in range(0, shape[axis]+1)]:
+                    test_np_diff = TestDiff(n=n, axis=axis)
+                    if hybridize:
+                        test_np_diff.hybridize()
+                    for itype in [_np.float16, _np.float32, _np.float64]:
+                        # note the tolerance shall be scaled by the input n
+                        if itype == _np.float16:
+                            rtol = atol = 1e-2*len(shape)*n
+                        else:
+                            rtol = atol = 1e-5*len(shape)*n
+                        x = rand_ndarray(shape).astype(itype).as_np_ndarray()
+                        x.attach_grad()
+                        np_out = _np.diff(x.asnumpy(), n=n, axis=axis)
+                        with mx.autograd.record():
+                            mx_out = test_np_diff(x)
+                        assert mx_out.shape == np_out.shape
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+                        mx_out.backward()
+                        if (np_out.size == 0):
+                            np_backward = _np.zeros(shape)
+                        else:
+                            np_backward = np_diff_backward(_np.ones(np_out.shape, dtype=itype), n=n, axis=axis)
+                        assert x.grad.shape == np_backward.shape
+                        assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol)
+
+                        mx_out = np.diff(x, n=n, axis=axis)
+                        np_out = _np.diff(x.asnumpy(), n=n, axis=axis)
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
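End-to-end usage sketch for the two new front-end ops, mirroring the docstring examples above (a minimal smoke test, assuming a build with this patch applied; note that this np.diff only supports n and axis, and prepend/append raise NotImplementedError):

    from mxnet import np, npx
    npx.set_np()

    x = np.array([1, 2, 4, 7, 0])      # float32 by default in mxnet.numpy
    print(np.diff(x))                  # [ 1.  2.  3. -7.]
    print(np.diff(x, n=2))             # [  1.   1. -10.]

    a = np.array((1, 2, 3))
    b = np.array((2, 3, 4))
    print(np.column_stack((a, b)))     # shape (3, 2): 1-D inputs become columns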