77
88from typing import Any , List
99
10+ import executorch .backends .arm .tosa .quant_utils as tqutils
11+ import executorch .backends .arm .tosa .utils as tutils
12+
1013from executorch .backends .arm .operators .node_visitor import (
1114 NodeVisitor ,
1215 register_node_visitor ,
1619 validate_same_dtype ,
1720 validate_valid_dtype ,
1821)
22+ from executorch .backends .arm .tosa import TosaSpecification
1923from executorch .backends .arm .tosa .mapping import TosaArg
20- from executorch .backends .arm .tosa .specification import TosaSpecification
2124from torch .fx import Node
2225
2326
2427@register_node_visitor
25- class SubVisitor (NodeVisitor ):
28+ class SubVisitor_INT (NodeVisitor ):
2629 target = "aten.sub.Tensor"
2730
2831 tosa_specs = [
2932 TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
30- TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
3133 ]
3234
35+ def __init__ (self , * args ):
36+ super ().__init__ (* args )
37+
3338 def define_node (
3439 self ,
3540 node : Node ,
@@ -45,18 +50,105 @@ def define_node(
4550 validate_valid_dtype (
4651 self .target ,
4752 [* inputs , output ],
48- [ts .DType .INT32 , ts .DType .FP32 ],
53+ [ts .DType .INT8 , ts .DType .INT16 , ts . DType . INT32 ],
4954 output .tosa_spec ,
5055 )
5156
57+ scale_back = 1.0
58+ if inputs [0 ].dtype == ts .DType .INT8 :
59+ rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32_maxscale (
60+ tosa_graph , inputs , node , self .tosa_spec
61+ )
62+ elif inputs [0 ].dtype == ts .DType .INT16 :
63+ rescaled_inputs , scale_back = (
64+ tqutils .insert_rescale_ops_int16_to_int32_maxscale (
65+ tosa_graph , inputs , node , self .tosa_spec
66+ )
67+ )
68+ else :
69+ # inputs[0].dtype == ts.DType.INT32
70+ # Non-quantized input, natively supported by TOSA.SUB
71+ rescaled_inputs = inputs
72+
73+ if output .dtype in [ts .DType .INT8 , ts .DType .INT16 ]:
74+ broadcasted_shape = tutils .tosa_shape (output .shape , output .dim_order )
75+ sub_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
76+ else :
77+ # output.dtype == ts.DType.INT32
78+ sub_output = output
79+
80+ # Do the INT32 Sub
5281 self ._serialize_operator (
5382 node ,
5483 tosa_graph ,
5584 ts .TosaOp .Op ().SUB ,
5685 [
57- inputs [0 ].name ,
58- inputs [1 ].name ,
86+ rescaled_inputs [0 ].name ,
87+ rescaled_inputs [1 ].name ,
5988 ],
60- [output .name ],
89+ [sub_output .name ],
6190 None ,
6291 )
92+
93+ if output .dtype == ts .DType .INT8 :
94+ # Scale output back to 8 bit
95+ # pyre-ignore
96+ tqutils .insert_rescale_op_to_int8 (
97+ tosa_graph ,
98+ sub_output ,
99+ scale_back ,
100+ node ,
101+ compute_rescale = False ,
102+ tosa_spec = self .tosa_spec ,
103+ ) # type: ignore[possibly-undefined]
104+ elif output .dtype == ts .DType .INT16 :
105+ tqutils .insert_rescale_op_to_int16 (
106+ tosa_graph ,
107+ sub_output ,
108+ scale_back ,
109+ node ,
110+ compute_rescale = False ,
111+ tosa_spec = self .tosa_spec ,
112+ ) # type: ignore[possibly-undefined]
113+
114+
@register_node_visitor
class SubVisitor_FP(SubVisitor_INT):
    """Node visitor for ``aten.sub.Tensor`` under the TOSA-1.0+FP specification.

    Integer operands are delegated to the inherited ``SubVisitor_INT``
    lowering; any other dtype is validated as FP32 and emitted as a single
    TOSA ``SUB`` operator.
    """

    # inheriting 'target' from INT class

    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Serialize the subtraction ``inputs[0] - inputs[1]`` into ``tosa_graph``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer graph that receives the emitted operator(s).
            inputs: The two operands of the subtraction.
            output: The result tensor of the subtraction.

        The operand count and dtype agreement between inputs and output are
        checked first via ``validate_num_inputs`` and ``validate_same_dtype``.
        NOTE(review): those helpers presumably raise on a failed check; their
        definitions are not visible in this file.
        """
        import serializer.tosa_serializer as ts  # type: ignore

        validate_num_inputs(self.target, inputs, 2)
        validate_same_dtype(self.target, [*inputs, output], ts)

        # Every integer dtype accepted by SubVisitor_INT.define_node must be
        # routed to it. INT16 was previously missing from this list, so INT16
        # operands fell through to the FP32 branch and failed its dtype
        # validation even though the integer lowering handles them.
        integer_dtypes = [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]
        if inputs[0].dtype in integer_dtypes:
            # Call the inherited define_node for handling integers
            super().define_node(node, tosa_graph, inputs, output)
            return

        # FP32 Sub lowering
        validate_valid_dtype(
            self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
        )

        # Floating-point operands map directly onto one TOSA SUB operator.
        self._serialize_operator(
            node,
            tosa_graph,
            ts.TosaOp.Op().SUB,
            [inputs[0].name, inputs[1].name],
            [output.name],
            None,
        )
0 commit comments