
Commit

numpy-compatible mean
haojin2 committed May 4, 2019
1 parent e791cbe commit c9007ae
Showing 4 changed files with 122 additions and 2 deletions.
1 change: 0 additions & 1 deletion src/operator/numpy/np_broadcast_reduce_op.h
@@ -207,7 +207,6 @@ inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
    Tensor<xpu, 1, IType> igrad = outputs[0].FlatTo1D<xpu, IType>(s);
-    printf("output size: %lu input_size: %lu\n", outputs[0].Size(), inputs[0].Size());
    igrad /= scalar<IType>(outputs[0].Size()/inputs[0].Size());
  });
}
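
The normalization in the context line above (igrad /= outputs[0].Size()/inputs[0].Size()) appears to divide the gradient w.r.t. the forward input (outputs[0]) by the ratio of input size to head-gradient size, i.e. by the number of averaged elements N. A minimal NumPy sketch of that behaviour, not the MXNet kernel, assuming the head gradient keeps the reduced axes as size-1 dimensions or is a scalar:

import numpy as np

def mean_backward_reference(ograd, in_shape):
    # N = number of input elements averaged into each output element,
    # i.e. input size / output size (the same ratio the C++ line above divides by).
    n = int(np.prod(in_shape)) // ograd.size
    # Broadcast the head gradient back to the input shape and scale by 1/N.
    return np.broadcast_to(ograd, in_shape) / n

# A unit head gradient for a full mean over a (2, 3) input spreads as 1/6 everywhere.
print(mean_backward_reference(np.array(1.0), (2, 3)))
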
51 changes: 51 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -75,5 +75,56 @@ NNVM_REGISTER_OP(_backward_numpy_sum)
.set_num_inputs(1)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBackwardUseNone<cpu>);

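// Integer dtypes start at 3 in mshadow's TypeFlag enum (kFloat32 = 0, kFloat64 = 1,
// kFloat16 = 2, kUint8 = 3, ...), so "dtype >= 3" selects the integer types.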
inline bool IsIntType(const int dtype) {
  return (dtype >= 3);
}

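// Type inference for mean: an explicitly requested dtype becomes the output type,
// except that an integer input may not request a floating-point output for now;
// without an explicit dtype, the output dtype simply follows the input dtype.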
inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
                          std::vector<int> *in_attrs,
                          std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  const NumpyReduceAxesParam &param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);

  if (param.dtype.has_value()) {
    if (IsIntType(in_attrs->at(0)) && !IsIntType(param.dtype.value())) {
      LOG(FATAL) << "Output cannot be float type when input is integer type for now";
    }
    TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
  } else {
    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
  }

  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

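// Forward of mean reuses the sum reduction kernel (NumpyReduceAxesCompute with
// mshadow_op::sum); the extra fourth template argument, true only for mean,
// presumably turns on the normalization that divides the reduced sum by the
// number of averaged elements.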
NNVM_REGISTER_OP(_numpy_mean)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyMeanType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"a"};
  })
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyReduceAxesParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesCompute<cpu, mshadow_op::sum, true, true>)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  })
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_mean"});

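// The gradient of mean is a constant 1/N per input element and does not depend on
// the forward inputs, which is why ElemwiseGradUseNone suffices above and why the
// backward op takes a single input: the head gradient, which
// NumpyReduceAxesBackwardUseNone broadcasts back to the input shape and normalizes.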
NNVM_REGISTER_OP(_backward_numpy_mean)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_inputs(1)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBackwardUseNone<cpu, true>);

} // namespace op
} // namespace mxnet
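
Taken together, the registrations above compute mean by reducing with sum and then normalizing, while NumpyMeanType decides the output dtype. A small Python sketch that mirrors both rules; reference_np_mean and is_int are illustrative helpers written against plain NumPy, not MXNet APIs:

import numpy as np

def reference_np_mean(a, axis=None, dtype=None, keepdims=False):
    # dtype rule mirrored from NumpyMeanType: an explicit dtype wins, but an
    # integer input may not request a floating-point output; otherwise the
    # output dtype follows the input dtype.
    def is_int(t):
        return np.issubdtype(np.dtype(t), np.integer)
    if dtype is not None:
        if is_int(a.dtype) and not is_int(dtype):
            raise TypeError("Output cannot be float type when input is integer type for now")
        out_dtype = np.dtype(dtype)
    else:
        out_dtype = a.dtype
    # forward rule mirrored from the kernel registration: reduce with sum, then
    # divide by the number of elements collapsed along the reduced axes.
    s = np.sum(a, axis=axis, dtype=out_dtype, keepdims=keepdims)
    n = a.size // s.size
    return (s / n).astype(out_dtype)

a = np.arange(6, dtype=np.float32).reshape(2, 3)
assert np.allclose(reference_np_mean(a, axis=1), a.mean(axis=1))
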
8 changes: 8 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -26,11 +26,19 @@

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_numpy_sum)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true>);

NNVM_REGISTER_OP(_backward_numpy_sum)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu>);

NNVM_REGISTER_OP(_numpy_mean)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true, true>);

NNVM_REGISTER_OP(_backward_numpy_mean)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu, true>);


} // namespace op
} // namespace mxnet
64 changes: 63 additions & 1 deletion tests/python/unittest/test_numpy_op.py
@@ -31,7 +31,7 @@
@with_seed()
def test_np_sum():
    class TestSum(HybridBlock):
-        def __init__(self, axis=None, dtype=None, keepdims=False):# , initial=None):
+        def __init__(self, axis=None, dtype=None, keepdims=False):
            super(TestSum, self).__init__()
            self._axis = axis
            self._dtype = dtype
@@ -87,6 +87,68 @@ def is_int(dtype):
                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@mx.use_np_compat
@with_seed()
def test_np_mean():
    class TestMean(HybridBlock):
        def __init__(self, axis=None, dtype=None, keepdims=False):
            super(TestMean, self).__init__()
            self._axis = axis
            self._dtype = dtype
            self._keepdims = keepdims

        def hybrid_forward(self, F, a, *args, **kwargs):
            return F.numpy.mean(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)

    def is_int(dtype):
        return 'int' in dtype

    in_data_dim = random.choice([2, 3, 4])
    shape = rand_shape_nd(in_data_dim, dim=3)
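    # The NumPy reference result is accumulated in a wider dtype (acc_type) and then
    # cast to the requested output dtype, so precision loss in the reference does not
    # mask real mismatches in the operator under test.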
    acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
                'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
    for hybridize in [False, True]:
        for keepdims in [True, False]:
            for axis in ([i for i in range(in_data_dim)] + [(), None]):
                for itype in ['float16', 'float32', 'float64']:
                    for dtype in ['float16', 'float32', 'float64']:
                        print(itype, dtype)
                        if is_int(dtype) and not is_int(itype):
                            continue
                        # test gluon
                        test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
                        if hybridize:
                            test_mean.hybridize()
                        if is_int(itype):
                            x = _np.random.randint(-128, 128, shape, dtype=itype)
                            x = mx.nd.array(x, dtype=itype)
                        else:
                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
                        x.attach_grad()
                        expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
                        expected_ret = expected_ret.astype(dtype)
                        with mx.autograd.record():
                            y = test_mean(x)
                        assert y.shape == expected_ret.shape
                        assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
                                            atol=1e-5 if dtype == 'float16' else 1e-5)

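                        # d(mean)/dx is 1/N for every element of x, where
                        # N = x.size / y.size is the number of elements averaged
                        # into each output element; the check below relies on that.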
                        y.backward()
                        N = x.size / y.size
                        assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)

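                        # check_numeric_gradient perturbs x and compares the finite-difference
                        # estimate with the gradient computed by the registered backward kernel.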
                        # test numeric
                        if itype == 'float32' and dtype == 'float32':
                            x_sym = mx.sym.Variable("x")
                            mx_sym = mx.sym.numpy.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims)
                            check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)

                        # test imperative
                        mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
                        np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
    import nose
    nose.runmodule()
