16 changes: 10 additions & 6 deletions src/relax/analysis/tir_op_pattern_kind.cc
@@ -517,19 +517,23 @@ bool HasReshapePattern(const PrimFunc& func) {
    arith::Analyzer ana_;
  };
 
-  if (func->params.size() < 2) {
-    return false;
+  Array<Buffer> buffer_args;
+  for (const auto& param : func->params) {
+    if (auto buffer = func->buffer_map.Get(param)) {
+      buffer_args.push_back(buffer.value());
+    }
   }
-  Optional<Buffer> src_buffer = func->buffer_map.Get(func->params.front());
-  Optional<Buffer> dst_buffer = func->buffer_map.Get(func->params.back());
-  if (!(src_buffer.defined() && dst_buffer.defined())) {
+
+  if (buffer_args.size() < 2) {
     return false;
   }
+  Buffer src_buffer = buffer_args.front();
+  Buffer dst_buffer = buffer_args.back();
 
   // To detect the reshape pattern, we require each For to have
   // either another For or a BlockRealize as body.
   ICHECK(func->body->IsInstance<BlockRealizeNode>());
-  return ReshapeDetector::Detect(src_buffer.value(), dst_buffer.value(), func->body);
+  return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body);
 }
 
 TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern);
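
The hunk above makes HasReshapePattern pair the first and last buffer-backed parameters rather than the first and last parameters overall, so a PrimFunc whose trailing parameters are scalars (e.g. shape variables) is no longer rejected outright. A minimal sketch of the case this enables — the function below is hypothetical, not part of this PR; its trailing scalar `n` has no buffer_map entry, so the old lookup on params.back() failed, while the new code skips it and pairs `x` with `y`:

import tvm
from tvm import relax
from tvm.script import tir as T


# Hypothetical reshape PrimFunc with a trailing scalar parameter `n`.
# `n` never appears in func->buffer_map, so the old code's
# buffer_map.Get(params.back()) returned NullOpt and the function was
# rejected; the new code collects only the buffer-backed params (x, y).
@T.prim_func
def dynamic_reshape(x_handle: T.handle, y_handle: T.handle, n: T.int64):
    x = T.match_buffer(x_handle, [n * T.int64(4)], "float32")
    y = T.match_buffer(y_handle, [n, T.int64(4)], "float32")
    for i, j in T.grid(n, T.int64(4)):
        with T.block("T_reshape"):
            vi, vj = T.axis.remap("SS", [i, j])
            y[vi, vj] = x[vi * T.int64(4) + vj]


# Expected to hold after this change (sketch, not a test from the PR):
assert relax.analysis.has_reshape_pattern(dynamic_reshape)
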
17 changes: 10 additions & 7 deletions src/relax/transform/rewrite_dataflow_reshape.cc
@@ -34,12 +34,15 @@
 namespace tvm {
 namespace relax {
 
-std::vector<size_t> GetUsedArgsIndices(const tir::PrimFunc& fn, size_t num_args) {
+std::vector<size_t> GetUsedTensorArgIndices(const tir::PrimFunc& fn, size_t num_args) {
   std::vector<size_t> indices;
   for (size_t i = 0; i < num_args; ++i) {
-    auto buffer_var = fn->buffer_map[fn->params[i]]->data;
-    if (tir::UsesVar(fn->body, [=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
-      indices.push_back(i);
+    if (auto buffer = fn->buffer_map.Get(fn->params[i])) {
+      auto buffer_var = buffer.value()->data;
+      if (tir::UsesVar(fn->body,
+                       [=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
+        indices.push_back(i);
+      }
     }
   }
   return indices;
@@ -83,17 +86,17 @@ class DataflowReshapeRewriter : public ExprMutator {
 
     auto prim_fn = Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
     auto arg_tuple = Downcast<Tuple>(call->args[1])->fields;
-    auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size());
+    auto used_tensor_arg_indices = GetUsedTensorArgIndices(prim_fn, arg_tuple.size());
 
     // The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps
     // can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR
     // then flattens the tuple input so that the fused TIR reshape function ends up having
     // multiple input buffers. But only one of them should be accessed and reshaped.
-    if (used_arg_indices.size() != 1) {
+    if (used_tensor_arg_indices.size() != 1) {
       return GetRef<Call>(call);
     }
 
-    auto arg = arg_tuple[used_arg_indices[0]];
+    auto arg = arg_tuple[used_tensor_arg_indices[0]];
 
     if (!IsCallingTIRReshape(call, arg)) {
       return GetRef<Call>(call);
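
As the comment in the second hunk explains, FuseOps can produce a fused TupleGetItem + reshape function, and FuseTIR then flattens its tuple input into several tensor parameters of which only one is read. A hypothetical sketch of such a PrimFunc (not from this PR) — GetUsedTensorArgIndices(fn, 2) would return only index 1 here, so the single-used-input check still passes and the rewrite applies:

from tvm.script import tir as T


# Hypothetical fused "TupleGetItem + reshape" PrimFunc after FuseTIR:
# the tuple input was flattened into two buffer parameters, but the
# body only reads `b`, so only its index is reported as used.
@T.prim_func(private=True)
def fused_getitem_reshape(
    a: T.Buffer((T.int64(8), T.int64(4)), "float32"),  # dead input
    b: T.Buffer((T.int64(8), T.int64(4)), "float32"),  # reshaped input
    out: T.Buffer((T.int64(32),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    for i in T.serial(T.int64(32)):
        with T.block("T_reshape"):
            vi = T.axis.spatial(T.int64(32), i)
            out[vi] = b[vi // T.int64(4), vi % T.int64(4)]
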
183 changes: 182 additions & 1 deletion tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -18,7 +18,7 @@
 import tvm
 import tvm.testing
 from tvm import relax
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
 
 
 def test_reshape_expand_dims():
@@ -581,5 +581,186 @@ def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"):
     tvm.ir.assert_structural_equal(rewritten, Expected)
 
 
+def test_rewrite_static_reshape():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor([256], dtype="float32")):
+            with R.dataflow():
+                y = R.reshape(x, [64, 4])
+                z = R.add(y, y)
+                R.output(z)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((256,), dtype="float32")):
+            cls = Expected
+
+            with R.dataflow():
+                y = R.reshape(x, R.shape([64, 4]))
+                z = R.call_tir(cls.add, (y, y), out_sinfo=R.Tensor((64, 4), dtype="float32"))
+                R.output(z)
+            return z
+
+        @T.prim_func(private=True)
+        def add(
+            y1: T.Buffer((T.int64(64), T.int64(4)), "float32"),
+            y2: T.Buffer((T.int64(64), T.int64(4)), "float32"),
+            z: T.Buffer((T.int64(64), T.int64(4)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+
+            for iters in T.grid(T.int64(64), T.int64(4)):
+                with T.block("T_add"):
+                    i, j = T.axis.remap("SS", iters)
+                    z[i, j] = y1[i, j] + y2[i, j]
+
+    After = tvm.ir.transform.Sequential(
+        [
+            # Lower both R.reshape and R.add from Relax to TIR
+            relax.transform.LegalizeOps(),
+            # Identify reshapes, raise calls to cls.reshape from TIR
+            # to Relax
+            relax.transform.RewriteDataflowReshape(),
+            # Clean up afterwards, removing the no-longer-required
+            # PrimFunc "reshape"
+            relax.transform.DeadCodeElimination(),
+        ]
+    )(Before)
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+# def test_rewrite_dynamic_reshape():
+#     @I.ir_module
+#     class Before:
+#         @R.function
+#         def main(x: R.Tensor(["N"], dtype="float32")):
+#             N = T.int64()
+#             with R.dataflow():
+#                 y = R.reshape(x, [N // 4, 4])
+#                 z = R.add(y, y)
+#                 R.output(z)
+#             return z
+
+#     @I.ir_module
+#     class Expected:
+#         @R.function
+#         def main(x: R.Tensor(["N"], dtype="float32")):
+#             N = T.int64()
+#             cls = Expected
+
+#             with R.dataflow():
+#                 y = R.reshape(x, R.shape([N // 4, 4]))
+#                 z = R.call_tir(
+#                     cls.add,
+#                     (y, y),
+#                     tir_vars=[N],
+#                     out_sinfo=R.Tensor((N // 4, 4), dtype="float32"),
+#                 )
+#                 R.output(z)
+#             return z

+#         @T.prim_func(private=True)
+#         def add(
+#             y1_handle: T.handle,
+#             y2_handle: T.handle,
+#             z_handle: T.handle,
+#             N: T.int64,
+#         ):

+#             y1 = T.match_buffer(y1_handle, [N // 4, 4], "float32")
+#             y2 = T.match_buffer(y2_handle, [N // 4, 4], "float32")
+#             z = T.match_buffer(z_handle, [N // 4, 4], "float32")

+#             T.func_attr({"tir.noalias": T.bool(True)})

+#             for iters in T.grid(T.int64(64), T.int64(4)):
+#                 with T.block("T_add"):
+#                     i, j = T.axis.remap("SS", iters)
+#                     z[i, j] = y1[i, j] + y2[i, j]

+#     After = tvm.ir.transform.Sequential(
+#         [
+#             # Lower both R.reshape and R.add from Relax to TIR
+#             relax.transform.LegalizeOps(),
+#             # Identify reshapes, raise calls to cls.reshape from TIR
+#             # to Relax
+#             relax.transform.RewriteDataflowReshape(),
+#             # Clean up afterwards, removing the no-longer-required
+#             # PrimFunc "reshape"
+#             relax.transform.DeadCodeElimination(),
+#         ]
+#     )(Before)
+#     After.show()
+#     tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_rewrite_dynamic_reshape():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
+            N = T.int64()
+            with R.dataflow():
+                y = R.reshape(x, [N * 4, T.int64(4)])
+                z = R.add(y, y)
+                R.output(z)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
+            N = T.int64()
+            cls = Expected
+
+            with R.dataflow():
+                y = R.reshape(x, R.shape([N * 4, T.int64(4)]))
+                z = R.call_tir(
+                    cls.add,
+                    (y, y),
+                    tir_vars=[N],
+                    out_sinfo=R.Tensor((N * 4, 4), dtype="float32"),
+                )
+                R.output(z)
+            return z
+
+        @T.prim_func(private=True)
+        def add(
+            y1_handle: T.handle,
+            y2_handle: T.handle,
+            z_handle: T.handle,
+            N: T.int64,
+        ):
+
+            y1 = T.match_buffer(y1_handle, [N * 4, T.int64(4)], "float32")
+            y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32")
+            z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32")
+
+            T.func_attr({"tir.noalias": T.bool(True)})
+
+            for iters in T.grid(N * 4, T.int64(4)):
+                with T.block("T_add"):
+                    i, j = T.axis.remap("SS", iters)
+                    z[i, j] = y1[i, j] + y2[i, j]
+
+    After = tvm.ir.transform.Sequential(
+        [
+            # Lower both R.reshape and R.add from Relax to TIR
+            relax.transform.LegalizeOps(),
+            # Identify reshapes, raise calls to cls.reshape from TIR
+            # to Relax
+            relax.transform.RewriteDataflowReshape(),
+            # Clean up afterwards, removing the no-longer-required
+            # PrimFunc "reshape"
+            relax.transform.DeadCodeElimination(),
+        ]
+    )(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
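
For orientation, the pass pipeline in these tests relies on LegalizeOps first lowering R.reshape to a TIR PrimFunc that HasReshapePattern can recognize. Below is a simplified, hypothetical sketch of what that legalized function looks like for the static test (the actual generated body uses equivalent but more general index arithmetic); RewriteDataflowReshape raises the call_tir to this function back into an R.reshape, after which DeadCodeElimination deletes the then-unreferenced PrimFunc:

from tvm.script import tir as T


# Simplified sketch of the PrimFunc that LegalizeOps emits for
# R.reshape(x, [64, 4]) in test_rewrite_static_reshape above.
@T.prim_func(private=True)
def reshape(
    x: T.Buffer((T.int64(256),), "float32"),
    y: T.Buffer((T.int64(64), T.int64(4)), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    for i, j in T.grid(T.int64(64), T.int64(4)):
        with T.block("T_reshape"):
            vi, vj = T.axis.remap("SS", [i, j])
            y[vi, vj] = x[vi * T.int64(4) + vj]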