1818"""Conv2d kernel generator and profiler for CUTLASS."""
1919from .conv2d_operation import Conv2dOperation , EmitConv2dInstance
2020from .conv2d_profiler import Conv2dProfilerEmitter
21+ from .gemm_profiler import GemmProfilerEmitter
+from .gen_tensor_op import (
+    ProfilerEngine,
+    generate_sm75_tensor_op_1688,
+    generate_sm80_tensor_op_16816,
+)
 from .library import (
     EpilogueFunctor,
     SwizzlingFunctor,
@@ -36,7 +42,6 @@
 
 
 def create_conv2d_operator(
-    layout,
     tile_descriptions,
     data_type,
     alignment_constraints,
@@ -51,6 +56,9 @@ def create_conv2d_operator(
     element_a, element_b, element_c, element_epilogue = data_type
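+    # Generate kernels for both of CUTLASS's conv2d iterator algorithms.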
     iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]
 
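+    # A, B, and C (activation, weight, and output) are all generated in NHWC layout.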
+    layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
     for tile in tile_descriptions:
         for alignment in alignment_constraints:
 
@@ -105,3 +111,96 @@ def create_conv2d_operator(
                 ret.append(op_entry)
 
     return ret
+
+
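+# Map a CUDA compute capability to the generator for its tensor-op kernels.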
+GENERATOR_FUNC_TABLE = {
+    75: generate_sm75_tensor_op_1688,
+    80: generate_sm80_tensor_op_16816,
+}
+
+
+# TODO(masahi): A sensible way to pick reasonable default kernels
+DEFAULT_KERNELS = {
+    75: {
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
+    },
+    80: {
+        "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
+        "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
+    },
+}
+
+
+class CutlassConv2DProfiler:
+    """Profile all candidate kernels and select the best one."""
+
+    def __init__(self, sm, cutlass_path, binary_path):
+        assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm
+        self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
+        self.sm = sm
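+        # Maps an (M, N, K) workload to the best kernel found for it.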
+        self.cache = {}
+
+    def check_align(self, op_name, M):
+        """Filter out kernels that cannot be supported."""
+        # TODO
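+        # A real check would parse the alignment from op_name (e.g. "align4")
+        # and verify that it divides the relevant problem dimension.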
+        return True
+
+    def get_default(self, out_dtype):
+        """Return the default kernel for the requested architecture.
+
+        For now, the default kernel is picked arbitrarily.
+        """
+        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
+        default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
+        filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
+        assert len(filtered) == 1
+        return filtered[0]
+
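+    # Note: the workload is described by GEMM-style extents (M, N, K), mirroring
+    # the GEMM profiler interface, since CUTLASS implements conv2d as an
+    # implicit GEMM.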
+    def profile(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False):
+        """Profile and select the best kernel from the candidate kernels.
+
+        If profile_all is False, return immediately after the first applicable kernel is
+        found. If use_multiprocessing is True, compile all profiler executables in parallel.
+        """
+        if (M, N, K) in self.cache:
+            return self.cache[(M, N, K)]
+
+        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
+        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
+
+        for op in ops:
+            op["runtime"] = -1
+
+        if profile_all:
+            self.engine.compile_all(ops, use_multiprocessing)
+
+        for op in ops:
+            out = self.engine.evaluate(op, [M, N, K])
+            op["runtime"] = out
+            if out > 0 and not profile_all:
+                break
+
+        valid_ops = filter(lambda op: op["runtime"] > 0, ops)
+        output = sorted(valid_ops, key=lambda i: i["runtime"])
+        self.cache[(M, N, K)] = output[0]
+        return output[0]
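+
+
+# A minimal usage sketch (hypothetical paths and workload; assumes a local
+# CUTLASS checkout and a writable directory for the profiler binaries):
+#
+#   profiler = CutlassConv2DProfiler(80, "/path/to/cutlass", "/tmp/cutlass")
+#   best = profiler.profile(25088, 64, 576, "float16", profile_all=False)
+#   print(best["name"], best["runtime"])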