Skip to content

Commit 8d6a1bf

Browse files
committed
test IC=3 convolution
1 parent ffce47d commit 8d6a1bf

File tree

1 file changed

+42
-27
lines changed

1 file changed

+42
-27
lines changed

tests/python/contrib/test_cutlass.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -312,13 +312,14 @@ def convert_conv2d_layout(mod, desired_layouts):
312312

313313

314314
def verify_conv2d(
315-
mod_nchw,
316-
mod_ref,
315+
mod_nchw, # can be dynamic batch
316+
mod_ref, # always static batch
317317
d_shape,
318318
w_shape,
319319
sm=80,
320320
atol=1e-5,
321321
rtol=1e-5,
322+
use_cudnn_ref=False,
322323
run_benchmark=False,
323324
):
324325
if not has_cutlass():
@@ -332,52 +333,66 @@ def verify_conv2d(
332333
typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
333334
use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)
334335

336+
mod_weight_ohwi = convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]})
337+
335338
if use_vm:
336-
rt_mod, dev, num_cutlass_partition = profile_and_build_vm(
337-
convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), params, sm
338-
)
339+
rt_mod, _, num_cutlass_partition = profile_and_build_vm(mod_weight_ohwi, params, sm)
339340
out = get_output_vm(rt_mod, ["data"], [np_data])
340341
else:
341-
rt_mod, dev, num_cutlass_partition = profile_and_build(
342-
convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}),
342+
rt_mod, _, num_cutlass_partition = profile_and_build(
343+
mod_weight_ohwi,
343344
params,
344345
sm,
345346
)
346347
out = get_output(rt_mod, ["data"], [np_data])
347348

348349
assert num_cutlass_partition > 0
349350

350-
rt_mod_ref, _ = get_ref_rt_mod(
351-
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
352-
params,
353-
target="cuda",
354-
)
355-
ref_out = get_output(rt_mod_ref, ["data"], [np_data])
351+
if use_cudnn_ref:
352+
rt_mod_ref, dev = get_ref_rt_mod(
353+
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "OHWI"]}),
354+
params,
355+
target="cuda -libs=cudnn",
356+
)
357+
else:
358+
rt_mod_ref, dev = get_ref_rt_mod(
359+
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
360+
params,
361+
target="cuda",
362+
)
356363

357-
np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
364+
ref_out = get_output(rt_mod_ref, ["data"], [np_data])
358365

359366
if run_benchmark:
360367
print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
361368
print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))
362369

370+
np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
371+
363372

364373
def test_conv2d():
374+
for IC in [3, 16]:
375+
d_shape = (16, IC, 32, 32)
376+
w_shape = (32, IC, 3, 3)
377+
mod_nchw = get_conv2d_nchw(d_shape, w_shape)
378+
379+
verify_conv2d(
380+
mod_nchw,
381+
mod_nchw,
382+
d_shape,
383+
w_shape,
384+
sm=80,
385+
atol=1e-5,
386+
rtol=1e-5,
387+
use_cudnn_ref=IC == 3,
388+
run_benchmark=False,
389+
)
390+
365391
d_shape = (16, 16, 32, 32)
366392
w_shape = (32, 16, 3, 3)
367-
mod_nchw = get_conv2d_nchw(d_shape, w_shape)
368-
369-
verify_conv2d(
370-
mod_nchw,
371-
mod_nchw,
372-
d_shape,
373-
w_shape,
374-
sm=80,
375-
atol=1e-5,
376-
rtol=1e-5,
377-
run_benchmark=False,
378-
)
379-
380393
dyn_batch_shape = (relay.Any(),) + d_shape[1:]
394+
395+
mod_nchw = get_conv2d_nchw(d_shape, w_shape)
381396
mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape)
382397

383398
verify_conv2d(

0 commit comments

Comments
 (0)