Tvm broadcast backward (#15938)
* tvm broadcast backward

* dispatch by req

* pad for broadcast to a larger dim
hzfan authored and yzhliu committed Aug 22, 2019
1 parent 434f185 commit 9023256
Showing 5 changed files with 215 additions and 17 deletions.
1 change: 1 addition & 0 deletions contrib/tvmop/__init__.py
@@ -18,5 +18,6 @@
# coding: utf-8
from .opdef import defop
from .utils import AllTypes, RealTypes
from .utils import assign_by_req, reduce_axes

from . import basic
54 changes: 52 additions & 2 deletions contrib/tvmop/basic/ufunc.py
@@ -18,6 +18,7 @@
# coding: utf-8
import tvm
from .. import defop, AllTypes
from .. import assign_by_req, reduce_axes

def compute_add(dtype, ndim):
A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
@@ -27,8 +28,9 @@ def compute_add(dtype, ndim):
s = tvm.create_schedule(C.op)
return s, A, B, C


@defop(name="vadd", target="cpu", auto_broadcast=True,
-dtype=AllTypes, ndim=list(range(1, 6)))
+dtype=AllTypes, ndim=[5])
def vadd(dtype, ndim):
s, A, B, C = compute_add(dtype, ndim)
axes = [axis for axis in C.op.axis]
@@ -37,8 +39,9 @@ def vadd(dtype, ndim):

return s, [A, B, C]


@defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
dtype=["float32", "float64"], ndim=list(range(1, 6)))
dtype=["float32", "float64"], ndim=[5])
def vadd_gpu(dtype, ndim):
s, A, B, C = compute_add(dtype, ndim)
s = tvm.create_schedule(C.op)
@@ -48,3 +51,50 @@ def vadd_gpu(dtype, ndim):
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
return s, [A, B, C]


def compute_backward_vadd(dtype, ndim, reduce1st, req):
# The backward of a broadcast op is basically a reduction over the broadcast axes.
# We label the reduce axes as 1 and the other axes as 0, which forms a bit string.
# Each bit string corresponds to a kernel, so the number of kernels grows as `2^n`.
# To reduce it, the bit string is compressed by merging consecutive 0s or 1s,
# which brings the number of bit strings (and hence kernels) down to `2 * n`.
# The compressed bit string is stored in `axes`, and `reduce1st` holds its first bit.
# Credit to @junrushao1994 and @yzhliu.
axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
X = tvm.placeholder([tvm.var() for _ in range(ndim)], name='X', dtype=dtype)
reducer = tvm.comm_reducer(lambda x, y: x + y,
lambda t: tvm.const(0, dtype=t), name="sum")
ret = reduce_axes(X, axes, reducer)
in_grad_a, in_grad = assign_by_req(ret, req)
s = tvm.create_schedule(in_grad.op)
return s, X, in_grad_a, in_grad, [ret, in_grad]


@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
ndim=[5], reduce1st=[0, 1],
req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
def backward_vadd(dtype, ndim, reduce1st, req):
s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
for t in c_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [X, in_grad_a, in_grad]


@defop(name="cuda_backward_vadd", target="gpu", dtype=["float32", "float64"],
ndim=[5], reduce1st=[0, 1],
req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
def backward_vadd_gpu(dtype, ndim, reduce1st, req):
s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
num_thread = 64
for t in c_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [X, in_grad_a, in_grad]
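
The compressed bit-string scheme described in the comment inside compute_backward_vadd is easiest to see on a concrete case. Below is a minimal NumPy sketch, not part of the commit (backward_vadd_reference is a hypothetical name), of what a single generated kernel computes for a given reduce1st and req; the real kernels come from defop, one per combination of the listed attrs, and are looked up by name on the C++ side.

# Minimal NumPy sketch (not in the commit) of one backward_vadd kernel.
# For ndim=5 and reduce1st=1 the compressed bit string is
# axes = [1, 0, 1, 0, 1]: axes labelled 1 are summed away, axes labelled 0 survive.
import numpy as np

def backward_vadd_reference(ograd, reduce1st, req, igrad=None):
    ndim = ograd.ndim
    axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
    reduce_axis = tuple(i for i, bit in enumerate(axes) if bit == 1)
    ret = ograd.sum(axis=reduce_axis)
    if req == "kAddTo" and igrad is not None:
        return igrad + ret   # accumulate into the existing gradient
    return ret               # kWriteTo: overwrite

ograd = np.ones((2, 3, 4, 5, 6))
print(backward_vadd_reference(ograd, reduce1st=1, req="kWriteTo").shape)  # (3, 5)

With ndim fixed at 5, only the two alternating patterns [0, 1, 0, 1, 0] and [1, 0, 1, 0, 1] are ever needed, which is why reduce1st=[0, 1] together with the shape compression on the C++ side covers every broadcast layout.
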
29 changes: 29 additions & 0 deletions contrib/tvmop/utils.py
@@ -16,5 +16,34 @@
# under the License.

# coding: utf-8
import tvm

AllTypes = ["float32", "float64", "float16", "uint8", "int8", "int32", "int64"]
RealTypes = ["float32", "float64", "float16"]

def assign_by_req(a, req):
b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
if (req == "kAddTo"):
c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx])
else:
c = tvm.compute(a.shape, lambda *idx: a[idx])
return b, c


def reduce_axes(X, axes, reducer):
def get_index(idx, ridx):
j = 0
k = 0
ret = []
for val in axes:
ret.append(idx[j] if val == 0 else ridx[k])
j += (val == 0)
k += (val != 0)
return tuple(ret)

ishape = X.shape
odim = (len(ishape) + 1 - axes[0]) // 2
oshape = [tvm.var() for _ in range(odim)]
ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx), name='ret')
return ret
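
To make the index bookkeeping in reduce_axes concrete, here is a hypothetical pure-Python rendering of its get_index helper (not part of the commit), with a worked example for axes = [0, 1, 0, 1, 0], where the output rank is (5 + 1 - 0) // 2 == 3.

# Hypothetical pure-Python version (not in the commit) of reduce_axes's index
# mapping: idx walks the kept (0) axes, ridx walks the reduced (1) axes.
def get_index(idx, ridx, axes):
    j = k = 0
    out = []
    for bit in axes:
        out.append(idx[j] if bit == 0 else ridx[k])
        j += (bit == 0)
        k += (bit != 0)
    return tuple(out)

# Output element (i0, i2, i4) accumulates X[i0, r1, i2, r3, i4] over all
# reduction indices (r1, r3):
print(get_index((7, 8, 9), (1, 2), [0, 1, 0, 1, 0]))  # (7, 1, 8, 2, 9)
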
99 changes: 91 additions & 8 deletions src/operator/contrib/tvmop/ufunc.cc
@@ -28,6 +28,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <mxnet/base.h>
#include <string>
#include "../../tensor/elemwise_binary_broadcast_op.h"
#include "../../tvmop/op_module.h"
#include "../../tensor/elemwise_binary_op.h"
@@ -37,29 +38,111 @@ namespace op {

static constexpr char func_vadd_cpu[] = "vadd";
static constexpr char func_vadd_gpu[] = "cuda_vadd";
static constexpr char func_backward_vadd_cpu[] = "backward_vadd";
static constexpr char func_backward_vadd_gpu[] = "cuda_backward_vadd";
static constexpr int max_dim = 5;

TBlob padding(const TBlob& tblob, const int max_dim) {
TShape tshape(max_dim, 1);
int ndim = tblob.shape_.ndim();
for (int i = max_dim - ndim; i < max_dim; ++i) {
tshape[i] = tblob.size(i - max_dim + ndim);
}
return tblob.reshape(tshape);
}

template<const char* func>
void TVMBroadcastCompute(const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
void TVMBinaryCompute(const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], inputs[1], outputs[0]});
TBlob idata[2], odata;
for (int k = 0; k < 2; ++k) {
idata[k] = padding(inputs[k], max_dim);
}
odata = padding(outputs[0], max_dim);
tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {idata[0], idata[1], odata});
}

template<const char* func>
void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 2U);
int ndim = inputs[0].shape_.ndim();
for (int k = 0; k < 2; ++k) {
// dispatch by backward
std::vector<int> ov, iv;
TBlob ograd = padding(inputs[0], ndim), igrad = padding(outputs[k], ndim);
int flag;
if (ograd.size(0) != igrad.size(0)) {
flag = 1;
} else {
flag = 0;
}
for (int i = 0; i < ndim; ++i) {
if (i == 0 || (ograd.size(i) != igrad.size(i)) != (ograd.size(i - 1) != igrad.size(i - 1))) {
ov.push_back(ograd.size(i));
} else {
ov.back() *= ograd.size(i);
}
}
for (int i = ov.size(); i < max_dim; ++i) {
ov.push_back(1);
}
for (int i = flag; i < ov.size(); i += 2) {
iv.push_back(ov[i]);
}
TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
TBlob ograd_tvm(ograd.reshape(oshape));
TBlob igrad_tvm(igrad.reshape(ishape));
std::string funcname = std::string(func) + "reduce1st_" + std::to_string(flag);
// dispatch by req
funcname += "req_";
MXNET_ASSIGN_REQ_SWITCH(req[k], req_type, {
if (req_type == kWriteTo) {
funcname += "kWriteTo";
} else {
funcname += "kAddTo";
}
})
tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx, {ograd_tvm, igrad_tvm, igrad_tvm});
}
}

NNVM_REGISTER_OP(_contrib_tvm_vadd)
.set_num_inputs(2)
.set_num_outputs(1)
.add_argument("a", "NDArray-or-Symbol", "first input")
.add_argument("b", "NDArray-or-Symbol", "second input")
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a", "b"};
})
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
.set_attr<nnvm::FInferType>("FInferType", mxnet::op::ElemwiseType<2, 1>)
#if MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<gpu>", mxnet::op::TVMBroadcastCompute<func_vadd_gpu>)
.set_attr<mxnet::FCompute>("FCompute<gpu>", mxnet::op::TVMBinaryCompute<func_vadd_gpu>)
#endif // MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<cpu>", mxnet::op::TVMBinaryCompute<func_vadd_cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_contrib_tvm_vadd"});

NNVM_REGISTER_OP(_backward_contrib_tvm_vadd)
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<gpu>",
mxnet::op::TVMBinaryBackwardComputeUseNone<func_bakcward_vadd_gpu>)
#endif // MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<cpu>", mxnet::op::TVMBroadcastCompute<func_vadd_cpu>);
.set_attr<mxnet::FCompute>("FCompute<cpu>",
mxnet::op::TVMBinaryBackwardComputeUseNone<func_bakcward_vadd_cpu>);

} // namespace op
} // namespace mxnet
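
The shape handling in TVMBinaryBackwardComputeUseNone is what lets a single pair of reduce1st kernels (per req) cover every broadcast pattern: consecutive axes with the same broadcast/non-broadcast status are merged, the merged shape is padded with 1s up to max_dim, and the kernel is then selected by name from the suffixes built above, e.g. "backward_vadd" + "reduce1st_1" + "req_kAddTo". A rough Python rendering of that compression, not part of the commit (compress_shapes is a hypothetical name) and assuming igrad never has more dimensions than ograd:

# Rough Python sketch (not in the commit) of the axis compression in
# TVMBinaryBackwardComputeUseNone: merge consecutive axes with the same
# "broadcast vs. not" status, pad to max_dim, and derive the igrad view.
def compress_shapes(ograd_shape, igrad_shape, max_dim=5):
    # left-pad igrad with 1s to ograd's rank, mirroring the padding() helper
    igrad_shape = [1] * (len(ograd_shape) - len(igrad_shape)) + list(igrad_shape)
    flag = 1 if ograd_shape[0] != igrad_shape[0] else 0  # is the first group reduced?
    ov = []
    for i, (o, g) in enumerate(zip(ograd_shape, igrad_shape)):
        same_status = i > 0 and (o != g) == (ograd_shape[i - 1] != igrad_shape[i - 1])
        if same_status:
            ov[-1] *= o   # merge into the previous group
        else:
            ov.append(o)  # start a new group
    ov += [1] * (max_dim - len(ov))
    iv = ov[flag::2]      # the groups that survive in the input gradient
    return flag, ov, iv

# ograd of shape (6, 4, 5, 2, 1) broadcast from igrad of shape (6, 1, 5, 1, 1):
print(compress_shapes([6, 4, 5, 2, 1], [6, 1, 5, 1, 1]))
# -> (0, [6, 4, 5, 2, 1], [6, 5, 1])
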
49 changes: 42 additions & 7 deletions tests/python/unittest/test_tvm_op.py
@@ -16,6 +16,7 @@
# under the License.

import mxnet as mx
import numpy as _np
from mxnet.test_utils import same, rand_shape_nd
from mxnet.runtime import Features
from common import with_seed
@@ -25,13 +26,47 @@
@with_seed()
def test_tvm_broadcast_add():
if _features.is_enabled("TVM_OP"):
-a_shape = rand_shape_nd(4)
-b_shape = (1,) + a_shape[1:2] + (1, 1)
-a = mx.nd.normal(shape=a_shape)
-b = mx.nd.normal(shape=b_shape)
-c = mx.nd.contrib.tvm_vadd(a, b)
-c_np = a.asnumpy() + b.asnumpy()
-assert same(c.asnumpy(), c_np)
configs = [
[[5, 6, 7, 8, 9], [1]],
[[6, 4, 5, 2, 1], [6, 1, 5, 1, 1]],
[[3, 5, 6], [1, 6]],
[[3, 5, 6], [5, 1]],
[[3, 5, 6], [5, 6]],
[[4, 3, 2, 1], [2, 1]],
[[4, 3, 2, 2], [4, 1, 1, 2]],
[[6, 6], [6, 6]],
]
for config in configs:
a_shape = config[0]
b_shape = config[1]
a = mx.nd.normal(shape=a_shape)
b = mx.nd.normal(shape=b_shape)
a.attach_grad()
b.attach_grad()
with mx.autograd.record():
c = mx.nd.contrib.tvm_vadd(a, b)
c_np = a.asnumpy() + b.asnumpy()
assert same(c.asnumpy(), c_np)
# test backward
c.backward()
expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
assert same(a.grad.asnumpy(), expected_grad_a)
assert same(b.grad.asnumpy(), expected_grad_b)
# test kAddTo request
a = mx.nd.normal(shape=a_shape)
b = mx.nd.normal(shape=b_shape)
a.attach_grad()
b.attach_grad()
with mx.autograd.record():
c = mx.nd.contrib.tvm_vadd(a, b)
d = mx.nd.contrib.tvm_vadd(a, b)
mx.autograd.backward([c, d])
expected_grad_a = 2 * _np.ones_like(a.asnumpy()) * c.size / a.size
expected_grad_b = 2 * _np.ones_like(b.asnumpy()) * c.size / b.size
assert same(a.grad.asnumpy(), expected_grad_a)
assert same(b.grad.asnumpy(), expected_grad_b)


if __name__ == '__main__':
import nose
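
A note on the expected values in the test above: for c = tvm_vadd(a, b) with a head gradient of all ones, every element of an input gradient receives one contribution per broadcast copy, so each entry equals c.size / a.size (resp. c.size / b.size); backpropagating through the two heads c and d accumulates into the same grad buffers and doubles it, which is what exercises the kAddTo path. A quick NumPy check of the identity, separate from the test file:

# Each element of db collects one gradient contribution per broadcast copy,
# i.e. c.size / b.size of them when the head gradient is all ones.
import numpy as np

a = np.random.normal(size=(3, 5, 6))
b = np.random.normal(size=(5, 1))
c = a + b                                            # broadcasting as in tvm_vadd
db = np.ones_like(c).sum(axis=(0, 2), keepdims=True).reshape(b.shape)
assert np.all(db == c.size / b.size)                 # every entry equals 3 * 6 == 18
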
