
Commit 0fb5365

[Relax] Ignore dynamic parameters in RewriteDataflowReshape (#17086)
The Relax transform `RewriteDataflowReshape` identifies TIR functions that are equivalent to `relax.op.reshape` and replaces them with calls to `relax.op.reshape`. This is used as a precursor for simplifications that rely on the high-level knowledge that an operator is a reshape, but that also require the low-level knowledge of the adjacent TIR PrimFuncs.

Prior to this commit, the `RewriteDataflowReshape` pass would only recognize static shapes, or dynamic shapes that could be inferred from the shapes of tensor arguments. This commit updates `RewriteDataflowReshape` to also recognize cases where an extra symbolic variable has been provided as a dynamic parameter.
Parent: 0984e97
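The pass is meant to run between legalization and cleanup. A minimal pipeline sketch, mirroring the tests added below (`Before` stands for any IRModule whose dataflow blocks contain reshape-equivalent call_tir calls):

import tvm
from tvm import relax

# Lower Relax ops to TIR, raise reshape-equivalent PrimFuncs back to
# relax.op.reshape, then drop the no-longer-referenced PrimFuncs.
pipeline = tvm.ir.transform.Sequential(
    [
        relax.transform.LegalizeOps(),
        relax.transform.RewriteDataflowReshape(),
        relax.transform.DeadCodeElimination(),
    ]
)
# After = pipeline(Before)  # `Before` is an IRModule, as in the tests below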

3 files changed: +202 −14 lines


src/relax/analysis/tir_op_pattern_kind.cc

Lines changed: 10 additions & 6 deletions
@@ -517,19 +517,23 @@ bool HasReshapePattern(const PrimFunc& func) {
     arith::Analyzer ana_;
   };
 
-  if (func->params.size() < 2) {
-    return false;
+  Array<Buffer> buffer_args;
+  for (const auto& param : func->params) {
+    if (auto buffer = func->buffer_map.Get(param)) {
+      buffer_args.push_back(buffer.value());
+    }
   }
-  Optional<Buffer> src_buffer = func->buffer_map.Get(func->params.front());
-  Optional<Buffer> dst_buffer = func->buffer_map.Get(func->params.back());
-  if (!(src_buffer.defined() && dst_buffer.defined())) {
+
+  if (buffer_args.size() < 2) {
     return false;
   }
+  Buffer src_buffer = buffer_args.front();
+  Buffer dst_buffer = buffer_args.back();
 
   // To detect the reshape pattern, we require each For to have
   // either another For or a BlockRealize as body.
   ICHECK(func->body->IsInstance<BlockRealizeNode>());
-  return ReshapeDetector::Detect(src_buffer.value(), dst_buffer.value(), func->body);
+  return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body);
 }
 
 TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern);
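For reference, a hedged sketch of exercising the updated analysis from Python. `has_reshape_pattern` is the Python binding of the global registered above; the PrimFunc below, its shapes, and its names are illustrative assumptions rather than part of this commit:

from tvm.script import tir as T
from tvm.relax.analysis import has_reshape_pattern


# Reshape-equivalent PrimFunc whose trailing parameter N is symbolic and
# therefore has no entry in the buffer_map.
@T.prim_func(private=True)
def reshape(x_handle: T.handle, y_handle: T.handle, N: T.int64):
    x = T.match_buffer(x_handle, [N * 16], "float32")
    y = T.match_buffer(y_handle, [N * 4, T.int64(4)], "float32")
    for i, j in T.grid(N * 4, T.int64(4)):
        with T.block("reshape"):
            vi, vj = T.axis.remap("SS", [i, j])
            y[vi, vj] = x[vi * T.int64(4) + vj]


# Previously params.back() was N, which has no buffer, so detection bailed
# out; with this change the non-tensor parameter is skipped.
print(has_reshape_pattern(reshape))  # expected: True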

src/relax/transform/rewrite_dataflow_reshape.cc

Lines changed: 10 additions & 7 deletions
@@ -34,12 +34,15 @@
 namespace tvm {
 namespace relax {
 
-std::vector<size_t> GetUsedArgsIndices(const tir::PrimFunc& fn, size_t num_args) {
+std::vector<size_t> GetUsedTensorArgIndices(const tir::PrimFunc& fn, size_t num_args) {
   std::vector<size_t> indices;
   for (size_t i = 0; i < num_args; ++i) {
-    auto buffer_var = fn->buffer_map[fn->params[i]]->data;
-    if (tir::UsesVar(fn->body, [=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
-      indices.push_back(i);
+    if (auto buffer = fn->buffer_map.Get(fn->params[i])) {
+      auto buffer_var = buffer.value()->data;
+      if (tir::UsesVar(fn->body,
+                       [=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
+        indices.push_back(i);
+      }
     }
   }
   return indices;
@@ -83,17 +86,17 @@ class DataflowReshapeRewriter : public ExprMutator {
 
     auto prim_fn = Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
     auto arg_tuple = Downcast<Tuple>(call->args[1])->fields;
-    auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size());
+    auto used_tensor_arg_indices = GetUsedTensorArgIndices(prim_fn, arg_tuple.size());
 
     // The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps
    // can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR
    // then flattens the tuple input so that the fused TIR reshape function ends up having
    // multiple input buffers. But only one of them should be accessed and reshaped.
-    if (used_arg_indices.size() != 1) {
+    if (used_tensor_arg_indices.size() != 1) {
       return GetRef<Call>(call);
     }
 
-    auto arg = arg_tuple[used_arg_indices[0]];
+    auto arg = arg_tuple[used_tensor_arg_indices[0]];
 
     if (!IsCallingTIRReshape(call, arg)) {
       return GetRef<Call>(call);
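A hedged Python rendition of the updated helper, useful for experimenting outside C++. The function name mirrors the C++ one; `post_order_visit` stands in for `tir::UsesVar` and is only an approximation (it inspects BufferLoad/BufferStore nodes and bare Vars, not block read/write regions):

import tvm
from tvm import tir


def get_used_tensor_arg_indices(fn: tir.PrimFunc, num_args: int):
    """Indices of tensor parameters whose backing buffer is accessed."""
    used_vars = []

    def visit(node):
        # Record buffer vars reached through loads/stores, plus bare Vars.
        if isinstance(node, (tir.BufferLoad, tir.BufferStore)):
            used_vars.append(node.buffer.data)
        elif isinstance(node, tir.Var):
            used_vars.append(node)

    tvm.tir.stmt_functor.post_order_visit(fn.body, visit)

    indices = []
    for i in range(num_args):
        buffer = fn.buffer_map.get(fn.params[i], None)
        # Parameters without a buffer_map entry (e.g. a symbolic shape
        # variable) are skipped, mirroring the C++ change above.
        if buffer is not None and any(buffer.data.same_as(v) for v in used_vars):
            indices.append(i)
    return indices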

tests/python/relax/test_transform_rewrite_dataflow_reshape.py

Lines changed: 182 additions & 1 deletion
@@ -18,7 +18,7 @@
 import tvm
 import tvm.testing
 from tvm import relax
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
 
 
 def test_reshape_expand_dims():
@@ -581,5 +581,186 @@ def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"):
     tvm.ir.assert_structural_equal(rewritten, Expected)
 
 
+def test_rewrite_static_reshape():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor([256], dtype="float32")):
+            with R.dataflow():
+                y = R.reshape(x, [64, 4])
+                z = R.add(y, y)
+                R.output(z)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((256,), dtype="float32")):
+            cls = Expected
+
+            with R.dataflow():
+                y = R.reshape(x, R.shape([64, 4]))
+                z = R.call_tir(cls.add, (y, y), out_sinfo=R.Tensor((64, 4), dtype="float32"))
+                R.output(z)
+            return z
+
+        @T.prim_func(private=True)
+        def add(
+            y1: T.Buffer((T.int64(64), T.int64(4)), "float32"),
+            y2: T.Buffer((T.int64(64), T.int64(4)), "float32"),
+            z: T.Buffer((T.int64(64), T.int64(4)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+
+            for iters in T.grid(T.int64(64), T.int64(4)):
+                with T.block("T_add"):
+                    i, j = T.axis.remap("SS", iters)
+                    z[i, j] = y1[i, j] + y2[i, j]
+
+    After = tvm.ir.transform.Sequential(
+        [
+            # Lower both R.reshape and R.add from Relax to TIR
+            relax.transform.LegalizeOps(),
+            # Identify reshapes, raise calls to cls.reshape from TIR
+            # to Relax
+            relax.transform.RewriteDataflowReshape(),
+            # Clean up afterwards, removing the no-longer-required
+            # PrimFunc "reshape"
+            relax.transform.DeadCodeElimination(),
+        ]
+    )(Before)
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+# def test_rewrite_dynamic_reshape():
+#     @I.ir_module
+#     class Before:
+#         @R.function
+#         def main(x: R.Tensor(["N"], dtype="float32")):
+#             N = T.int64()
+#             with R.dataflow():
+#                 y = R.reshape(x, [N // 4, 4])
+#                 z = R.add(y, y)
+#                 R.output(z)
+#             return z
+
+#     @I.ir_module
+#     class Expected:
+#         @R.function
+#         def main(x: R.Tensor(["N"], dtype="float32")):
+#             N = T.int64()
+#             cls = Expected
+
+#             with R.dataflow():
+#                 y = R.reshape(x, R.shape([N // 4, 4]))
+#                 z = R.call_tir(
+#                     cls.add,
+#                     (y, y),
+#                     tir_vars=[N],
+#                     out_sinfo=R.Tensor((N // 4, 4), dtype="float32"),
+#                 )
+#                 R.output(z)
+#             return z
+
+#         @T.prim_func(private=True)
+#         def add(
+#             y1_handle: T.handle,
+#             y2_handle: T.handle,
+#             z_handle: T.handle,
+#             N: T.int64,
+#         ):
+
+#             y1 = T.match_buffer(y1_handle, [N // 4, 4], "float32")
+#             y2 = T.match_buffer(y2_handle, [N // 4, 4], "float32")
+#             z = T.match_buffer(z_handle, [N // 4, 4], "float32")
+
+#             T.func_attr({"tir.noalias": T.bool(True)})
+
+#             for iters in T.grid(T.int64(64), T.int64(4)):
+#                 with T.block("T_add"):
+#                     i, j = T.axis.remap("SS", iters)
+#                     z[i, j] = y1[i, j] + y2[i, j]
+
+#     After = tvm.ir.transform.Sequential(
+#         [
+#             # Lower both R.reshape and R.add from Relax to TIR
+#             relax.transform.LegalizeOps(),
+#             # Identify reshapes, raise calls to cls.reshape from TIR
+#             # to Relax
+#             relax.transform.RewriteDataflowReshape(),
+#             # Clean up afterwards, removing the no-longer-required
+#             # PrimFunc "reshape"
+#             relax.transform.DeadCodeElimination(),
+#         ]
+#     )(Before)
+#     After.show()
+#     tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_rewrite_dynamic_reshape():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
+            N = T.int64()
+            with R.dataflow():
+                y = R.reshape(x, [N * 4, T.int64(4)])
+                z = R.add(y, y)
+                R.output(z)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
+            N = T.int64()
+            cls = Expected
+
+            with R.dataflow():
+                y = R.reshape(x, R.shape([N * 4, T.int64(4)]))
+                z = R.call_tir(
+                    cls.add,
+                    (y, y),
+                    tir_vars=[N],
+                    out_sinfo=R.Tensor((N * 4, 4), dtype="float32"),
+                )
+                R.output(z)
+            return z
+
+        @T.prim_func(private=True)
+        def add(
+            y1_handle: T.handle,
+            y2_handle: T.handle,
+            z_handle: T.handle,
+            N: T.int64,
+        ):
+
+            y1 = T.match_buffer(y1_handle, [N * 4, T.int64(4)], "float32")
+            y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32")
+            z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32")
+
+            T.func_attr({"tir.noalias": T.bool(True)})
+
+            for iters in T.grid(N * 4, T.int64(4)):
+                with T.block("T_add"):
+                    i, j = T.axis.remap("SS", iters)
+                    z[i, j] = y1[i, j] + y2[i, j]
+
+    After = tvm.ir.transform.Sequential(
+        [
+            # Lower both R.reshape and R.add from Relax to TIR
+            relax.transform.LegalizeOps(),
+            # Identify reshapes, raise calls to cls.reshape from TIR
+            # to Relax
+            relax.transform.RewriteDataflowReshape(),
+            # Clean up afterwards, removing the no-longer-required
+            # PrimFunc "reshape"
+            relax.transform.DeadCodeElimination(),
+        ]
+    )(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
