Commit c7e04e4

Make Relu quantization non-shared
1 parent e8e0ea4 commit c7e04e4

File tree: 1 file changed (+24, -30 lines)

backends/nxp/quantizer/patterns.py

Lines changed: 24 additions & 30 deletions
@@ -121,6 +121,24 @@ def get_anchors(
         )
 
 
+class SingleInputBasicPattern(QuantizationPattern):
+    @abstractmethod
+    def partition_types(self) -> list[OpOverload]:
+        pass
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(node, NodeArgsIdx(0))],
+            weights=[],
+            biases=[],
+            output=[(node,)],
+        )
+
+
 def get_anchors_for_fixed_quant_specs(
     fused_partition: list[fx.GraphModule],
     scale: float,
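The new SingleInputBasicPattern base class anchors a single-input operator's first argument and its output independently, so subclasses only need to name their target ops. As a hypothetical illustration (not part of this commit, and assuming the imports already present in patterns.py, such as torch and the QuantizationPattern hierarchy), a new single-input activation pattern would look roughly like this:

# Hypothetical sketch only -- not part of this commit.
class SigmoidPattern(SingleInputBasicPattern):
    """Illustrative quantizer pattern for a single-input activation."""

    def partition_types(self):
        # The aten op(s) this pattern should match.
        return [torch.ops.aten.sigmoid.default]

    def replacement_op(self):
        # Mirrors HardTanhPattern in this diff: no replacement op is defined.
        raise AssertionError()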
@@ -376,7 +394,7 @@ def partition_types(self):
         return [torch.ops.aten.flatten.using_ints]
 
 
-class HardTanhPattern(QuantizationPattern):
+class HardTanhPattern(SingleInputBasicPattern):
     """
     Quantizer for HardTanh operator. Shared quantization spec is selected, as activation functions usually follows
     computation layer.
@@ -385,23 +403,12 @@ class HardTanhPattern(QuantizationPattern):
     def partition_types(self):
         return [torch.ops.aten.hardtanh.default]
 
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
-    ) -> PartitionAnchors | None:
-        node = fused_partition[0].nodes[-1]
-
-        return PartitionAnchors(
-            inputs=[(node, NodeArgsIdx(0))],
-            weights=[],
-            biases=[],
-            output=[(node,)],
-        )
 
     def replacement_op(self):
         raise AssertionError()
 
 
-class HardTanhInPlacePattern(QuantizationPattern):
+class HardTanhInPlacePattern(SingleInputBasicPattern):
     """
     Quantizer for HardTanh operator with param inplace=True. Shared quantization spec is selected, as activation
     functions usually follows computation layer.
@@ -410,18 +417,6 @@ class HardTanhInPlacePattern(QuantizationPattern):
     def partition_types(self):
         return [torch.ops.aten.hardtanh_.default]
 
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
-    ) -> PartitionAnchors | None:
-        node = fused_partition[0].nodes[-1]
-
-        return PartitionAnchors(
-            inputs=[(node, NodeArgsIdx(0))],
-            weights=[],
-            biases=[],
-            output=[(node,)],
-        )
-
 
     def replacement_op(self):
         raise AssertionError()
@@ -513,19 +508,18 @@ def partition_types(self):
         return [torch.ops.aten.permute.default]
 
 
-class ReluPattern(SharedSpecPattern):
+class ReluPattern(SingleInputBasicPattern):
     """
-    Quantizer for Relu operator. Shared quantization spec is selected, as ReLU usually follows computation layer.
+    Quantizer for Relu operator.
     """
 
     def partition_types(self):
         return [torch.ops.aten.relu.default]
 
 
-class ReluInPlacePattern(SharedSpecPattern):
+class ReluInPlacePattern(SingleInputBasicPattern):
     """
-    Quantizer for Relu operator with param inplace=True. Shared quantization spec is selected, as ReLU usually
-    follows computation layer.
+    Quantizer for Relu operator with param inplace=True.
     """
 
     def partition_types(self):
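Taken together with the commit title, the practical change for ReLU is that ReluPattern and ReluInPlacePattern no longer derive from SharedSpecPattern, so a ReLU output no longer reuses the quantization spec of the layer feeding it; each ReLU is anchored and observed on its own. A rough sketch of the difference for a conv -> relu pair follows. It assumes SharedSpecPattern tied the output to its input edge via torch.ao's SharedQuantizationSpec (an inference from the removed docstrings, not code shown in this diff), reuses PartitionAnchors and NodeArgsIdx as named in patterns.py, and conv_node / relu_node are hypothetical fx nodes:

from torch.ao.quantization.quantizer import SharedQuantizationSpec

def relu_output_spec_before(conv_node, relu_node):
    # Assumed old behaviour: ReluPattern(SharedSpecPattern) made the ReLU output
    # share quantization parameters along the conv -> relu edge, so no
    # independent observer was placed on the ReLU output.
    return SharedQuantizationSpec((conv_node, relu_node))

def relu_anchors_after(relu_node):
    # New behaviour (via SingleInputBasicPattern above): the ReLU input (arg 0)
    # and its output are anchored separately, so the quantizer derives
    # non-shared qparams from the ReLU's own output range.
    return PartitionAnchors(
        inputs=[(relu_node, NodeArgsIdx(0))],
        weights=[],
        biases=[],
        output=[(relu_node,)],
    )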