diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index cfc0ad0087fc..30eba069ca3a 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -350,7 +350,14 @@ class TECompilerImpl : public TECompilerNode { GlobalVar global_var = kv.first->name_hint == value->cached_func->prim_fn_var->name_hint ? value->cached_func->prim_fn_var : kv.first; - value->cached_func->funcs->Add(global_var, kv.second); + auto func = kv.second; + // Propagate the structural hash of the relay function to the tir + // function so associations can be made between the two. + Optional hash = key->source_func->attrs.GetAttr("hash"); + if (hash) { + func = WithAttrs(Downcast(func), {{String("hash"), hash.value()}}); + } + value->cached_func->funcs->Add(global_var, func); } ICHECK(value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var) .as()); diff --git a/tests/python/relay/test_relay_te_compiler.py b/tests/python/relay/test_relay_te_compiler.py index f8498ae83648..e200e79c1532 100644 --- a/tests/python/relay/test_relay_te_compiler.py +++ b/tests/python/relay/test_relay_te_compiler.py @@ -261,6 +261,30 @@ def test_compile_nhwc_pack(): relay.build(mod, target="llvm") +def test_compile_propogate_hash(): + data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8") + weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8") + p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32") + conv = relay.nn.conv2d( + data, + weight, + kernel_size=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + multiply = relay.multiply(relay.const(-22, dtype="int32"), p2) + tile = relay.tile(multiply, reps=(1, 1, 1, 1001)) + subtract = relay.subtract(conv, tile) + + func = subtract + mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(func), func)) + vm = relay.vm.VMCompiler() + opt_mod, _ = vm.optimize(mod, target="llvm") + for f in opt_mod.functions.values(): + assert "hash" in f.attrs.keys() + + if __name__ == "__main__": test_get_valid_implementations() test_select_implementation()