Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/tvm/relax/dpl/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,19 +871,23 @@ def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern":
def _is_call_tir(
func_pattern: DFPattern,
args: Union[List, Tuple, TuplePattern] = None,
tir_vars: Optional[DFPattern] = None,
) -> CallPattern:
if args is None:
args = wildcard()
elif isinstance(args, (list, tuple)):
args = TuplePattern(args)

return is_op("relax.call_tir")(func_pattern, args, add_constraint=False)
if tir_vars is None:
return is_op("relax.call_tir")(func_pattern, args, add_constraint=False)
return is_op("relax.call_tir")(func_pattern, args, tir_vars, add_constraint=False)


# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
def is_call_tir(
func_name: str,
args: Union[List, Tuple, TuplePattern] = None,
tir_vars: Optional[DFPattern] = None,
) -> CallPattern:
"""
Syntax sugar for creating a CallPattern for call_tir that calls an function through global var.
Expand All @@ -894,14 +898,15 @@ def is_call_tir(
Name of the CPS function to call.
args : Union[List[DFPattern], Tuple[DFPattern]], optional
Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments

tir_vars : Optional[DFPattern]
Pattern to match the tuple of integers that are unpacked when calling the tir func.
Returns
-------
CallPattern
The resulting CallPattern
"""
func_pattern = GlobalVarPattern(func_name)
return _is_call_tir(func_pattern, args)
return _is_call_tir(func_pattern, args, tir_vars)


def _is_call_dps_packed(
Expand Down
6 changes: 5 additions & 1 deletion src/relax/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,18 @@ ConstantPattern IsConst() { return ConstantPattern(make_object<ConstantPatternNo
WildcardPattern Wildcard() { return WildcardPattern(make_object<WildcardPatternNode>()); }
ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args) {
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args,
Optional<DFPattern> tir_vars) {
DFPattern arg_pattern;
if (!var_args.defined()) {
arg_pattern = Wildcard();
} else {
arg_pattern = var_args.value();
}

if (tir_vars.defined()) {
return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern, tir_vars.value());
}
return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern);
}

Expand Down
21 changes: 18 additions & 3 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,27 @@ def tir_relu(x: T.handle, y: T.handle):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = T.max(A[vi, vj], 0.0)

@T.prim_func
def tir_zeros(x: T.handle, n: T.int64):
T.func_attr({"global_symbol": "tir_zeros"})
A = T.match_buffer(x, [n])
for i in range(n):
with T.block():
vi = T.axis.remap("S", [i])
A[vi] = 1.0

@R.function
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tuple:
cls = Module
with R.dataflow():
lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32"))
R.output(lv1)
return lv1
lv2 = R.call_tir(
cls.tir_zeros, (lv1), R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32])
)
gv = (lv1, lv2)
R.output(gv)
return gv


main_fn = Module["main"]
Expand Down Expand Up @@ -293,10 +306,12 @@ def test_match_call_attr():

def test_is_call_tir():
lv1_val = bindings[1].value
lv2_val = bindings[2].value
var2val = get_var2val(Module["main"])
assert is_call_tir("tir_relu").match(lv1_val)
assert is_call_tir("tir_relu", [is_call_tir("tir_matmul")]).match(lv1_val, var2val=var2val)
assert not is_call_tir("tir_relu", [is_call_tir("tir_relu")]).match(lv1_val, var2val=var2val)
assert is_call_tir("tir_zeros", wildcard(), wildcard()).match(lv2_val, var2val=var2val)


@R.function
Expand Down