Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
tvm broadcast backward
Browse files Browse the repository at this point in the history
  • Loading branch information
Fan committed Aug 18, 2019
1 parent 1a6fe60 commit 53dd8cd
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 2 deletions.
64 changes: 64 additions & 0 deletions contrib/tvmop/basic/ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def compute_add(dtype, ndim):
s = tvm.create_schedule(C.op)
return s, A, B, C


@defop(name="vadd", target="cpu", auto_broadcast=True,
dtype=AllTypes, ndim=list(range(1, 6)))
def vadd(dtype, ndim):
Expand All @@ -37,6 +38,7 @@ def vadd(dtype, ndim):

return s, [A, B, C]


@defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
dtype=["float32", "float64"], ndim=list(range(1, 6)))
def vadd_gpu(dtype, ndim):
Expand All @@ -48,3 +50,65 @@ def vadd_gpu(dtype, ndim):
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
return s, [A, B, C]


def reduce_axes(X, ishape, axes):
    """Build a TVM compute that sums `X` over the axes flagged in `axes`.

    Parameters
    ----------
    X : tensor indexed with one index per entry of `axes`.
    ishape : sequence of the extents of `X`, one per axis.
    axes : list of 0/1 flags, one per axis of `X`; axes flagged 1 are summed
        away, axes flagged 0 are kept in the output.

    Returns
    -------
    The tensor produced by `tvm.compute`, named 'ret'. Its shape is made of
    fresh `tvm.var()` symbols rather than entries of `ishape`.
    """
    def get_index(idx, ridx):
        # Interleave output indices (flag == 0) with reduction iterators
        # (any other flag) following the order given by `axes`.
        kept_pos, reduced_pos = 0, 0
        index = []
        for flag in axes:
            if flag == 0:
                index.append(idx[kept_pos])
                kept_pos += 1
            else:
                index.append(ridx[reduced_pos])
                reduced_pos += 1
        return tuple(index)

    # Dedicated path for a 2-D input whose first axis is reduced.
    if len(ishape) == 2 and axes == [1, 0]:
        r = tvm.reduce_axis((0, ishape[0]), "r")
        return tvm.compute([tvm.var()], lambda i: tvm.sum(X[r, i], axis=r), name='ret')

    # Output rank; equals the number of kept axes when the flags alternate
    # starting from axes[0].
    out_ndim = (len(ishape) + 1 - axes[0]) // 2
    out_shape = [tvm.var() for _ in range(out_ndim)]
    reduce_iters = [tvm.reduce_axis((0, ishape[pos]))
                    for pos, flag in enumerate(axes) if flag == 1]
    return tvm.compute(
        out_shape,
        lambda *idx: tvm.sum(X[get_index(idx, reduce_iters)], axis=reduce_iters),
        name='ret')


def compute_backward_vadd(dtype, ndim, reduce1st):
    """Declare the reduction used by the backward pass of broadcast add.

    Parameters
    ----------
    dtype : data type of the input placeholder.
    ndim : number of axes of the input placeholder.
    reduce1st : 0 or 1; 1 means the first axis is reduced. The remaining axes
        alternate between kept and reduced from there.

    Returns
    -------
    (s, X, ret, [ret]) : the schedule, the input placeholder, the reduced
        output tensor, and the list of tensors for callers to schedule.
    """
    # Alternating reduce/keep flags starting with `reduce1st`, cut to ndim entries.
    axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
    ishape = [tvm.var() for _ in range(ndim)]
    X = tvm.placeholder(ishape, name='X', dtype=dtype)
    # reduce_axes creates the output shape itself, so none is built here.
    ret = reduce_axes(X, ishape, axes)
    s = tvm.create_schedule(ret.op)
    return s, X, ret, [ret]


@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
       ndim=list(range(1, 6)), reduce1st=[0, 1], attrs=["reduce1st"])
def backward_vadd(dtype, ndim, reduce1st):
    """CPU schedule for the backward reduction of broadcast add.

    Returns the schedule together with the kernel arguments [X, ret].
    """
    sched, X, ret, stages = compute_backward_vadd(dtype, ndim, reduce1st)
    for stage in stages:
        # Collapse all output axes into one loop and mark that loop parallel.
        merged = sched[stage].fuse(*stage.op.axis)
        sched[stage].parallel(merged)
    return sched, [X, ret]


@defop(name="cuda_backward_vadd", target="gpu", dtype=["float32", "float64"],
       ndim=list(range(1, 6)), reduce1st=[0, 1], attrs=["reduce1st"])
def backward_vadd_gpu(dtype, ndim, reduce1st):
    """GPU schedule for the backward reduction of broadcast add.

    Returns the schedule together with the kernel arguments [X, ret].
    """
    sched, X, ret, stages = compute_backward_vadd(dtype, ndim, reduce1st)
    threads_per_block = 64
    for stage in stages:
        block_axis = tvm.thread_axis("blockIdx.x")
        thread_axis = tvm.thread_axis("threadIdx.x")
        # Fuse every output axis, then split the fused loop by 64: the outer
        # part is bound to blockIdx.x and the inner part to threadIdx.x.
        merged = sched[stage].fuse(*stage.op.axis)
        outer, inner = sched[stage].split(merged, factor=threads_per_block)
        sched[stage].bind(outer, block_axis)
        sched[stage].bind(inner, thread_axis)
    return sched, [X, ret]
52 changes: 51 additions & 1 deletion src/operator/contrib/tvmop/ufunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <mxnet/base.h>
#include <string>
#include "../../tensor/elemwise_binary_broadcast_op.h"
#include "../../tvmop/op_module.h"
#include "../../tensor/elemwise_binary_op.h"
Expand All @@ -37,6 +38,8 @@ namespace op {

// Names of the TVM-compiled kernels. Each array is used as the `func`
// template argument of the compute functions registered in this file.
static constexpr char func_vadd_cpu[] = "vadd";
static constexpr char func_vadd_gpu[] = "cuda_vadd";
// NOTE(review): "bakcward" looks like a typo for "backward". The identifiers
// are referenced by the _backward_contrib_tvm_vadd registration, so any rename
// has to update those uses at the same time.
static constexpr char func_bakcward_vadd_cpu[] = "backward_vadd";
static constexpr char func_bakcward_vadd_gpu[] = "cuda_backward_vadd";

template<const char* func>
void TVMBroadcastCompute(const nnvm::NodeAttrs& attrs,
Expand All @@ -49,17 +52,64 @@ void TVMBroadcastCompute(const nnvm::NodeAttrs& attrs,
tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], inputs[1], outputs[0]});
}

/*!
 * \brief Backward compute for a TVM broadcast op: for each of the two input
 *        gradients, calls a TVM kernel that reduces the output gradient.
 *
 * Consecutive axes that are either all reduced or all kept are merged into one
 * axis, so the kernel always sees alternating reduce/keep axes. The kernel
 * name is `func` followed by "reduce1st_0" or "reduce1st_1", depending on
 * whether the first axis is reduced.
 *
 * \tparam func base name of the TVM kernel.
 * \param attrs node attributes (not used).
 * \param ctx operator context forwarded to the TVM module.
 * \param inputs exactly one blob: the output gradient.
 * \param req write requests (not consulted).
 *        NOTE(review): kAddTo / kNullOp are therefore not honoured here - confirm
 *        this is acceptable for callers.
 * \param outputs exactly two blobs: the input gradients.
 *
 * Assumes the output gradient has at least one axis and that each input
 * gradient has the same number of axes - TODO confirm against shape inference.
 */
template<const char* func>
void TVMBackwardBroadcastCompute(const nnvm::NodeAttrs& attrs,
                                 const mxnet::OpContext& ctx,
                                 const std::vector<TBlob>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 2U);
  const TBlob& ograd = inputs[0];
  const int ndim = ograd.shape_.ndim();
  for (int k = 0; k < 2; ++k) {
    // Bind by reference: the original declaration copied the TBlob.
    const TBlob& igrad = outputs[k];
    // ov: merged extents of the output gradient; iv: the kept extents only.
    // NOTE(review): extents are accumulated in int; very large merged axes
    // could overflow - confirm the expected size range.
    std::vector<int> ov, iv;
    // True when the first axis is reduced (its extent differs between blobs).
    const bool flag = ograd.size(0) != igrad.size(0);
    bool prev_reduced = flag;
    for (int i = 0; i < ndim; ++i) {
      const bool reduced = ograd.size(i) != igrad.size(i);
      if (i == 0 || reduced != prev_reduced) {
        // Start a new merged axis whenever the reduce/keep state flips.
        ov.push_back(ograd.size(i));
      } else {
        ov.back() *= ograd.size(i);
      }
      prev_reduced = reduced;
    }
    // Merged axes alternate, so kept axes sit at every other position,
    // starting at 1 when the first axis is reduced and at 0 otherwise.
    for (size_t i = flag; i < ov.size(); i += 2) {
      iv.push_back(ov[i]);
    }
    TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
    const std::string funcname = std::string(func) + "reduce1st_" + std::to_string(flag);
    TBlob ograd_tvm(ograd.reshape(oshape).dltensor());
    TBlob igrad_tvm(igrad.reshape(ishape).dltensor());
    tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx, {ograd_tvm, igrad_tvm});
  }
}

// Registers the forward op _contrib_tvm_vadd: two inputs "a" and "b", one
// output, broadcast shape inference, and TVM-backed FCompute kernels. Its
// gradient is declared as _backward_contrib_tvm_vadd.
// NOTE(review): the FCompute<cpu> attribute appears twice below, once ending in
// ';' and once not. This looks like the removed/added pair of a diff view; only
// the unterminated line can be followed by the FGradient attribute.
NNVM_REGISTER_OP(_contrib_tvm_vadd)
.set_num_inputs(2)
.set_num_outputs(1)
.add_argument("a", "NDArray-or-Symbol", "first input")
.add_argument("b", "NDArray-or-Symbol", "second input")
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a", "b"};
})
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
.set_attr<nnvm::FInferType>("FInferType", mxnet::op::ElemwiseType<2, 1>)
#if MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<gpu>", mxnet::op::TVMBroadcastCompute<func_vadd_gpu>)
#endif // MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<cpu>", mxnet::op::TVMBroadcastCompute<func_vadd_cpu>);
.set_attr<mxnet::FCompute>("FCompute<cpu>", mxnet::op::TVMBroadcastCompute<func_vadd_cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_contrib_tvm_vadd"});

// Registers the backward op of _contrib_tvm_vadd: one input and two outputs,
// computed by TVMBackwardBroadcastCompute with the CPU kernel name, plus the
// CUDA kernel name when MXNET_USE_CUDA is set.
NNVM_REGISTER_OP(_backward_contrib_tvm_vadd)
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<gpu>",
mxnet::op::TVMBackwardBroadcastCompute<func_bakcward_vadd_gpu>)
#endif // MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<cpu>",
mxnet::op::TVMBackwardBroadcastCompute<func_bakcward_vadd_cpu>);

} // namespace op
} // namespace mxnet
Expand Down
12 changes: 11 additions & 1 deletion tests/python/unittest/test_tvm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import mxnet as mx
import numpy as _np
from mxnet.test_utils import same, rand_shape_nd
from mxnet.runtime import Features
from common import with_seed
Expand All @@ -29,9 +30,18 @@ def test_tvm_broadcast_add():
b_shape = (1,) + a_shape[1:2] + (1, 1)
a = mx.nd.normal(shape=a_shape)
b = mx.nd.normal(shape=b_shape)
c = mx.nd.contrib.tvm_vadd(a, b)
a.attach_grad()
b.attach_grad()
with mx.autograd.record():
c = mx.nd.contrib.tvm_vadd(a, b)
c_np = a.asnumpy() + b.asnumpy()
assert same(c.asnumpy(), c_np)
c.backward()
expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
assert same(a.grad.asnumpy(), expected_grad_a)
assert same(b.grad.asnumpy(), expected_grad_b)


if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 53dd8cd

Please sign in to comment.