@@ -41,20 +41,28 @@ def _get_model(
4141):
4242 """Return a model and any parameters it may have"""
4343
44- iinfo = np .iinfo (dtype )
45- data_min = iinfo .min
46- data_max = iinfo .max
44+ def create_or_assign_constant (shape , dtype , default_data ):
45+ """Creates new numpy array or assigns default_data if available."""
46+
47+ iinfo = np .iinfo (dtype )
48+ data_min = iinfo .min
49+ data_max = iinfo .max
50+
51+ nparray = None
52+ if default_data :
53+ nparray = np .array (default_data , dtype = dtype ).reshape (shape )
54+ else :
55+ nparray = np .random .randint (data_min , data_max + 1 , size = shape , dtype = dtype )
56+
57+ return relay .const (nparray , dtype = dtype )
4758
4859 if lhs_is_constant :
49- a_data = np .array (constant_data , dtype = dtype ).reshape (lhs_shape )
50- a = relay .const (a_data , dtype = dtype )
60+ a = create_or_assign_constant (lhs_shape , dtype , constant_data )
5161 else :
5262 a = relay .var ("a" , shape = lhs_shape , dtype = dtype )
5363
5464 if rhs_is_constant :
55- b_data = np .array (constant_data , dtype = dtype ).reshape (rhs_shape )
56- np .random .randint (data_min , data_max + 1 , size = rhs_shape , dtype = dtype )
57- b = relay .const (b_data , dtype = dtype )
65+ b = create_or_assign_constant (rhs_shape , dtype , constant_data )
5866 else :
5967 b = relay .var ("b" , shape = rhs_shape , dtype = dtype )
6068
@@ -125,6 +133,46 @@ def test_addition(dtype, shape):
125133 tei .verify (outputs , dtype , 1 )
126134
127135
136+ @requires_ethosn
137+ @pytest .mark .parametrize ("dtype" , ["uint8" , "int8" ])
138+ @pytest .mark .parametrize (
139+ "lhs_shape,lhs_is_constant,rhs_shape,rhs_is_constant" ,
140+ [
141+ ((1 , 4 , 4 , 8 ), True , (1 , 1 , 1 , 8 ), True ),
142+ ((4 ,), True , (1 , 16 , 12 , 4 ), True ),
143+ ((1 , 1 , 1 , 8 ), True , (1 , 4 , 4 , 8 ), True ),
144+ ((1 , 16 , 12 , 4 ), True , (4 ,), True ),
145+ ],
146+ )
147+ def test_addition_both_inputs_constants (
148+ dtype , lhs_shape , lhs_is_constant , rhs_shape , rhs_is_constant
149+ ):
150+ """Check if addition is simplified when both inputs are constants."""
151+ np .random .seed (0 )
152+
153+ lhs_zp , lhs_sc , rhs_zp , rhs_sc , out_zp , out_sc = _get_addition_qnn_params (dtype )
154+
155+ model = _get_model (
156+ lhs_shape ,
157+ rhs_shape ,
158+ lhs_zp ,
159+ lhs_sc ,
160+ rhs_zp ,
161+ rhs_sc ,
162+ out_zp ,
163+ out_sc ,
164+ dtype ,
165+ lhs_is_constant = lhs_is_constant ,
166+ rhs_is_constant = rhs_is_constant ,
167+ )
168+ from tvm .relay .op .contrib import partition_for_ethosn # pylint: disable=import-outside-toplevel
169+
170+ mod = tei .make_module (model , {})
171+ assert "qnn.add" in mod .astext (False )
172+ mod = partition_for_ethosn (mod , {})
173+ assert "qnn.add" not in mod .astext (False )
174+
175+
128176@requires_ethosn
129177@pytest .mark .parametrize ("dtype" , ["uint8" , "int8" ])
130178@pytest .mark .parametrize (
@@ -145,9 +193,6 @@ def test_addition_to_depthwise(dtype, lhs_shape, lhs_is_constant, rhs_shape, rhs
145193 data_max = iinfo .max
146194 lhs_zp , lhs_sc , rhs_zp , rhs_sc , out_zp , out_sc = _get_addition_qnn_params (dtype )
147195
148- constant_shape = lhs_shape if lhs_is_constant else rhs_shape
149- constant_data = np .random .randint (data_min , data_max + 1 , size = constant_shape , dtype = dtype )
150-
151196 model = _get_model (
152197 lhs_shape ,
153198 rhs_shape ,
@@ -160,7 +205,6 @@ def test_addition_to_depthwise(dtype, lhs_shape, lhs_is_constant, rhs_shape, rhs
160205 dtype ,
161206 lhs_is_constant = lhs_is_constant ,
162207 rhs_is_constant = rhs_is_constant ,
163- constant_data = constant_data ,
164208 )
165209 input_shape = rhs_shape if lhs_is_constant else lhs_shape
166210 input_name = "b" if lhs_is_constant else "a"
0 commit comments