Skip to content

Commit

Permalink
numpy-compatible cumsum (apache#15309)
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Jul 26, 2019
1 parent a88ebd5 commit f2ceddd
Show file tree
Hide file tree
Showing 4 changed files with 355 additions and 0 deletions.
184 changes: 184 additions & 0 deletions src/operator/numpy/np_cumsum-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_cumsum-inl.h
* \brief Function definition of numpy-compatible cumsum operator
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_CUMSUM_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_CUMSUM_INL_H_

#include <mxnet/base.h>
#include <mxnet/operator_util.h>
#include <vector>
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

struct CumsumParam : public dmlc::Parameter<CumsumParam> {
dmlc::optional<int> axis;
dmlc::optional<int> dtype;
DMLC_DECLARE_PARAMETER(CumsumParam) {
DMLC_DECLARE_FIELD(axis)
.set_default(dmlc::optional<int>())
.describe("Axis along which the cumulative sum is computed."
" The default (None) is to compute the cumsum over the flattened array.");
DMLC_DECLARE_FIELD(dtype)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("int8", mshadow::kInt8)
.add_enum("int32", mshadow::kInt32)
.add_enum("int64", mshadow::kInt64)
.set_default(dmlc::optional<int>())
.describe("Type of the returned array and of the accumulator in which the elements"
" are summed. If dtype is not specified, it defaults to the dtype of a,"
" unless a has an integer dtype with a precision less than that of the"
" default platform integer. In that case, the default platform integer is used.");
}
};

struct cumsum_forward {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int i,
OType *out,
const IType *in,
const int middle,
const int trailing) {
int left = i / trailing, right = i % trailing;
int offset = left * middle * trailing + right;
const IType *lane_in = in + offset;
OType *lane_out = out + offset;
lane_out[0] = OType(lane_in[0]);
for (int j = 1; j < middle; ++j) {
lane_out[j * trailing] = lane_out[(j - 1) * trailing] + OType(lane_in[j * trailing]);
}
}
};

template<typename xpu>
void CumsumForwardImpl(const OpContext& ctx,
const TBlob& in,
const TBlob& out,
const dmlc::optional<int>& axis) {
using namespace mshadow;
using namespace mxnet_op;

int middle = axis.has_value() ? out.shape_[axis.value()] : out.Size();
if (middle == 0 || out.Size() == 0) return;
int trailing = 1;
if (axis.has_value()) {
for (int i = axis.value() + 1; i < out.shape_.ndim(); ++i) {
trailing *= out.shape_[i];
}
}

Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(in.type_flag_, IType, {
MSHADOW_TYPE_SWITCH(out.type_flag_, OType, {
Kernel<cumsum_forward, xpu>::Launch(
s, out.Size() / middle, out.dptr<OType>(),
in.dptr<IType>(), middle, trailing);
});
});
}

template<typename xpu>
void CumsumForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const CumsumParam &param = nnvm::get<CumsumParam>(attrs.parsed);

CumsumForwardImpl<xpu>(ctx, inputs[0], outputs[0], param.axis);
}

struct cumsum_backward {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int i,
IType *igrad,
const OType *ograd,
const int middle,
const int trailing) {
int left = i / trailing, right = i % trailing;
int offset = left * middle * trailing + right;
const OType *lane_ograd = ograd + offset;
IType *lane_igrad = igrad + offset;
lane_igrad[(middle - 1) * trailing] = IType(lane_ograd[(middle - 1) * trailing]);
for (int j = middle - 2; j >= 0; --j) {
lane_igrad[j * trailing] = lane_igrad[(j + 1) * trailing] + IType(lane_ograd[j * trailing]);
}
}
};

template<typename xpu>
void CumsumBackwardImpl(const OpContext& ctx,
const TBlob& ograd,
const TBlob& igrad,
const dmlc::optional<int>& axis) {
using namespace mshadow;
using namespace mxnet_op;
int middle = axis.has_value() ? igrad.shape_[axis.value()] : igrad.Size();
if (middle == 0 || igrad.Size() == 0) return;
int trailing = 1;
if (axis.has_value()) {
for (int i = axis.value() + 1; i < igrad.shape_.ndim(); ++i) {
trailing *= igrad.shape_[i];
}
}
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(igrad.type_flag_, IType, {
MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
Kernel<cumsum_backward, xpu>::Launch(
s, igrad.Size() / middle, igrad.dptr<IType>(),
ograd.dptr<OType>(), middle, trailing);
});
});
}

template<typename xpu>
void CumsumBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const CumsumParam &param = nnvm::get<CumsumParam>(attrs.parsed);

CumsumBackwardImpl<xpu>(ctx, inputs[0], outputs[0], param.axis);
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_CUMSUM_INL_H_
92 changes: 92 additions & 0 deletions src/operator/numpy/np_cumsum.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_cumsum.cc
* \brief CPU implementation of numpy-compatible cumsum operator
*/

#include "./np_cumsum-inl.h"

namespace mxnet {
namespace op {

inline bool CumsumShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const CumsumParam &param = nnvm::get<CumsumParam>(attrs.parsed);

if (param.axis.has_value()) {
return ElemwiseShape<1, 1>(attrs, in_attrs, out_attrs);
} else {
TShape out_shape(1, in_attrs->at(0).Size());
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape);
return shape_is_known(out_attrs->at(0));
}
}

inline bool CumsumType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const CumsumParam &param = nnvm::get<CumsumParam>(attrs.parsed);

if (param.dtype.has_value()) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
} else {
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
}

return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

DMLC_REGISTER_PARAMETER(CumsumParam);

NNVM_REGISTER_OP(_np_cumsum)
.set_attr_parser(ParamParser<CumsumParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
})
.set_attr<mxnet::FInferShape>("FInferShape", CumsumShape)
.set_attr<nnvm::FInferType>("FInferType", CumsumType)
.set_attr<FCompute>("FCompute<cpu>", CumsumForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_cumsum"})
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.add_argument("a", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(CumsumParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_np_cumsum)
.set_attr_parser(ParamParser<CumsumParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", CumsumBackward<cpu>);

} // namespace op
} // namespace mxnet
37 changes: 37 additions & 0 deletions src/operator/numpy/np_cumsum.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_cumsum.cu
* \brief GPU implementation of numpy-compatible cumsum operator
*/

#include "./np_cumsum-inl.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_np_cumsum)
.set_attr<FCompute>("FCompute<gpu>", CumsumForward<gpu>);

NNVM_REGISTER_OP(_backward_np_cumsum)
.set_attr<FCompute>("FCompute<gpu>", CumsumBackward<gpu>);

} // namespace op
} // namespace mxnet
42 changes: 42 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,48 @@ def get_indices(axis_size):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@npx.use_np_shape
def test_np_cumsum():
def np_cumsum_backward(ograd, axis=None, dtype=None):
return _np.flip(_np.cumsum(_np.flip(ograd, axis=axis), axis=axis, dtype=dtype), axis=axis)

@npx.use_np_shape
class TestCumsum(HybridBlock):
def __init__(self, axis=None, dtype=None):
super(TestCumsum, self).__init__()
self._axis = axis
self._dtype = dtype

def hybrid_forward(self, F, a):
return F.np.cumsum(a, axis=self._axis, dtype=self._dtype)

shapes = [(2, 3, 4), (2, 0, 3), ()]
for hybridize in [True, False]:
for shape in shapes:
for axis in [None] + [i for i in range(0, len(shape))]:
for otype in [None, _np.float32, _np.float64]:
test_cumsum = TestCumsum(axis=axis, dtype=otype)
if hybridize:
test_cumsum.hybridize()
for itype in [_np.float16, _np.float32, _np.float64]:
x = rand_ndarray(shape).astype(itype).as_np_ndarray()
x.attach_grad()
np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype)
with mx.autograd.record():
mx_out = test_cumsum(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
mx_out.backward()
np_backward = np_cumsum_backward(_np.ones(np_out.shape, dtype=otype),
axis=axis, dtype=otype).reshape(x.shape)
assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)

mx_out = np.cumsum(x, axis=axis, dtype=otype)
np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@npx.use_np_shape
def test_np_tile():
Expand Down

0 comments on commit f2ceddd

Please sign in to comment.