
Commit 2f2ecaf

More testing cleanup
1 parent 8e168f6 commit 2f2ecaf

2 files changed, +88 -95 lines changed

python/tvm/topi/nn/conv2d.py

Lines changed: 8 additions & 2 deletions

@@ -606,8 +606,14 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
     if N % tile_rows != 0:
         pad_N = tile_rows - (N % tile_rows)
 
-    if K % (tile_cols * 4) != 0:
-        pad_K = (tile_cols * 4) - (K % (tile_cols * 4))
+    # Tensorize will later make use of 4 tiles at once across the columns, so make sure we pad
+    # such that the number of columns is a multiple of 4
+    column_multiplier = 4
+    tile_cols_multiplied = tile_cols * column_multiplier
+    K_misalignment = K % tile_cols_multiplied
+
+    if K_misalignment != 0:
+        pad_K = tile_cols_multiplied - K_misalignment
 
     N_padded = N + pad_N
     K_padded = K + pad_K
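To see what the new padding rule does, here is a small standalone Python sketch of the same arithmetic; the variable names mirror the hunk above, while the concrete N, K and tile sizes are made-up illustration values, not taken from the source:

    # Standalone illustration of the padding rule introduced above (values are made up).
    tile_rows, tile_cols = 4, 4
    N, K = 30, 70  # unpadded GEMM dimensions

    pad_N = 0 if N % tile_rows == 0 else tile_rows - (N % tile_rows)  # 4 - (30 % 4) = 2

    # Tensorize consumes 4 column tiles at once, so K must become a multiple of tile_cols * 4.
    column_multiplier = 4
    tile_cols_multiplied = tile_cols * column_multiplier  # 16
    K_misalignment = K % tile_cols_multiplied             # 70 % 16 = 6
    pad_K = 0 if K_misalignment == 0 else tile_cols_multiplied - K_misalignment  # 16 - 6 = 10

    N_padded = N + pad_N  # 32
    K_padded = K + pad_K  # 80, a multiple of 16
    assert K_padded % tile_cols_multiplied == 0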

tests/python/topi/python/test_topi_conv2d_int8.py

Lines changed: 80 additions & 93 deletions

@@ -28,6 +28,7 @@
 from tvm.topi.utils import get_const_tuple
 from tvm.topi.nn.conv2d import _get_workload
 from tvm.topi.generic.conv2d import fallback_schedule_cpu_common_int8
+from tvm.testing.aot import get_dtype_range
 
 from common import Int8Fallback
 import tvm.testing
@@ -125,6 +126,8 @@ def test_conv2d_NHWC_gemm_int8(params, device):
         add_relu,
     ) = params
 
+    dtype = "int8"
+
     # TODO(ekalda): These combinations hang during compilation
     failing_cases = [
         (devices[1], (1, 128, 17, 192, 7, 1, "SAME", 2, False, False)),
@@ -135,7 +138,7 @@ def test_conv2d_NHWC_gemm_int8(params, device):
         ),  # this one passes but is just incredibly slow
     ]
     if (device, params) in failing_cases:
-        return
+        pytest.skip("Skipping because this test will hang")
 
     print("Compiling for target: %s" % target)
 
@@ -148,19 +151,15 @@ def test_conv2d_NHWC_gemm_int8(params, device):
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int8")
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W", dtype="int8")
-    bias = te.placeholder((num_filter,), name="bias", dtype="int8")
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
+    a_shape = (batch, in_height, in_width, in_channel)
+    w_shape = (kernel, kernel, in_channel, num_filter)
+    bias_shape = (num_filter,)
 
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+    @memoize("topi.tests.test_topi_conv2d_int8.test_conv2d_NHWC_gemm_int8")
     def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+        input_min, input_max = get_dtype_range(dtype)
+        a_np = np.random.randint(low=input_min, high=input_max, size=a_shape).astype(dtype)
+        w_np = np.random.randint(low=input_min, high=input_max, size=w_shape).astype(dtype)
         b_np = np.random.uniform(size=bias_shape).astype(dtype)
         dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
         c_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding).astype(dtype)
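The tests now take the random-data bounds from the dtype via get_dtype_range instead of hard-coding -128/127. As a rough sketch of the behaviour that helper is assumed to provide (this is not the tvm.testing.aot implementation, just an equivalent written with NumPy):

    import numpy as np

    def dtype_range(dtype):
        # Assumed behaviour: return the (min, max) representable values of an integer dtype,
        # e.g. "int8" -> (-128, 127), "uint8" -> (0, 255).
        info = np.iinfo(dtype)
        return info.min, info.max

    input_min, input_max = dtype_range("int8")  # (-128, 127)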
@@ -173,28 +172,22 @@ def get_ref_data():
 
         return a_np, w_np, b_np, c_np
 
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    dev = tvm.device(target, 0)
     with tvm.target.Target(target) as tvm_target:
+        A = te.placeholder(a_shape, name="A", dtype=dtype)
+        W = te.placeholder(w_shape, name="W", dtype=dtype)
+        bias = te.placeholder(bias_shape, name="bias", dtype=dtype)
         C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
         if add_bias:
             C = topi.add(C, bias)
         if add_relu:
             C = topi.nn.relu(C)
         s = schedule([C])
 
-        a = tvm.nd.array(a_np, dev)
-        w = tvm.nd.array(w_np, dev)
-        b = tvm.nd.array(b_np, dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
-        inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+        build_args = [A, W, bias, C] if add_bias else [A, W, C]
 
         func = tvm.build(
             s,
-            build_inputs,
+            build_args,
             target,
             name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
             % (
@@ -211,10 +204,22 @@ def get_ref_data():
 
         build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        if not build_only:
-            print("Running on target: %s" % target)
-            func(*inference_inputs)
-            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
+        if build_only:
+            return
+
+        print("Running on target: %s" % target)
+
+        dev = tvm.device(target, 0)
+        a_np, w_np, b_np, c_np = get_ref_data()
+        a = tvm.nd.array(a_np, dev)
+        w = tvm.nd.array(w_np, dev)
+        b = tvm.nd.array(b_np, dev)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+
+        run_args = [a, w, b, c] if add_bias else [a, w, c]
+        func(*run_args)
+
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
 
 @pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
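The reshuffled test body above follows a compile-first shape: build the function, return early when the target can only be cross-compiled on this host, and only then materialise reference data and device buffers. A minimal self-contained sketch of that structure, using a trivial TE elementwise compute rather than the conv2d operator from the test, and assuming a TVM build where the legacy te.create_schedule API is available:

    import numpy as np
    import tvm
    from tvm import te

    def build_and_maybe_run(build_only=False):
        n = 16
        A = te.placeholder((n,), name="A", dtype="float32")
        B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
        s = te.create_schedule(B.op)

        build_args = [A, B]
        func = tvm.build(s, build_args, "llvm")  # compile first

        if build_only:  # e.g. cross-compiling for a different host
            return

        # Allocate inputs and outputs only when the function will actually run.
        dev = tvm.device("llvm", 0)
        a = tvm.nd.array(np.arange(n, dtype="float32"), dev)
        b = tvm.nd.array(np.zeros(n, dtype="float32"), dev)
        run_args = [a, b]
        func(*run_args)
        np.testing.assert_allclose(b.numpy(), a.numpy() + 1.0)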
@@ -339,27 +344,28 @@ def test_conv2d_NCHWc_int8(in_dtype, params):
     w_shape = get_const_tuple(W.shape)
     dtype = A.dtype
     out_dtype = "int32" if in_dtype == "int8" else "uint32"
-    lo = -128 if in_dtype == "int8" else 0
-    hi = 127 if in_dtype == "int8" else 255
+    input_min, input_max = get_dtype_range(in_dtype)
 
     def check_target(target, compute, schedule, oc_block_factor, build_only):
         dev = tvm.device(target, 0)
         if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
+            pytest.skip(reason="Skip because %s is not enabled" % target)
         if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            pytest.skip(reason="Skip because int8 intrinsics are not available")
 
         bias = te.placeholder(
             (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype=out_dtype
         )
         bias_shape = get_const_tuple(bias.shape)
 
-        @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+        @memoize("topi.tests.test_topi_conv2d_int8.test_conv2d_NCHWc_int8")
         def get_ref_data():
-            a_np = np.random.randint(low=lo, high=hi, size=a_shape).astype(out_dtype)
-            w_np = np.random.randint(low=lo, high=hi, size=w_shape).astype(out_dtype)
+            a_np = np.random.randint(low=input_min, high=input_max, size=a_shape).astype(
+                out_dtype
+            )
+            w_np = np.random.randint(low=input_min, high=input_max, size=w_shape).astype(
+                out_dtype
+            )
             b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
             dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
             c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(
@@ -380,8 +386,6 @@ def get_ref_data():
 
             return a_np, w_np, b_np, c_np
 
-        a_np, w_np, b_np, c_np = get_ref_data()
-
         with tvm.target.Target(target):
             C = compute(
                 A,
@@ -399,17 +403,7 @@ def get_ref_data():
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-            a = tvm.nd.array(a_np.astype(dtype), dev)
-            w = tvm.nd.array(w_np.astype(dtype), dev)
-            b = tvm.nd.array(b_np.astype(out_dtype), dev)
-            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-            if add_bias:
-                compile_args = [A, W, bias, C]
-                run_args = [a, w, b, c]
-            else:
-                compile_args = [A, W, C]
-                run_args = [a, w, c]
+            compile_args = [A, W, bias, C] if add_bias else [A, W, C]
 
             func = tvm.build(
                 s,
@@ -422,6 +416,14 @@ def get_ref_data():
             if build_only:
                 return
 
+            a_np, w_np, b_np, c_np = get_ref_data()
+
+            a = tvm.nd.array(a_np.astype(dtype), dev)
+            w = tvm.nd.array(w_np.astype(dtype), dev)
+            b = tvm.nd.array(b_np.astype(out_dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+            run_args = [a, w, b, c] if add_bias else [a, w, c]
+
             print("Running on target: %s" % target)
 
             func(*run_args)
@@ -531,7 +533,7 @@ def test_conv2d_nchw_int8(in_dtype, params):
     bias_shape = get_const_tuple(bias.shape)
     dtype = A.dtype
 
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+    @memoize("topi.tests.test_topi_conv2d_int8.test_conv2d_nchw_int8")
     def get_ref_data():
         a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
         w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
@@ -550,7 +552,7 @@ def get_ref_data():
     a_np, w_np, b_np, c_np = get_ref_data()
 
     def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
+        _, _, _, out_width = get_const_tuple(c_np.shape)
         wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
 
         # for testing functionality,
@@ -568,11 +570,9 @@ def verify_workload_padding():
     def check_target(target):
         dev = tvm.device(target, 0)
         if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
+            pytest.skip("Skip because %s is not enabled" % target)
         if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            pytest.skip("Skip because int8 intrinsics are not available")
 
         print("Running on target: %s" % target)
         with tvm.target.Target(target):
@@ -585,52 +585,39 @@ def check_target(target):
                 C = topi.nn.relu(C)
             s = topi.cuda.schedule_conv2d_nchw_int8([C])
 
+            build_args = [A, W, bias, C] if add_bias else [A, W, C]
+
+            func = tvm.build(
+                s,
+                build_args,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
+
             a = tvm.nd.array(a_np, dev)
             w = tvm.nd.array(w_np, dev)
             b = tvm.nd.array(b_np, dev)
             c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-            if add_bias:
-                func = tvm.build(
-                    s,
-                    [A, W, bias, C],
-                    target,
-                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                    % (
-                        batch,
-                        in_channel,
-                        in_size,
-                        num_filter,
-                        kernel,
-                        stride,
-                        padding_sum,
-                        dilation,
-                    ),
-                )
-                func(a, w, b, c)
-            else:
-                func = tvm.build(
-                    s,
-                    [A, W, C],
-                    target,
-                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                    % (
-                        batch,
-                        in_channel,
-                        in_size,
-                        num_filter,
-                        kernel,
-                        stride,
-                        padding_sum,
-                        dilation,
-                    ),
-                )
-                func(a, w, c)
+
+            run_args = [a, w, b, c] if add_bias else [a, w, c]
+
+            func(*run_args)
+
             tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
         verify_workload_padding()
 
-    for target in ["cuda"]:
-        check_target(target)
+    check_target("cuda")
 
 
 if __name__ == "__main__":
