Skip to content

Commit f21c5a4

Browse files
authored
[DPL] Support tir_vars field in is_call_tir pattern (#16494)
* [DPL] Support tir_vars field in is_call_tir pattern * add doc * lint * lint
1 parent 00a2c7d commit f21c5a4

File tree

3 files changed

+31
-7
lines changed

3 files changed

+31
-7
lines changed

python/tvm/relax/dpl/pattern.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -871,19 +871,23 @@ def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern":
871871
def _is_call_tir(
872872
func_pattern: DFPattern,
873873
args: Union[List, Tuple, TuplePattern] = None,
874+
tir_vars: Optional[DFPattern] = None,
874875
) -> CallPattern:
875876
if args is None:
876877
args = wildcard()
877878
elif isinstance(args, (list, tuple)):
878879
args = TuplePattern(args)
879880

880-
return is_op("relax.call_tir")(func_pattern, args, add_constraint=False)
881+
if tir_vars is None:
882+
return is_op("relax.call_tir")(func_pattern, args, add_constraint=False)
883+
return is_op("relax.call_tir")(func_pattern, args, tir_vars, add_constraint=False)
881884

882885

883886
# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
884887
def is_call_tir(
885888
func_name: str,
886889
args: Union[List, Tuple, TuplePattern] = None,
890+
tir_vars: Optional[DFPattern] = None,
887891
) -> CallPattern:
888892
"""
889893
Syntax sugar for creating a CallPattern for call_tir that calls an function through global var.
@@ -894,14 +898,15 @@ def is_call_tir(
894898
Name of the CPS function to call.
895899
args : Union[List[DFPattern], Tuple[DFPattern]], optional
896900
Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments
897-
901+
tir_vars : Optional[DFPattern]
902+
Pattern to match the tuple of integers that are unpacked when calling the tir func.
898903
Returns
899904
-------
900905
CallPattern
901906
The resulting CallPattern
902907
"""
903908
func_pattern = GlobalVarPattern(func_name)
904-
return _is_call_tir(func_pattern, args)
909+
return _is_call_tir(func_pattern, args, tir_vars)
905910

906911

907912
def _is_call_dps_packed(

src/relax/ir/dataflow_pattern.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,14 +574,18 @@ ConstantPattern IsConst() { return ConstantPattern(make_object<ConstantPatternNo
574574
WildcardPattern Wildcard() { return WildcardPattern(make_object<WildcardPatternNode>()); }
575575
ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
576576
ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
577-
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args) {
577+
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args,
578+
Optional<DFPattern> tir_vars) {
578579
DFPattern arg_pattern;
579580
if (!var_args.defined()) {
580581
arg_pattern = Wildcard();
581582
} else {
582583
arg_pattern = var_args.value();
583584
}
584585

586+
if (tir_vars.defined()) {
587+
return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern, tir_vars.value());
588+
}
585589
return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern);
586590
}
587591

tests/python/relax/test_dataflow_pattern.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,27 @@ def tir_relu(x: T.handle, y: T.handle):
5656
vi, vj = T.axis.remap("SS", [i, j])
5757
B[vi, vj] = T.max(A[vi, vj], 0.0)
5858

59+
@T.prim_func
60+
def tir_zeros(x: T.handle, n: T.int64):
61+
T.func_attr({"global_symbol": "tir_zeros"})
62+
A = T.match_buffer(x, [n])
63+
for i in range(n):
64+
with T.block():
65+
vi = T.axis.remap("S", [i])
66+
A[vi] = 1.0
67+
5968
@R.function
60-
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
69+
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tuple:
6170
cls = Module
6271
with R.dataflow():
6372
lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
6473
lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32"))
65-
R.output(lv1)
66-
return lv1
74+
lv2 = R.call_tir(
75+
cls.tir_zeros, (lv1), R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32])
76+
)
77+
gv = (lv1, lv2)
78+
R.output(gv)
79+
return gv
6780

6881

6982
main_fn = Module["main"]
@@ -293,10 +306,12 @@ def test_match_call_attr():
293306

294307
def test_is_call_tir():
295308
lv1_val = bindings[1].value
309+
lv2_val = bindings[2].value
296310
var2val = get_var2val(Module["main"])
297311
assert is_call_tir("tir_relu").match(lv1_val)
298312
assert is_call_tir("tir_relu", [is_call_tir("tir_matmul")]).match(lv1_val, var2val=var2val)
299313
assert not is_call_tir("tir_relu", [is_call_tir("tir_relu")]).match(lv1_val, var2val=var2val)
314+
assert is_call_tir("tir_zeros", wildcard(), wildcard()).match(lv2_val, var2val=var2val)
300315

301316

302317
@R.function

0 commit comments

Comments
 (0)