Skip to content

Commit 059f629

Browse files
authored
[BYOC] Skip processed functions in FuseOpsByPattern and RunCodegen (#16567)
1 parent 07ecb34 commit 059f629

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

src/relax/transform/fuse_ops.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1199,7 +1199,8 @@ class CompositeFunctionAnnotator : public ExprMutator {
11991199
auto all_functions = mod->functions;
12001200
for (const auto& entry : all_functions) {
12011201
if (const auto* func = entry.second.as<FunctionNode>()) {
1202-
if (func->GetAttr<String>(attr::kComposite).defined()) {
1202+
if (func->GetAttr<String>(attr::kComposite).defined() ||
1203+
func->GetAttr<String>(attr::kCodegen).defined()) {
12031204
continue;
12041205
}
12051206
auto new_body = VisitExpr(func->body);
@@ -1270,6 +1271,13 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,
12701271
if (entry.second->IsInstance<tir::PrimFuncNode>()) {
12711272
continue;
12721273
}
1274+
const FunctionNode* function = entry.second.as<FunctionNode>();
1275+
if (function->GetAttr<Integer>(attr::kPrimitive).defined() ||
1276+
function->GetAttr<String>(attr::kComposite).defined() ||
1277+
function->GetAttr<String>(attr::kCodegen).defined()) {
1278+
continue;
1279+
}
1280+
12731281
auto map = PatternBasedPartitioner::Run(pattern->name, pattern->pattern,
12741282
pattern->annotation_patterns,
12751283
pattern->check.value_or(nullptr), entry.second,

tests/python/relax/test_transform_fuse_ops_by_pattern.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,5 +1046,14 @@ def main(
10461046
assert "fused_relax_permute_dims_relax_matmul_cublas" in func_names # add is not fused
10471047

10481048

1049+
def test_multple_runs():
1050+
check(
1051+
Conv2dReLU_composite_annotated,
1052+
[("dnnl.conv2d_relu", conv2d_relu_pat)],
1053+
Conv2dReLU_composite_annotated,
1054+
annotate_codegen=True,
1055+
)
1056+
1057+
10491058
if __name__ == "__main__":
10501059
pytest.main([__file__])

0 commit comments

Comments
 (0)