Skip to content

Commit 35d2e8b

Browse files
author
Tristan Konolige
authored
[TE COMPILER] Propagate structural hash from relay function to TIR function (#10475)
The structural hash of each relay function is copied to the TIR function so that users can associate relay functions with their lowered TIR version.
1 parent 7688db7 commit 35d2e8b

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

src/relay/backend/te_compiler.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,14 @@ class TECompilerImpl : public TECompilerNode {
350350
GlobalVar global_var = kv.first->name_hint == value->cached_func->prim_fn_var->name_hint
351351
? value->cached_func->prim_fn_var
352352
: kv.first;
353-
value->cached_func->funcs->Add(global_var, kv.second);
353+
auto func = kv.second;
354+
// Propagate the structural hash of the relay function to the tir
355+
// function so associations can be made between the two.
356+
Optional<String> hash = key->source_func->attrs.GetAttr<String>("hash");
357+
if (hash) {
358+
func = WithAttrs(Downcast<tir::PrimFunc>(func), {{String("hash"), hash.value()}});
359+
}
360+
value->cached_func->funcs->Add(global_var, func);
354361
}
355362
ICHECK(value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var)
356363
.as<tir::PrimFuncNode>());

tests/python/relay/test_relay_te_compiler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,30 @@ def test_compile_nhwc_pack():
261261
relay.build(mod, target="llvm")
262262

263263

264+
def test_compile_propogate_hash():
265+
data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8")
266+
weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8")
267+
p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32")
268+
conv = relay.nn.conv2d(
269+
data,
270+
weight,
271+
kernel_size=(1, 1),
272+
data_layout="NHWC",
273+
kernel_layout="HWIO",
274+
out_dtype="int32",
275+
)
276+
multiply = relay.multiply(relay.const(-22, dtype="int32"), p2)
277+
tile = relay.tile(multiply, reps=(1, 1, 1, 1001))
278+
subtract = relay.subtract(conv, tile)
279+
280+
func = subtract
281+
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(func), func))
282+
vm = relay.vm.VMCompiler()
283+
opt_mod, _ = vm.optimize(mod, target="llvm")
284+
for f in opt_mod.functions.values():
285+
assert "hash" in f.attrs.keys()
286+
287+
264288
if __name__ == "__main__":
265289
test_get_valid_implementations()
266290
test_select_implementation()

0 commit comments

Comments
 (0)