Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix numpy-compatible mean output type for integer inputs #16792

Merged
merged 2 commits into from
Nov 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ndarray/ndarray_function-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ void EvalRandom<DEVICE, GenNegBinomialDistribution>(
template<>
void Eval<DEVICE>(const real_t &rhs, TBlob *ret, RunContext ctx) {
mshadow::Stream<DEVICE> *s = ctx.get_stream<DEVICE>();
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(ret->type_flag_, DType, {
ret->FlatTo2D<DEVICE, DType>(s) = DType(rhs);
});
}
Expand Down
12 changes: 10 additions & 2 deletions src/operator/numpy/np_broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
.add_enum("int8", mshadow::kInt8)
.add_enum("int32", mshadow::kInt32)
.add_enum("int64", mshadow::kInt64)
.add_enum("bool", mshadow::kBool)
.set_default(dmlc::optional<int>())
.describe("The type of the returned array and of the accumulator in which the elements are "
"summed. The dtype of a is used by default unless a has an integer dtype of less "
Expand Down Expand Up @@ -221,15 +222,15 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
if (req[0] == kNullOp) return;
const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
if (param.initial.has_value()) {
LOG(FATAL) << "initial is not supported yet";
}
Stream<xpu>* s = ctx.get_stream<xpu>();
if (inputs[0].shape_.Size() == 0 && outputs[0].shape_.Size() != 0) {
using namespace mxnet_op;
using namespace mshadow;
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Kernel<set_zero, xpu>::Launch(s, outputs[0].shape_.Size(), outputs[0].dptr<DType>());
});
Expand All @@ -246,6 +247,13 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
LOG(FATAL) << "Only reduce op: `sum` is supported for boolean ndarrays";
}
TVMOpReduce(ctx, inputs[0], param.axis, outputs[0], req[0], reducer_name);
if (normalize) {
using namespace mshadow::expr;
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
auto out = outputs[0].FlatTo2D<xpu, OType>(s);
out /= scalar<OType>(inputs[0].Size()/outputs[0].Size());
});
}
return;
}
#endif
Expand Down
11 changes: 6 additions & 5 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,14 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
const NumpyReduceAxesParam &param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);

if (param.dtype.has_value()) {
if (IsIntType(in_attrs->at(0)) && !IsIntType(param.dtype.value())) {
LOG(FATAL) << "Output cannot be float type when input is integer type for now";
}
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
} else {
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
if (common::is_float(in_attrs->at(0))) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
} else {
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
}
}

return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
Expand Down
13 changes: 0 additions & 13 deletions src/operator/tensor/broadcast_reduce-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -619,8 +619,6 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
ReduceImplConfig<ndim> config =
ConfigureReduceImpl<ndim, DType>(small.shape_, big.shape_, NULL, NULL);
if (safe_acc) {
// TODO(haojin2): Use real-only type swtich for windows temporarily due to CI issues.
#ifndef _WIN32
MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
Expand All @@ -630,17 +628,6 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
stream, small, req, big, workspace, config);
});
});
#else
MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
config = ConfigureReduceImpl<ndim, AccType>(small.shape_, big.shape_, NULL, NULL);
ReduceImpl<Reducer, ndim, AccType, DataType, OutType, OP>(
stream, small, req, big, workspace, config);
});
});
#endif
} else {
ReduceImpl<Reducer, ndim, DType, DType, DType, OP>(stream, small, req, big, workspace, config);
}
Expand Down
15 changes: 1 addition & 14 deletions src/operator/tensor/broadcast_reduce-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,28 +241,15 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
N, M, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(),
big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
} else {
// TODO(haojin2): Use real-only type swtich for windows temporarily due to CI issues.
#ifndef _WIN32
MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, {
typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
});
});
#else
MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
});
});
#endif
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
const TBlob in_data = inputs[0].reshape(src_shape);
const TBlob out_data = outputs[0].reshape(dst_shape);
BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
Expand Down Expand Up @@ -1045,8 +1045,8 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape;
for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) {
Expand Down
103 changes: 67 additions & 36 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,52 +614,83 @@ def hybrid_forward(self, F, a, *args, **kwargs):
def is_int(dtype):
return 'int' in dtype

is_windows = sys.platform.startswith('win')
in_data_dim = random.choice([2, 3, 4])
shape = rand_shape_nd(in_data_dim, dim=3)
acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
'bool': 'int64', 'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
ft_types = ['float16', 'float32', 'float64']
it_types = ['bool', 'int8', 'int32', 'int64']
for hybridize in [False, True]:
for keepdims in [True, False]:
for axis in ([i for i in range(in_data_dim)] + [(), None]):
for itype in ['float16', 'float32', 'float64']:
for dtype in ['float16', 'float32', 'float64']:
if is_int(dtype) and not is_int(itype):
continue
# test gluon
test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
if hybridize:
test_mean.hybridize()
if is_int(itype):
x = _np.random.randint(-128, 128, shape, dtype=itype)
x = mx.nd.array(x, dtype=itype)
else:
x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
x = x.as_np_ndarray()
x.attach_grad()
for itype, dtype in itertools.product(ft_types, [None] + ft_types + it_types):
if dtype == 'bool':
continue
# test gluon
test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
if hybridize:
test_mean.hybridize()
x = np.random.uniform(-1.0, 1.0, size=shape).astype(itype)
x = x.as_np_ndarray()
x.attach_grad()

expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
expected_ret = expected_ret.astype(dtype)
with mx.autograd.record():
y = test_mean(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
atol=1e-5 if dtype == 'float16' else 1e-5)
expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
expected_ret = expected_ret.astype(dtype)
with mx.autograd.record():
y = test_mean(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
atol=1e-5 if dtype == 'float16' else 1e-5)

y.backward()
N = x.size / y.size
assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)
y.backward()
N = x.size / y.size
assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)

# test numeric
if itype == 'float32' and dtype == 'float32':
x_sym = mx.sym.Variable("x").as_np_ndarray()
mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
# test numeric
if itype == 'float32' and dtype == 'float32':
x_sym = mx.sym.Variable("x").as_np_ndarray()
mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)

# test imperative
mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
# test imperative
mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

for itype, dtype in itertools.product(it_types, [None] + ft_types + it_types):
if dtype == 'bool':
continue
# test gluon
test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
if hybridize:
test_mean.hybridize()

if itype == 'bool':
x = np.array(_np.random.uniform(size=shape) > 0.5)
else:
x = np.random.uniform(-128, 127, size=shape).astype(itype)

expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims)

if itype == 'bool':
if is_op_runnable() and (not is_windows) and dtype not in ['float16', 'int8']: # special handling of boolean ndarray
y = test_mean(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
atol=1e-5 if dtype == 'float16' else 1e-5)
continue

y = test_mean(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
atol=1e-5 if dtype == 'float16' else 1e-5)

# test imperative
mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
np_out = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims).astype(dtype)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
Expand Down