@@ -312,13 +312,14 @@ def convert_conv2d_layout(mod, desired_layouts):
312312
313313
314314def verify_conv2d (
315- mod_nchw ,
316- mod_ref ,
315+ mod_nchw , # can be dynamic batch
316+ mod_ref , # always static batch
317317 d_shape ,
318318 w_shape ,
319319 sm = 80 ,
320320 atol = 1e-5 ,
321321 rtol = 1e-5 ,
322+ use_cudnn_ref = False ,
322323 run_benchmark = False ,
323324):
324325 if not has_cutlass ():
@@ -332,52 +333,66 @@ def verify_conv2d(
332333 typ = relay .transform .InferType ()(mod_nchw )["main" ].body .checked_type
333334 use_vm = any (isinstance (s , tvm .tir .Any ) for s in typ .shape )
334335
336+ mod_weight_ohwi = convert_conv2d_layout (mod_nchw , {"nn.conv2d" : ["NHWC" , "OHWI" ]})
337+
335338 if use_vm :
336- rt_mod , dev , num_cutlass_partition = profile_and_build_vm (
337- convert_conv2d_layout (mod_nchw , {"nn.conv2d" : ["NHWC" , "OHWI" ]}), params , sm
338- )
339+ rt_mod , _ , num_cutlass_partition = profile_and_build_vm (mod_weight_ohwi , params , sm )
339340 out = get_output_vm (rt_mod , ["data" ], [np_data ])
340341 else :
341- rt_mod , dev , num_cutlass_partition = profile_and_build (
342- convert_conv2d_layout ( mod_nchw , { "nn.conv2d" : [ "NHWC" , "OHWI" ]}) ,
342+ rt_mod , _ , num_cutlass_partition = profile_and_build (
343+ mod_weight_ohwi ,
343344 params ,
344345 sm ,
345346 )
346347 out = get_output (rt_mod , ["data" ], [np_data ])
347348
348349 assert num_cutlass_partition > 0
349350
350- rt_mod_ref , _ = get_ref_rt_mod (
351- convert_conv2d_layout (mod_ref , {"nn.conv2d" : ["NHWC" , "HWIO" ]}),
352- params ,
353- target = "cuda" ,
354- )
355- ref_out = get_output (rt_mod_ref , ["data" ], [np_data ])
351+ if use_cudnn_ref :
352+ rt_mod_ref , dev = get_ref_rt_mod (
353+ convert_conv2d_layout (mod_ref , {"nn.conv2d" : ["NHWC" , "OHWI" ]}),
354+ params ,
355+ target = "cuda -libs=cudnn" ,
356+ )
357+ else :
358+ rt_mod_ref , dev = get_ref_rt_mod (
359+ convert_conv2d_layout (mod_ref , {"nn.conv2d" : ["NHWC" , "HWIO" ]}),
360+ params ,
361+ target = "cuda" ,
362+ )
356363
357- np . testing . assert_allclose ( out , ref_out , atol = atol , rtol = rtol )
364+ ref_out = get_output ( rt_mod_ref , [ "data" ], [ np_data ] )
358365
359366 if run_benchmark :
360367 print ("CUTLASS:" , rt_mod .benchmark (dev , number = 1 , repeat = 600 ))
361368 print ("TVM Tensorcore (no tuning):" , rt_mod_ref .benchmark (dev , number = 1 , repeat = 600 ))
362369
370+ np .testing .assert_allclose (out , ref_out , atol = atol , rtol = rtol )
371+
363372
364373def test_conv2d ():
374+ for IC in [3 , 16 ]:
375+ d_shape = (16 , IC , 32 , 32 )
376+ w_shape = (32 , IC , 3 , 3 )
377+ mod_nchw = get_conv2d_nchw (d_shape , w_shape )
378+
379+ verify_conv2d (
380+ mod_nchw ,
381+ mod_nchw ,
382+ d_shape ,
383+ w_shape ,
384+ sm = 80 ,
385+ atol = 1e-5 ,
386+ rtol = 1e-5 ,
387+ use_cudnn_ref = IC == 3 ,
388+ run_benchmark = False ,
389+ )
390+
365391 d_shape = (16 , 16 , 32 , 32 )
366392 w_shape = (32 , 16 , 3 , 3 )
367- mod_nchw = get_conv2d_nchw (d_shape , w_shape )
368-
369- verify_conv2d (
370- mod_nchw ,
371- mod_nchw ,
372- d_shape ,
373- w_shape ,
374- sm = 80 ,
375- atol = 1e-5 ,
376- rtol = 1e-5 ,
377- run_benchmark = False ,
378- )
379-
380393 dyn_batch_shape = (relay .Any (),) + d_shape [1 :]
394+
395+ mod_nchw = get_conv2d_nchw (d_shape , w_shape )
381396 mod_dyn = get_conv2d_nchw (dyn_batch_shape , w_shape )
382397
383398 verify_conv2d (
0 commit comments