1515# specific language governing permissions and limitations
1616# under the License.
1717"""Codegen for Arm(R) Ethos(TM)-U NPU"""
18+ from collections import defaultdict
1819
1920import tvm
2021from tvm import relay
2425from tvm .relay .backend .contrib .ethosu .legalize import LegalizeEthosU
2526from tvm .relay .backend .contrib .ethosu import tir_to_cs_translator
2627from tvm .relay .backend .contrib .ethosu import util
27- from tvm .relay .expr_functor import ExprMutator
28+ from tvm .relay .expr_functor import ExprMutator , ExprVisitor
2829
2930# pylint: disable=unused-import
3031from tvm .relay .backend .contrib .ethosu .op import op_attrs
@@ -138,38 +139,76 @@ def __call__(self, *args, **kwargs):
138139 pass
139140
140141
141- class LayoutOptimization (ExprMutator ):
142- """A pass to optimize the layout of NPU operations. If both the
143- producer and consumer of a tensor are NPU operators, then the
144- layout is converted from NHWC to NHCWB16.
class AnalyzeConsumers(ExprVisitor):
    """Traverses the graph to determine consumers that are NPU operations. The
    result is maintained in `npu_consumers`.

    Attributes
    ----------
    npu_consumers : Dict[tvm.relay.expr.Call, List[bool]]
        Mapping from NPU operation to list of boolean values that represent
        whether or not each consumer is an NPU operation.
    optimize_ops : Dict[str, Callable]
        A map from NPU operation name to function that creates NPU operation.
    """

    def __init__(self, optimize_ops):
        self.npu_consumers = defaultdict(list)
        self.optimize_ops = optimize_ops
        super().__init__()

    def _is_npu_call(self, expr) -> bool:
        """Return True if `expr` is a call to an operator listed in `optimize_ops`.

        The callee is checked to be a `tvm.ir.Op` before its `name` is read, so
        calls whose callee is not an operator are reported as non-NPU instead
        of raising an AttributeError. This mirrors the check performed in
        `LayoutOptimization.visit_call`.
        """
        return (
            isinstance(expr, relay.Call)
            and isinstance(expr.op, tvm.ir.Op)
            and expr.op.name in self.optimize_ops
        )

    def visit_call(self, call: relay.Call):
        """Record, for every NPU operation feeding `call`, whether `call` is
        itself an NPU operation, then continue the traversal.

        Parameters
        ----------
        call : tvm.relay.expr.Call
            The call node being visited; it is treated as the consumer of each
            of its arguments.
        """
        is_npu_consumer = self._is_npu_call(call)

        for arg in call.args:
            # Expand tuples: each field of a tuple argument is treated as a
            # direct producer of this call.
            producers = arg.fields if isinstance(arg, relay.Tuple) else [arg]
            for producer in producers:
                if self._is_npu_call(producer):
                    self.npu_consumers[producer].append(is_npu_consumer)

        super().visit_call(call)
177+
178+ class LayoutOptimization (ExprMutator ):
179+ """A pass to optimize the layout of NPU operations by converting to brick format (NHCWB16).
180+ This pass traverses the graph and attempts to alter the input/output layouts when an NPU
181+ operation is visited. Whether or not the input/output layout can be altered for a given NPU
182+ operation depends on the following:
183+
184+ Check alter input layout: For each argument, if the producer is also an NPU operation and
185+ its output is altered to brick format, then the input layout with respect to the current
186+ argument is altered to brick format.
187+
188+ Check alter output layout: If all consumers (child nodes) are an NPU operation, then the
189+ output layout is altered to brick format.
190+
191+ Note
192+ ----
193+ In order for this pass to be run, the consumers of each NPU operation must first be analyzed
194+ by the `AnalyzeConsumers` pass, since Relay doesn't keep a reference to child nodes.
195+
196+ Attributes
197+ ----------
198+ npu_consumers : Dict[tvm.relay.expr.Call, List[bool]]
199+ A map from current call to a list of boolean values that state whether or not each consumer
200+ is an NPU operation.
201+ optimize_ops : Dict[str, Callable]
202+ A map from NPU operation name to function that creates NPU operation.
203+ """
204+
205+ def __init__ (self , npu_consumers , optimize_ops ):
206+ self .npu_consumers = npu_consumers
207+ self .optimize_ops = optimize_ops
166208 super ().__init__ ()
167209
168210 def alter_ethosu_op_layout (self , call : tvm .relay .expr .Call ) -> tvm .relay .expr .Call :
169- """Alter the input and output layouts of an NPU operation if needed.
170- Input layout is only altered if the producing operation is an NPU
171- operation. Likewise, the output layout is only altered if the consuming
172- operation is an NPU operation.
211+ """Alter the layouts of given NPU operation to brick format if possible.
173212
174213 Parameters
175214 ----------
@@ -189,46 +228,26 @@ def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Ca
189228 )
190229
191230 new_attrs = dict (call .attrs )
192- parents = []
193231
194232 # Check if we can rewrite the input layouts
195233 input_count = 0
196234 for arg in call .args :
197235 input_count += 1
198- if not isinstance ( arg , tvm . relay . expr . Call ) :
236+ if arg not in self . npu_consumers :
199237 continue
200- if isinstance (arg .op , tvm .ir .op .Op ) and arg .op .name in self .optimize_op :
238+ consumers = self .npu_consumers [arg ]
239+ parent_has_brick_output = consumers and all (consumers )
240+ if parent_has_brick_output :
201241 layout_string = "ifm_layout" if input_count <= 1 else f"ifm{ input_count } _layout"
202242 new_attrs [layout_string ] = "NHCWB16"
203- parents .append (arg )
204243
205244 # Check if we can rewrite the output layouts
206- if call in self .children :
207- children = self .children [call ]
208- if all (
209- isinstance (child , tvm .relay .expr .Call )
210- and isinstance (child .op , tvm .ir .op .Op )
211- and child .op .name in self .optimize_op
212- and child .attrs ["ifm_layout" ] == "NHCWB16"
213- for child in children
214- ):
215- new_attrs ["ofm_layout" ] = "NHCWB16"
245+ consumers = self .npu_consumers [call ]
246+ if consumers and all (consumers ):
247+ new_attrs ["ofm_layout" ] = "NHCWB16"
216248
217249 name = call .op .name
218- assert name in self .optimize_op , (
219- f"Could not create operator '{ name } ' as the creation function "
220- "is unknown. Please provide a mapping."
221- )
222- new_call = self .optimize_op [name ](* call .args , ** new_attrs )
223-
224- # Update map of children
225- for input_arg in parents :
226- if input_arg in self .children :
227- self .children [input_arg ].append (new_call )
228- else :
229- self .children [input_arg ] = [new_call ]
230-
231- return super ().visit_call (new_call )
250+ return self .optimize_ops [name ](* call .args , ** new_attrs )
232251
233252 def visit_call (self , call : tvm .relay .expr .Call ) -> tvm .relay .expr .Call :
234253 """Recursively visit call nodes in the input graph and alter the
@@ -246,23 +265,33 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
246265 not refer to an Op. Else, a new call node with altered Op
247266 attributes.
248267 """
249- if isinstance (call .op , tvm .ir .op . Op ) and call .op .name in self .optimize_op :
250- return self .alter_ethosu_op_layout (call )
268+ if isinstance (call .op , tvm .ir .Op ) and call .op .name in self .optimize_ops :
269+ call = self .alter_ethosu_op_layout (call )
251270 return super ().visit_call (call )
252271
253272
@ir.transform.module_pass(opt_level=1, name="LayoutOptimizer")
class LayoutOptimizer:
    """Register LayoutOptimizer as a Relay pass."""

    # Maps each NPU operator name to the function used to re-create that
    # operator with altered layout attributes.
    OPTIMIZE_OPS = {
        "contrib.ethosu.conv2d": op.ethosu_conv2d,
        "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
        "contrib.ethosu.pooling": op.ethosu_pooling,
        "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
        "contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
    }

    def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
        """Optimize the layout of the NPU operations in a single-function module.

        The consumers of every NPU operation are first collected with
        `AnalyzeConsumers`; the result is then handed to `LayoutOptimization`,
        which rewrites layouts from NHWC to NHCWB16 where both the producer
        and the consumer of a tensor are NPU operators.

        Parameters
        ----------
        mod : tvm.ir.IRModule
            The module to transform. It must contain exactly one function.
        _ : Any
            Second argument supplied by the pass infrastructure; unused.

        Returns
        -------
        mod : tvm.IRModule
            The same module, with its function replaced by the optimized one.
        """
        functions = mod.functions.items()
        assert len(functions) == 1, "Module can only contain one function."
        global_var, func = functions[0]

        # Relay keeps no reference from a node to its consumers, so they are
        # gathered in a separate traversal before any layout is rewritten.
        consumer_analysis = AnalyzeConsumers(self.OPTIMIZE_OPS)
        consumer_analysis.visit(func)

        layout_pass = LayoutOptimization(consumer_analysis.npu_consumers, self.OPTIMIZE_OPS)
        mod.update_func(global_var, layout_pass.visit(func))
        return mod
268297
0 commit comments