From dbdff6ed65bc5d39d088ed4640db4b93e97f4029 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 30 Jan 2024 20:13:59 +0000 Subject: [PATCH 1/4] [DPL] Support tir_vars field in is_call_tir pattern --- python/tvm/relax/dpl/pattern.py | 6 ++++-- src/relax/ir/dataflow_pattern.cc | 5 ++++- tests/python/relax/test_dataflow_pattern.py | 21 ++++++++++++++++++--- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index e5670dee4b7e..9ce6a14bf83a 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -871,19 +871,21 @@ def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern": def _is_call_tir( func_pattern: DFPattern, args: Union[List, Tuple, TuplePattern] = None, + tir_vars: Optional[DFPattern] = None ) -> CallPattern: if args is None: args = wildcard() elif isinstance(args, (list, tuple)): args = TuplePattern(args) - return is_op("relax.call_tir")(func_pattern, args, add_constraint=False) + return is_op("relax.call_tir")(func_pattern, args, tir_vars, add_constraint=False) # Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo def is_call_tir( func_name: str, args: Union[List, Tuple, TuplePattern] = None, + tir_vars: Optional[DFPattern] = None ) -> CallPattern: """ Syntax sugar for creating a CallPattern for call_tir that calls an function through global var. @@ -901,7 +903,7 @@ def is_call_tir( The resulting CallPattern """ func_pattern = GlobalVarPattern(func_name) - return _is_call_tir(func_pattern, args) + return _is_call_tir(func_pattern, args, tir_vars) def _is_call_dps_packed( diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 1286a32e4cb8..19f7f7848749 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -574,7 +574,7 @@ ConstantPattern IsConst() { return ConstantPattern(make_object()); } ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); } ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); } -CallPattern IsCallTIR(const String& name, Optional var_args) { +CallPattern IsCallTIR(const String& name, Optional var_args, Optional tir_vars) { DFPattern arg_pattern; if (!var_args.defined()) { arg_pattern = Wildcard(); @@ -582,6 +582,9 @@ CallPattern IsCallTIR(const String& name, Optional var_args) { arg_pattern = var_args.value(); } + if (tir_vars.defined()) { + return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern, tir_vars.value()); + } return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern); } diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 7f2cb241bb1d..39b3c40c2676 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -56,14 +56,27 @@ def tir_relu(x: T.handle, y: T.handle): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = T.max(A[vi, vj], 0.0) + @T.prim_func + def tir_zeros(x: T.handle, n: T.int64): + T.func_attr({"global_symbol": "tir_zeros"}) + A = T.match_buffer(x, [n]) + for i in range(n): + with T.block(): + vi = T.axis.remap("S", [i]) + A[vi] = 1.0 + @R.function - def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tuple: cls = Module with R.dataflow(): lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) - R.output(lv1) - return lv1 + lv2 = R.call_tir( + cls.tir_zeros, (lv1), R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32]) + ) + gv = (lv1, lv2) + R.output(gv) + return gv main_fn = Module["main"] @@ -293,10 +306,12 @@ def test_match_call_attr(): def test_is_call_tir(): lv1_val = bindings[1].value + lv2_val = bindings[2].value var2val = get_var2val(Module["main"]) assert is_call_tir("tir_relu").match(lv1_val) assert is_call_tir("tir_relu", [is_call_tir("tir_matmul")]).match(lv1_val, var2val=var2val) assert not is_call_tir("tir_relu", [is_call_tir("tir_relu")]).match(lv1_val, var2val=var2val) + assert is_call_tir("tir_zeros", wildcard(), wildcard()).match(lv2_val, var2val=var2val) @R.function From b6dff5f6a9aafc2ea9ebdef3e71797af44e0a8b7 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 30 Jan 2024 20:40:19 +0000 Subject: [PATCH 2/4] add doc --- python/tvm/relax/dpl/pattern.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 9ce6a14bf83a..7bc5febdd073 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -878,6 +878,8 @@ def _is_call_tir( elif isinstance(args, (list, tuple)): args = TuplePattern(args) + if tir_vars is None: + return is_op("relax.call_tir")(func_pattern, args, add_constraint=False) return is_op("relax.call_tir")(func_pattern, args, tir_vars, add_constraint=False) @@ -896,7 +898,8 @@ def is_call_tir( Name of the CPS function to call. args : Union[List[DFPattern], Tuple[DFPattern]], optional Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments - + tir_vars : Optional[DFPattern] + Pattern to match the tuple of integers that are unpacked when calling the tir func. Returns ------- CallPattern From 689746046212ec37ae3f6a75d8836310b66a19b4 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 30 Jan 2024 16:34:22 -0800 Subject: [PATCH 3/4] lint --- python/tvm/relax/dpl/pattern.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 7bc5febdd073..5594dea3ad74 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -871,7 +871,7 @@ def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern": def _is_call_tir( func_pattern: DFPattern, args: Union[List, Tuple, TuplePattern] = None, - tir_vars: Optional[DFPattern] = None + tir_vars: Optional[DFPattern] = None, ) -> CallPattern: if args is None: args = wildcard() @@ -887,7 +887,7 @@ def _is_call_tir( def is_call_tir( func_name: str, args: Union[List, Tuple, TuplePattern] = None, - tir_vars: Optional[DFPattern] = None + tir_vars: Optional[DFPattern] = None, ) -> CallPattern: """ Syntax sugar for creating a CallPattern for call_tir that calls an function through global var. From 6cbc0844c7cf7e508e005bedefefdd6daacfa123 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 31 Jan 2024 21:12:08 +0000 Subject: [PATCH 4/4] lint --- src/relax/ir/dataflow_pattern.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 19f7f7848749..ca81b910126a 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -574,7 +574,8 @@ ConstantPattern IsConst() { return ConstantPattern(make_object()); } ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); } ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); } -CallPattern IsCallTIR(const String& name, Optional var_args, Optional tir_vars) { +CallPattern IsCallTIR(const String& name, Optional var_args, + Optional tir_vars) { DFPattern arg_pattern; if (!var_args.defined()) { arg_pattern = Wildcard();