
Commit

fix mean output type for integer inputs
haojin2 committed Nov 13, 2019
1 parent 2c02bff commit d306d28
Showing 6 changed files with 70 additions and 47 deletions.
2 changes: 1 addition & 1 deletion src/ndarray/ndarray_function-inl.h
@@ -379,7 +379,7 @@ void EvalRandom<DEVICE, GenNegBinomialDistribution>(
 template<>
 void Eval<DEVICE>(const real_t &rhs, TBlob *ret, RunContext ctx) {
   mshadow::Stream<DEVICE> *s = ctx.get_stream<DEVICE>();
-  MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(ret->type_flag_, DType, {
     ret->FlatTo2D<DEVICE, DType>(s) = DType(rhs);
   });
 }
1 change: 1 addition & 0 deletions src/operator/numpy/np_broadcast_reduce_op.h
@@ -52,6 +52,7 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
       .add_enum("int8", mshadow::kInt8)
       .add_enum("int32", mshadow::kInt32)
       .add_enum("int64", mshadow::kInt64)
+      .add_enum("bool", mshadow::kBool)
       .set_default(dmlc::optional<int>())
       .describe("The type of the returned array and of the accumulator in which the elements are "
                 "summed. The dtype of a is used by default unless a has an integer dtype of less "
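
Note: the dtype enum above belongs to NumpyReduceAxesParam, the parameter used by the numpy-namespace reductions, so callers can request the output/accumulator type explicitly; "bool" is simply now a recognized name. A minimal usage sketch, assuming an MXNet build that contains this commit (whether a given operator accepts a particular dtype at runtime still depends on its type-inference function; the mean test below, for instance, skips dtype='bool'):

    import mxnet as mx
    from mxnet import np, npx

    npx.set_np()  # enable NumPy semantics for this standalone snippet

    a = np.array([[1, 2], [3, 4]], dtype='int32')

    # Explicitly request the output/accumulator dtype of the reduction.
    print(np.mean(a, dtype='int64').dtype)  # expected: int64
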
11 changes: 6 additions & 5 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -257,13 +257,14 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
   const NumpyReduceAxesParam &param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
 
   if (param.dtype.has_value()) {
-    if (IsIntType(in_attrs->at(0)) && !IsIntType(param.dtype.value())) {
-      LOG(FATAL) << "Output cannot be float type when input is integer type for now";
-    }
     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
   } else {
-    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    if (common::is_float(in_attrs->at(0))) {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+      TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    } else {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+    }
   }
 
   return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
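
The net effect of the revised NumpyMeanType: an explicitly requested dtype is honored (the old hard failure for float outputs on integer inputs is gone), floating-point inputs keep their own dtype, and integer or boolean inputs now default to a float32 output instead of an integer one. A small behavioral sketch, assuming an MXNet build that includes this commit:

    import mxnet as mx
    from mxnet import np, npx

    npx.set_np()  # enable NumPy semantics for this standalone snippet

    f = np.array([1.0, 2.0, 3.0], dtype='float64')
    i = np.array([1, 2, 3], dtype='int64')

    print(np.mean(f).dtype)                   # float64: float inputs keep their dtype
    print(np.mean(i).dtype)                   # float32: integer inputs now default to float32
    print(np.mean(i, dtype='float64').dtype)  # float64: an explicit dtype is honored
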
4 changes: 2 additions & 2 deletions src/operator/tensor/broadcast_reduce-inl.h
@@ -245,7 +245,7 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
 #ifndef _WIN32
   MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
     typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-    MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, {
       typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
       seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
         N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
@@ -255,7 +255,7 @@
 #else
   MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
     typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-    MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, {
       typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
       seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
         N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
6 changes: 3 additions & 3 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -617,7 +617,7 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
   BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       const TBlob in_data = inputs[0].reshape(src_shape);
       const TBlob out_data = outputs[0].reshape(dst_shape);
       BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
@@ -1045,8 +1045,8 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
   mxnet::TShape src_shape, dst_shape;
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape;
       for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) {
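
Relaxing these switches to their _WITH_BOOL variants lets the reduce and broadcast kernels be instantiated when the input or output tensor is boolean, which is what allows np.mean to run on boolean arrays end to end. A quick sketch mirroring what the updated test below exercises, assuming a build with this commit:

    import mxnet as mx
    from mxnet import np, npx

    npx.set_np()  # enable NumPy semantics for this standalone snippet

    # A boolean ndarray, constructed the same way as in the updated test.
    x = np.random.uniform(size=(3, 4)) > 0.5

    print(np.mean(x, axis=0).dtype)           # expected: float32 (non-float input, no explicit dtype)
    print(np.mean(x, dtype='float64').dtype)  # expected: float64
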
93 changes: 57 additions & 36 deletions tests/python/unittest/test_numpy_op.py
@@ -617,49 +617,70 @@ def is_int(dtype):
     in_data_dim = random.choice([2, 3, 4])
     shape = rand_shape_nd(in_data_dim, dim=3)
     acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
-                'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+                'bool': 'int64', 'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+    ft_types = ['float16', 'float32', 'float64']
+    it_types = ['bool', 'int8', 'int32', 'int64']
     for hybridize in [False, True]:
         for keepdims in [True, False]:
             for axis in ([i for i in range(in_data_dim)] + [(), None]):
-                for itype in ['float16', 'float32', 'float64']:
-                    for dtype in ['float16', 'float32', 'float64']:
-                        if is_int(dtype) and not is_int(itype):
-                            continue
-                        # test gluon
-                        test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
-                        if hybridize:
-                            test_mean.hybridize()
-                        if is_int(itype):
-                            x = _np.random.randint(-128, 128, shape, dtype=itype)
-                            x = mx.nd.array(x, dtype=itype)
-                        else:
-                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
-                        x = x.as_np_ndarray()
-                        x.attach_grad()
+                for itype, dtype in itertools.product(ft_types, [None] + ft_types + it_types):
+                    if dtype == 'bool':
+                        continue
+                    # test gluon
+                    test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
+                    if hybridize:
+                        test_mean.hybridize()
+                    x = np.random.uniform(-1.0, 1.0, size=shape).astype(itype)
+                    x = x.as_np_ndarray()
+                    x.attach_grad()
 
-                        expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
-                        expected_ret = expected_ret.astype(dtype)
-                        with mx.autograd.record():
-                            y = test_mean(x)
-                        assert y.shape == expected_ret.shape
-                        assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
-                                            atol=1e-5 if dtype == 'float16' else 1e-5)
+                    expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
+                    expected_ret = expected_ret.astype(dtype)
+                    with mx.autograd.record():
+                        y = test_mean(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                        atol=1e-5 if dtype == 'float16' else 1e-5)
 
-                        y.backward()
-                        N = x.size / y.size
-                        assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)
+                    y.backward()
+                    N = x.size / y.size
+                    assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)
 
-                        # test numeric
-                        if itype == 'float32' and dtype == 'float32':
-                            x_sym = mx.sym.Variable("x").as_np_ndarray()
-                            mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
-                            check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
-                                                   numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
+                    # test numeric
+                    if itype == 'float32' and dtype == 'float32':
+                        x_sym = mx.sym.Variable("x").as_np_ndarray()
+                        mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
+                        check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
+                                               numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
 
-                        # test imperative
-                        mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
-                        np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
-                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+                    # test imperative
+                    mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
+                    np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+                for itype, dtype in itertools.product(it_types, [None] + ft_types + it_types):
+                    if dtype == 'bool':
+                        continue
+                    # test gluon
+                    test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
+                    if hybridize:
+                        test_mean.hybridize()
+
+                    if itype == 'bool':
+                        x = np.random.uniform(size=shape) > 0.5
+                    else:
+                        x = np.random.uniform(-128, 127, size=shape).astype(itype)
+
+                    expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims)
+                    y = test_mean(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                        atol=1e-5 if dtype == 'float16' else 1e-5)
+
+                    # test imperative
+                    mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
+                    np_out = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims).astype(dtype)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
 @with_seed()
