@@ -46,9 +46,12 @@ def __init__(self):
4646 def create_op_with_lut (self , call ):
4747 """Extract the parameters and attributes from the NPU operator and create
4848 a new operator with LUT.
49+
50+ Parameters
4951 ----------
5052 call : tvm.relay.expr.Call
5153 The current call node being visited.
54+
5255 Returns
5356 -------
5457 tvm.relay.expr.Call
@@ -63,8 +66,7 @@ def create_op_with_lut(self, call):
6366 new_attrs ["activation" ] = activation
6467
6568 # Assume that LUT is always the last argument
66- new_args = [ethosu_op .args [n ] for n in range (len (ethosu_op .args ) - 1 )]
67- new_args .append (lut )
69+ new_args = ethosu_op .args [:- 1 ] + [lut ]
6870 assert ethosu_op .op .name in self .lut_ops .keys ()
6971
7072 return self .lut_ops [ethosu_op .op .name ](* new_args , ** new_attrs )
@@ -73,10 +75,12 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
7375 """Recursively visit call nodes in the input graph and if an ethosu.identity
7476 operator with LUT is found and the preceding operator has a LUT attribute, create
7577 a new NPU operator.
78+
7679 Parameters
7780 ----------
7881 call : tvm.relay.expr.Call
7982 The current call node being visited.
83+
8084 Returns
8185 -------
8286 tvm.relay.expr.Call
@@ -104,24 +108,26 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
104108 return new_call
105109
106110
107- @relay .transform .function_pass (opt_level = 1 , name = "LutOptimizer " )
111+ @relay .transform .function_pass (opt_level = 1 , name = "LUTsOptimizer " )
108112class LUTsOptimizer (Pass ):
109- """Register LutOptimizer as a relay pass."""
113+ """Register LUTsOptimizer as a relay pass."""
110114
111115 def transform_function (
112116 self , func : tvm .relay .function .Function , mod : tvm .IRModule , _
113117 ) -> tvm .IRModule :
114118 """Visit relay nodes in the given module.
119+
115120 Parameters
116121 ----------
117122 func : tvm.relay.function.Function
118- The function to apply the layout optimization pass to.
123+ The function to apply the optimization pass for multiple LUTs to.
119124 mod : tvm.IRModule
120- The module to apply the layout optimization pass to.
125+ The module to apply the optimization pass for multiple LUTs to.
126+
121127 Returns
122128 -------
123129 mod : tvm.IRModule
124- New module with augmented layouts .
130+ New module with optimized LUTs .
125131 """
126132 assert len (mod .functions .items ()) == 1 , "Module can only contain one function."
127133 return OptimizeLUTs ().visit (func )
0 commit comments