@@ -134,6 +134,134 @@ def transform_function(
         return OptimizeLUTs().visit(func)
 
 
+class LayoutOptimization(ExprMutator):
+    """A pass to optimize the layout of NPU operations. If both the
+    producer and consumer of a tensor are NPU operators, then the
+    layout is converted from NHWC to NHCWB16.
+
+    Attributes
+    ----------
+    children : Dict[tvm.relay.expr.Call, List[tvm.relay.expr.Call]]
+        A map from the current call to a list of calls that rely on it.
+        This allows the graph to be traversed backwards, which is useful
+        for checking whether the output layouts can be rewritten.
+    optimize_op : Dict[str, Callable]
+        A map from NPU op name to the function that creates that NPU op.
+    """
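+
+    # Illustrative note (an assumption, not part of the original patch):
+    # `children` maps a producer call to the calls that consume it, e.g.
+    # {conv_a: [conv_b, pool_c]} records that conv_a's output feeds both
+    # conv_b and pool_c, which lets the pass walk the graph backwards.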
+
+    def __init__(self):
+        self.children = {}
+        self.optimize_op = {
+            "contrib.ethosu.conv2d": op.ethosu_conv2d,
+            "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
+            "contrib.ethosu.pooling": op.ethosu_pooling,
+            "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
+            "contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
+        }
+
+        super().__init__()
+
+    def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
+        """Alter the input and output layouts of an NPU operation if needed.
+        The input layout is only altered if the producing operation is an NPU
+        operation. Likewise, the output layout is only altered if the consuming
+        operation is an NPU operation.
+
+        Parameters
+        ----------
+        call : tvm.relay.expr.Call
+            The call pointing to an NPU operation whose layouts will be
+            checked and altered if needed.
+
+        Returns
+        -------
+        new_call : tvm.relay.expr.Call
+            New call with altered layouts.
+        """
+        assert isinstance(call.attrs, tvm.ir.Attrs), (
+            f"The attributes for operator '{call.op.name}' could not be "
+            "found. Did you register the relay.attrs.Ethosu<opname>Attrs "
+            "object in the Python API?"
+        )
+
+        new_attrs = dict(call.attrs)
+        parents = []
+
+        # Check if we can rewrite the input layouts
+        input_count = 0
+        for arg in call.args:
+            input_count += 1
+            if not isinstance(arg, tvm.relay.expr.Call):
+                continue
+            if isinstance(arg.op, tvm.ir.op.Op) and arg.op.name in self.optimize_op:
+                layout_string = "ifm_layout" if input_count <= 1 else f"ifm{input_count}_layout"
+                new_attrs[layout_string] = "NHCWB16"
+                parents.append(arg)
+
+        # Check if we can rewrite the output layouts
+        if call in self.children:
+            children = self.children[call]
+            if all(
+                isinstance(child, tvm.relay.expr.Call)
+                and isinstance(child.op, tvm.ir.op.Op)
+                and child.op.name in self.optimize_op
+                and child.attrs["ifm_layout"] == "NHCWB16"
+                for child in children
+            ):
+                new_attrs["ofm_layout"] = "NHCWB16"
+
+        name = call.op.name
+        assert name in self.optimize_op, (
+            f"Could not create operator '{name}' as the creation function "
+            "is unknown. Please provide a mapping."
+        )
+        new_call = self.optimize_op[name](*call.args, **new_attrs)
+
+        # Update map of children
+        for input_arg in parents:
+            if input_arg in self.children:
+                self.children[input_arg].append(new_call)
+            else:
+                self.children[input_arg] = [new_call]
+
+        return super().visit_call(new_call)
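+
+    # A minimal sketch of the intended effect (illustrative, an assumed
+    # shape of the rewrite rather than output from this pass): for two
+    # chained NPU convolutions, only the intermediate tensor switches layout:
+    #
+    #   before: conv2d(ifm=NHWC, ofm=NHWC)    -> conv2d(ifm=NHWC, ofm=NHWC)
+    #   after:  conv2d(ifm=NHWC, ofm=NHCWB16) -> conv2d(ifm=NHCWB16, ofm=NHWC)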
+
+    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
+        """Recursively visit call nodes in the input graph and alter the
+        layout of an op if needed.
+
+        Parameters
+        ----------
+        call : tvm.relay.expr.Call
+            The current call node being visited.
+
+        Returns
+        -------
+        tvm.relay.expr.Call
+            The input call node, in the case the current call node does
+            not refer to an NPU op. Otherwise, a new call node with
+            altered op attributes.
+        """
+        if isinstance(call.op, tvm.ir.op.Op) and call.op.name in self.optimize_op:
+            return self.alter_ethosu_op_layout(call)
+        return super().visit_call(call)
+
+
+@relay.transform.function_pass(opt_level=1, name="LayoutOptimizer")
+class LayoutOptimizer(Pass):
+    """Register LayoutOptimizer as a Relay pass."""
+
+    def transform_function(
+        self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
+    ) -> tvm.IRModule:
+        """A pass to optimize the layout of NPU operations. If both the
+        producer and consumer of a tensor are NPU operators, then the
+        layout is converted from NHWC to NHCWB16, as this is the layout
+        the NPU uses internally."""
+        assert len(mod.functions.items()) == 1, "Module can only contain one function."
+        return LayoutOptimization().visit(func)
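+
+    # A minimal usage sketch (an assumption, not part of this patch): the
+    # registered pass can be applied to a single-function module like any
+    # other Relay pass, e.g.
+    #
+    #   mod = tvm.IRModule.from_expr(func)
+    #   mod = LayoutOptimizer()(mod)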
+
+
 @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
 def constant_updater(expr, symbol):  # pylint: disable=unused-argument
     """