 from tvm.topi.utils import get_const_tuple
 from tvm.topi.nn.conv2d import _get_workload
 from tvm.topi.generic.conv2d import fallback_schedule_cpu_common_int8
+from tvm.testing.aot import get_dtype_range

 from common import Int8Fallback
 import tvm.testing
@@ -125,6 +126,8 @@ def test_conv2d_NHWC_gemm_int8(params, device):
         add_relu,
     ) = params

+    dtype = "int8"
+
     # TODO(ekalda): These combinations hang during compilation
     failing_cases = [
         (devices[1], (1, 128, 17, 192, 7, 1, "SAME", 2, False, False)),
@@ -135,7 +138,7 @@ def test_conv2d_NHWC_gemm_int8(params, device):
         ),  # this one passes but is just incredibly slow
     ]
     if (device, params) in failing_cases:
-        return
+        pytest.skip("Skipping because this test will hang")

     print("Compiling for target: %s" % target)

@@ -148,19 +151,15 @@ def test_conv2d_NHWC_gemm_int8(params, device):

     in_height = in_width = in_size

-    A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int8")
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W", dtype="int8")
-    bias = te.placeholder((num_filter,), name="bias", dtype="int8")
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
+    a_shape = (batch, in_height, in_width, in_channel)
+    w_shape = (kernel, kernel, in_channel, num_filter)
+    bias_shape = (num_filter,)

-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+    @memoize("topi.tests.test_topi_conv2d_int8.test_conv2d_NHWC_gemm_int8")
     def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+        input_min, input_max = get_dtype_range(dtype)
+        a_np = np.random.randint(low=input_min, high=input_max, size=a_shape).astype(dtype)
+        w_np = np.random.randint(low=input_min, high=input_max, size=w_shape).astype(dtype)
         b_np = np.random.uniform(size=bias_shape).astype(dtype)
         dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
         c_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding).astype(dtype)
@@ -173,28 +172,22 @@ def get_ref_data():

         return a_np, w_np, b_np, c_np

-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    dev = tvm.device(target, 0)
     with tvm.target.Target(target) as tvm_target:
+        A = te.placeholder(a_shape, name="A", dtype=dtype)
+        W = te.placeholder(w_shape, name="W", dtype=dtype)
+        bias = te.placeholder(bias_shape, name="bias", dtype=dtype)
         C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
         if add_bias:
             C = topi.add(C, bias)
         if add_relu:
             C = topi.nn.relu(C)
         s = schedule([C])

-        a = tvm.nd.array(a_np, dev)
-        w = tvm.nd.array(w_np, dev)
-        b = tvm.nd.array(b_np, dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
-        inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+        build_args = [A, W, bias, C] if add_bias else [A, W, C]

         func = tvm.build(
             s,
-            build_inputs,
+            build_args,
             target,
             name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
             % (
@@ -211,10 +204,22 @@ def get_ref_data():

         build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")

-        if not build_only:
-            print("Running on target: %s" % target)
-            func(*inference_inputs)
-            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
+        if build_only:
+            return
+
+        print("Running on target: %s" % target)
+
+        dev = tvm.device(target, 0)
+        a_np, w_np, b_np, c_np = get_ref_data()
+        a = tvm.nd.array(a_np, dev)
+        w = tvm.nd.array(w_np, dev)
+        b = tvm.nd.array(b_np, dev)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+
+        run_args = [a, w, b, c] if add_bias else [a, w, c]
+        func(*run_args)
+
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)


 @pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
@@ -339,27 +344,28 @@ def test_conv2d_NCHWc_int8(in_dtype, params):
     w_shape = get_const_tuple(W.shape)
     dtype = A.dtype
     out_dtype = "int32" if in_dtype == "int8" else "uint32"
-    lo = -128 if in_dtype == "int8" else 0
-    hi = 127 if in_dtype == "int8" else 255
+    input_min, input_max = get_dtype_range(in_dtype)

     def check_target(target, compute, schedule, oc_block_factor, build_only):
         dev = tvm.device(target, 0)
         if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
+            pytest.skip(reason="Skip because %s is not enabled" % target)
         if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            pytest.skip(reason="Skip because int8 intrinsics are not available")

         bias = te.placeholder(
             (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype=out_dtype
         )
         bias_shape = get_const_tuple(bias.shape)

-        @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+        @memoize("topi.tests.test_topi_conv2d_int8.test_conv2d_NCHWc_int8")
         def get_ref_data():
-            a_np = np.random.randint(low=lo, high=hi, size=a_shape).astype(out_dtype)
-            w_np = np.random.randint(low=lo, high=hi, size=w_shape).astype(out_dtype)
+            a_np = np.random.randint(low=input_min, high=input_max, size=a_shape).astype(
+                out_dtype
+            )
+            w_np = np.random.randint(low=input_min, high=input_max, size=w_shape).astype(
+                out_dtype
+            )
             b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
             dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
             c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(
@@ -380,8 +386,6 @@ def get_ref_data():

             return a_np, w_np, b_np, c_np

-        a_np, w_np, b_np, c_np = get_ref_data()
-
         with tvm.target.Target(target):
             C = compute(
                 A,
@@ -399,17 +403,7 @@ def get_ref_data():
                 C = topi.nn.relu(C)
             s = schedule([C])

-            a = tvm.nd.array(a_np.astype(dtype), dev)
-            w = tvm.nd.array(w_np.astype(dtype), dev)
-            b = tvm.nd.array(b_np.astype(out_dtype), dev)
-            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-            if add_bias:
-                compile_args = [A, W, bias, C]
-                run_args = [a, w, b, c]
-            else:
-                compile_args = [A, W, C]
-                run_args = [a, w, c]
+            compile_args = [A, W, bias, C] if add_bias else [A, W, C]

             func = tvm.build(
                 s,
@@ -422,6 +416,14 @@ def get_ref_data():
             if build_only:
                 return

+            a_np, w_np, b_np, c_np = get_ref_data()
+
+            a = tvm.nd.array(a_np.astype(dtype), dev)
+            w = tvm.nd.array(w_np.astype(dtype), dev)
+            b = tvm.nd.array(b_np.astype(out_dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+            run_args = [a, w, b, c] if add_bias else [a, w, c]
+
             print("Running on target: %s" % target)

             func(*run_args)
@@ -531,7 +533,7 @@ def test_conv2d_nchw_int8(in_dtype, params):
     bias_shape = get_const_tuple(bias.shape)
     dtype = A.dtype

-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+    @memoize("topi.tests.test_topi_conv2d_int8.test_conv2d_nchw_int8")
     def get_ref_data():
         a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
         w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
@@ -550,7 +552,7 @@ def get_ref_data():
     a_np, w_np, b_np, c_np = get_ref_data()

     def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
+        _, _, _, out_width = get_const_tuple(c_np.shape)
         wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)

         # for testing functionality,
@@ -568,11 +570,9 @@ def verify_workload_padding():
     def check_target(target):
         dev = tvm.device(target, 0)
         if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
+            pytest.skip("Skip because %s is not enabled" % target)
         if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            pytest.skip("Skip because int8 intrinsics are not available")

         print("Running on target: %s" % target)
         with tvm.target.Target(target):
@@ -585,52 +585,39 @@ def check_target(target):
                 C = topi.nn.relu(C)
             s = topi.cuda.schedule_conv2d_nchw_int8([C])

+            build_args = [A, W, bias, C] if add_bias else [A, W, C]
+
+            func = tvm.build(
+                s,
+                build_args,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
+
             a = tvm.nd.array(a_np, dev)
             w = tvm.nd.array(w_np, dev)
             b = tvm.nd.array(b_np, dev)
             c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-            if add_bias:
-                func = tvm.build(
-                    s,
-                    [A, W, bias, C],
-                    target,
-                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                    % (
-                        batch,
-                        in_channel,
-                        in_size,
-                        num_filter,
-                        kernel,
-                        stride,
-                        padding_sum,
-                        dilation,
-                    ),
-                )
-                func(a, w, b, c)
-            else:
-                func = tvm.build(
-                    s,
-                    [A, W, C],
-                    target,
-                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                    % (
-                        batch,
-                        in_channel,
-                        in_size,
-                        num_filter,
-                        kernel,
-                        stride,
-                        padding_sum,
-                        dilation,
-                    ),
-                )
-                func(a, w, c)
+
+            run_args = [a, w, b, c] if add_bias else [a, w, c]
+
+            func(*run_args)
+
             tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

         verify_workload_padding()

-    for target in ["cuda"]:
-        check_target(target)
+    check_target("cuda")


 if __name__ == "__main__":
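
Note on the helper used above: the diff replaces hard-coded int8/uint8 bounds with get_dtype_range imported from tvm.testing.aot. The helper's body is not shown in this diff; the sketch below is an assumption of what such a dtype-range lookup typically does, inferred only from how the tests call it, and is not the actual TVM implementation. The shape used in the usage line is likewise an arbitrary illustrative value.

    import numpy as np

    def get_dtype_range(dtype):
        # Assumed behaviour: return (min, max) representable values for a dtype,
        # e.g. get_dtype_range("int8") -> (-128, 127), get_dtype_range("uint8") -> (0, 255).
        np_dtype = np.dtype(dtype)
        info = np.iinfo(np_dtype) if np.issubdtype(np_dtype, np.integer) else np.finfo(np_dtype)
        return info.min, info.max

    # Usage mirroring the updated tests: draw test data across the dtype's range.
    input_min, input_max = get_dtype_range("int8")
    a_np = np.random.randint(low=input_min, high=input_max, size=(1, 56, 56, 64)).astype("int8")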