22 | 22 | from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU |
23 | 23 | from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator |
24 | 24 | from tvm.relay.backend.contrib.ethosu import util |
| 25 | +from tvm.relay.expr_functor import ExprMutator |
| 26 | +from tvm.ir.transform import Pass |
| 27 | + |
| 28 | +# pylint: disable=unused-import |
| 29 | +from tvm.relay.backend.contrib.ethosu.op import op_attrs |
| 30 | +from tvm.relay.backend.contrib.ethosu import op |
| 31 | + |
| 32 | + |
| 33 | +class OptimizeLUTs(ExprMutator): |
| 34 | +    """A pass to merge an identity operator carrying a LUT-based activation function |
| 35 | +    into the preceding operator, provided that operator can perform the activation |
| 36 | +    table lookup in hardware.""" |
| 37 | + |
| 38 | + def __init__(self): |
| 39 | + super().__init__() |
| 40 | + self.lut_ops = { |
| 41 | + "contrib.ethosu.conv2d": op.ethosu_conv2d, |
| 42 | + "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d, |
| 43 | + "contrib.ethosu.pooling": op.ethosu_pooling, |
| 44 | + } |
| 45 | + |
| 46 | + def create_op_with_lut(self, call): |
| 47 | + """Extract the parameters and attributes from the NPU operator and create |
| 48 | +        a new operator with the LUT attached. |
| 49 | + |
| 50 | + Parameters |
| 51 | + ---------- |
| 52 | + call : tvm.relay.expr.Call |
| 53 | + The current call node being visited. |
| 54 | + |
| 55 | + Returns |
| 56 | + ------- |
| 57 | + tvm.relay.expr.Call |
| 58 | +            The new operator with the LUT attached. |
| 59 | + """ |
| 60 | + identity = call |
| 61 | + ethosu_op = call.args[0] |
| 62 | + lut = identity.args[1] |
| 63 | + activation = identity.attrs.activation |
| 64 | + |
| 65 | + new_attrs = dict(ethosu_op.attrs) |
| 66 | + new_attrs["activation"] = activation |
| 67 | + |
| 68 | + # Assume that LUT is always the last argument |
| 69 | + new_args = ethosu_op.args[:-1] + [lut] |
| 70 | + assert ethosu_op.op.name in self.lut_ops.keys() |
| 71 | + |
| 72 | + return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs) |
| 73 | + |
| 74 | + def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: |
| 75 | +        """Recursively visit call nodes in the input graph. If an ethosu.identity |
| 76 | +        operator with a LUT-based activation is found and the preceding operator can |
| 77 | +        perform that activation itself, create a new NPU operator with the LUT attached. |
| 78 | + |
| 79 | + Parameters |
| 80 | + ---------- |
| 81 | + call : tvm.relay.expr.Call |
| 82 | + The current call node being visited. |
| 83 | + |
| 84 | + Returns |
| 85 | + ------- |
| 86 | + tvm.relay.expr.Call |
| 87 | +            The input call node, unchanged, if no LUT fusion applies to the |
| 88 | +            current call. Otherwise, a new call node with the fused operator. |
| 89 | + """ |
| 90 | + new_call = call |
| 91 | + lut_activations = ["TANH", "LUT"] |
| 92 | + |
| 93 | + if isinstance(call.op, tvm.ir.Op) and isinstance(call.args[0], tvm.relay.expr.Call): |
| 94 | + producer_op = call.args[0] |
| 95 | + # Check if the producer can do a LUT operation |
| 96 | + if ( |
| 97 | + producer_op.op.name in self.lut_ops.keys() |
| 98 | + and call.op.name == "contrib.ethosu.identity" |
| 99 | + and call.attrs.activation in lut_activations |
| 100 | + ): |
| 101 | + # Check the producer doesn't already have a LUT |
| 102 | + has_lut = producer_op.attrs.activation in lut_activations |
| 103 | + if not has_lut: |
| 104 | + new_call = self.create_op_with_lut(call) |
| 105 | + |
| 106 | + new_call = super().visit_call(new_call) |
| 107 | + |
| 108 | + return new_call |
| 109 | + |
| 110 | + |
| 111 | +@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer") |
| 112 | +class LUTsOptimizer(Pass): |
| 113 | + """Register LUTsOptimizer as a relay pass.""" |
| 114 | + |
| 115 | + def transform_function( |
| 116 | + self, func: tvm.relay.function.Function, mod: tvm.IRModule, _ |
| 117 | + ) -> tvm.IRModule: |
| 118 | + """Visit relay nodes in the given module. |
| 119 | + |
| 120 | + Parameters |
| 121 | + ---------- |
| 122 | + func : tvm.relay.function.Function |
| 123 | +            The function to apply the LUT optimization pass to. |
| 124 | +        mod : tvm.IRModule |
| 125 | +            The module that the function belongs to. |
| 126 | + |
| 127 | + Returns |
| 128 | + ------- |
| 129 | + mod : tvm.IRModule |
| 130 | + New module with optimized LUTs. |
| 131 | + """ |
| 132 | + assert len(mod.functions.items()) == 1, "Module can only contain one function." |
| 133 | + return OptimizeLUTs().visit(func) |
25 | 134 |
26 | 135 |
27 | 136 | @tvm._ffi.register_func("relay.ext.ethos-u") |
@@ -74,6 +183,7 @@ def _compile(ext_func): |
74 | 183 | mod = tvm.IRModule() |
75 | 184 | mod["main"] = ext_func |
76 | 185 | mod = LegalizeEthosU()(mod) |
| 186 | + mod = LUTsOptimizer()(mod) |
77 | 187 | mod = relay.transform.InferType()(mod) |
78 | 188 | # We are currently using copy_constants scheduler In the long run, |
79 | 189 | # this should be a single intelligent and a composite scheduler |
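Reviewer note, not part of the patch: the structure used above (an `ExprMutator` that rewrites a matched call node, wrapped in a `relay.transform.function_pass`) can be exercised in isolation. Below is a minimal, self-contained sketch of that traversal and registration pattern using generic Relay operators (`nn.relu` rewritten to `clip`) rather than the Ethos-U operators, since constructing a valid `ethosu_conv2d` call needs many more arguments. All class names and values in the sketch are illustrative assumptions, not part of this change.

```python
# Illustrative sketch only: mirrors the OptimizeLUTs/LUTsOptimizer structure
# with generic Relay ops. Not part of the patch.
import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator


class FoldReluIntoClip(ExprMutator):
    """Toy mutator: rewrite nn.relu(x) as clip(x, 0, large), using the same
    visit_call / super().visit_call structure as OptimizeLUTs above."""

    def visit_call(self, call):
        new_call = call
        if (
            isinstance(call.op, tvm.ir.Op)
            and call.op.name == "nn.relu"
            and isinstance(call.args[0], relay.Call)
        ):
            producer = call.args[0]
            # A large finite value stands in for +inf in the clip attributes.
            new_call = relay.clip(producer, a_min=0.0, a_max=3.4e38)
        # Keep recursing so the arguments of the rewritten call are also visited.
        return super().visit_call(new_call)


@relay.transform.function_pass(opt_level=1, name="FoldReluIntoClipPass")
class FoldReluIntoClipPass:
    """Register the toy mutator as a Relay function pass."""

    def transform_function(self, func, mod, ctx):
        return FoldReluIntoClip().visit(func)


if __name__ == "__main__":
    x = relay.var("x", shape=(1, 4), dtype="float32")
    y = relay.var("y", shape=(1, 4), dtype="float32")
    func = relay.Function([x, y], relay.nn.relu(x + y))
    mod = tvm.IRModule.from_expr(func)
    # Applying the pass replaces nn.relu(add(x, y)) with clip(add(x, y), ...).
    mod = FoldReluIntoClipPass()(mod)
    print(mod)
```

The calling convention is the same as what the patch adds to `_compile`: the pass object is constructed and applied to the module (`mod = LUTsOptimizer()(mod)`) right after legalization and before `InferType`.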