Skip to content

Commit ee7e95e

Browse files
authored
Arm backend: Support quantized cond and while (#15849)
Arm backend: Support quantized while_loop - Add annotation logic. - Extend cond handling in q-dq folding to while - Extend InsertCondRescale pass to handle while. ---------------------------------------------------- Arm backend: Add initial while_loop support. - Refactor CondSupported to also test while, move to its own file and split into one check for submodule nodes, and one for ops. - Add node visitor - Add tests ----------------------------------------------------- Arm backend: Initial quantization support for conditional The standard prepare/convert_pt2 does not seem to support quantization out of the box. Instead, a quantization call is introduced in the TOSAQuantizer, that does the necessary steps to get correct quantization on submodules. A custom Quantize step is needed in the ArmTester to make this work in testing. Additionally, getting correct quantization parameters needs some delicate handling. The model is calibrated twice, once for each code path. Because of this, the observers outside the if/else submodules see different data than the observers inside the submodules. Rescales need to be inserted to handle this. To get a correctly traced graph at all times, we 1. Fold the outermost quant ops in the submodules at the same time as the cond is folded. Add qparam meta to folded nodes inside submodule. 2. Use this meta in the InsertCondRescale pass to insert a tosa.RESCALE to handle the different qparams. 3. After this, the submodule's q-dq nodes can be folded normally. Signed-off-by: Erik Lundell <[email protected]>
1 parent 65d4b94 commit ee7e95e

22 files changed

+1019
-168
lines changed

backends/arm/_passes/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,11 @@
8888
from .insert_int32_casts_after_int64_placeholders import ( # noqa
8989
InsertInt32CastsAfterInt64PlaceholdersPass,
9090
)
91-
from .insert_rescales_pass import InsertRescaleInt32Pass, InsertRescalePass # noqa
91+
from .insert_rescales_pass import ( # noqa
92+
InsertControlFlowRescalesPass,
93+
InsertRescaleInt32Pass,
94+
InsertRescalePass,
95+
)
9296
from .insert_table_ops import InsertTableOpsPass # noqa
9397
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
9498
from .match_arg_ranks_pass import MatchArgRanksPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
FuseEqualPlaceholdersPass,
8686
FuseQuantizedActivationPass,
8787
FuseViewCopyTransformPass,
88+
InsertControlFlowRescalesPass,
8889
InsertInt32CastsAfterInt64PlaceholdersPass,
8990
InsertRescaleInt32Pass,
9091
InsertRescalePass,
@@ -195,6 +196,7 @@ def _tosa_pipeline(
195196
# Ticket: MLETORCH-1539
196197
DecomposeLinearPass(),
197198
InsertRescaleInt32Pass(),
199+
InsertControlFlowRescalesPass(),
198200
]
199201
)
200202

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from typing import cast, Optional, Set, Type
1111

12+
import torch
1213
from executorch.backends.arm._passes import ArmPass
1314
from executorch.backends.arm._passes.arm_pass_utils import (
1415
get_param_tensor,
@@ -152,6 +153,83 @@ def fold_and_annotate_arg(
152153
if len(n.users) == 0:
153154
graph_module.graph.erase_node(n)
154155

156+
def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
157+
"""Fold outmost quant nodes inside submodule.
158+
placeholders => qs => dqs => ... => qs => dqs => output
159+
becomes
160+
placeholders => dqs => ... => qs => output,
161+
With output_qparams meta in the placeholders, and input_qparams meta in the output node.
162+
"""
163+
match node.target:
164+
case torch.ops.higher_order.cond:
165+
submodule_nodes = cast(list[Node], node.args[1:3])
166+
args = cast(list[Node], node.args[-1])
167+
case torch.ops.higher_order.while_loop:
168+
submodule_nodes = cast(list[Node], node.args[0:2])
169+
args = cast(list[Node], node.args[-2])
170+
case _:
171+
raise ValueError(f"Unhandled target {node.target}")
172+
submodules = (
173+
graph_module.get_submodule(str(submodule_node.target))
174+
for submodule_node in submodule_nodes
175+
)
176+
for submodule in submodules:
177+
submodule = cast(GraphModule, submodule)
178+
output_node = submodule.graph.output_node()
179+
output_node.meta["input_qparams"] = {}
180+
nodes_to_remove = []
181+
arg_id = 0
182+
for submodule_node in submodule.graph.nodes:
183+
# Remove initial q nodes and ending dq nodes in the module.
184+
submodule_node = cast(Node, submodule_node)
185+
if (
186+
submodule_node.target in Q_OPS
187+
and list(submodule_node.all_input_nodes)[0].op == "placeholder"
188+
):
189+
input_node = cast(Node, submodule_node.args[0])
190+
input_node.meta["val"] = submodule_node.meta["val"]
191+
quant_args = QuantArgs.from_operator(
192+
submodule_node.target, submodule_node.args
193+
)
194+
input_node.meta["output_qparams"] = {0: quant_args}
195+
196+
submodule_node.replace_all_uses_with(input_node)
197+
nodes_to_remove.append(submodule_node)
198+
if submodule_node.target in DQ_OPS:
199+
has_non_output_user = False
200+
for user in copy.copy(submodule_node.users):
201+
if user.op != "output":
202+
has_non_output_user = True
203+
else:
204+
input_node = cast(Node, submodule_node.args[0])
205+
submodule_node.replace_all_uses_with(input_node)
206+
arg_index = cast(list[Node], output_node.args[0]).index(
207+
input_node
208+
)
209+
quant_args = QuantArgs.from_operator(
210+
submodule_node.target, submodule_node.args
211+
)
212+
output_node.meta["input_qparams"][arg_index] = quant_args
213+
214+
# Remove dq node if it only has the output node as its user.
215+
if not has_non_output_user:
216+
nodes_to_remove.append(submodule_node)
217+
# Placeholders without users won't be retraced with correct dtype, do it manually.
218+
# Control flow node input is matched to placeholder nodes in the submodule by index.
219+
# This means it will break if another pass inserts a placeholder before this pass.
220+
if submodule_node.op == "placeholder":
221+
if len(submodule_node.users) == 0:
222+
submodule_node.meta["val"] = args[arg_id].meta["val"]
223+
arg_id += 1
224+
if arg_id > len(args):
225+
raise RuntimeError(
226+
"Submodule had more placeholders than calling node had inputs."
227+
" This is probably due to a placeholder being inserted in a pass."
228+
)
229+
for node_to_remove in nodes_to_remove:
230+
submodule.graph.erase_node(node_to_remove)
231+
return
232+
155233
def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901
156234

157235
# Loop over the graph nodes and find any node in the 'targeted_ops' list.
@@ -181,8 +259,8 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901
181259
n.meta["input_qparams"] = {}
182260
n.meta["output_qparams"] = {}
183261
for i, arg in enumerate(n.args):
184-
if isinstance(arg, list):
185-
self.fold_and_annotate_arg(graph_module, n, arg, i)
262+
if isinstance(arg, (list, tuple)):
263+
self.fold_and_annotate_arg(graph_module, n, arg, i) # type: ignore
186264

187265
elif isinstance(arg, Node):
188266
self.fold_and_annotate_arg(graph_module, n, [arg], i)
@@ -211,6 +289,12 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901
211289
output_dtype = output_qparams[0].dtype
212290
set_node_arg(n, "dtype", output_dtype)
213291

292+
if n.target in (
293+
torch.ops.higher_order.cond,
294+
torch.ops.higher_order.while_loop,
295+
):
296+
self._handle_control_flow_node(n, graph_module)
297+
214298
# retrace the graph to update the fake tensor types
215299
graph_module = super().call(graph_module).graph_module
216300

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,216 @@ def call(self, graph_module: GraphModule) -> PassResult:
369369
graph_module.recompile()
370370

371371
return PassResult(graph_module, modified)
372+
373+
374+
class InsertControlFlowRescalesPass(ArmPass):
375+
"""The quantization parameters for tensors going into and coming out of a submodule are not guaranteed to
376+
match the quantization parameters for the corresponding tensors inside the submodule. For example, cond has
377+
different annotation on input and output, while the entire graph inside the submodule could be using shared
378+
annotation. This pass solves this by inserting rescales in the beginning and end of the submodule
379+
that transform the tensor from one set of quantization parameters to another.
380+
381+
The pass is run by the graph_module containing the control flow operator, but requires that the affected nodes
382+
inside the submodule have been q-dq folded and have input/output_qparams meta.
383+
"""
384+
385+
_passes_required_after: Set[Type[ExportPass]] = set()
386+
387+
def _get_input_nodes(self, graph_module: GraphModule):
388+
return [node for node in graph_module.graph.nodes if node.op == "placeholder"]
389+
390+
def _insert_rescale(
391+
self,
392+
in_qparams: QuantArgs,
393+
out_qparams: QuantArgs,
394+
from_node: Node,
395+
graph_module: GraphModule,
396+
):
397+
"""Insert a rescale into the graph, inheriting meta from `from_node`.
398+
The node is not connected to anything, that is up to the user."""
399+
400+
new_scales = [
401+
in_qparams.get_scale_per_tensor() / out_qparams.get_scale_per_tensor()
402+
]
403+
404+
rescale_node = create_node(
405+
graph_module.graph,
406+
exir_ops.backend.tosa.RESCALE.default,
407+
(
408+
None,
409+
out_qparams.dtype,
410+
new_scales,
411+
in_qparams.get_zp_per_tensor(), # Old zero point
412+
out_qparams.get_zp_per_tensor(), # New zero point
413+
),
414+
from_node=from_node,
415+
)
416+
return rescale_node
417+
418+
def _rescale_submodule_inputs(
419+
self, submodule: GraphModule, input_qparams_map: Dict[int, QuantArgs]
420+
) -> bool:
421+
"""Insert rescales at the inputs of `submodule` to match the qparams outside the submodule.
422+
Matching the correct qparams gets a bit tricky:
423+
Containing module: | submodule:
424+
ops => cond | => placeholders => ...
425+
426+
The dq->q qparam pair we want to convert to a rescale is:
427+
(input qparams of op, output qparams of placeholder)
428+
And the rescale is inserted after the placeholder.
429+
430+
Args:
431+
submodule: GraphModule: the GraphModule in which to rescale the inputs.
432+
input_qparams_map: A map of input indexes mapping to QuantArgs. Not guaranteed to contain a mapping
433+
for every submodule input.
434+
Returns:
435+
True if at least one rescale was inserted, False otherwise.
436+
"""
437+
438+
modified = False
439+
input_nodes = self._get_input_nodes(submodule)
440+
for qargs_index in input_qparams_map:
441+
input_node = input_nodes[qargs_index]
442+
if len(input_node.users) == 0:
443+
continue
444+
if len(out_qparams_map := input_node.meta.get("output_qparams", {})) != 1:
445+
raise ValueError(
446+
f"Expected submodule input {input_node} to have exactly one output qparam, got {out_qparams_map}"
447+
)
448+
in_qparams = input_qparams_map[qargs_index]
449+
out_qparams = cast(QuantArgs, out_qparams_map[0])
450+
451+
# Remove qparam meta to not confuse folding pass.
452+
del input_node.meta["output_qparams"]
453+
if in_qparams == out_qparams:
454+
continue
455+
with submodule.graph.inserting_after(input_node):
456+
modified = True
457+
rescale_node = self._insert_rescale(
458+
in_qparams, out_qparams, input_node, submodule
459+
)
460+
input_node.replace_all_uses_with(replace_with=rescale_node)
461+
rescale_node.update_arg(0, input_node)
462+
return modified
463+
464+
def _rescale_submodule_outputs(
465+
self, submodule: GraphModule, output_qparams_map: Dict[int, QuantArgs]
466+
) -> bool:
467+
"""Insert rescales at the outputs of `submodule` to match the qparams outside the submodule.
468+
Matching the correct qparams gets a bit tricky:
469+
Submodule: | Containing module:
470+
output_nodes => output |=> getitems => ...
471+
472+
The dq->q qparam pair we want to convert to a rescale is:
473+
(input qparam of output_node, output qparam of getitem)
474+
And the rescale is inserted between op and output. Note that the output qparam of op is called input_qargs,
475+
since it is the input to the dq-q pair.
476+
477+
Args:
478+
submodule: GraphModule: the GraphModule in which to rescale the outputs.
479+
output_qparams_map: A map of output indexes mapping to QuantArgs. Not guaranteed to contain a mapping
480+
for every submodule output.
481+
Returns:
482+
True if at least one rescale was inserted, False otherwise.
483+
"""
484+
485+
modified = False
486+
output_node = submodule.graph.output_node()
487+
output_args = list(cast(tuple[Node], output_node.args[0]))
488+
input_qparams_map = cast(
489+
dict[int, QuantArgs], output_node.meta["input_qparams"]
490+
)
491+
for qargs_index in output_qparams_map:
492+
output_arg_node = output_args[qargs_index]
493+
in_qparams = input_qparams_map[qargs_index]
494+
out_qparams = output_qparams_map[qargs_index]
495+
if in_qparams == out_qparams:
496+
continue
497+
with submodule.graph.inserting_before(output_node):
498+
modified = True
499+
rescale_node = self._insert_rescale(
500+
in_qparams, out_qparams, output_arg_node, submodule
501+
)
502+
output_args[qargs_index] = rescale_node
503+
rescale_node.update_arg(0, output_arg_node)
504+
output_node.update_arg(0, tuple(output_args))
505+
# Remove qparam meta to not confuse folding pass.
506+
del output_node.meta["input_qparams"]
507+
return modified
508+
509+
def _get_input_qparams_map(self, node: Node, idx: int):
510+
input_qparams_meta = cast(
511+
dict[int, QuantArgs], node.meta.get("input_qparams", None)
512+
)
513+
if input_qparams_meta:
514+
input_qparams = cast(QuantArgs, input_qparams_meta.get(idx, None))
515+
if not input_qparams:
516+
raise ValueError(
517+
f"Expected entry with key {idx} in input_qparams meta, got {input_qparams_meta}"
518+
)
519+
num_inputs = len(cast(list, node.args[idx]))
520+
521+
# Currently, infra only supports one set of qparams for a list of inputs
522+
# Map all inputs to the same qparams.
523+
input_qparams_map = {i: input_qparams for i in range(num_inputs)}
524+
return input_qparams_map
525+
return None
526+
527+
def _get_output_qparams_map(self, node: Node):
528+
output_qparams_map: dict[int, QuantArgs] = {}
529+
for getitem_node in node.users:
530+
idx = cast(int, getitem_node.args[1])
531+
qparam = getitem_node.meta.get("output_qparams", None)
532+
if qparam:
533+
output_qparams_map[idx] = cast(QuantArgs, qparam[0])
534+
return output_qparams_map
535+
536+
def _rescale_cond_submodules(self, node: Node, graph_module: GraphModule) -> bool:
537+
modified = False
538+
if_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[1].target)) # type: ignore
539+
else_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[2].target)) # type: ignore
540+
input_qparams_map = self._get_input_qparams_map(node, 3)
541+
if input_qparams_map:
542+
modified |= self._rescale_submodule_inputs(if_graph, input_qparams_map)
543+
modified |= self._rescale_submodule_inputs(else_graph, input_qparams_map)
544+
545+
output_qparams_map = self._get_output_qparams_map(node)
546+
if output_qparams_map:
547+
modified |= self._rescale_submodule_outputs(if_graph, output_qparams_map)
548+
modified |= self._rescale_submodule_outputs(else_graph, output_qparams_map)
549+
return modified
550+
551+
def _rescale_while_submodules(self, node: Node, graph_module: GraphModule):
552+
modified = False
553+
cond_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[0].target)) # type: ignore
554+
body_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[1].target)) # type: ignore
555+
556+
input_qparams_map = self._get_input_qparams_map(node, 2)
557+
if input_qparams_map:
558+
modified |= self._rescale_submodule_inputs(cond_graph, input_qparams_map)
559+
modified |= self._rescale_submodule_inputs(body_graph, input_qparams_map)
560+
561+
output_qparams_map = self._get_output_qparams_map(node)
562+
if output_qparams_map:
563+
modified |= self._rescale_submodule_outputs(body_graph, output_qparams_map)
564+
return modified
565+
566+
def call(self, graph_module: GraphModule) -> PassResult:
567+
modified = False
568+
569+
for node in list(graph_module.graph.nodes):
570+
node = cast(Node, node)
571+
if node.op != "call_function":
572+
continue
573+
574+
if node.target == torch.ops.higher_order.cond:
575+
modified = self._rescale_cond_submodules(node, graph_module)
576+
if node.target == torch.ops.higher_order.while_loop:
577+
modified = self._rescale_while_submodules(node, graph_module)
578+
579+
if modified:
580+
# Retrace the graph to update the fake tensor types
581+
graph_module = super().call(graph_module).graph_module
582+
graph_module.recompile()
583+
584+
return PassResult(graph_module, modified)

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from . import ( # noqa
88
clone_dim_order_support,
9+
control_flow_support,
910
convolution_support,
1011
embedding_support,
1112
ethos_u55_support,

0 commit comments

Comments
 (0)