This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

NumPy-compatible std and var
haojin2 committed Aug 31, 2019
1 parent 3baa6eb commit 24db38d
Showing 7 changed files with 339 additions and 6 deletions.
16 changes: 15 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -33,7 +33,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']


@set_module('mxnet.ndarray.numpy')
@@ -2201,3 +2201,17 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable
array(0.55)
"""
return _npi.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)


@set_module('mxnet.ndarray.numpy')
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
"""
"""
return _npi.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.ndarray.numpy')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
"""
"""
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
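
For orientation, here is a minimal usage sketch of the two frontend functions added above. It assumes an MXNet build that contains this commit; the printed values are approximate.

from mxnet import np, npx
npx.set_np()

a = np.array([[1.0, 2.0], [3.0, 4.0]])
print(np.std(a))           # population std (ddof=0) -> ~1.118
print(np.var(a, axis=0))   # per-column variance     -> [1., 1.]
print(np.std(a, ddof=1))   # sample std, divisor N - 1 -> ~1.291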
20 changes: 16 additions & 4 deletions python/mxnet/numpy/multiarray.py
@@ -52,7 +52,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax']
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -1172,11 +1172,9 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disa
"""Returns the average of the array elements along given axis."""
raise NotImplementedError

# TODO(junwu): Use mxnet std op instead of onp.std
def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=arguments-differ
"""Returns the standard deviation of the array elements along given axis."""
ret_np = self.asnumpy().std(axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims)
return array(ret_np, dtype=ret_np.dtype, ctx=self.context)
return _mx_np_op.std(self, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)

def cumsum(self, axis=None, dtype=None, out=None):
"""Return the cumulative sum of the elements along the given axis."""
@@ -3644,3 +3642,17 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable
array(0.55)
"""
return _npi.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)


@set_module('mxnet.numpy')
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):
"""
"""
return _npi.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.numpy')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):
"""
"""
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
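
A short sketch of the ndarray method form, which now dispatches to the native MXNet operator instead of round-tripping through NumPy via asnumpy() (assumes a build with this commit):

from mxnet import np, npx
npx.set_np()

a = np.arange(6.0).reshape(2, 3)
s = a.std(axis=1)                     # method form, result shape (2,)
v = np.var(a, axis=1, keepdims=True)  # keepdims preserves the reduced axis, shape (2, 1)
print(s.shape, v.shape)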
16 changes: 15 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -35,7 +35,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']


def _num_outputs(sym):
@@ -2569,4 +2569,18 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable
return _npi.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)


@set_module('mxnet.symbol.numpy')
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
"""
"""
return _npi.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.symbol.numpy')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
"""
"""
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


_set_np_symbol_class(_Symbol)
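
The symbol-namespace copies exist so the same calls work under hybridization. A hedged Gluon sketch follows; the Standardize block and the 1e-5 epsilon are illustrative, not part of this commit.

from mxnet import npx
from mxnet.gluon import HybridBlock
npx.set_np()

class Standardize(HybridBlock):
    """Zero-mean / unit-std normalization over the last axis (sketch)."""
    def hybrid_forward(self, F, x):
        mu = F.np.mean(x, axis=-1, keepdims=True)
        sigma = F.np.std(x, axis=-1, keepdims=True)
        return (x - mu) / (sigma + 1e-5)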
134 changes: 134 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op.h
@@ -27,6 +27,7 @@

#include <algorithm>
#include <vector>
#include "../nn/moments-inl.h"
#include "../tensor/broadcast_reduce_op.h"

namespace mxnet {
@@ -230,6 +231,139 @@ void NumpyReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
ReduceAxesBackwardUseInOutImpl<xpu, OP, normalize>(ctx, small, inputs, req, outputs);
}

struct NumpyMomentsParam : public dmlc::Parameter<NumpyMomentsParam> {
dmlc::optional<mxnet::Tuple<int>> axis;
dmlc::optional<int> dtype;
bool keepdims;
int ddof;
DMLC_DECLARE_PARAMETER(NumpyMomentsParam) {
DMLC_DECLARE_FIELD(axis)
.set_default(dmlc::optional<mxnet::Tuple<int>>())
.describe("Axis or axes along which a sum is performed. The default, axis=None, will sum "
"all of the elements of the input array. If axis is negative it counts from the "
"last to the first axis.");
DMLC_DECLARE_FIELD(dtype)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("int8", mshadow::kInt8)
.add_enum("int32", mshadow::kInt32)
.add_enum("int64", mshadow::kInt64)
.set_default(dmlc::optional<int>())
.describe("The type of the returned array and of the accumulator in which the elements are "
"summed. The dtype of a is used by default unless a has an integer dtype of less "
"precision than the default platform integer. In that case, if a is signed then "
"the platform integer is used while if a is unsigned then an unsigned integer of "
"the same precision as the platform integer is used.");
DMLC_DECLARE_FIELD(ddof).set_default(0)
.describe("Delta degrees of freedom. The divisor used in the calculation is N - ddof, "
"where N is the number of elements over which the moment is computed.");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
}
};

template<typename xpu, typename reducer, bool safe_acc, bool normalize = false,
typename OP = op::mshadow_op::identity>
void ReduceAxesComputeWithWorkspaceImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
mshadow::Tensor<xpu, 1, char>& workspace,
const mxnet::TShape& src_shape,
const mxnet::TShape& dst_shape,
const int ddof = 0) {
using namespace mshadow;
using namespace mshadow::expr;

Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
const TBlob in_data = inputs[0].reshape(src_shape);
const TBlob out_data = outputs[0].reshape(dst_shape);
BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
broadcast::Reduce<reducer, NDim, DType, OP, safe_acc>(
s, out_data, req[0], workspace, in_data);
if (normalize) {
auto out = out_data.FlatTo2D<xpu, OType>(s);
out /= scalar<OType>(src_shape.Size()/dst_shape.Size() - ddof);
}
});
});
});
}

template<typename xpu, bool sqrt>
void NumpyMomentsForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow_op;
using namespace mxnet_op;

CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(req.size(), 2U);
CHECK_EQ(outputs.size(), 2U);

const NumpyMomentsParam& param = nnvm::get<NumpyMomentsParam>(attrs.parsed);

Stream<xpu> *s = ctx.get_stream<xpu>();

const TBlob& data = inputs[0];
const TBlob& moment = outputs[0];
const TBlob& mean = outputs[1];

mxnet::TShape small;
if (param.keepdims) {
small = moment.shape_;
} else {
small = NumpyReduceAxesShapeImpl(data.shape_, param.axis, true);
}

mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(data.shape_, small, &src_shape, &dst_shape);

MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
// Get workspace and temp space for data - mean
size_t workspace_size = 0;
BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
workspace_size = broadcast::ReduceWorkspaceSize<NDim, DType>(
s, dst_shape, req[0], src_shape);
});
size_t temp_data_size = data.shape_.Size() * sizeof(DType);
size_t temp_mem_size = temp_data_size + workspace_size;
Tensor<xpu, 1, char> temp_mem =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
DType *temp_data_ptr = reinterpret_cast<DType*>(temp_mem.dptr_);
char *workspace_ptr = temp_mem.dptr_ + temp_data_size;
Tensor<xpu, 1, char> workspace(workspace_ptr, Shape1(workspace_size), s);
// Compute mean
ReduceAxesComputeWithWorkspaceImpl<xpu, mshadow_op::sum, true, true>(
ctx, inputs, {kWriteTo}, {mean}, workspace, src_shape, dst_shape);
// Compute data - mean
Shape<6> data_shape, mean_shape;
for (int i = 0; i < 6; ++i) {
data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1;
mean_shape[i] = (i < small.ndim()) ? small[i] : 1;
}
Kernel<VarBroadcastKernel, xpu>::Launch(s, data_shape.Size(), temp_data_ptr,
data.dptr<DType>(), mean.dptr<DType>(), data_shape, mean_shape);
Tensor<xpu, 1, DType> temp_data_tensor(temp_data_ptr, Shape1(data.shape_.Size()), s);
TBlob temp_data_blob = TBlob(temp_data_tensor).reshape(data.shape_);
ReduceAxesComputeWithWorkspaceImpl<xpu, mshadow_op::sum, true, true>(
ctx, {temp_data_blob}, {req[0]}, {moment}, workspace, src_shape, dst_shape, param.ddof);
if (sqrt) {
Tensor<xpu, 1, OType> moment_tensor = moment.FlatTo1D<xpu, OType>(s);
moment_tensor = F<mshadow_op::square_root>(moment_tensor);
}
});
});
}

template<typename xpu>
void NumpyBroadcastToForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
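As a readability aid, the forward pass above is equivalent to the following NumPy reference (a sketch, not MXNet code): one reduction for the mean, a broadcast squared-difference (the VarBroadcastKernel step), a second reduction normalized by N - ddof, and an optional square root for std.

import numpy as onp

def moments_reference(a, axis=None, ddof=0, sqrt=False):
    # First reduction: mean, kept with full dims for broadcasting.
    mean = onp.mean(a, axis=axis, keepdims=True)
    # Broadcast squared deviation, analogous to VarBroadcastKernel.
    sq_dev = (a - mean) ** 2
    # N = number of input elements contributing to each output element.
    n = a.size // mean.size
    # Second reduction with the ddof-adjusted divisor.
    moment = onp.sum(sq_dev, axis=axis) / (n - ddof)
    return onp.sqrt(moment) if sqrt else moment
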
89 changes: 89 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -29,6 +29,7 @@ namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam);
DMLC_REGISTER_PARAMETER(NumpyMomentsParam);

inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
@@ -153,6 +154,94 @@ NNVM_REGISTER_OP(_backward_np_mean)
.set_num_inputs(1)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBackwardUseNone<cpu, true>);

inline bool NumpyMomentsShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 2U);
if (!shape_is_known(in_attrs->at(0))) {
return false;
}
const NumpyMomentsParam& param = nnvm::get<NumpyMomentsParam>(attrs.parsed);
mxnet::TShape out_shape = NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape);
SHAPE_ASSIGN_CHECK(*out_attrs, 1, out_shape);

return shape_is_known(out_attrs->at(0)) && shape_is_known(out_attrs->at(1));
}

inline bool NumpyMomentsType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 2U);
const NumpyMomentsParam &param = nnvm::get<NumpyMomentsParam>(attrs.parsed);

if (param.dtype.has_value()) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
} else {
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
}
TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0));

return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

NNVM_REGISTER_OP(_npi_std)
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr_parser(ParamParser<NumpyMomentsParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyMomentsShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyMomentsType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"std", "mean"};
})
.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
[](const NodeAttrs& attrs) {
return 1;
})
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyMomentsParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyMomentsForward<cpu, true>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

NNVM_REGISTER_OP(_npi_var)
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr_parser(ParamParser<NumpyMomentsParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyMomentsShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyMomentsType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"var", "mean"};
})
.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
[](const NodeAttrs& attrs) {
return 1;
})
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyMomentsParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyMomentsForward<cpu, false>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

bool NumpyBroadcastToShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
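Both operators produce two outputs (the moment and the mean) but expose only the first to the frontend via FNumVisibleOutputs, and they register MakeZeroGradNodes, so there is no real backward pass yet. A hedged autograd sketch of the consequence (assumes a build with this commit):

from mxnet import np, npx, autograd
npx.set_np()

a = np.array([1.0, 2.0, 3.0])
a.attach_grad()
with autograd.record():
    s = np.std(a)
s.backward()
print(a.grad)  # all zeros: FGradient is MakeZeroGradNodes, not a true gradient
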
6 changes: 6 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -44,6 +44,12 @@ NNVM_REGISTER_OP(_npi_mean)
NNVM_REGISTER_OP(_backward_np_mean)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu, true>);

NNVM_REGISTER_OP(_npi_std)
.set_attr<FCompute>("FCompute<gpu>", NumpyMomentsForward<gpu, true>);

NNVM_REGISTER_OP(_npi_var)
.set_attr<FCompute>("FCompute<gpu>", NumpyMomentsForward<gpu, false>);

NNVM_REGISTER_OP(_np_broadcast_to)
.set_attr<FCompute>("FCompute<gpu>", NumpyBroadcastToForward<gpu>);

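With the GPU registrations in place, the same frontend code runs on a GPU context when one is available (a sketch; the context selection is illustrative):

import mxnet as mx
from mxnet import np, npx
npx.set_np()

ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
a = np.array([[1.0, 2.0], [3.0, 4.0]], ctx=ctx)
print(np.std(a), np.var(a, axis=1))  # dispatches to NumpyMomentsForward<gpu>
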
