
Commit 0867b29

Marco Giordano authored and facebook-github-bot committed
Including mixed quant GRU op in Jarvis (#15011)
Summary:

# Context

With the goal of porting mHML to ExecuTorch, a few operators are missing. The main focus is on improving performance for the operators used by the model.

# Summary

This diff adds a general and a HiFi4-optimized GRU operator: a standard GRU implementation and a version optimized for HiFi4 DSPs, ensuring better performance on supported hardware.

Reviewed By: skrtskrtfb, mcremon-meta

Differential Revision: D81703253
1 parent 3bfd5e0 commit 0867b29

File tree

6 files changed: +132 −2 lines changed

backends/cadence/aot/compiler.py
backends/cadence/aot/functions_hifi.yaml
backends/cadence/aot/ops_registrations.py
backends/cadence/aot/quantizer/fusion_pass.py
backends/cadence/aot/quantizer/patterns.py
backends/cadence/aot/quantizer/quantizer.py

backends/cadence/aot/compiler.py

Lines changed: 10 additions & 2 deletions
@@ -24,6 +24,7 @@
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceDefaultQuantizer,
     CadenceQuantizer,
+    CadenceW8A32MixedQuantizer,
 )
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
@@ -59,6 +60,7 @@ def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
+    quantizer: Optional[CadenceQuantizer] = None,
 ) -> ExportedProgram:
     """
     Trace the model with export and return an ExportedProgram.
@@ -73,6 +75,12 @@ def trace(
         torch.ops.aten.rms_norm.default,
     ]
 
+    if isinstance(quantizer, CadenceW8A32MixedQuantizer):
+        ops_to_keep += [
+            torch.ops.aten.gru.input,
+            torch.ops.aten.gru.data,
+        ]
+
     program = trace_fn(
         model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
     )
@@ -99,7 +107,7 @@ def prepare_pt2(
     Returns a GraphModule with the prepared model.
     """
 
-    traced_program = trace(model, inputs, dump_graphs=dump_graphs)
+    traced_program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
     prepared_program = prepare_traced_pt2(
         traced_program, quantizer, dump_graphs=dump_graphs
     )
@@ -184,7 +192,7 @@ def get_fake_quant_model(
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    program = trace(model, inputs, dump_graphs=dump_graphs)
+    program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
 
     if dump_graphs:
         logging.info("Graph after trace:")
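With this change, the quantizer is threaded through to trace() so that the aten GRU ops survive decomposition and remain visible to the pattern matcher. A minimal sketch of how the mixed quantizer might be driven end to end; the model, shapes, and the exact prepare_pt2 keyword usage are assumptions for illustration:

import torch

from executorch.backends.cadence.aot.compiler import prepare_pt2
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceW8A32MixedQuantizer,
)

# Hypothetical model: feature sizes are multiples of 4 on purpose, to
# satisfy the SIMD check in MixedW8A32GruPattern (see patterns.py below).
model = torch.nn.GRU(input_size=16, hidden_size=32)
inputs = (torch.randn(5, 1, 16), torch.randn(1, 1, 32))

quantizer = CadenceW8A32MixedQuantizer()
# trace() now receives the quantizer and keeps aten.gru.input/.data from
# being decomposed, so the GRU quantization pattern can still match.
prepared = prepare_pt2(model, inputs, quantizer=quantizer)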

backends/cadence/aot/functions_hifi.yaml

Lines changed: 5 additions & 0 deletions
@@ -558,3 +558,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::quantized_w8a32_conv_out
+
+- func: cadence::quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantized_w8a32_gru_out

backends/cadence/aot/ops_registrations.py

Lines changed: 24 additions & 0 deletions
@@ -578,6 +578,15 @@
     "quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale) -> Tensor"
+)
+
+lib.define(
+    "quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
 aten_lib.define(
@@ -2646,3 +2655,18 @@ def quantized_w8a32_conv_meta(
         channel_last=False,
     )
     return src.new_empty(output_size, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_w8a32_gru")
+def quantized_w8a32_gru_meta(
+    inputs: torch.Tensor,
+    hidden: torch.Tensor,
+    weights_inputs: torch.Tensor,
+    w_i_scale: float,
+    weights_hidden: torch.Tensor,
+    w_h_scale: float,
+    bias_inputs: torch.Tensor,
+    b_i_scale: float,
+    bias_hidden: torch.Tensor,
+    b_h_scale: float,
+) -> torch.Tensor:
+    return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)
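The schema spells out the mixed-precision contract: activations and hidden state stay float (a32), while the two weight tensors and two bias tensors are int8 with per-tensor scales (w8). The fake kernel above only fixes the output shape, (2, hidden_size). As a reference point, a plausible dequantization of such parameters could look like the sketch below; this is illustrative only, the real HiFi4 kernel operates on the int8 data directly:

import torch

def dequant_w8(t_int8: torch.Tensor, scale: float) -> torch.Tensor:
    # w8a32 parameters: int8 storage plus one per-tensor scale back to float.
    return t_int8.to(torch.float32) * scale

# Illustrative GRU parameters (3 gates stacked; input=4, hidden=8).
w_ih = torch.randint(-128, 128, (3 * 8, 4), dtype=torch.int8)
w_hh = torch.randint(-128, 128, (3 * 8, 8), dtype=torch.int8)
w_ih_fp = dequant_w8(w_ih, 0.02)  # scale plays the role of w_i_scale
w_hh_fp = dequant_w8(w_hh, 0.03)  # scale plays the role of w_h_scale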

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 46 additions & 0 deletions
@@ -26,6 +26,7 @@
     MatmulPattern,
     MixedW8A32ConvPattern,
     MixedW8A32LinearPattern,
+    MixedW8A32GruPattern,
     ReluPattern0,
     ReluPattern1,
     SoftmaxPattern,
@@ -528,6 +529,41 @@ def get_args_and_kwargs_mixed_w8a32_conv(
     return args, kwargs
 
 
+def get_args_and_kwargs_mixed_w8a32_gru(
+    graph_module: GraphModule,
+    other_inputs: List[fx.Node],
+    weights_inputs: List[fx.Node],
+    dequants_weights: List[fx.Node],
+    bias_inputs: List[fx.Node],
+    dequants_biases: List[fx.Node],
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Stride, padding, dilation, groups not supported yet
+
+    assert len(dequants_weights) == 2
+    assert len(dequants_biases) == 2
+    w_i_scale = dequants_weights[0].args[1]
+    w_h_scale = dequants_weights[1].args[1]
+    b_i_scale = dequants_biases[0].args[1]
+    b_h_scale = dequants_biases[1].args[1]
+
+    args = (
+        other_inputs[0],
+        other_inputs[1],
+        weights_inputs[0],
+        w_i_scale,
+        weights_inputs[1],
+        w_h_scale,
+        bias_inputs[0],
+        b_i_scale,
+        bias_inputs[1],
+        b_h_scale,
+    )
+    kwargs = {}
+
+    return args, kwargs
+
+
 class QuantFusion(ExportPass):
     # pyre-ignore[2]: Parameter `patterns` has no type specified
     def __init__(self, patterns) -> None:
@@ -707,6 +743,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     dequants_biases,
                     op_node,
                 )
+            elif isinstance(pattern, MixedW8A32GruPattern):
+                args, kwargs = get_args_and_kwargs_mixed_w8a32_gru(
+                    graph_module,
+                    other_inputs,
+                    weights_inputs,
+                    dequants_weights,
+                    bias_inputs,
+                    dequants_biases,
+                    op_node,
+                )
 
             fused = graph_module.graph.call_function(
                 pattern.replacement_op(),
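The tuple packed by get_args_and_kwargs_mixed_w8a32_gru interleaves each tensor with its scale in exactly the order the quantized_w8a32_gru schema expects, and the scales are read off the dequantize nodes feeding the pattern (each dequant node carries its scale as args[1]). A toy restatement with fx.Node objects swapped for placeholder tensors and made-up scale values:

import torch

x, h0 = torch.randn(5, 16), torch.randn(1, 32)  # float activations (a32)
w_ih = torch.randint(-128, 128, (96, 16), dtype=torch.int8)
w_hh = torch.randint(-128, 128, (96, 32), dtype=torch.int8)
b_ih = torch.randint(-128, 128, (96,), dtype=torch.int8)
b_hh = torch.randint(-128, 128, (96,), dtype=torch.int8)

# Positional layout required by cadence::quantized_w8a32_gru:
args = (
    x, h0,        # other_inputs: input activations and hidden state
    w_ih, 0.02,   # weights_inputs[0] with w_i_scale (dequants_weights[0].args[1])
    w_hh, 0.03,   # weights_inputs[1] with w_h_scale (dequants_weights[1].args[1])
    b_ih, 0.01,   # bias_inputs[0] with b_i_scale (dequants_biases[0].args[1])
    b_hh, 0.01,   # bias_inputs[1] with b_h_scale (dequants_biases[1].args[1])
)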

backends/cadence/aot/quantizer/patterns.py

Lines changed: 45 additions & 0 deletions
@@ -661,3 +661,48 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_w8a32_conv.default
+
+
+class MixedW8A32GruPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.gru.input]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        gru_layer = fused_partition[0].nodes[-1]
+        if len(gru_layer.kwargs) > 0:
+            return (PartitionAnchors(
+                empty=True,
+            ), gru_layer)
+
+        # Bail if input or states are not a multiple of 4 (SIMD)
+        if gru_layer.args[0].meta['tensor_meta'].shape[-1] % 4 != 0:
+            return (PartitionAnchors(
+                empty=True,
+            ), gru_layer)
+        if gru_layer.args[1].meta['tensor_meta'].shape[-1] % 4 != 0:
+            return (PartitionAnchors(
+                empty=True,
+            ), gru_layer)
+
+        class Wrapper:
+            def __init__(self, args, meta):
+                self.args = args
+                self.meta = meta
+
+        wrapper = Wrapper(tuple(gru_layer.args[2]), gru_layer.meta)
+
+        return (PartitionAnchors(
+            inputs=[],
+            # pyre-fixme[6]: Expected `List[Tuple[Node, int]]` but got `List[Tuple[Wrapper, int]]`.
+            weights=[(wrapper, 0), (wrapper, 1)],
+            # pyre-fixme[6]: Expected `List[Union[Tuple[Node, int], Tuple[Node, int, DerivedQuantizationSpec]]]` but got `List[Tuple[Wrapper, int]]`.
+            biases=[(wrapper, 2), (wrapper, 3)],
+            output=[],
+            others=[(gru_layer, 0), (gru_layer, 1)],
+        ), gru_layer)
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_w8a32_gru.default
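The pattern only anchors a GRU when both the input feature dimension and the hidden dimension are multiples of 4, since the HiFi4 kernel vectorizes over 4 lanes; anything else is left as a float GRU. A standalone restatement of that gate (the helper name is hypothetical):

def gru_qualifies_for_w8a32(input_size: int, hidden_size: int) -> bool:
    # Mirrors the checks in MixedW8A32GruPattern.get_anchors: both trailing
    # dims must be multiples of 4 so the SIMD kernel processes full 4-lane groups.
    return input_size % 4 == 0 and hidden_size % 4 == 0

assert gru_qualifies_for_w8a32(16, 32)        # anchored, will be fused
assert not gru_qualifies_for_w8a32(10, 32)    # bails out, stays float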

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
     MatmulPattern,
     MixedW8A32ConvPattern,
     MixedW8A32LinearPattern,
+    MixedW8A32GruPattern,
     QuantizationPattern,
     ReluPattern0,
     ReluPattern1,
@@ -325,6 +326,7 @@ def __init__(self) -> None:
         quantizers.append(
             CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)
         )
+        quantizers.append(CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym))
         super().__init__(quantizers)
