
Tvm broadcast backward #15938

Merged (3 commits, Aug 22, 2019)
1 change: 1 addition & 0 deletions contrib/tvmop/__init__.py
@@ -18,5 +18,6 @@
# coding: utf-8
from .opdef import defop
from .utils import AllTypes, RealTypes
from .utils import assign_by_req, reduce_axes

from . import basic
54 changes: 52 additions & 2 deletions contrib/tvmop/basic/ufunc.py
@@ -18,6 +18,7 @@
# coding: utf-8
import tvm
from .. import defop, AllTypes
from .. import assign_by_req, reduce_axes

def compute_add(dtype, ndim):
    A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
@@ -27,8 +28,9 @@ def compute_add(dtype, ndim):
    s = tvm.create_schedule(C.op)
    return s, A, B, C


@defop(name="vadd", target="cpu", auto_broadcast=True,
dtype=AllTypes, ndim=list(range(1, 6)))
dtype=AllTypes, ndim=[5])
def vadd(dtype, ndim):
    s, A, B, C = compute_add(dtype, ndim)
    axes = [axis for axis in C.op.axis]
@@ -37,8 +39,9 @@ def vadd(dtype, ndim):

    return s, [A, B, C]


@defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
dtype=["float32", "float64"], ndim=list(range(1, 6)))
dtype=["float32", "float64"], ndim=[5])
def vadd_gpu(dtype, ndim):
    s, A, B, C = compute_add(dtype, ndim)
    s = tvm.create_schedule(C.op)
@@ -48,3 +51,50 @@ def vadd_gpu(dtype, ndim):
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
    return s, [A, B, C]


def compute_backward_vadd(dtype, ndim, reduce1st, req):
    # The backward pass of a broadcast op is basically a reduction over the broadcast axes.
    # We label each reduce axis as 1 and every other axis as 0, and together they form a bit string.
    # Each bit string corresponds to a kernel, so there can be as many as `2^n` kernels.
    # To reduce that number, the bit string is compressed by combining consecutive 0s or 1s.
    # In this way, the number of bit strings (and hence of kernels) is reduced to `2 * n`.
    # The compressed bit string is stored in `axes`, and `reduce1st` represents the first bit
    # of the compressed bit string. Credit to @junrushao1994 and @yzhliu.
    axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
    X = tvm.placeholder([tvm.var() for _ in range(ndim)], name='X', dtype=dtype)
    reducer = tvm.comm_reducer(lambda x, y: x + y,
                               lambda t: tvm.const(0, dtype=t), name="sum")
    ret = reduce_axes(X, axes, reducer)
    in_grad_a, in_grad = assign_by_req(ret, req)
    s = tvm.create_schedule(in_grad.op)
    return s, X, in_grad_a, in_grad, [ret, in_grad]
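For intuition, here is a minimal plain-Python sketch of the axis compression described in the comment above (the helper name compress_axes is hypothetical, not part of this PR):

    # Hypothetical illustration of the bit-string compression: reduce axes are
    # marked 1, kept axes 0, and runs of equal bits collapse into a single axis
    # whose extents are multiplied together.
    def compress_axes(shape, reduce_axes_set):
        bits = [1 if i in reduce_axes_set else 0 for i in range(len(shape))]
        out_shape, out_bits = [], []
        for dim, bit in zip(shape, bits):
            if out_bits and out_bits[-1] == bit:
                out_shape[-1] *= dim  # merge consecutive identical bits
            else:
                out_shape.append(dim)
                out_bits.append(bit)
        return out_shape, out_bits

    # A (2, 3, 4, 5, 6) tensor reduced on axes {1, 2} collapses to (2, 12, 30)
    # with compressed bit string [0, 1, 0], i.e. reduce1st == 0.
    print(compress_axes((2, 3, 4, 5, 6), {1, 2}))  # ([2, 12, 30], [0, 1, 0])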


@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
ndim=[5], reduce1st=[0, 1],
req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
def backward_vadd(dtype, ndim, reduce1st, req):
s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
for t in c_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [X, in_grad_a, in_grad]


@defop(name="cuda_backward_vadd", target="gpu", dtype=["float32", "float64"],
ndim=[5], reduce1st=[0, 1],
req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
def backward_vadd_gpu(dtype, ndim, reduce1st, req):
s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
num_thread = 64
for t in c_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [X, in_grad_a, in_grad]
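Since reduce1st and req are listed in attrs, each combination registers a separate kernel. Judging from the name dispatch in the C++ change below, the generated names concatenate the attribute values (a sketch, not part of this PR):

    # Enumerate the four CPU kernels implied by attrs=["reduce1st", "req"];
    # the concatenation mirrors the funcname construction in ufunc.cc below.
    for reduce1st in [0, 1]:
        for req in ["kWriteTo", "kAddTo"]:
            print("backward_vadd" + "reduce1st_" + str(reduce1st) + "req_" + req)
    # backward_vaddreduce1st_0req_kWriteTo
    # backward_vaddreduce1st_0req_kAddTo
    # backward_vaddreduce1st_1req_kWriteTo
    # backward_vaddreduce1st_1req_kAddTo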
29 changes: 29 additions & 0 deletions contrib/tvmop/utils.py
@@ -16,5 +16,34 @@
# under the License.

# coding: utf-8
import tvm

AllTypes = ["float32", "float64", "float16", "uint8", "int8", "int32", "int64"]
RealTypes = ["float32", "float64", "float16"]

def assign_by_req(a, req):
    b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
    if req == "kAddTo":
        c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx])
    else:
        c = tvm.compute(a.shape, lambda *idx: a[idx])
    return b, c
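In NumPy terms, the two req modes behave as follows: kWriteTo overwrites the destination, while kAddTo accumulates into it, which is why the kAddTo kernel takes the existing output buffer as the extra placeholder b (an illustration, not part of this PR):

    import numpy as np

    ret = np.arange(6.0)    # freshly computed value (the `a` argument above)
    old = np.full(6, 10.0)  # existing destination buffer (the placeholder `b`)

    write_to = ret.copy()   # kWriteTo: destination is overwritten
    add_to = old + ret      # kAddTo: destination accumulates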


def reduce_axes(X, axes, reducer):
    def get_index(idx, ridx):
        j = 0
        k = 0
        ret = []
        for val in axes:
            ret.append(idx[j] if val == 0 else ridx[k])
            j += (val == 0)
            k += (val != 0)
        return tuple(ret)

    ishape = X.shape
    odim = (len(ishape) + 1 - axes[0]) // 2
    oshape = [tvm.var() for _ in range(odim)]
    ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
    ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx), name='ret')
    return ret
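As a quick sanity check on the odim formula (not part of this PR): for the alternating bit strings produced by compute_backward_vadd, the number of kept (0) axes is exactly (ndim + 1 - axes[0]) // 2:

    # Verify the output-dimension formula for both values of reduce1st at ndim = 5.
    for reduce1st in (0, 1):
        axes = ([reduce1st, 1 - reduce1st] * 5)[:5]
        kept = sum(1 for v in axes if v == 0)
        assert kept == (len(axes) + 1 - axes[0]) // 2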
99 changes: 91 additions & 8 deletions src/operator/contrib/tvmop/ufunc.cc
@@ -28,6 +28,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <mxnet/base.h>
#include <string>
#include "../../tensor/elemwise_binary_broadcast_op.h"
#include "../../tvmop/op_module.h"
#include "../../tensor/elemwise_binary_op.h"
@@ -37,29 +38,111 @@ namespace op {

static constexpr char func_vadd_cpu[] = "vadd";
static constexpr char func_vadd_gpu[] = "cuda_vadd";
static constexpr char func_backward_vadd_cpu[] = "backward_vadd";
static constexpr char func_backward_vadd_gpu[] = "cuda_backward_vadd";
static constexpr int max_dim = 5;

TBlob padding(const TBlob& tblob, const int max_dim) {
  TShape tshape(max_dim, 1);
  int ndim = tblob.shape_.ndim();
  for (int i = max_dim - ndim; i < max_dim; ++i) {
    tshape[i] = tblob.size(i - max_dim + ndim);
  }
  return tblob.reshape(tshape);
}
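The effect of padding is to left-pad a shape with ones up to max_dim, which preserves broadcasting semantics; a Python sketch of the same shape arithmetic (hypothetical helper, not part of this PR):

    # Shape arithmetic performed by padding() above, in plain Python.
    def pad_shape(shape, max_dim=5):
        return (1,) * (max_dim - len(shape)) + tuple(shape)

    assert pad_shape((3, 4)) == (1, 1, 1, 3, 4)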

template<const char* func>
-void TVMBroadcastCompute(const nnvm::NodeAttrs& attrs,
-                         const mxnet::OpContext& ctx,
-                         const std::vector<TBlob>& inputs,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<TBlob>& outputs) {
+void TVMBinaryCompute(const nnvm::NodeAttrs& attrs,
+                      const mxnet::OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
-  tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], inputs[1], outputs[0]});
+  TBlob idata[2], odata;
+  for (int k = 0; k < 2; ++k) {
+    idata[k] = padding(inputs[k], max_dim);
+  }
+  odata = padding(outputs[0], max_dim);
+  tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {idata[0], idata[1], odata});
}

template<const char* func>
void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
                                     const mxnet::OpContext& ctx,
                                     const std::vector<TBlob>& inputs,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 2U);
  int ndim = inputs[0].shape_.ndim();
  for (int k = 0; k < 2; ++k) {
    // dispatch by backward
    std::vector<int> ov, iv;
    TBlob ograd = padding(inputs[0], ndim), igrad = padding(outputs[k], ndim);
    int flag;
    if (ograd.size(0) != igrad.size(0)) {
      flag = 1;
    } else {
      flag = 0;
    }
    for (int i = 0; i < ndim; ++i) {
Inline review discussion on this loop:

Reviewer (Contributor): If my understanding is correct, there seems to be an assumption that ograd.ndim = igrad.ndim, which is not necessarily true. I think you need to prepend axes before igrad if igrad.ndim < ograd.ndim and then use the logic here.

@hzfan (Contributor, Author), Aug 21, 2019: Yes, igrad.ndim = ograd.ndim is assumed. @yzhliu suggests padding the input to 5-dim, which is the largest possible dim supported by this op. The padding will 1) reduce the number of kernels (by a factor of 5) and 2) handle the igrad.ndim < ograd.ndim issue. But there may be a loss in performance. I think prepending axes before igrad to make it ograd.ndim requires more kernels, but the performance is better. It is a tradeoff.

Reviewer: Please correct me if my understanding is wrong, but don't you still need kernels generated for ndim < 5, since you will collapse consecutive dimensions where the reduction is performed? For example, given a 5-d shape (2, 3, 4, 5, 6) and a reduction on axis=(1, 2), the tblob will first be reshaped into (2, 12, 30) and then reduced on axis=1. In this case, do you need a kernel generated for 3-D shapes?

@hzfan: I think we can pad the shape after dimension collapse. In this case, the tblob will be reshaped into (2, 12, 30, 1, 1) and then reduced on axis=[1, 3].

Reviewer: I see. I am in favor of the approach with fewer kernels generated. We can revisit the performance concern if that turns out to be an issue.

@hzfan: I pushed a new version, where the inputs and outputs are padded to 5 dim.

      if (i == 0 || (ograd.size(i) != igrad.size(i)) != (ograd.size(i - 1) != igrad.size(i - 1))) {
        ov.push_back(ograd.size(i));
      } else {
        ov.back() *= ograd.size(i);
      }
    }
    for (int i = ov.size(); i < max_dim; ++i) {
      ov.push_back(1);
    }
    for (int i = flag; i < ov.size(); i += 2) {
      iv.push_back(ov[i]);
    }
    TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
    TBlob ograd_tvm(ograd.reshape(oshape));
    TBlob igrad_tvm(igrad.reshape(ishape));
    std::string funcname = std::string(func) + "reduce1st_" + std::to_string(flag);
    // dispatch by req
    funcname += "req_";
    MXNET_ASSIGN_REQ_SWITCH(req[k], req_type, {
      if (req_type == kWriteTo) {
        funcname += "kWriteTo";
      } else {
        funcname += "kAddTo";
      }
    })
    tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx, {ograd_tvm, igrad_tvm, igrad_tvm});
  }
}
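To make the collapse-and-pad logic concrete, here is a rough Python transcription of the loop above (an illustration under the same assumptions, not the shipped code):

    # ograd and igrad are already padded to a common ndim; an axis is a reduction
    # axis exactly where their sizes differ (igrad has extent 1 there).
    def collapse_and_pad(ograd_shape, igrad_shape, max_dim=5):
        ndim = len(ograd_shape)
        flag = 1 if ograd_shape[0] != igrad_shape[0] else 0  # first bit of the string
        ov = []
        for i in range(ndim):
            differs = ograd_shape[i] != igrad_shape[i]
            if i == 0 or differs != (ograd_shape[i - 1] != igrad_shape[i - 1]):
                ov.append(ograd_shape[i])  # start a new collapsed axis
            else:
                ov[-1] *= ograd_shape[i]   # merge with the previous axis
        ov += [1] * (max_dim - len(ov))    # pad to max_dim with size-1 axes
        iv = [ov[i] for i in range(flag, len(ov), 2)]  # the non-reduced axes
        return ov, iv, flag

    # The example from the review thread: (2, 3, 4, 5, 6) with igrad broadcast
    # on axes 1 and 2 collapses to (2, 12, 30, 1, 1), reduced on axes [1, 3].
    print(collapse_and_pad((2, 3, 4, 5, 6), (2, 1, 1, 5, 6)))
    # ([2, 12, 30, 1, 1], [2, 30, 1], 0)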

NNVM_REGISTER_OP(_contrib_tvm_vadd)
.set_num_inputs(2)
.set_num_outputs(1)
.add_argument("a", "NDArray-or-Symbol", "first input")
.add_argument("b", "NDArray-or-Symbol", "second input")
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"a", "b"};
  })
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
.set_attr<nnvm::FInferType>("FInferType", mxnet::op::ElemwiseType<2, 1>)
#if MXNET_USE_CUDA
-.set_attr<mxnet::FCompute>("FCompute<gpu>", mxnet::op::TVMBroadcastCompute<func_vadd_gpu>)
+.set_attr<mxnet::FCompute>("FCompute<gpu>", mxnet::op::TVMBinaryCompute<func_vadd_gpu>)
#endif // MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<cpu>", mxnet::op::TVMBinaryCompute<func_vadd_cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_contrib_tvm_vadd"});
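A minimal usage sketch of the registered forward operator, assuming an MXNet build with TVM op support enabled (as the test below checks via the TVM_OP feature):

    import mxnet as mx

    a = mx.nd.ones((2, 3, 4, 5, 6))
    b = mx.nd.ones((2, 1, 1, 5, 6))
    c = mx.nd.contrib.tvm_vadd(a, b)  # b is broadcast, then added elementwise
    print(c.shape)  # (2, 3, 4, 5, 6)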

NNVM_REGISTER_OP(_backward_contrib_tvm_vadd)
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<gpu>",
                           mxnet::op::TVMBinaryBackwardComputeUseNone<func_backward_vadd_gpu>)
#endif // MXNET_USE_CUDA
-.set_attr<mxnet::FCompute>("FCompute<cpu>", mxnet::op::TVMBroadcastCompute<func_vadd_cpu>);
+.set_attr<mxnet::FCompute>("FCompute<cpu>",
+                           mxnet::op::TVMBinaryBackwardComputeUseNone<func_backward_vadd_cpu>);

} // namespace op
} // namespace mxnet
49 changes: 42 additions & 7 deletions tests/python/unittest/test_tvm_op.py
@@ -16,6 +16,7 @@
# under the License.

import mxnet as mx
import numpy as _np
from mxnet.test_utils import same, rand_shape_nd
from mxnet.runtime import Features
from common import with_seed
@@ -25,13 +26,47 @@
@with_seed()
def test_tvm_broadcast_add():
    if _features.is_enabled("TVM_OP"):
-        a_shape = rand_shape_nd(4)
-        b_shape = (1,) + a_shape[1:2] + (1, 1)
-        a = mx.nd.normal(shape=a_shape)
-        b = mx.nd.normal(shape=b_shape)
-        c = mx.nd.contrib.tvm_vadd(a, b)
-        c_np = a.asnumpy() + b.asnumpy()
-        assert same(c.asnumpy(), c_np)
+        configs = [
+            [[5, 6, 7, 8, 9], [1]],
+            [[6, 4, 5, 2, 1], [6, 1, 5, 1, 1]],
+            [[3, 5, 6], [1, 6]],
+            [[3, 5, 6], [5, 1]],
+            [[3, 5, 6], [5, 6]],
+            [[4, 3, 2, 1], [2, 1]],
+            [[4, 3, 2, 2], [4, 1, 1, 2]],
+            [[6, 6], [6, 6]],
+        ]
+        for config in configs:
+            a_shape = config[0]
+            b_shape = config[1]
+            a = mx.nd.normal(shape=a_shape)
+            b = mx.nd.normal(shape=b_shape)
+            a.attach_grad()
+            b.attach_grad()
+            with mx.autograd.record():
+                c = mx.nd.contrib.tvm_vadd(a, b)
+            c_np = a.asnumpy() + b.asnumpy()
+            assert same(c.asnumpy(), c_np)
+            # test backward
+            c.backward()
+            expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
+            expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
+            assert same(a.grad.asnumpy(), expected_grad_a)
+            assert same(b.grad.asnumpy(), expected_grad_b)
+            # test kAddTo request
+            a = mx.nd.normal(shape=a_shape)
+            b = mx.nd.normal(shape=b_shape)
+            a.attach_grad()
+            b.attach_grad()
+            with mx.autograd.record():
+                c = mx.nd.contrib.tvm_vadd(a, b)
+                d = mx.nd.contrib.tvm_vadd(a, b)
+            mx.autograd.backward([c, d])
+            expected_grad_a = 2 * _np.ones_like(a.asnumpy()) * c.size / a.size
+            expected_grad_b = 2 * _np.ones_like(b.asnumpy()) * c.size / b.size
+            assert same(a.grad.asnumpy(), expected_grad_a)
+            assert same(b.grad.asnumpy(), expected_grad_b)
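The expected gradients follow from summing a tensor of ones of c's shape over the broadcast axes; a quick NumPy check for the second config above (illustration only):

    import numpy as np

    ones = np.ones((6, 4, 5, 2, 1))                 # c's shape for this config
    grad_b = ones.sum(axis=(1, 3), keepdims=True)   # reduce over b's broadcast axes
    assert (grad_b == ones.size / grad_b.size).all()  # every entry equals 8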


if __name__ == '__main__':
    import nose