
Commit 14351d9: pad for broadcast to a larger dim
Fan committed Aug 21, 2019 (1 parent: f373aa8)
Showing 5 changed files with 111 additions and 67 deletions.
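In outline, the change compiles each TVM kernel only for ndim = 5 and pads lower-rank tensors with leading length-1 axes before the call; prepending size-1 axes leaves broadcast semantics unchanged, so one 5-D kernel covers every rank up to 5 while cutting the number of generated kernels. A minimal numpy sketch of the padding idea (pad_to_max_dim is an illustrative helper, not code from the commit):

    import numpy as np

    def pad_to_max_dim(arr, max_dim=5):
        # Prepend size-1 axes; they are neutral under broadcasting.
        return arr.reshape((1,) * (max_dim - arr.ndim) + arr.shape)

    a, b = np.random.rand(3, 5, 6), np.random.rand(5, 1)
    c = pad_to_max_dim(a) + pad_to_max_dim(b)      # one fixed 5-D kernel
    assert np.allclose(c.reshape(3, 5, 6), a + b)  # same as the rank-3 op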
1 change: 1 addition & 0 deletions contrib/tvmop/__init__.py
@@ -18,5 +18,6 @@
# coding: utf-8
from .opdef import defop
from .utils import AllTypes, RealTypes
+from .utils import assign_by_req, reduce_axes

from . import basic
44 changes: 12 additions & 32 deletions contrib/tvmop/basic/ufunc.py
@@ -18,6 +18,7 @@
# coding: utf-8
import tvm
from .. import defop, AllTypes
+from .. import assign_by_req, reduce_axes

def compute_add(dtype, ndim):
    A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
@@ -29,7 +30,7 @@ def compute_add(dtype, ndim):


@defop(name="vadd", target="cpu", auto_broadcast=True,
-       dtype=AllTypes, ndim=list(range(1, 6)))
+       dtype=AllTypes, ndim=[5])
def vadd(dtype, ndim):
    s, A, B, C = compute_add(dtype, ndim)
    axes = [axis for axis in C.op.axis]
@@ -40,7 +41,7 @@ def vadd(dtype, ndim):


@defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
dtype=["float32", "float64"], ndim=list(range(1, 6)))
dtype=["float32", "float64"], ndim=[5])
def vadd_gpu(dtype, ndim):
    s, A, B, C = compute_add(dtype, ndim)
    s = tvm.create_schedule(C.op)
@@ -52,35 +53,14 @@ def vadd_gpu(dtype, ndim):
    return s, [A, B, C]


-def assign_by_req(a, req):
-    b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
-    if (req == "kAddTo"):
-        c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx])
-    else:
-        c = tvm.compute(a.shape, lambda *idx: a[idx])
-    return b, c
-
-
-def reduce_axes(X, axes, reducer):
-    def get_index(idx, ridx):
-        j = 0
-        k = 0
-        ret = []
-        for val in axes:
-            ret.append(idx[j] if val == 0 else ridx[k])
-            j += (val == 0)
-            k += (val != 0)
-        return tuple(ret)
-
-    ishape = X.shape
-    odim = (len(ishape) + 1 - axes[0]) // 2
-    oshape = [tvm.var() for _ in range(odim)]
-    ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
-    ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx), name='ret')
-    return ret


def compute_backward_vadd(dtype, ndim, reduce1st, req):
+    # The backward of a broadcast op is basically a reduction over the broadcast axes.
+    # We label the reduce axes as 1 and the other axes as 0, and they form a bit string.
+    # Each bit string corresponds to a kernel, so there would be as many as `2^n` kernels.
+    # To reduce that, the bit string is compressed by combining consecutive 0s or 1s.
+    # In this way, the number of bit strings (i.e. the number of kernels) drops to `2 * n`.
+    # The compressed bit string is stored in `axes`, and `reduce1st` represents the first
+    # bit of the compressed bit string. Credit to @junrushao1994 and @yzhliu.
    axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
    X = tvm.placeholder([tvm.var() for _ in range(ndim)], name='X', dtype=dtype)
    reducer = tvm.comm_reducer(lambda x, y: x + y,
@@ -92,7 +72,7 @@ def compute_backward_vadd(dtype, ndim, reduce1st, req):


@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
-       ndim=list(range(1, 6)), reduce1st=[0, 1],
+       ndim=[5], reduce1st=[0, 1],
       req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
def backward_vadd(dtype, ndim, reduce1st, req):
    s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
@@ -104,7 +84,7 @@ def backward_vadd(dtype, ndim, reduce1st, req):


@defop(name="cuda_backward_vadd", target="gpu", dtype=["float32", "float64"],
-       ndim=list(range(1, 6)), reduce1st=[0, 1],
+       ndim=[5], reduce1st=[0, 1],
       req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
def backward_vadd_gpu(dtype, ndim, reduce1st, req):
    s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
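To make the compressed bit string concrete: for a forward broadcast from shape (1, 1, 2, 1) to (4, 3, 2, 2) (one of the configurations exercised in the test below), the per-axis reduce labels are [1, 1, 0, 1]; merging consecutive equal labels compresses the output gradient to shape [12, 2, 2], and reduce1st = 1 says the first compressed axis is reduced. A minimal numpy sketch of the same bookkeeping (illustrative, not part of the commit):

    import numpy as np

    o_shape, i_shape = (4, 3, 2, 2), (1, 1, 2, 1)
    ograd = np.random.rand(*o_shape)

    # Merge consecutive axes that share the same "was broadcast" label:
    # labels [1, 1, 0, 1] compress the view to shape [12, 2, 2].
    flag, ov = int(o_shape[0] != i_shape[0]), []
    for i in range(len(o_shape)):
        if i == 0 or (o_shape[i] != i_shape[i]) != (o_shape[i - 1] != i_shape[i - 1]):
            ov.append(o_shape[i])
        else:
            ov[-1] *= o_shape[i]

    # Reduce the axes labeled 1; labels alternate, starting with `flag`.
    igrad = ograd.reshape(ov).sum(axis=tuple(range(1 - flag, len(ov), 2)))
    assert np.allclose(igrad.reshape(i_shape),
                       ograd.sum(axis=(0, 1, 3), keepdims=True))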
29 changes: 29 additions & 0 deletions contrib/tvmop/utils.py
@@ -16,5 +16,34 @@
# under the License.

# coding: utf-8
+import tvm

AllTypes = ["float32", "float64", "float16", "uint8", "int8", "int32", "int64"]
RealTypes = ["float32", "float64", "float16"]

+def assign_by_req(a, req):
+    b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
+    if (req == "kAddTo"):
+        c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx])
+    else:
+        c = tvm.compute(a.shape, lambda *idx: a[idx])
+    return b, c
+
+
+def reduce_axes(X, axes, reducer):
+    def get_index(idx, ridx):
+        j = 0
+        k = 0
+        ret = []
+        for val in axes:
+            ret.append(idx[j] if val == 0 else ridx[k])
+            j += (val == 0)
+            k += (val != 0)
+        return tuple(ret)
+
+    ishape = X.shape
+    odim = (len(ishape) + 1 - axes[0]) // 2
+    oshape = [tvm.var() for _ in range(odim)]
+    ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
+    ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx), name='ret')
+    return ret
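Both helpers are small graph builders: assign_by_req adds a second input holding the destination's current contents, so one kernel definition can serve both kWriteTo and kAddTo, and reduce_axes reduces every axis labeled 1 while keeping the axes labeled 0. A rough numpy rendering of their semantics (names with the _np suffix are illustrative, not part of the commit):

    import numpy as np

    def assign_by_req_np(a, out, req):
        # kWriteTo overwrites the destination, kAddTo accumulates into it.
        out[...] = a + out if req == "kAddTo" else a

    def reduce_axes_np(x, axes):
        # Reduce every axis labeled 1; axes labeled 0 survive in the output.
        return x.sum(axis=tuple(i for i, v in enumerate(axes) if v == 1))

    x = np.random.rand(12, 2, 2)
    out = np.zeros(2)
    assign_by_req_np(reduce_axes_np(x, [1, 0, 1]), out, "kWriteTo")
    assert np.allclose(out, x.sum(axis=(0, 2)))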
35 changes: 29 additions & 6 deletions src/operator/contrib/tvmop/ufunc.cc
@@ -40,6 +40,16 @@ static constexpr char func_vadd_cpu[] = "vadd";
static constexpr char func_vadd_gpu[] = "cuda_vadd";
static constexpr char func_bakcward_vadd_cpu[] = "backward_vadd";
static constexpr char func_bakcward_vadd_gpu[] = "cuda_backward_vadd";
+static constexpr int max_dim = 5;

+TBlob padding(const TBlob& tblob, const int max_dim) {
+  TShape tshape(max_dim, 1);
+  int ndim = tblob.shape_.ndim();
+  for (int i = max_dim - ndim; i < max_dim; ++i) {
+    tshape[i] = tblob.size(i - max_dim + ndim);
+  }
+  return tblob.reshape(tshape);
+}

template<const char* func>
void TVMBinaryCompute(const nnvm::NodeAttrs& attrs,
@@ -49,7 +59,12 @@ void TVMBinaryCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
-  tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], inputs[1], outputs[0]});
+  TBlob idata[2], odata;
+  for (int k = 0; k < 2; ++k) {
+    idata[k] = padding(inputs[k], max_dim);
+  }
+  odata = padding(outputs[0], max_dim);
+  tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {idata[0], idata[1], odata});
}

template<const char* func>
@@ -64,27 +79,35 @@ void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
  for (int k = 0; k < 2; ++k) {
    // dispatch by backward
    std::vector<int> ov, iv;
-    const TBlob& ograd = inputs[0], igrad = outputs[k];
-    bool flag = ograd.size(0) != igrad.size(0);
+    TBlob ograd = padding(inputs[0], ndim), igrad = padding(outputs[k], ndim);
+    int flag;
+    if (ograd.size(0) != igrad.size(0)) {
+      flag = 1;
+    } else {
+      flag = 0;
+    }
    for (int i = 0; i < ndim; ++i) {
      if (i == 0 || (ograd.size(i) != igrad.size(i)) != (ograd.size(i - 1) != igrad.size(i - 1))) {
        ov.push_back(ograd.size(i));
      } else {
        ov.back() *= ograd.size(i);
      }
    }
+    for (int i = ov.size(); i < max_dim; ++i) {
+      ov.push_back(1);
+    }
    for (int i = flag; i < ov.size(); i += 2) {
      iv.push_back(ov[i]);
    }
    TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
-    TBlob ograd_tvm(ograd.reshape(oshape).dltensor());
-    TBlob igrad_tvm(igrad.reshape(ishape).dltensor());
+    TBlob ograd_tvm(ograd.reshape(oshape));
+    TBlob igrad_tvm(igrad.reshape(ishape));
    std::string funcname = std::string(func) + "reduce1st_" + std::to_string(flag);
    // dispatch by req
    funcname += "req_";
    MXNET_ASSIGN_REQ_SWITCH(req[k], req_type, {
      if (req_type == kWriteTo) {
        funcname += "kWriteTo";
      } else {
        funcname += "kAddTo";
      }
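The dispatch above selects a pre-compiled kernel by name: the compressed-bit-string flag and the write request are appended to the base function name, which matches the attrs=["reduce1st", "req"] lists in the @defop registrations earlier in the diff. A sketch of the resulting name, as implied by the string concatenation in the code (illustrative only):

    def kernel_name(func, flag, req):
        # Mirrors: funcname = func + "reduce1st_" + str(flag) + "req_" + req
        return func + "reduce1st_" + str(flag) + "req_" + req

    print(kernel_name("backward_vadd", 1, "kAddTo"))
    # backward_vaddreduce1st_1req_kAddTo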
69 changes: 40 additions & 29 deletions tests/python/unittest/test_tvm_op.py
@@ -26,35 +26,46 @@
@with_seed()
def test_tvm_broadcast_add():
    if _features.is_enabled("TVM_OP"):
-        a_shape = rand_shape_nd(4)
-        b_shape = (1,) + a_shape[1:2] + (1, 1)
-        a = mx.nd.normal(shape=a_shape)
-        b = mx.nd.normal(shape=b_shape)
-        a.attach_grad()
-        b.attach_grad()
-        with mx.autograd.record():
-            c = mx.nd.contrib.tvm_vadd(a, b)
-        c_np = a.asnumpy() + b.asnumpy()
-        assert same(c.asnumpy(), c_np)
-        # test backward
-        c.backward()
-        expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
-        expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
-        assert same(a.grad.asnumpy(), expected_grad_a)
-        assert same(b.grad.asnumpy(), expected_grad_b)
-        # test kAddTo request
-        a = mx.nd.normal(shape=a_shape)
-        b = mx.nd.normal(shape=b_shape)
-        a.attach_grad()
-        b.attach_grad()
-        with mx.autograd.record():
-            c = mx.nd.contrib.tvm_vadd(a, b)
-            d = mx.nd.contrib.tvm_vadd(a, b)
-        mx.autograd.backward([c, d])
-        expected_grad_a = 2 * _np.ones_like(a.asnumpy()) * c.size / a.size
-        expected_grad_b = 2 * _np.ones_like(b.asnumpy()) * c.size / b.size
-        assert same(a.grad.asnumpy(), expected_grad_a)
-        assert same(b.grad.asnumpy(), expected_grad_b)
+        configs = [
+            [[5, 6, 7, 8, 9], [1]],
+            [[6, 4, 5, 2, 1], [6, 1, 5, 1, 1]],
+            [[3, 5, 6], [1, 6]],
+            [[3, 5, 6], [5, 1]],
+            [[3, 5, 6], [5, 6]],
+            [[4, 3, 2, 1], [2, 1]],
+            [[4, 3, 2, 2], [4, 1, 1, 2]],
+            [[6, 6], [6, 6]],
+        ]
+        for config in configs:
+            a_shape = config[0]
+            b_shape = config[1]
+            a = mx.nd.normal(shape=a_shape)
+            b = mx.nd.normal(shape=b_shape)
+            a.attach_grad()
+            b.attach_grad()
+            with mx.autograd.record():
+                c = mx.nd.contrib.tvm_vadd(a, b)
+            c_np = a.asnumpy() + b.asnumpy()
+            assert same(c.asnumpy(), c_np)
+            # test backward
+            c.backward()
+            expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
+            expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
+            assert same(a.grad.asnumpy(), expected_grad_a)
+            assert same(b.grad.asnumpy(), expected_grad_b)
+            # test kAddTo request
+            a = mx.nd.normal(shape=a_shape)
+            b = mx.nd.normal(shape=b_shape)
+            a.attach_grad()
+            b.attach_grad()
+            with mx.autograd.record():
+                c = mx.nd.contrib.tvm_vadd(a, b)
+                d = mx.nd.contrib.tvm_vadd(a, b)
+            mx.autograd.backward([c, d])
+            expected_grad_a = 2 * _np.ones_like(a.asnumpy()) * c.size / a.size
+            expected_grad_b = 2 * _np.ones_like(b.asnumpy()) * c.size / b.size
+            assert same(a.grad.asnumpy(), expected_grad_a)
+            assert same(b.grad.asnumpy(), expected_grad_b)


if __name__ == '__main__':
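The expected gradients in the test follow from counting broadcast copies: with an all-ones head gradient, each input element receives one contribution per output element it was broadcast into, i.e. c.size / a.size, and twice that when two backward passes accumulate under kAddTo. A quick numpy check of that count for one hypothetical configuration (illustrative, not part of the commit):

    import numpy as np

    a_shape, b_shape = (3, 5, 6), (5, 1)
    out_size = np.prod(np.broadcast_shapes(a_shape, b_shape))
    # Each element of b is broadcast into out_size / b.size output elements.
    assert out_size / np.prod(b_shape) == 3 * 6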
