1616# under the License.
1717# pylint: disable=invalid-name, dangerous-default-value
1818"""Conv2d kernel generator and profiler for CUTLASS."""
19+ import os
20+ import pickle
1921from functools import partial
2022from .conv2d_operation import Conv2dOperation , EmitConv2dInstance
2123from .gen_gemm import CutlassGemmProfiler
@@ -40,6 +42,7 @@ def create_conv2d_operator_with_epilogue(
4042 tile_description ,
4143 data_type ,
4244 alignment ,
45+ alignment_epilogue ,
4346 swizzling_functor ,
4447 split_k_slices ,
4548):
@@ -78,7 +81,7 @@ def create_conv2d_operator_with_epilogue(
7881
7982 A = TensorDescription (element_a , LayoutType .TensorNHWC , alignment )
8083 B = TensorDescription (element_b , LayoutType .TensorNHWC , alignment )
81- C = TensorDescription (element_c , LayoutType .TensorNHWC , alignment )
84+ C = TensorDescription (element_c , LayoutType .TensorNHWC , alignment_epilogue )
8285
8386 op = Conv2dOperation (
8487 conv_kind ,
@@ -110,6 +113,7 @@ def enumerate_conv2d_operators(
110113 conv_kind ,
111114 stride_support ,
112115 split_k_slices ,
116+ alignment_c ,
113117 tile_descriptions ,
114118 data_type ,
115119 alignment_constraints ,
@@ -128,47 +132,49 @@ def enumerate_conv2d_operators(
128132
129133 for split_k_slice in split_k_slices :
130134 for tile in tile_descriptions :
131- for alignment in alignment_constraints :
132-
133- A = TensorDescription (element_a , LayoutType .TensorNHWC , alignment )
134- B = TensorDescription (element_b , LayoutType .TensorNHWC , alignment )
135- C = TensorDescription (element_c , LayoutType .TensorNHWC , alignment )
136-
137- if element_c == DataType .s32 and A .alignment == 1 :
138- tile .threadblock_shape [0 ] = min (tile .threadblock_shape [0 ], 128 )
139- tile .threadblock_shape [1 ] = min (tile .threadblock_shape [1 ], 128 )
140-
141- op = Conv2dOperation (
142- conv_kind ,
143- IteratorAlgorithm .Optimized ,
144- tile .minimum_compute_capability ,
145- tile ,
146- A ,
147- B ,
148- C ,
149- element_epilogue ,
150- stride_support ,
151- EpilogueFunctor .LinearCombination ,
152- swizzling_functor ,
153- split_k_slice ,
154- )
155-
156- ret .append (
157- {
158- "src" : profiler_emitter .emit (
159- kernel_emitter .emit (op , emit_reduction = split_k_slice > 1 ),
160- op .procedural_name (),
161- element_output = element_c ,
162- split_k_slices = split_k_slice ,
163- ),
164- "name" : op .procedural_name (),
165- "tile_description" : tile ,
166- "alignment" : alignment ,
167- "data_type" : data_type ,
168- "swizzle_functor" : swizzling_functor ,
169- "split_k_slices" : split_k_slice ,
170- }
171- )
135+ for alignmentAB in alignment_constraints :
136+ for alignmentC in alignment_c :
137+
138+ A = TensorDescription (element_a , LayoutType .TensorNHWC , alignmentAB )
139+ B = TensorDescription (element_b , LayoutType .TensorNHWC , alignmentAB )
140+ C = TensorDescription (element_c , LayoutType .TensorNHWC , alignmentC )
141+
142+ if element_c == DataType .s32 and A .alignment == 1 :
143+ tile .threadblock_shape [0 ] = min (tile .threadblock_shape [0 ], 128 )
144+ tile .threadblock_shape [1 ] = min (tile .threadblock_shape [1 ], 128 )
145+
146+ op = Conv2dOperation (
147+ conv_kind ,
148+ IteratorAlgorithm .Optimized ,
149+ tile .minimum_compute_capability ,
150+ tile ,
151+ A ,
152+ B ,
153+ C ,
154+ element_epilogue ,
155+ stride_support ,
156+ EpilogueFunctor .LinearCombination ,
157+ swizzling_functor ,
158+ split_k_slice ,
159+ )
160+
161+ ret .append (
162+ {
163+ "src" : profiler_emitter .emit (
164+ kernel_emitter .emit (op , emit_reduction = split_k_slice > 1 ),
165+ op .procedural_name (),
166+ element_output = element_c ,
167+ split_k_slices = split_k_slice ,
168+ ),
169+ "name" : op .procedural_name (),
170+ "tile_description" : tile ,
171+ "alignment" : alignmentAB ,
172+ "alignment_epilogue" : alignmentC ,
173+ "data_type" : data_type ,
174+ "swizzle_functor" : swizzling_functor ,
175+ "split_k_slices" : split_k_slice ,
176+ }
177+ )
172178
173179 return ret
174180
@@ -181,7 +187,11 @@ def __init__(self, sm, cutlass_path, binary_path):
181187 self .sm = sm
182188 assert sm in GENERATOR_FUNC_TABLE , f"sm{ sm } not supported yet."
183189 self .engine = ProfilerEngine (sm , cutlass_path , binary_path )
184- self .cache = {}
190+ self .cache_path = os .path .join (binary_path , "cutlass_conv2d_cache.pickle" )
191+ if os .path .exists (self .cache_path ):
192+ self .cache = pickle .load (open (self .cache_path , "rb" ))
193+ else :
194+ self .cache = {}
185195
186196 def get_default (
187197 self ,
@@ -216,6 +226,7 @@ def get_default(
216226 tile_description ,
217227 data_type ,
218228 alignment ,
229+ alignment ,
219230 swizzling_functor ,
220231 split_k_slices = 1 ,
221232 )
@@ -265,12 +276,32 @@ def select_op(
265276 if workload in self .cache :
266277 return self .cache [workload ]
267278
def alignments(dtype):
    """Return the candidate alignments for ``dtype``, widest first.

    Parameters
    ----------
    dtype : str
        Data type name. ``"float16"`` yields ``[8, 4, 2, 1]``; ``"float"``
        and ``"float32"`` yield ``[4, 2, 1]``.

    Returns
    -------
    list of int
        A freshly built list on every call, ordered from the largest
        alignment to the smallest.

    Raises
    ------
    ValueError
        If ``dtype`` is not one of the names listed above.
    """
    if dtype == "float16":
        return [8, 4, 2, 1]
    if dtype in ("float", "float32"):
        return [4, 2, 1]
    # An f-string (rather than "%s" % dtype) keeps the error a ValueError even
    # when ``dtype`` is a tuple, which %-formatting would try to unpack.
    raise ValueError(f"Unsupported data type: {dtype}")
287+
288+ alignments_c = [align for align in alignments (out_dtype ) if OC % align == 0 ]
289+
290+ if not profile_all_alignments :
291+ alignments_c = [alignments_c [0 ]]
292+
268293 ops = GENERATOR_FUNC_TABLE [self .sm ](
269294 out_dtype ,
270295 data_dtype ,
271296 weight_dtype ,
272- partial (enumerate_conv2d_operators , conv_kind , stride_support , split_k_slices ),
273- lambda align : all ([dim % align == 0 for dim in [IC , OC ]]),
297+ partial (
298+ enumerate_conv2d_operators ,
299+ conv_kind ,
300+ stride_support ,
301+ split_k_slices ,
302+ alignments_c ,
303+ ),
304+ lambda align : all ([dim % align == 0 for dim in [IC ]]),
274305 use_3xtf32 ,
275306 profile_all_alignments ,
276307 # Use fp32 accumulation for wgrad to align with cuDNN
@@ -294,6 +325,8 @@ def select_op(
294325
295326 op = min (ops , key = lambda i : i ["runtime" ])
296327 self .cache [workload ] = op
328+ with open (self .cache_path , "wb" ) as f :
329+ pickle .dump (self .cache , f )
297330 return op
298331
299332 def profile (
@@ -350,6 +383,7 @@ def profile(
350383 op ["tile_description" ],
351384 op ["data_type" ],
352385 op ["alignment" ],
386+ op ["alignment_epilogue" ],
353387 op ["swizzle_functor" ],
354388 op ["split_k_slices" ],
355389 )
0 commit comments