Skip to content

Commit

Permalink
fix add_n bug: when input mem overlap with output mem, results is wro…
Browse files Browse the repository at this point in the history
…ng (apache#14889)

* fix add_n bug: when input memory overlaps with output memory, the result is wrong

* add testcase for bugfix verification

* add more comments for modification and change testcase name to test_add_n
  • Loading branch information
rongzha1 authored and haohuw committed Jun 23, 2019
1 parent f2baf50 commit bcee862
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/ndarray/ndarray_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,11 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream<cpu>* s,
using namespace mxnet::op::mxnet_op;
const TBlob& out_data = out->data();
MSHADOW_TYPE_SWITCH(out->dtype(), DType, { // data type
Kernel<set_zero, cpu>::Launch(s, out_data.Size(), out_data.dptr<DType>());
// Skip set_zero when the output memory is in-place with input[0]'s memory.
// For the add_n op, the output memory can be in-placed with the first input.
if (nds[0].data().dptr<DType>() != out_data.dptr<DType>()) {
Kernel<set_zero, cpu>::Launch(s, out_data.Size(), out_data.dptr<DType>());
}
for (size_t i = 0; i < nds.size(); ++i) {
const NDArray& nd = nds[i];
const TBlob& nd_data = nd.data();
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8357,6 +8357,18 @@ def check_concat(shape1, shape2, axis):
check_concat((8, 0, 0), (8, 0, 0), 2)


@with_seed()
def test_add_n():
    """Verify add_n when the output array is also its first input (out=data[0])."""
    shape = (2, 2)
    num_inputs = 5
    inputs = [mx.nd.random.uniform(shape=shape) for _ in range(num_inputs)]
    # Build the reference sum first, before add_n is asked to write into inputs[0].
    expected = mx.nd.zeros(shape=shape)
    for arr in inputs:
        expected += arr
    # Request the result to be written into the first input array.
    actual = mx.nd.add_n(*inputs, out=inputs[0])
    assert_almost_equal(expected.asnumpy(), actual.asnumpy(), atol=1e-5)


if __name__ == '__main__':
    # When this file is executed directly, hand the module to nose's runner.
    import nose
    nose.runmodule()

0 comments on commit bcee862

Please sign in to comment.