
Commit 228b68d

cccclai authored and facebook-github-bot committed
Support eq.Scalar
Differential Revision: D86891707
1 parent 6de1f4e commit 228b68d
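
For context, aten.eq.Scalar is the ATen overload produced when a tensor is compared against a Python number, whereas a tensor-to-tensor comparison lowers to aten.eq.Tensor. A minimal sketch of a module that exercises the new overload (module and values are illustrative, not part of this commit):

import torch

class EqScalar(torch.nn.Module):
    def forward(self, x):
        # Comparing a tensor against a Python number dispatches to
        # aten.eq.Scalar; comparing two tensors would use aten.eq.Tensor.
        return x == 3

ep = torch.export.export(EqScalar(), (torch.arange(4),))
print(ep.graph)  # the exported graph contains a torch.ops.aten.eq.Scalar call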

File tree

6 files changed (+194, -7 lines)


backends/qualcomm/_passes/replace_inplace_copy.py

Whitespace-only changes.

backends/qualcomm/builders/node_visitor.py

Lines changed: 1 addition & 0 deletions

@@ -466,6 +466,7 @@ def define_tensor(
             tensor_source_node, target_build_node
         )
         dtype = self.get_data_type(tensor, quant_configs)
+        print(f"tensor_name: {tensor_name}, tensor_type: {tensor_type}, dtype: {dtype}")
         if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
             tensor_wrapper = PyQnnWrapper.TensorWrapper(
                 tensor_name,

backends/qualcomm/builders/op_eq.py

Lines changed: 46 additions & 7 deletions

@@ -8,15 +8,22 @@
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
 
 import torch
+from executorch.exir.dialects._ops import ops as exir_ops
 
 from .node_visitor import NodeVisitor
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import OpElementWiseEqual, QNN_OP_PACKAGE_NAME_QTI_AISW
-
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
+    QCOM_SCALE,
+    QCOM_ZERO_POINT,
+)
 
 
 @register_node_visitor
 class Equal(NodeVisitor):
-    target = ["aten.eq.Tensor"]
+    target = ["aten.eq.Tensor", "aten.eq.Scalar"]
 
     def __init__(self, *args) -> None:
         super().__init__(*args)
@@ -37,11 +44,43 @@ def define_node(
         output_tensors = [output_tensor_wrapper]
 
         input_tensors = []
-        for index in range(2):
-            input_node = self.get_node(node.args[index])
-            input_tensor = self.get_tensor(input_node, node)
-            tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
-
+        for index, arg in enumerate(node.args):
+            if isinstance(arg, torch.fx.Node):
+                # Normal tensor input
+                input_node = self.get_node(arg)
+                input_tensor = self.get_tensor(input_node, node)
+                tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
+            else:
+                assert index == 1, f"eq op expects the scalar arg at index 1, but found one at index {index}"
+                assert isinstance(arg, int), f"eq op scalar arg {arg} has to be int, but the type is {type(arg)}"
+                print(f"arg is {arg}, type is {type(arg)}")
+                # Handle scalar input (e.g., int or float)
+                scalar = arg
+                scalar_value = float(scalar)
+                input_tensor = torch.tensor(
+                    scalar_value, dtype=torch.int32
+                )
+                input_node = torch.fx.Node(
+                    node.graph,
+                    node.name + "_runtime_scalar",
+                    "call_function",
+                    exir_ops.edge.aten.scalar_tensor.default,
+                    (),  # args
+                    {},  # kwargs
+                )
+                tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
+                if quant_attrs := node.args[index ^ 1].meta.get(QCOM_QUANT_ATTRS):
+                    quant_attrs = quant_attrs.copy()
+                    quant_range = (
+                        quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
+                    )
+                    quant_attrs[QCOM_ZERO_POINT] = (
+                        0 if scalar >= 0 else quant_attrs[QCOM_QUANT_MAX]
+                    )
+                    quant_attrs[QCOM_SCALE] = (
+                        scalar / quant_range if scalar >= 0 else -scalar / quant_range
+                    )
+                    input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
             input_tensor_wrapper = self.define_tensor(
                 input_node,
                 node,
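
A note on the quantization handling above: when the other operand carries quantization attributes, the scalar borrows that operand's quantized range via node.args[index ^ 1] and is given a scale and zero point chosen so the scalar maps to one end of that range. A small standalone sketch of the same arithmetic, using made-up values (not taken from the diff):

# Hypothetical values: a uint8-style range borrowed from the tensor operand.
quant_min, quant_max = 0, 255
scalar = 3  # the Scalar argument of eq.Scalar

quant_range = quant_max - quant_min                 # 255
zero_point = 0 if scalar >= 0 else quant_max        # 0 for a non-negative scalar
scale = scalar / quant_range if scalar >= 0 else -scalar / quant_range  # 3 / 255

# Quantizing the scalar with these parameters lands it at the top of the range:
q = round(scalar / scale) + zero_point              # 255
print(scale, zero_point, q)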

backends/qualcomm/tests/models.py

Lines changed: 11 additions & 0 deletions

@@ -944,6 +944,17 @@ def forward(self, x):
         return x == self.constant
 
 
+class EqualFromInplaceCopyDecomp(torch.nn.Module):
+    def __init__(self, hidden_size=4):
+        super().__init__()
+        # a small state tensor
+        self.register_buffer("h", torch.zeros((1, hidden_size)))
+
+    def forward(self, x):
+        self.h[0] = x
+        return self.h[0]
+
+
 class ExpandCopy(torch.nn.Module):
     def __init__(self):
         super().__init__()
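
Judging from its name, this module is meant to reach the equal path through the decomposition of an in-place copy into a registered buffer rather than through an explicit comparison. A quick way to inspect what the module decomposes into after export (illustrative, not part of the commit):

import torch
from executorch.backends.qualcomm.tests.models import EqualFromInplaceCopyDecomp

ep = torch.export.export(
    EqualFromInplaceCopyDecomp(), (torch.tensor([1.0, 2.0, 3.0, 4.0]),)
)
print(ep.graph)  # shows the ops the in-place buffer write decomposes into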

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 33 additions & 0 deletions

@@ -765,6 +765,19 @@ def test_qnn_backend_equal(self):
                     test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
                 )
 
+    def test_qnn_backend_equal_debug(self):
+        test_comb = [
+            {
+                QCOM_MODULE: EqualFromInplaceCopyDecomp(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.tensor([1.0, 2.0, 3.0, 4.0]),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+
     def test_qnn_backend_expand(self):
         modules = [ExpandAs(), ExpandCopy()]  # noqa: F405
         sample_inputs = [
@@ -2842,6 +2855,26 @@ def test_qnn_backend_equal(self):
             )
             self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
 
+    def test_qnn_backend_equal_debug(self):
+        test_comb = [
+            {
+                QCOM_MODULE: EqualFromInplaceCopyDecomp(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.tensor([1.0, 2.0, 3.0, 4.0]),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+
+                print("quantized module")
+                module.print_readable()
+
+                self.lower_module_and_test_output(
+                    module, test[QCOM_SAMPLE_INPUTS]
+                )
+
     def test_qnn_backend_expand(self):
         modules = [ExpandAs(), ExpandCopy()]  # noqa: F405
         sample_inputs = [
Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Set
+
+import torch
+from executorch.exir.dialects._ops import ops
+from torch.export import ExportedProgram
+
+
+def _is_index_put(node: torch.fx.Node) -> bool:
+    """Check if a node is an index_put operation."""
+    return node.op == "call_function" and node.target in (
+        torch.ops.aten.index_put.default,
+        ops.edge.aten.index_put.default,
+    )
+
+
+def _is_safe_to_reinplace(
+    node: torch.fx.Node,
+    later_nodes: Set[torch.fx.Node],
+    inputs: Set[torch.fx.Node],
+    mutable_inputs: Set[torch.fx.Node],
+) -> bool:
+    # This node is used later in the graph, so we can't reinplace it.
+    # There is probably a faster way to do this, but this works for now.
+    if node in later_nodes:
+        return False
+    # If it's not an input, then we can reinplace it.
+    if node not in inputs:
+        return True
+    # If it's a mutable input, then we can reinplace it.
+    elif node in mutable_inputs:
+        return True
+    else:  # input but not mutable input
+        return False
+
+
+def _is_mutable_user_input(
+    node: torch.fx.Node, exported_program: ExportedProgram
+) -> bool:
+    return (
+        node.target in exported_program.graph_signature.user_inputs_to_mutate.values()
+    )
+
+
+def _is_mutable_buffer(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
+    if node.target not in exported_program.graph_signature.inputs_to_buffers:
+        return False
+    buf = exported_program.graph_signature.inputs_to_buffers[node.target]
+    return buf in exported_program.graph_signature.buffers_to_mutate.values()
+
+
+def functionalize_pass(ep: ExportedProgram) -> ExportedProgram:
+    """
+    Pass that loops over the nodes in an exported program and rewrites
+    index_put nodes back to in-place index_put_ when the base tensor is
+    safe to mutate (not used later, and either a non-input or a mutable input).
+
+    Args:
+        ep: The ExportedProgram to transform
+
+    Returns:
+        The transformed ExportedProgram
+    """
+    seen_nodes: Set[torch.fx.Node] = set()
+    # Get all placeholders
+    inputs = set()
+    for node in ep.graph.nodes:
+        if node.op == "placeholder":
+            inputs.add(node)
+    # Get all inputs that we could potentially mutate
+    mutable_nodes = set(
+        [
+            node
+            for node in inputs
+            if _is_mutable_user_input(node, ep) or _is_mutable_buffer(node, ep)
+        ]
+    )
+
+    results = set()
+    for node in reversed(ep.graph.nodes):
+        if _is_index_put(node):
+            # Check if this index_put node is safe to reinplace.
+            # The first argument is the base tensor being indexed into.
+            first_arg = node.args[0]
+            if _is_safe_to_reinplace(first_arg, seen_nodes, inputs, mutable_nodes):
+                # This index_put is safe to reinplace
+                with ep.graph.inserting_before(node):
+                    new_node = ep.graph.call_function(
+                        ops.edge.aten.index_put_.default, args=node.args
+                    )
+                new_node.meta["val"] = node.meta["val"]
+                node.replace_all_uses_with(new_node)
+                ep.graph.erase_node(node)
+                results.add(first_arg)
+        elif node.op == "call_function":
+            seen_nodes.update(node.all_input_nodes)
+    return ep
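
A usage sketch for the new pass. The file's path is not captured in this diff, so the import below is left as a commented placeholder; the example only shows the kind of exported program with a mutable-buffer index_put that the pass is meant to rewrite:

import torch
from torch.export import export

# Placeholder: the new file's actual module path is not shown in this diff.
# from executorch.backends.qualcomm._passes.<new_module> import functionalize_pass

class BufferWriter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(4))

    def forward(self, x):
        # Advanced (tensor) indexing into a mutable buffer functionalizes
        # into an index_put that the pass may rewrite back to index_put_.
        self.state[torch.tensor([0])] = x.sum()
        return self.state

ep = export(BufferWriter(), (torch.ones(4),))
# ep = functionalize_pass(ep)  # would reinplace index_put where safe
print(ep.graph)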
