From ecb117c5707bb4d3640ee58aa0e10c08f9d830c7 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 3 Mar 2022 14:41:51 -0800 Subject: [PATCH] [TE COMPILER] Propagate structural hash from relay function to TIR function The structural hash of each relay function is copied to the TIR function so that users can associate relay functions with their lowered TIR version. --- src/relay/backend/te_compiler.cc | 9 +++++++- tests/python/relay/test_relay_te_compiler.py | 24 ++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index cfc0ad0087fc..30eba069ca3a 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -350,7 +350,14 @@ class TECompilerImpl : public TECompilerNode { GlobalVar global_var = kv.first->name_hint == value->cached_func->prim_fn_var->name_hint ? value->cached_func->prim_fn_var : kv.first; - value->cached_func->funcs->Add(global_var, kv.second); + auto func = kv.second; + // Propagate the structural hash of the relay function to the tir + // function so associations can be made between the two. + Optional hash = key->source_func->attrs.GetAttr("hash"); + if (hash) { + func = WithAttrs(Downcast(func), {{String("hash"), hash.value()}}); + } + value->cached_func->funcs->Add(global_var, func); } ICHECK(value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var) .as()); diff --git a/tests/python/relay/test_relay_te_compiler.py b/tests/python/relay/test_relay_te_compiler.py index f8498ae83648..e200e79c1532 100644 --- a/tests/python/relay/test_relay_te_compiler.py +++ b/tests/python/relay/test_relay_te_compiler.py @@ -261,6 +261,30 @@ def test_compile_nhwc_pack(): relay.build(mod, target="llvm") +def test_compile_propogate_hash(): + data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8") + weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8") + p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32") + conv = relay.nn.conv2d( + data, + weight, + kernel_size=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + multiply = relay.multiply(relay.const(-22, dtype="int32"), p2) + tile = relay.tile(multiply, reps=(1, 1, 1, 1001)) + subtract = relay.subtract(conv, tile) + + func = subtract + mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(func), func)) + vm = relay.vm.VMCompiler() + opt_mod, _ = vm.optimize(mod, target="llvm") + for f in opt_mod.functions.values(): + assert "hash" in f.attrs.keys() + + if __name__ == "__main__": test_get_valid_implementations() test_select_implementation()