From ef53f2723444f37951785bda6980b533fc1a2f5e Mon Sep 17 00:00:00 2001
From: Fan
Date: Fri, 16 Aug 2019 15:07:24 +0800
Subject: [PATCH 1/3] tvm broadcast backward

---
 contrib/tvmop/basic/ufunc.py         | 58 +++++++++++++++++++++++++
 src/operator/contrib/tvmop/ufunc.cc  | 64 +++++++++++++++++++++++++---
 tests/python/unittest/test_tvm_op.py | 12 +++++-
 3 files changed, 126 insertions(+), 8 deletions(-)

diff --git a/contrib/tvmop/basic/ufunc.py b/contrib/tvmop/basic/ufunc.py
index 0419e5fd2ca9..d67fb3d140a4 100644
--- a/contrib/tvmop/basic/ufunc.py
+++ b/contrib/tvmop/basic/ufunc.py
@@ -27,6 +27,7 @@ def compute_add(dtype, ndim):
     s = tvm.create_schedule(C.op)
     return s, A, B, C
 
+
 @defop(name="vadd", target="cpu", auto_broadcast=True,
        dtype=AllTypes, ndim=list(range(1, 6)))
 def vadd(dtype, ndim):
@@ -37,6 +38,7 @@ def vadd(dtype, ndim):
 
     return s, [A, B, C]
 
+
 @defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
        dtype=["float32", "float64"], ndim=list(range(1, 6)))
 def vadd_gpu(dtype, ndim):
@@ -48,3 +50,59 @@ def vadd_gpu(dtype, ndim):
     s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
     s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
     return s, [A, B, C]
+
+
+def reduce_axes(X, axes, reducer):
+    def get_index(idx, ridx):
+        j = 0
+        k = 0
+        ret = []
+        for val in axes:
+            ret.append(idx[j] if val == 0 else ridx[k])
+            j += (val == 0)
+            k += (val != 0)
+        return tuple(ret)
+
+    ishape = X.shape
+    odim = (len(ishape) + 1 - axes[0]) // 2
+    oshape = [tvm.var() for _ in range(odim)]
+    ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
+    ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx), name='ret')
+    return ret
+
+
+def compute_backward_vadd(dtype, ndim, reduce1st):
+    axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
+    X = tvm.placeholder([tvm.var() for _ in range(ndim)], name='X', dtype=dtype)
+    reducer = tvm.comm_reducer(lambda x, y: x + y,
+                               lambda t: tvm.const(0, dtype=t), name="sum")
+    ret = reduce_axes(X, axes, reducer)
+    s = tvm.create_schedule(ret.op)
+    return s, X, ret, [ret]
+
+
+@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
+       ndim=list(range(1, 6)), reduce1st=[0, 1], attrs=["reduce1st"])
+def backward_vadd(dtype, ndim, reduce1st):
+    s, X, ret, c_list = compute_backward_vadd(dtype, ndim, reduce1st)
+    for t in c_list:
+        axes = [axis for axis in t.op.axis]
+        fused = s[t].fuse(*axes)
+        s[t].parallel(fused)
+    return s, [X, ret]
+
+
+@defop(name="cuda_backward_vadd", target="gpu", dtype=["float32", "float64"],
+       ndim=list(range(1, 6)), reduce1st=[0, 1], attrs=["reduce1st"])
+def backward_vadd_gpu(dtype, ndim, reduce1st):
+    s, X, ret, c_list = compute_backward_vadd(dtype, ndim, reduce1st)
+    num_thread = 64
+    for t in c_list:
+        block_x = tvm.thread_axis("blockIdx.x")
+        thread_x = tvm.thread_axis("threadIdx.x")
+        axes = [axis for axis in t.op.axis]
+        fused = s[t].fuse(*axes)
+        bx, tx = s[t].split(fused, factor=num_thread)
+        s[t].bind(bx, block_x)
+        s[t].bind(tx, thread_x)
+    return s, [X, ret]
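For intuition about the `reduce_axes` / `compute_backward_vadd` helpers added above: the backward of a broadcast add is a sum of the output gradient over the broadcast axes. Below is a minimal NumPy sketch of the kernel's semantics; it is illustrative only, not part of the patch, and `backward_vadd_reference` is a made-up name.

    import numpy as np

    def backward_vadd_reference(ograd, axes):
        # axes[i] == 1 marks a reduce (broadcast) axis, 0 a kept axis,
        # matching the alternating pattern compute_backward_vadd builds
        reduce_dims = tuple(i for i, bit in enumerate(axes) if bit == 1)
        return ograd.sum(axis=reduce_dims)

    ograd = np.ones((2, 3, 4))
    # reduce1st=0, ndim=3 gives axes == [0, 1, 0]: sum over the middle axis
    print(backward_vadd_reference(ograd, [0, 1, 0]).shape)  # (2, 4)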
diff --git a/src/operator/contrib/tvmop/ufunc.cc b/src/operator/contrib/tvmop/ufunc.cc
index 3475a211cfec..e6999e27b6a0 100644
--- a/src/operator/contrib/tvmop/ufunc.cc
+++ b/src/operator/contrib/tvmop/ufunc.cc
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include <string>
 #include "../../tensor/elemwise_binary_broadcast_op.h"
 #include "../../tvmop/op_module.h"
 #include "../../tensor/elemwise_binary_op.h"
@@ -37,29 +38,78 @@ namespace op {
 
 static constexpr char func_vadd_cpu[] = "vadd";
 static constexpr char func_vadd_gpu[] = "cuda_vadd";
+static constexpr char func_backward_vadd_cpu[] = "backward_vadd";
+static constexpr char func_backward_vadd_gpu[] = "cuda_backward_vadd";
 
 template<typename xpu, const char* func>
-void TVMBroadcastCompute(const nnvm::NodeAttrs& attrs,
-                         const mxnet::OpContext& ctx,
-                         const std::vector<TBlob>& inputs,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<TBlob>& outputs) {
+void TVMBinaryCompute(const nnvm::NodeAttrs& attrs,
+                      const mxnet::OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
   tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], inputs[1], outputs[0]});
 }
 
+template<typename xpu, const char* func>
+void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
+                                     const mxnet::OpContext& ctx,
+                                     const std::vector<TBlob>& inputs,
+                                     const std::vector<OpReqType>& req,
+                                     const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 2U);
+  int ndim = inputs[0].shape_.ndim();
+  for (int k = 0; k < 2; ++k) {
+    std::vector<int64_t> ov, iv;
+    const TBlob& ograd = inputs[0], igrad = outputs[k];
+    bool flag = ograd.size(0) != igrad.size(0);
+    for (int i = 0; i < ndim; ++i) {
+      if (i == 0 || (ograd.size(i) != igrad.size(i)) != (ograd.size(i - 1) != igrad.size(i - 1))) {
+        ov.push_back(ograd.size(i));
+      } else {
+        ov.back() *= ograd.size(i);
+      }
+    }
+    for (int i = flag; i < ov.size(); i += 2) {
+      iv.push_back(ov[i]);
+    }
+    TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
+    TBlob ograd_tvm(ograd.reshape(oshape).dltensor());
+    TBlob igrad_tvm(igrad.reshape(ishape).dltensor());
+    std::string funcname = std::string(func) + "reduce1st_" + std::to_string(flag);
+    tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx, {ograd_tvm, igrad_tvm});
+  }
+}
+
 NNVM_REGISTER_OP(_contrib_tvm_vadd)
 .set_num_inputs(2)
 .set_num_outputs(1)
 .add_argument("a", "NDArray-or-Symbol", "first input")
 .add_argument("b", "NDArray-or-Symbol", "second input")
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a", "b"};
+  })
 .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
 .set_attr<nnvm::FInferType>("FInferType", mxnet::op::ElemwiseType<2, 1>)
 #if MXNET_USE_CUDA
-.set_attr<FCompute>("FCompute<gpu>", mxnet::op::TVMBroadcastCompute<gpu, func_vadd_gpu>)
+.set_attr<FCompute>("FCompute<gpu>", mxnet::op::TVMBinaryCompute<gpu, func_vadd_gpu>)
 #endif  // MXNET_USE_CUDA
-.set_attr<FCompute>("FCompute<cpu>", mxnet::op::TVMBroadcastCompute<cpu, func_vadd_cpu>);
+.set_attr<FCompute>("FCompute<cpu>", mxnet::op::TVMBinaryCompute<cpu, func_vadd_cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_contrib_tvm_vadd"});
+
+NNVM_REGISTER_OP(_backward_contrib_tvm_vadd)
+.set_num_inputs(1)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+#if MXNET_USE_CUDA
+.set_attr<FCompute>("FCompute<gpu>",
+                    mxnet::op::TVMBinaryBackwardComputeUseNone<gpu, func_backward_vadd_gpu>)
+#endif  // MXNET_USE_CUDA
+.set_attr<FCompute>("FCompute<cpu>",
+                    mxnet::op::TVMBinaryBackwardComputeUseNone<cpu, func_backward_vadd_cpu>);
 
 }  // namespace op
 }  // namespace mxnet
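The shape-collapsing loop in `TVMBinaryBackwardComputeUseNone` merges consecutive dimensions that agree (or disagree) between ograd and igrad, so any broadcast pattern reduces to an alternating keep/reduce shape. A hedged Python model of that loop, illustrative only, assuming ograd and igrad have equal ndim (which holds before patch 3's padding):

    def collapse_shapes(ograd_shape, igrad_shape):
        flag = int(ograd_shape[0] != igrad_shape[0])  # 1 if dim 0 is reduced
        ov = []
        for i, dim in enumerate(ograd_shape):
            differs = ograd_shape[i] != igrad_shape[i]
            prev_differs = i > 0 and ograd_shape[i - 1] != igrad_shape[i - 1]
            if i == 0 or differs != prev_differs:
                ov.append(dim)   # open a new group
            else:
                ov[-1] *= dim    # fold into the current group
        iv = [ov[i] for i in range(flag, len(ov), 2)]  # the kept groups
        return flag, ov, iv

    # ograd (2, 3, 4, 5) vs igrad (2, 1, 1, 5): dims 1-2 are broadcast
    print(collapse_shapes((2, 3, 4, 5), (2, 1, 1, 5)))  # (0, [2, 12, 5], [2, 5])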
diff --git a/tests/python/unittest/test_tvm_op.py b/tests/python/unittest/test_tvm_op.py
index 3ab2a25bc20e..7253ad9a40ce 100644
--- a/tests/python/unittest/test_tvm_op.py
+++ b/tests/python/unittest/test_tvm_op.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import mxnet as mx
+import numpy as _np
 from mxnet.test_utils import same, rand_shape_nd
 from mxnet.runtime import Features
 from common import with_seed
@@ -29,9 +30,18 @@ def test_tvm_broadcast_add():
         b_shape = (1,) + a_shape[1:2] + (1, 1)
         a = mx.nd.normal(shape=a_shape)
         b = mx.nd.normal(shape=b_shape)
-        c = mx.nd.contrib.tvm_vadd(a, b)
+        a.attach_grad()
+        b.attach_grad()
+        with mx.autograd.record():
+            c = mx.nd.contrib.tvm_vadd(a, b)
         c_np = a.asnumpy() + b.asnumpy()
         assert same(c.asnumpy(), c_np)
+        c.backward()
+        expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
+        expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
+        assert same(a.grad.asnumpy(), expected_grad_a)
+        assert same(b.grad.asnumpy(), expected_grad_b)
+
 
 if __name__ == '__main__':
     import nose
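Why the test's expected gradients hold: with an all-ones head gradient, each element of an input receives one contribution per output element it was broadcast into, i.e. `c.size / a.size` contributions in total. A small NumPy check of this identity, for illustration only (shapes chosen arbitrarily):

    import numpy as np

    a = np.random.normal(size=(2, 3))
    b = np.random.normal(size=(1, 3))
    c = a + b                                            # broadcast to (2, 3)
    grad_b = np.ones_like(c).sum(axis=0, keepdims=True)  # sum over the broadcast axis
    assert (grad_b == np.ones_like(b) * c.size / b.size).all()  # 6 / 3 = 2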
From f373aa84ac74742d64c30761fcdb6fca31ce2bd3 Mon Sep 17 00:00:00 2001
From: Fan
Date: Mon, 19 Aug 2019 14:09:19 +0800
Subject: [PATCH 2/3] dispatch by req

---
 contrib/tvmop/basic/ufunc.py         | 36 ++++++++++++++++++----------
 src/operator/contrib/tvmop/ufunc.cc  | 12 +++++++++-
 tests/python/unittest/test_tvm_op.py | 16 ++++++++++++-
 3 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/contrib/tvmop/basic/ufunc.py b/contrib/tvmop/basic/ufunc.py
index d67fb3d140a4..d526e463412a 100644
--- a/contrib/tvmop/basic/ufunc.py
+++ b/contrib/tvmop/basic/ufunc.py
@@ -52,6 +52,15 @@ def vadd_gpu(dtype, ndim):
     return s, [A, B, C]
 
 
+def assign_by_req(a, req):
+    b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
+    if req == "kAddTo":
+        c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx])
+    else:
+        c = tvm.compute(a.shape, lambda *idx: a[idx])
+    return b, c
+
+
 def reduce_axes(X, axes, reducer):
     def get_index(idx, ridx):
         j = 0
@@ -71,31 +80,34 @@ def get_index(idx, ridx):
     return ret
 
 
-def compute_backward_vadd(dtype, ndim, reduce1st):
+def compute_backward_vadd(dtype, ndim, reduce1st, req):
     axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
     X = tvm.placeholder([tvm.var() for _ in range(ndim)], name='X', dtype=dtype)
     reducer = tvm.comm_reducer(lambda x, y: x + y,
                                lambda t: tvm.const(0, dtype=t), name="sum")
     ret = reduce_axes(X, axes, reducer)
-    s = tvm.create_schedule(ret.op)
-    return s, X, ret, [ret]
+    in_grad_a, in_grad = assign_by_req(ret, req)
+    s = tvm.create_schedule(in_grad.op)
+    return s, X, in_grad_a, in_grad, [ret, in_grad]
 
 
-@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
-       ndim=list(range(1, 6)), reduce1st=[0, 1], attrs=["reduce1st"])
-def backward_vadd(dtype, ndim, reduce1st):
-    s, X, ret, c_list = compute_backward_vadd(dtype, ndim, reduce1st)
+@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
+       ndim=list(range(1, 6)), reduce1st=[0, 1],
+       req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
+def backward_vadd(dtype, ndim, reduce1st, req):
+    s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
     for t in c_list:
         axes = [axis for axis in t.op.axis]
         fused = s[t].fuse(*axes)
         s[t].parallel(fused)
-    return s, [X, ret]
+    return s, [X, in_grad_a, in_grad]
 
 
 @defop(name="cuda_backward_vadd", target="gpu", dtype=["float32", "float64"],
-       ndim=list(range(1, 6)), reduce1st=[0, 1], attrs=["reduce1st"])
-def backward_vadd_gpu(dtype, ndim, reduce1st):
-    s, X, ret, c_list = compute_backward_vadd(dtype, ndim, reduce1st)
+       ndim=list(range(1, 6)), reduce1st=[0, 1],
+       req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
+def backward_vadd_gpu(dtype, ndim, reduce1st, req):
+    s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
     num_thread = 64
     for t in c_list:
         block_x = tvm.thread_axis("blockIdx.x")
@@ -105,4 +117,4 @@ def backward_vadd_gpu(dtype, ndim, reduce1st):
         bx, tx = s[t].split(fused, factor=num_thread)
         s[t].bind(bx, block_x)
         s[t].bind(tx, thread_x)
-    return s, [X, ret]
+    return s, [X, in_grad_a, in_grad]
diff --git a/src/operator/contrib/tvmop/ufunc.cc b/src/operator/contrib/tvmop/ufunc.cc
index e6999e27b6a0..b4f3ab4bd317 100644
--- a/src/operator/contrib/tvmop/ufunc.cc
+++ b/src/operator/contrib/tvmop/ufunc.cc
@@ -62,6 +62,7 @@ void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs.size(), 2U);
   int ndim = inputs[0].shape_.ndim();
   for (int k = 0; k < 2; ++k) {
+    // dispatch by backward
     std::vector<int64_t> ov, iv;
     const TBlob& ograd = inputs[0], igrad = outputs[k];
     bool flag = ograd.size(0) != igrad.size(0);
@@ -79,7 +80,16 @@ void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
     TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
     TBlob ograd_tvm(ograd.reshape(oshape).dltensor());
     TBlob igrad_tvm(igrad.reshape(ishape).dltensor());
     std::string funcname = std::string(func) + "reduce1st_" + std::to_string(flag);
-    tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx, {ograd_tvm, igrad_tvm});
+    // dispatch by req
+    funcname += "req_";
+    MXNET_ASSIGN_REQ_SWITCH(req[k], req_type, {
+      if (req_type == kWriteTo) {
+        funcname += "kWriteTo";
+      } else {
+        funcname += "kAddTo";
+      }
+    })
+    tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx, {ograd_tvm, igrad_tvm, igrad_tvm});
   }
 }
diff --git a/tests/python/unittest/test_tvm_op.py b/tests/python/unittest/test_tvm_op.py
index 7253ad9a40ce..2126631077d4 100644
--- a/tests/python/unittest/test_tvm_op.py
+++ b/tests/python/unittest/test_tvm_op.py
@@ -36,12 +36,26 @@ def test_tvm_broadcast_add():
             c = mx.nd.contrib.tvm_vadd(a, b)
         c_np = a.asnumpy() + b.asnumpy()
         assert same(c.asnumpy(), c_np)
+        # test backward
         c.backward()
         expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
         expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
         assert same(a.grad.asnumpy(), expected_grad_a)
         assert same(b.grad.asnumpy(), expected_grad_b)
-
+        # test kAddTo request
+        a = mx.nd.normal(shape=a_shape)
+        b = mx.nd.normal(shape=b_shape)
+        a.attach_grad()
+        b.attach_grad()
+        with mx.autograd.record():
+            c = mx.nd.contrib.tvm_vadd(a, b)
+            d = mx.nd.contrib.tvm_vadd(a, b)
+        mx.autograd.backward([c, d])
+        expected_grad_a = 2 * _np.ones_like(a.asnumpy()) * c.size / a.size
+        expected_grad_b = 2 * _np.ones_like(b.asnumpy()) * c.size / b.size
+        assert same(a.grad.asnumpy(), expected_grad_a)
+        assert same(b.grad.asnumpy(), expected_grad_b)
+
 
 if __name__ == '__main__':
     import nose
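With this patch the `req` mode is baked into the compiled kernel's name, and the backward kernel now takes (ograd, old igrad, igrad) so the kAddTo variant can read the old value. A sketch of how the C++ side assembles that name, assuming the defop attr mangling mirrors the attrs=["reduce1st", "req"] list (the helper name is made up for illustration):

    def backward_vadd_kernel_name(func, reduce1st, req):
        # mirrors: func + "reduce1st_" + flag + "req_" + ("kWriteTo" | "kAddTo")
        return func + "reduce1st_" + str(reduce1st) + "req_" + req

    print(backward_vadd_kernel_name("backward_vadd", 1, "kAddTo"))
    # backward_vaddreduce1st_1req_kAddTo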
From 14351d98643c919e872770de550058d317b98f51 Mon Sep 17 00:00:00 2001
From: Fan
Date: Wed, 21 Aug 2019 11:36:19 +0800
Subject: [PATCH 3/3] pad for broadcast to a larger dim

---
 contrib/tvmop/__init__.py            |  1 +
 contrib/tvmop/basic/ufunc.py         | 44 +++++-------------
 contrib/tvmop/utils.py               | 29 ++++++++++++
 src/operator/contrib/tvmop/ufunc.cc  | 35 +++++++++++---
 tests/python/unittest/test_tvm_op.py | 69 ++++++++++++++++------------
 5 files changed, 111 insertions(+), 67 deletions(-)

diff --git a/contrib/tvmop/__init__.py b/contrib/tvmop/__init__.py
index 31189d499b5a..1234ee7d31f1 100644
--- a/contrib/tvmop/__init__.py
+++ b/contrib/tvmop/__init__.py
@@ -18,5 +18,6 @@
 # coding: utf-8
 from .opdef import defop
 from .utils import AllTypes, RealTypes
+from .utils import assign_by_req, reduce_axes
 
 from . import basic
diff --git a/contrib/tvmop/basic/ufunc.py b/contrib/tvmop/basic/ufunc.py
index d526e463412a..6bb102ccf7e3 100644
--- a/contrib/tvmop/basic/ufunc.py
+++ b/contrib/tvmop/basic/ufunc.py
@@ -18,6 +18,7 @@
 # coding: utf-8
 import tvm
 from .. import defop, AllTypes
+from .. import assign_by_req, reduce_axes
 
 def compute_add(dtype, ndim):
     A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
@@ -29,7 +30,7 @@ def compute_add(dtype, ndim):
 
 @defop(name="vadd", target="cpu", auto_broadcast=True,
-       dtype=AllTypes, ndim=list(range(1, 6)))
+       dtype=AllTypes, ndim=[5])
 def vadd(dtype, ndim):
     s, A, B, C = compute_add(dtype, ndim)
     axes = [axis for axis in C.op.axis]
@@ -40,7 +41,7 @@ def vadd(dtype, ndim):
 
 @defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
-       dtype=["float32", "float64"], ndim=list(range(1, 6)))
+       dtype=["float32", "float64"], ndim=[5])
 def vadd_gpu(dtype, ndim):
     s, A, B, C = compute_add(dtype, ndim)
     s = tvm.create_schedule(C.op)
@@ -52,35 +53,14 @@ def vadd_gpu(dtype, ndim):
     return s, [A, B, C]
 
 
-def assign_by_req(a, req):
-    b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
-    if req == "kAddTo":
-        c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx])
-    else:
-        c = tvm.compute(a.shape, lambda *idx: a[idx])
-    return b, c
-
-
-def reduce_axes(X, axes, reducer):
-    def get_index(idx, ridx):
-        j = 0
-        k = 0
-        ret = []
-        for val in axes:
-            ret.append(idx[j] if val == 0 else ridx[k])
-            j += (val == 0)
-            k += (val != 0)
-        return tuple(ret)
-
-    ishape = X.shape
-    odim = (len(ishape) + 1 - axes[0]) // 2
-    oshape = [tvm.var() for _ in range(odim)]
-    ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
-    ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx), name='ret')
-    return ret
-
-
 def compute_backward_vadd(dtype, ndim, reduce1st, req):
+    # The backward of a broadcast op is basically a reduction over the broadcast axes.
+    # We label each reduce axis as 1 and every other axis as 0, so the axes form a
+    # bit string. Each bit string corresponds to a kernel, so there would be as many
+    # as 2^n kernels. To reduce that number, the bit string is compressed by combining
+    # consecutive 0s or 1s; this cuts the number of bit strings (and kernels) to 2 * n.
+    # The compressed bit string is stored in `axes`, and `reduce1st` is its first bit.
+    # Credit to @junrushao1994 and @yzhliu.
     axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
     X = tvm.placeholder([tvm.var() for _ in range(ndim)], name='X', dtype=dtype)
     reducer = tvm.comm_reducer(lambda x, y: x + y,
                                lambda t: tvm.const(0, dtype=t), name="sum")
@@ -92,7 +72,7 @@ def compute_backward_vadd(dtype, ndim, reduce1st, req):
 
 @defop(name="backward_vadd", target="cpu", dtype=AllTypes,
-       ndim=list(range(1, 6)), reduce1st=[0, 1],
+       ndim=[5], reduce1st=[0, 1],
        req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
 def backward_vadd(dtype, ndim, reduce1st, req):
     s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
@@ -104,7 +84,7 @@ def backward_vadd(dtype, ndim, reduce1st, req):
 
 @defop(name="cuda_backward_vadd", target="gpu", dtype=["float32", "float64"],
-       ndim=list(range(1, 6)), reduce1st=[0, 1],
+       ndim=[5], reduce1st=[0, 1],
        req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
 def backward_vadd_gpu(dtype, ndim, reduce1st, req):
     s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
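To make the compressed-bit-string comment concrete, here is a small enumeration of the patterns `compute_backward_vadd` actually compiles, taken directly from the `axes` expression above (illustration only):

    for reduce1st in (0, 1):
        for ndim in range(1, 6):
            axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
            print(reduce1st, ndim, axes)
    # e.g. reduce1st=1, ndim=5 -> [1, 0, 1, 0, 1]: two kernels per ndim, not 2^ndim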
diff --git a/contrib/tvmop/utils.py b/contrib/tvmop/utils.py
index 0b2416b4f3ae..329dce2148d9 100644
--- a/contrib/tvmop/utils.py
+++ b/contrib/tvmop/utils.py
@@ -16,5 +16,34 @@
 # under the License.
 
 # coding: utf-8
+import tvm
+
 AllTypes = ["float32", "float64", "float16", "uint8", "int8", "int32", "int64"]
 RealTypes = ["float32", "float64", "float16"]
+
+def assign_by_req(a, req):
+    b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
+    if req == "kAddTo":
+        c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx])
+    else:
+        c = tvm.compute(a.shape, lambda *idx: a[idx])
+    return b, c
+
+
+def reduce_axes(X, axes, reducer):
+    def get_index(idx, ridx):
+        j = 0
+        k = 0
+        ret = []
+        for val in axes:
+            ret.append(idx[j] if val == 0 else ridx[k])
+            j += (val == 0)
+            k += (val != 0)
+        return tuple(ret)
+
+    ishape = X.shape
+    odim = (len(ishape) + 1 - axes[0]) // 2
+    oshape = [tvm.var() for _ in range(odim)]
+    ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
+    ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx), name='ret')
+    return ret
diff --git a/src/operator/contrib/tvmop/ufunc.cc b/src/operator/contrib/tvmop/ufunc.cc
index b4f3ab4bd317..7dc28bd0f419 100644
--- a/src/operator/contrib/tvmop/ufunc.cc
+++ b/src/operator/contrib/tvmop/ufunc.cc
@@ -40,6 +40,16 @@ static constexpr char func_vadd_cpu[] = "vadd";
 static constexpr char func_vadd_gpu[] = "cuda_vadd";
 static constexpr char func_backward_vadd_cpu[] = "backward_vadd";
 static constexpr char func_backward_vadd_gpu[] = "cuda_backward_vadd";
+static constexpr int max_dim = 5;
+
+TBlob padding(const TBlob& tblob, const int max_dim) {
+  TShape tshape(max_dim, 1);
+  int ndim = tblob.shape_.ndim();
+  for (int i = max_dim - ndim; i < max_dim; ++i) {
+    tshape[i] = tblob.size(i - max_dim + ndim);
+  }
+  return tblob.reshape(tshape);
+}
 
 template<typename xpu, const char* func>
 void TVMBinaryCompute(const nnvm::NodeAttrs& attrs,
@@ -49,7 +59,12 @@ void TVMBinaryCompute(const nnvm::NodeAttrs& attrs,
                       const std::vector<TBlob>& outputs) {
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
-  tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], inputs[1], outputs[0]});
+  TBlob idata[2], odata;
+  for (int k = 0; k < 2; ++k) {
+    idata[k] = padding(inputs[k], max_dim);
+  }
+  odata = padding(outputs[0], max_dim);
+  tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {idata[0], idata[1], odata});
 }
 
 template<typename xpu, const char* func>
@@ -64,8 +79,13 @@ void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
   for (int k = 0; k < 2; ++k) {
     // dispatch by backward
     std::vector<int64_t> ov, iv;
-    const TBlob& ograd = inputs[0], igrad = outputs[k];
-    bool flag = ograd.size(0) != igrad.size(0);
+    TBlob ograd = padding(inputs[0], ndim), igrad = padding(outputs[k], ndim);
+    int flag;
+    if (ograd.size(0) != igrad.size(0)) {
+      flag = 1;
+    } else {
+      flag = 0;
+    }
     for (int i = 0; i < ndim; ++i) {
       if (i == 0 || (ograd.size(i) != igrad.size(i)) != (ograd.size(i - 1) != igrad.size(i - 1))) {
         ov.push_back(ograd.size(i));
@@ -73,18 +93,21 @@ void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
         ov.back() *= ograd.size(i);
       }
     }
+    for (int i = ov.size(); i < max_dim; ++i) {
+      ov.push_back(1);
+    }
     for (int i = flag; i < ov.size(); i += 2) {
       iv.push_back(ov[i]);
     }
     TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
-    TBlob ograd_tvm(ograd.reshape(oshape).dltensor());
-    TBlob igrad_tvm(igrad.reshape(ishape).dltensor());
+    TBlob ograd_tvm(ograd.reshape(oshape));
+    TBlob igrad_tvm(igrad.reshape(ishape));
     std::string funcname = std::string(func) + "reduce1st_" + std::to_string(flag);
     // dispatch by req
     funcname += "req_";
     MXNET_ASSIGN_REQ_SWITCH(req[k], req_type, {
       if (req_type == kWriteTo) {
-          funcname += "kWriteTo";
+        funcname += "kWriteTo";
       } else {
         funcname += "kAddTo";
       }
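A hedged Python model of the new `padding()` helper: shapes are right-aligned into `max_dim` dimensions and padded with leading 1s, which is why the @defop registrations above now compile only ndim=[5] kernels. `pad_shape` is a made-up name for illustration:

    def pad_shape(shape, max_dim=5):
        # right-align the shape and fill the leading dimensions with 1s,
        # mirroring the loop in the C++ padding() helper
        return (1,) * (max_dim - len(shape)) + tuple(shape)

    print(pad_shape((3, 5, 6)))  # (1, 1, 3, 5, 6)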
diff --git a/tests/python/unittest/test_tvm_op.py b/tests/python/unittest/test_tvm_op.py
index 2126631077d4..9c1f567ad31a 100644
--- a/tests/python/unittest/test_tvm_op.py
+++ b/tests/python/unittest/test_tvm_op.py
@@ -26,35 +26,46 @@
 @with_seed()
 def test_tvm_broadcast_add():
     if _features.is_enabled("TVM_OP"):
-        a_shape = rand_shape_nd(4)
-        b_shape = (1,) + a_shape[1:2] + (1, 1)
-        a = mx.nd.normal(shape=a_shape)
-        b = mx.nd.normal(shape=b_shape)
-        a.attach_grad()
-        b.attach_grad()
-        with mx.autograd.record():
-            c = mx.nd.contrib.tvm_vadd(a, b)
-        c_np = a.asnumpy() + b.asnumpy()
-        assert same(c.asnumpy(), c_np)
-        # test backward
-        c.backward()
-        expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
-        expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
-        assert same(a.grad.asnumpy(), expected_grad_a)
-        assert same(b.grad.asnumpy(), expected_grad_b)
-        # test kAddTo request
-        a = mx.nd.normal(shape=a_shape)
-        b = mx.nd.normal(shape=b_shape)
-        a.attach_grad()
-        b.attach_grad()
-        with mx.autograd.record():
-            c = mx.nd.contrib.tvm_vadd(a, b)
-            d = mx.nd.contrib.tvm_vadd(a, b)
-        mx.autograd.backward([c, d])
-        expected_grad_a = 2 * _np.ones_like(a.asnumpy()) * c.size / a.size
-        expected_grad_b = 2 * _np.ones_like(b.asnumpy()) * c.size / b.size
-        assert same(a.grad.asnumpy(), expected_grad_a)
-        assert same(b.grad.asnumpy(), expected_grad_b)
+        configs = [
+            [[5, 6, 7, 8, 9], [1]],
+            [[6, 4, 5, 2, 1], [6, 1, 5, 1, 1]],
+            [[3, 5, 6], [1, 6]],
+            [[3, 5, 6], [5, 1]],
+            [[3, 5, 6], [5, 6]],
+            [[4, 3, 2, 1], [2, 1]],
+            [[4, 3, 2, 2], [4, 1, 1, 2]],
+            [[6, 6], [6, 6]],
+        ]
+        for config in configs:
+            a_shape = config[0]
+            b_shape = config[1]
+            a = mx.nd.normal(shape=a_shape)
+            b = mx.nd.normal(shape=b_shape)
+            a.attach_grad()
+            b.attach_grad()
+            with mx.autograd.record():
+                c = mx.nd.contrib.tvm_vadd(a, b)
+            c_np = a.asnumpy() + b.asnumpy()
+            assert same(c.asnumpy(), c_np)
+            # test backward
+            c.backward()
+            expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
+            expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
+            assert same(a.grad.asnumpy(), expected_grad_a)
+            assert same(b.grad.asnumpy(), expected_grad_b)
+            # test kAddTo request
+            a = mx.nd.normal(shape=a_shape)
+            b = mx.nd.normal(shape=b_shape)
+            a.attach_grad()
+            b.attach_grad()
+            with mx.autograd.record():
+                c = mx.nd.contrib.tvm_vadd(a, b)
+                d = mx.nd.contrib.tvm_vadd(a, b)
+            mx.autograd.backward([c, d])
+            expected_grad_a = 2 * _np.ones_like(a.asnumpy()) * c.size / a.size
+            expected_grad_b = 2 * _np.ones_like(b.asnumpy()) * c.size / b.size
+            assert same(a.grad.asnumpy(), expected_grad_a)
+            assert same(b.grad.asnumpy(), expected_grad_b)
 
 
 if __name__ == '__main__':