Skip to content

Commit 349d736

Browse files
ibsidorenkoMikael Sevenier
authored andcommitted
[QNN] Change in Pass Context for lookup table calculation (apache#13660)
Motivation: It is possible to disable specific passes through the "disabled_pass" parameter in the Pass Context. These "disabled" passes can be optional for one target and mandatory for another one. Since lookup table for some QNN operations (tanh, round and etc.) is calculated on the host and some of disabled passes can be required for the host, no need to disable these passes. This constant calculation/ evaluation is orthogonal to the compilation process for specific target. What was changed: This commit creates its own compilation Pass Context for lookup table calculation and evaluation (for elemwise QNN ops: tanh, sqrt ...).
1 parent 5e6ae53 commit 349d736

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

python/tvm/relay/qnn/op/canonicalizations.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,25 @@
2323

2424

2525
def run_const_expr(expr: "relay.Expr") -> np.ndarray:
26-
"""Evaluate a const expression, receiving result as np array."""
27-
mod = tvm.IRModule.from_expr(expr)
28-
vm_exe = relay.create_executor("vm", mod=mod)
29-
return vm_exe.evaluate()().asnumpy()
26+
"""Evaluate a const expression, receiving result as np array.
27+
28+
If a number of passes are disabled in the current Pass Context, then there is no need to disable
29+
these passes for const expression evaluation as well. That's why we use empty list
30+
"disabled_pass=[]", all other arguments are inherited from the current Pass Context.
31+
"""
32+
curr_pass_ctx = tvm.ir.transform.PassContext.current()
33+
with tvm.ir.transform.PassContext(
34+
opt_level=curr_pass_ctx.opt_level,
35+
required_pass=curr_pass_ctx.required_pass,
36+
disabled_pass=[],
37+
instruments=curr_pass_ctx.instruments,
38+
config=curr_pass_ctx.config,
39+
):
40+
mod = tvm.IRModule.from_expr(expr)
41+
vm_exe = relay.create_executor("vm", mod=mod)
42+
output = vm_exe.evaluate()().asnumpy()
43+
44+
return output
3045

3146

3247
def create_integer_lookup_table(

0 commit comments

Comments
 (0)