Skip to content

Commit

Permalink
Cherry-pick of apache#17995 and apache#17937 to 1.x branch (apache#18041
Browse files Browse the repository at this point in the history
)

* Fix ElemwiseSum for more than 4 inputs (apache#17995)

* Fix ElemwiseSum for more than 4 inputs

* Added test

* Fix for handling negative indices in the fusion of slice (apache#17937)

* Fix for handling of negative axis, begin and end in fusion of slice ops

* Added test
  • Loading branch information
ptrendx authored Apr 18, 2020
1 parent 3835139 commit 814530d
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/operator/fusion/fused_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,8 @@ __device__ inline VectorType<DType, nvec> load_slice(const DType * input, const
strides[ndim-1] = 1;
#pragma unroll
for (int dim = ndim-1; dim >=0; dim--) {
if (begin[dim] < 0) begin[dim] = shape[dim] - begin[dim];
if (end[dim] < 0) end[dim] = shape[dim] - end[dim];
if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
if (end[dim] == INT_MAX) end[dim] = shape[dim];
if (dim > 0) {
ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
Expand Down Expand Up @@ -442,8 +442,8 @@ __device__ inline VectorType<DType, nvec> fast_load_slice(const DType * input,
strides[ndim-1] = 1;
#pragma unroll
for (int dim = ndim-1; dim >=0; dim--) {
if (begin[dim] < 0) begin[dim] = shape[dim] - begin[dim];
if (end[dim] < 0) end[dim] = shape[dim] - end[dim];
if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
if (end[dim] == INT_MAX) end[dim] = shape[dim];
if (dim > 0) {
ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
Expand Down
7 changes: 7 additions & 0 deletions src/operator/fusion/fused_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,13 @@ std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
return out;
};
auto build_tuple = [ndim](int axis, const std::string str, const std::string def) {
if (axis < 0 &&
axis >= -ndim) {
axis += ndim;
}
if (axis < 0 || axis >= ndim) {
LOG(FATAL) << "Axis " << axis << " is out of bounds for array of dimension " << ndim;
}
std::string tuple = "{";
for (int i = 0; i < axis; i++) {
tuple = tuple + def + ",";
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/elemwise_sum.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void ElementWiseSumCompute_(const nnvm::NodeAttrs& attrs,
Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], in_0_dptr);
for (size_t i = 1; i < size; ++i) {
DType* in_dptr = in_data[i].dptr<DType>();
Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], out_dptr, in_dptr);
Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, kWriteTo, out_dptr, in_dptr);
}
break;
}
Expand Down
16 changes: 15 additions & 1 deletion tests/python/gpu/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,10 @@ def check_other_ops():
b = mx.sym.Variable('b')
c = mx.sym.Variable('c')
shape = rand_shape_2d()
shape = (5,) + shape
shape = list((5,) + shape)
# Make sure there is at least 2 elements for the test with negative indices
shape[1] += 1
shape[2] += 1
arr1 = mx.random.uniform(shape=shape)
arr2 = mx.random.uniform(shape=shape)
arr3 = mx.random.uniform(shape=shape)
Expand All @@ -197,6 +200,9 @@ def check_other_ops():

check_fused_symbol(mx.sym.slice_axis(a, axis=0, begin=1, end=4), a=arr1)

# Testing handling of negative axis
check_fused_symbol(mx.sym.slice_axis(a, axis=-3, begin=1, end=4), a=arr1)

begin = (random.randint(0, shape[0]-1),
random.randint(0, shape[1]-1),
random.randint(0, shape[2]-1))
Expand All @@ -205,6 +211,14 @@ def check_other_ops():
random.randint(begin[2]+1, shape[2]))
check_fused_symbol(mx.sym.slice(a, begin=begin, end=end), a=arr1)

begin = (random.randint(-shape[0], -2),
random.randint(-shape[1], -2),
random.randint(-shape[2], -2))
end = (random.randint(begin[0]+1, -1),
random.randint(begin[1]+1, -1),
random.randint(begin[2]+1, -1))
check_fused_symbol(mx.sym.slice(a, begin=begin, end=end), a=arr1)

arr1 = mx.random.uniform(shape=(2,3,4,5))
arr2 = mx.random.uniform(shape=(1,2,3))
check_fused_symbol(mx.sym.slice_like(a,b, axes=[-2, 0]), a=arr1, b=arr2)
Expand Down
19 changes: 19 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9875,6 +9875,25 @@ def test_grad(input_shape, kernel, stride=1, dilate=1, pad=0):
pad = 1
)

def test_elemwise_sum_for_gradient_accumulation():
for nrepeat in range(1, 10):
stored_grad = dict()
for grad_req in ['write', 'add']:
a = mx.nd.array([1])
b = mx.nd.array([2])
if grad_req == 'write':
a.attach_grad(grad_req='write')
elif grad_req == 'add':
a.attach_grad(grad_req='add')
a.grad[:] = 0
with mx.autograd.record():
for _ in range(nrepeat):
b = b * a
b.backward()
stored_grad[grad_req] = a.grad.asscalar()
assert stored_grad['write'] == stored_grad['add']
assert stored_grad['write'] == 2 * nrepeat


if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 814530d

Please sign in to comment.