diff --git a/contrib/tvmop/basic/ufunc.py b/contrib/tvmop/basic/ufunc.py
index 0419e5fd2ca9..b828ac0f9834 100644
--- a/contrib/tvmop/basic/ufunc.py
+++ b/contrib/tvmop/basic/ufunc.py
@@ -27,6 +27,7 @@ def compute_add(dtype, ndim):
     s = tvm.create_schedule(C.op)
     return s, A, B, C
 
+
 @defop(name="vadd", target="cpu", auto_broadcast=True,
        dtype=AllTypes, ndim=list(range(1, 6)))
 def vadd(dtype, ndim):
@@ -37,6 +38,7 @@ def vadd(dtype, ndim):
 
     return s, [A, B, C]
 
+
 @defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
        dtype=["float32", "float64"], ndim=list(range(1, 6)))
 def vadd_gpu(dtype, ndim):
@@ -48,3 +50,65 @@ def vadd_gpu(dtype, ndim):
     s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
     s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
     return s, [A, B, C]
+
+
+def reduce_axes(X, ishape, axes):
+    def get_index(idx, ridx):
+        j = 0
+        k = 0
+        ret = []
+        for val in axes:
+            ret.append(idx[j] if val == 0 else ridx[k])
+            j += (val == 0)
+            k += (val != 0)
+        return tuple(ret)
+
+    if (len(ishape) == 2 and axes == [1, 0]):
+        r = tvm.reduce_axis((0, ishape[0]), "r")
+        oshape = [tvm.var()]
+        ret = tvm.compute(oshape, lambda i: tvm.sum(X[r, i], axis=r), name='ret')
+    else:
+        odim = (len(ishape) + 1 - axes[0]) // 2
+        oshape = [tvm.var() for _ in range(odim)]
+        ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
+        ret = tvm.compute(oshape, lambda *idx: tvm.sum(X[get_index(idx, ridx)], axis=ridx), name='ret')
+
+    return ret
+
+
+def compute_backward_vadd(dtype, ndim, reduce1st):
+    axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
+    ishape = [tvm.var() for _ in range(ndim)]
+    odim = (len(ishape) + 1 - axes[0]) // 2
+    oshape = [tvm.var() for _ in range(odim)]
+    X = tvm.placeholder(ishape, name='X', dtype=dtype)
+    ret = reduce_axes(X, ishape, axes)
+    s = tvm.create_schedule(ret.op)
+    return s, X, ret, [ret]
+
+
+@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
+       ndim=list(range(1, 6)), reduce1st=[0, 1], attrs=["reduce1st"])
+def backward_vadd(dtype, ndim, reduce1st):
+    s, X, ret, c_list = compute_backward_vadd(dtype, ndim, reduce1st)
+    for t in c_list:
+        axes = [axis for axis in t.op.axis]
+        fused = s[t].fuse(*axes)
+        s[t].parallel(fused)
+    return s, [X, ret]
+
+
+@defop(name="cuda_backward_vadd", target="gpu", dtype=["float32", "float64"],
+       ndim=list(range(1, 6)), reduce1st=[0, 1], attrs=["reduce1st"])
+def backward_vadd_gpu(dtype, ndim, reduce1st):
+    s, X, ret, c_list = compute_backward_vadd(dtype, ndim, reduce1st)
+    num_thread = 64
+    for t in c_list:
+        block_x = tvm.thread_axis("blockIdx.x")
+        thread_x = tvm.thread_axis("threadIdx.x")
+        axes = [axis for axis in t.op.axis]
+        fused = s[t].fuse(*axes)
+        bx, tx = s[t].split(fused, factor=num_thread)
+        s[t].bind(bx, block_x)
+        s[t].bind(tx, thread_x)
+    return s, [X, ret]
diff --git a/src/operator/contrib/tvmop/ufunc.cc b/src/operator/contrib/tvmop/ufunc.cc
index 3475a211cfec..348a09e66ffe 100644
--- a/src/operator/contrib/tvmop/ufunc.cc
+++ b/src/operator/contrib/tvmop/ufunc.cc
@@ -28,6 +28,7 @@
 #include 
 #include 
 #include 
+#include <string>
 #include "../../tensor/elemwise_binary_broadcast_op.h"
 #include "../../tvmop/op_module.h"
 #include "../../tensor/elemwise_binary_op.h"
@@ -37,6 +38,8 @@ namespace op {
 
 static constexpr char func_vadd_cpu[] = "vadd";
 static constexpr char func_vadd_gpu[] = "cuda_vadd";
+static constexpr char func_backward_vadd_cpu[] = "backward_vadd";
+static constexpr char func_backward_vadd_gpu[] = "cuda_backward_vadd";
 
 template<const char* func>
 void TVMBroadcastCompute(const nnvm::NodeAttrs& attrs,
@@ -49,17 +52,64 @@ void TVMBroadcastCompute(const nnvm::NodeAttrs& attrs,
   tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], inputs[1], outputs[0]});
 }
 
+template<const char* func>
+void TVMBackwardBroadcastCompute(const nnvm::NodeAttrs& attrs,
+                                 const mxnet::OpContext& ctx,
+                                 const std::vector<TBlob>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 2U);
+  int ndim = inputs[0].shape_.ndim();
+  for (int k = 0; k < 2; ++k) {
+    std::vector<int> ov, iv;
+    const TBlob& ograd = inputs[0], igrad = outputs[k];
+    bool flag = ograd.size(0) != igrad.size(0);
+    for (int i = 0; i < ndim; ++i) {
+      if (i == 0 || (ograd.size(i) != igrad.size(i)) != (ograd.size(i - 1) != igrad.size(i - 1))) {
+        ov.push_back(ograd.size(i));
+      } else {
+        ov.back() *= ograd.size(i);
+      }
+    }
+    for (int i = flag; i < ov.size(); i += 2) {
+      iv.push_back(ov[i]);
+    }
+    TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
+    std::string funcname = std::string(func) + "reduce1st_" + std::to_string(flag);
+    TBlob ograd_tvm(ograd.reshape(oshape).dltensor());
+    TBlob igrad_tvm(igrad.reshape(ishape).dltensor());
+    tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx, {ograd_tvm, igrad_tvm});
+  }
+}
+
 NNVM_REGISTER_OP(_contrib_tvm_vadd)
 .set_num_inputs(2)
 .set_num_outputs(1)
 .add_argument("a", "NDArray-or-Symbol", "first input")
 .add_argument("b", "NDArray-or-Symbol", "second input")
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a", "b"};
+  })
 .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
 .set_attr<nnvm::FInferType>("FInferType", mxnet::op::ElemwiseType<2, 1>)
 #if MXNET_USE_CUDA
 .set_attr<FCompute>("FCompute<gpu>", mxnet::op::TVMBroadcastCompute<func_vadd_gpu>)
 #endif  // MXNET_USE_CUDA
-.set_attr<FCompute>("FCompute<cpu>", mxnet::op::TVMBroadcastCompute<func_vadd_cpu>);
+.set_attr<FCompute>("FCompute<cpu>", mxnet::op::TVMBroadcastCompute<func_vadd_cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_contrib_tvm_vadd"});
+
+NNVM_REGISTER_OP(_backward_contrib_tvm_vadd)
+.set_num_inputs(1)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+#if MXNET_USE_CUDA
+.set_attr<FCompute>("FCompute<gpu>",
+                    mxnet::op::TVMBackwardBroadcastCompute<func_backward_vadd_gpu>)
+#endif  // MXNET_USE_CUDA
+.set_attr<FCompute>("FCompute<cpu>",
+                    mxnet::op::TVMBackwardBroadcastCompute<func_backward_vadd_cpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_tvm_op.py b/tests/python/unittest/test_tvm_op.py
index 3ab2a25bc20e..7253ad9a40ce 100644
--- a/tests/python/unittest/test_tvm_op.py
+++ b/tests/python/unittest/test_tvm_op.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import mxnet as mx
+import numpy as _np
 from mxnet.test_utils import same, rand_shape_nd
 from mxnet.runtime import Features
 from common import with_seed
@@ -29,9 +30,18 @@ def test_tvm_broadcast_add():
         b_shape = (1,) + a_shape[1:2] + (1, 1)
         a = mx.nd.normal(shape=a_shape)
         b = mx.nd.normal(shape=b_shape)
-        c = mx.nd.contrib.tvm_vadd(a, b)
+        a.attach_grad()
+        b.attach_grad()
+        with mx.autograd.record():
+            c = mx.nd.contrib.tvm_vadd(a, b)
         c_np = a.asnumpy() + b.asnumpy()
         assert same(c.asnumpy(), c_np)
+        c.backward()
+        expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
+        expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
+        assert same(a.grad.asnumpy(), expected_grad_a)
+        assert same(b.grad.asnumpy(), expected_grad_b)
+
 
 if __name__ == '__main__':
     import nose
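
Note (not part of the patch): TVMBackwardBroadcastCompute works by collapsing consecutive dimensions of the output gradient into groups that alternate between "reduced" (sizes differ from the input gradient, i.e. broadcast in the forward pass) and "kept" (sizes match), so one compiled backward_vadd kernel per collapsed ndim and reduce1st value covers every broadcast pattern. The sketch below models that scheme in NumPy; collapse_axes is a hypothetical helper written only for illustration, and NumPy's sum stands in for the generated TVM reduction.

import numpy as np

def collapse_axes(oshape, ishape):
    # Merge consecutive dims of the output-gradient shape into alternating
    # reduced/kept groups, mirroring the ov/flag loop in TVMBackwardBroadcastCompute.
    ov = []
    for i, (o, s) in enumerate(zip(oshape, ishape)):
        if i == 0 or (o != s) != (oshape[i - 1] != ishape[i - 1]):
            ov.append(o)
        else:
            ov[-1] *= o
    reduce1st = int(oshape[0] != ishape[0])  # is the first group a reduced one?
    return ov, reduce1st

a_shape, b_shape = (2, 3, 4, 5), (1, 3, 1, 1)      # same pattern as the unit test
ograd = np.ones(a_shape, dtype=np.float32)         # head gradient of c = a + b
ov, reduce1st = collapse_axes(a_shape, b_shape)    # -> [2, 3, 20], reduce1st = 1
reduced = tuple(range(1 - reduce1st, len(ov), 2))  # group indices to sum out
grad_b = ograd.reshape(ov).sum(axis=reduced).reshape(b_shape)

# agrees with the expectation checked in test_tvm_broadcast_add
assert np.array_equal(grad_b, np.ones(b_shape) * ograd.size / grad_b.size)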