
Commit 6efddba

Support sine operator on XNNPACK (#14711)
Summary: Wire up the unary sine operator in XNNPACK for fp32 and fp16.

Differential Revision: D83623086
1 parent 9560800 commit 6efddba
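
For context, a minimal sketch of what this change enables end to end: a model containing torch.sin can now be lowered to the XNNPACK delegate through the standard ExecuTorch export flow. This mirrors the flow exercised by the new test below; the module and input shapes are illustrative, not part of the commit.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class SinModule(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)

# Export, then partition: with this commit, aten.sin.default is claimed by
# the XNNPACK partitioner instead of falling back to portable CPU kernels.
exported = torch.export.export(SinModule(), (torch.randn(2, 4),))
program = to_edge_transform_and_lower(
    exported, partitioner=[XnnpackPartitioner()]
).to_executorch()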

File tree

9 files changed, +160 -0 lines changed


backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@
     op_relu,
     op_rsqrt,
     op_sigmoid,
+    op_sin,
     op_skip_ops,
     op_slice_copy,
     op_softmax,
backends/xnnpack/operators/op_sin.py (new file)

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGraph,
+    XNNSin,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class SinVisitor(NodeVisitor):
+    target = "aten.sin.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # input
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNSin(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
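
The visitor above relies on the @register_node_visitor decorator to make itself discoverable by its target string. A simplified, assumed sketch of that registry pattern (the actual implementation lives in node_visitor.py):

_node_visitors = {}  # hypothetical registry; the real one lives in node_visitor.py

def register_node_visitor(visitor_cls):
    # Key the visitor by its target, e.g. "aten.sin.default", so the
    # serializer can dispatch on node.target when walking the graph.
    _node_visitors[visitor_cls.target] = visitor_cls
    return visitor_cls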

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,7 @@
     ReciprocalSquareRootConfig,
     ReLUConfig,
     SigmoidConfig,
+    SinConfig,
     SliceCopyConfig,
     SoftmaxConfig,
     SquareRootConfig,
@@ -105,6 +106,7 @@
     TanhConfig,
     ToDimOrderCopyConfig,
     SigmoidConfig,
+    SinConfig,
     SliceCopyConfig,
     SoftmaxConfig,
     SquareRootConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions
@@ -636,3 +636,10 @@ class BMMConfig(GenericNodePartitionerConfig):

     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
+
+
+class SinConfig(GenericNodePartitionerConfig):
+    target_name = "sin.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
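
This small config is all the partitioner needs: GenericNodePartitionerConfig matches nodes by target_name and filters by precision. A hedged sketch of that dispatch, with simplified names (the real checks live in the partitioner and the config base classes, not in this commit):

def node_supported(node, configs):
    # str(node.target) is e.g. "aten.sin.default"; each config declares the
    # op suffix it handles plus the precision types it supports.
    for cfg in configs:
        if str(node.target).endswith(cfg.target_name):
            return ConfigPrecisionType.FP32 in cfg.supported_precision_types()
    return False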

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 2 additions & 0 deletions
@@ -1690,6 +1690,7 @@ _DEFINE_UNARY_NODE_NO_PARAMS(Log, xnn_unary_log)
 _DEFINE_UNARY_NODE_NO_PARAMS(Negate, xnn_unary_negate)
 _DEFINE_UNARY_NODE_NO_PARAMS(Square, xnn_unary_square)
 _DEFINE_UNARY_NODE_NO_PARAMS(Abs, xnn_unary_abs)
+_DEFINE_UNARY_NODE_NO_PARAMS(Sin, xnn_unary_sine)

 // Unary Ops with min/max params
 _DEFINE_UNARY_NODE_WITH_MINMAX(Clamp, xnn_unary_clamp)
@@ -1737,6 +1738,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Floor)
     _DEFINE(PReLU)
     _DEFINE(Sigmoid)
+    _DEFINE(Sin)

     // Others
     _DEFINE(FullyConnected)

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions
@@ -156,6 +156,7 @@ union XNodeUnion {
   XNNGelu: _XNNNode1x1,
   XNNTanh: _XNNNode1x1,
   XNNExp: _XNNNode1x1,
+  XNNSin: _XNNNode1x1,
 }

 union XValueUnion {

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions
@@ -152,6 +152,7 @@ union XNodeUnion {
   XNNGelu: _XNNNode1x1,
   XNNTanh: _XNNNode1x1,
   XNNExp: _XNNNode1x1,
+  XNNSin: _XNNNode1x1,
 }

 union XValueUnion {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 7 additions & 0 deletions
@@ -347,6 +347,11 @@ class XNNPReLU(XNNNode2x1):
     pass


+@dataclass
+class XNNSin(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNScaledDotProductAttention:
     query_id: int
@@ -402,6 +407,8 @@ class XNNScaledDotProductAttention:
     XNNLog,
     XNNGelu,
     XNNTanh,
+    XNNExp,
+    XNNSin,
 ]

backends/xnnpack/test/ops/test_sin.py (new file)

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestSin(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    class Sin(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            z = torch.sin(x)
+            return z
+
+    def _test_sin(self, inputs, legacy_mode: bool = False):
+        tester = (
+            Tester(self.Sin(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.sin.default": 1})
+        )
+
+        if legacy_mode:
+            tester = tester.to_edge().partition()
+        else:
+            tester = tester.to_edge_transform_and_lower()
+
+        (
+            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_sin_default"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp16_sin(self):
+        inputs = (
+            torch.Tensor(
+                [
+                    [0.0, 0.1, 0.5, 0.785398],
+                    [-0.5, -0.785398, 1.5708, -1.5708],
+                ],
+            ).to(torch.float16),
+        )
+        self._test_sin(inputs, legacy_mode=False)
+
+    def test_fp16_sin_legacy_mode(self):
+        inputs = (
+            torch.Tensor(
+                [
+                    [0.0, 0.1, 0.5, 0.785398],
+                    [-0.5, -0.785398, 1.5708, -1.5708],
+                ],
+            ).to(torch.float16),
+        )
+        self._test_sin(inputs, legacy_mode=True)
+
+    def test_fp32_sin(self):
+        inputs = (
+            torch.Tensor(
+                [
+                    [0.0, 0.1, 0.5, 0.785398],
+                    [-0.5, -0.785398, 1.5708, -1.5708],
+                ],
+            ),
+        )
+        self._test_sin(inputs, legacy_mode=False)
+
+    def test_fp32_sin_legacy_mode(self):
+        inputs = (
+            torch.Tensor(
+                [
+                    [0.0, 0.1, 0.5, 0.785398],
+                    [-0.5, -0.785398, 1.5708, -1.5708],
+                ],
+            ),
+        )
+        self._test_sin(inputs, legacy_mode=True)
