diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h
index d494f0882bbb..6ac586acd99c 100644
--- a/src/ndarray/ndarray_function-inl.h
+++ b/src/ndarray/ndarray_function-inl.h
@@ -379,7 +379,7 @@ void EvalRandom(
 template<>
 void Eval<DEVICE>(const real_t &rhs, TBlob *ret, RunContext ctx) {
   mshadow::Stream<DEVICE> *s = ctx.get_stream<DEVICE>();
-  MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(ret->type_flag_, DType, {
     ret->FlatTo2D<DEVICE, DType>(s) = DType(rhs);
   });
 }
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index 4951f62457a7..e94cd96fe61f 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -52,6 +52,7 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
       .add_enum("int8", mshadow::kInt8)
       .add_enum("int32", mshadow::kInt32)
       .add_enum("int64", mshadow::kInt64)
+      .add_enum("bool", mshadow::kBool)
       .set_default(dmlc::optional<int>())
       .describe("The type of the returned array and of the accumulator in which the elements are "
                 "summed. The dtype of a is used by default unless a has an integer dtype of less "
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index 435fe1df1134..fb133568a7a5 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -257,13 +257,14 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
   const NumpyReduceAxesParam &param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
 
   if (param.dtype.has_value()) {
-    if (IsIntType(in_attrs->at(0)) && !IsIntType(param.dtype.value())) {
-      LOG(FATAL) << "Output cannot be float type when input is integer type for now";
-    }
     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
   } else {
-    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    if (common::is_float(in_attrs->at(0))) {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+      TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    } else {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+    }
   }
 
   return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h
index 415059ac12a0..800dbffbcd38 100644
--- a/src/operator/tensor/broadcast_reduce-inl.h
+++ b/src/operator/tensor/broadcast_reduce-inl.h
@@ -245,7 +245,7 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
 #ifndef _WIN32
     MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
       typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-      MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, {
        typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
        seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
          N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
@@ -255,7 +255,7 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
 #else
     MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
       typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-      MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, {
        typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
        seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
          N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 414b606bbe0a..27e22491ca35 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -617,7 +617,7 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
   BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       const TBlob in_data = inputs[0].reshape(src_shape);
       const TBlob out_data = outputs[0].reshape(dst_shape);
       BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
@@ -1045,8 +1045,8 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
   mxnet::TShape src_shape, dst_shape;
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape;
       for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) {
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 9aabdfd4cabc..10637edaa9c7 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -617,49 +617,70 @@ def is_int(dtype):
     in_data_dim = random.choice([2, 3, 4])
     shape = rand_shape_nd(in_data_dim, dim=3)
     acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
-                'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+                'bool': 'int64', 'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+    ft_types = ['float16', 'float32', 'float64']
+    it_types = ['bool', 'int8', 'int32', 'int64']
     for hybridize in [False, True]:
         for keepdims in [True, False]:
             for axis in ([i for i in range(in_data_dim)] + [(), None]):
-                for itype in ['float16', 'float32', 'float64']:
-                    for dtype in ['float16', 'float32', 'float64']:
-                        if is_int(dtype) and not is_int(itype):
-                            continue
-                        # test gluon
-                        test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
-                        if hybridize:
-                            test_mean.hybridize()
-                        if is_int(itype):
-                            x = _np.random.randint(-128, 128, shape, dtype=itype)
-                            x = mx.nd.array(x, dtype=itype)
-                        else:
-                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
-                        x = x.as_np_ndarray()
-                        x.attach_grad()
+                for itype, dtype in itertools.product(ft_types, [None] + ft_types + it_types):
+                    if dtype == 'bool':
+                        continue
+                    # test gluon
+                    test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
+                    if hybridize:
+                        test_mean.hybridize()
+                    x = np.random.uniform(-1.0, 1.0, size=shape).astype(itype)
+                    x = x.as_np_ndarray()
+                    x.attach_grad()
 
-                        expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
-                        expected_ret = expected_ret.astype(dtype)
-                        with mx.autograd.record():
-                            y = test_mean(x)
-                        assert y.shape == expected_ret.shape
-                        assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
-                                            atol=1e-5 if dtype == 'float16' else 1e-5)
+                    expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
+                    expected_ret = expected_ret.astype(dtype)
+                    with mx.autograd.record():
+                        y = test_mean(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                        atol=1e-5 if dtype == 'float16' else 1e-5)
 
-                        y.backward()
-                        N = x.size / y.size
-                        assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)
+                    y.backward()
+                    N = x.size / y.size
+                    assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)
 
-                        # test numeric
-                        if itype == 'float32' and dtype == 'float32':
-                            x_sym = mx.sym.Variable("x").as_np_ndarray()
-                            mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
-                            check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
-                                                   numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
+                    # test numeric
+                    if itype == 'float32' and dtype == 'float32':
+                        x_sym = mx.sym.Variable("x").as_np_ndarray()
+                        mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
+                        check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
+                                               numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
 
-                        # test imperative
-                        mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
-                        np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
-                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+                    # test imperative
+                    mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
+                    np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+                for itype, dtype in itertools.product(it_types, [None] + ft_types + it_types):
+                    if dtype == 'bool':
+                        continue
+                    # test gluon
+                    test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
+                    if hybridize:
+                        test_mean.hybridize()
+
+                    if itype == 'bool':
+                        x = np.random.uniform(size=shape) > 0.5
+                    else:
+                        x = np.random.uniform(-128, 127, size=shape).astype(itype)
+
+                    expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims)
+                    y = test_mean(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                        atol=1e-5 if dtype == 'float16' else 1e-5)
 
+                    # test imperative
+                    mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
+                    np_out = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims).astype(dtype)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
 @with_seed()
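
# ---------------------------------------------------------------------------
# Not part of the diff: a minimal usage sketch of the behavior this change is
# meant to enable, mirroring the new test cases above. Boolean and integer
# inputs to np.mean should be accepted, with the output dtype defaulting to
# float32 when no dtype is given (per the NumpyMeanType change).
from mxnet import np, npx
npx.set_np()

x = np.random.uniform(size=(2, 3)) > 0.5   # boolean ndarray, as in the new test
print(np.mean(x))                          # mean over a bool array, float32 result by default
print(np.mean(x, dtype='float64'))         # explicit accumulator/output dtype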