1515# specific language governing permissions and limitations
1616# under the License.
1717# pylint: disable=invalid-name
18- """Kernel generator and profiler for CUTLASS."""
19- import logging
20- import os
18+ """GEMM kernel generator and profiler for CUTLASS."""
19+ from functools import partial
2120import re
22- import tempfile
23- import subprocess
24- import multiprocessing
2521from .gemm_operation import GemmOperation , EmitGemmInstance
2622from .gemm_profiler import GemmProfilerEmitter
23+ from .gen_tensor_op import (
24+ ProfilerEngine ,
25+ generate_sm75_tensor_op_1688 ,
26+ generate_sm80_tensor_op_16816 ,
27+ )
2728from .library import (
2829 EpilogueFunctor ,
2930 SwizzlingFunctor ,
3031 TensorDescription ,
3132 DataTypeTag ,
3233 LayoutType ,
33- MathInstruction ,
34- DataType ,
35- OpcodeClass ,
36- MathOperation ,
37- TileDescription ,
3834)
3935
40- logger = logging .getLogger ("cutlass" )
41-
4236
4337def create_gemm_operator (
44- layouts ,
4538 tile_descriptions ,
4639 data_type ,
4740 alignment_constraints ,
48- epilogue_functor = EpilogueFunctor .LinearCombination ,
4941 swizzling_functor = SwizzlingFunctor .Identity8 ,
5042 batched = False ,
5143):
@@ -59,6 +51,10 @@ def create_gemm_operator(
5951 if batched :
6052 swizzling_functor = SwizzlingFunctor .Batched
6153
54+ layouts = [
55+ (LayoutType .RowMajor , LayoutType .ColumnMajor , LayoutType .RowMajor ),
56+ ]
57+
6258 for layout in layouts :
6359 for tile_description in tile_descriptions :
6460 for alignment in alignment_constraints :
@@ -76,7 +72,7 @@ def create_gemm_operator(
7672 B ,
7773 C ,
7874 element_epilogue ,
79- epilogue_functor ,
75+ EpilogueFunctor . LinearCombination ,
8076 swizzling_functor ,
8177 )
8278 op_bias = GemmOperation (
@@ -110,7 +106,6 @@ def create_gemm_operator(
110106 swizzling_functor ,
111107 )
112108
113- kernel_emitter = EmitGemmInstance ()
114109 op_entry ["op" ] = op
115110 op_entry ["name" ] = op .procedural_name ()
116111 op_entry ["opdef" ] = kernel_emitter .emit (op , batched = batched )
@@ -134,141 +129,12 @@ def create_gemm_operator(
134129 return ret
135130
136131
def generate_tensor_op_common(
    math_instructions, alignment_constraints, get_tile_descriptions, batched=False
):
    """Common kernel generator shared by the architecture-specific generators.

    For each math instruction, derives the tile descriptions and the
    A/B/accumulator/epilogue data types, then delegates kernel creation to
    ``create_gemm_operator`` and collects every generated op into one list.
    """
    # Only row-major A x column-major B -> row-major C is generated for now.
    supported_layouts = [
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
    ]
    generated = []
    for inst in math_instructions:
        dtypes = [
            inst.element_a,
            inst.element_b,
            inst.element_accumulator,
            inst.element_accumulator,
        ]
        generated.extend(
            create_gemm_operator(
                supported_layouts,
                get_tile_descriptions(inst),
                dtypes,
                alignment_constraints,
                batched=batched,
            )
        )
    return generated
162-
def generate_sm75_tensor_op_1688(out_dtype, batched=False):
    """Generate GEMM kernels for Turing.

    The fp16 x fp16 tensor-op instruction accumulates in fp32 or fp16
    depending on the requested output dtype.
    """
    assert out_dtype in ["float32", "float16"]
    accumulator = {"float32": DataType.f32, "float16": DataType.f16}[out_dtype]
    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f16,
            DataType.f16,
            accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        )
    ]

    alignment_constraints = [8, 4, 2, 1]

    def get_tile_descriptions(math_inst):
        # Every SM75 candidate uses 2 pipeline stages; only the thread-block
        # shape and warp arrangement vary.
        min_cc, max_cc = 75, 1024
        shape_and_warp = [
            ([256, 128, 32], [4, 2, 1]),
            ([128, 256, 32], [2, 4, 1]),
            ([128, 128, 32], [2, 2, 1]),
            ([64, 128, 32], [2, 2, 1]),
            ([128, 64, 32], [2, 2, 1]),
            ([64, 64, 32], [2, 2, 1]),
            ([64, 128, 64], [1, 2, 2]),
        ]
        return [
            TileDescription(shape, 2, warp, math_inst, min_cc, max_cc)
            for shape, warp in shape_and_warp
        ]

    return generate_tensor_op_common(
        math_instructions, alignment_constraints, get_tile_descriptions, batched
    )
207-
208-
def generate_sm80_tensor_op_16816(out_dtype, batched=False):
    """Generate GEMM kernels for Ampere.

    The fp16 x fp16 tensor-op instruction accumulates in fp32 or fp16
    depending on the requested output dtype.
    """
    assert out_dtype in ["float32", "float16"]
    accumulator = {"float32": DataType.f32, "float16": DataType.f16}[out_dtype]
    math_instructions = [
        MathInstruction(
            [16, 8, 16],
            DataType.f16,
            DataType.f16,
            accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        )
    ]

    alignment_constraints = [8, 4, 2]

    def get_tile_descriptions(math_inst):
        min_cc = 80
        max_cc = 1024
        # Some large tiles exceed shared memory on consumer parts, so they are
        # capped at compute capability 80 exactly.
        max_cc_smem_limited = 80
        configs = [
            ([256, 128, 32], 3, [4, 2, 1], max_cc),
            ([128, 256, 32], 3, [2, 4, 1], max_cc),
            ([256, 64, 32], 4, [4, 1, 1], max_cc),
            ([64, 256, 32], 4, [1, 4, 1], max_cc),
            ([128, 128, 32], 3, [2, 2, 1], max_cc),
            ([128, 128, 32], 4, [2, 2, 1], max_cc),
            ([128, 128, 32], 5, [2, 2, 1], max_cc),
            ([128, 64, 32], 6, [2, 2, 1], max_cc),
            ([64, 128, 32], 6, [2, 2, 1], max_cc),
            ([64, 64, 32], 10, [2, 2, 1], max_cc),
            ([256, 128, 64], 3, [4, 2, 1], max_cc_smem_limited),
            ([128, 256, 64], 3, [2, 4, 1], max_cc_smem_limited),
            ([256, 64, 64], 4, [4, 1, 1], max_cc_smem_limited),
            ([64, 256, 64], 4, [1, 4, 1], max_cc_smem_limited),
            ([128, 128, 64], 4, [2, 2, 1], max_cc),
            ([128, 64, 64], 3, [2, 2, 1], max_cc),
            ([64, 128, 64], 3, [2, 2, 1], max_cc),
            ([64, 64, 64], 5, [2, 2, 1], max_cc),
        ]
        return [
            TileDescription(shape, stages, warp, math_inst, min_cc, cc_limit)
            for shape, stages, warp, cc_limit in configs
        ]

    return generate_tensor_op_common(
        math_instructions, alignment_constraints, get_tile_descriptions, batched
    )
265-
266-
# Dispatch table: CUDA compute capability (sm) -> GEMM kernel generator.
GENERATOR_FUNC_TABLE = {
    75: generate_sm75_tensor_op_1688,
    80: generate_sm80_tensor_op_16816,
}
137+
272138# TODO(masahi): A sensible way to pick reasonable default kernels
273139DEFAULT_KERNELS = {
274140 75 : {
@@ -282,67 +148,7 @@ def get_tile_descriptions(math_inst):
282148}
283149
284150
class ProfilerEngine:
    """Compile and run a given profiler executable.

    Parameters
    ----------
    cuda_arch : int or str
        Compute capability used in the nvcc -gencode flag (e.g. 75, 80).
    cutlass_path : str
        Root of the CUTLASS source tree, used for include paths.
    binary_prefix : str
        Directory where compiled profiler binaries are cached.
    """

    def __init__(self, cuda_arch, cutlass_path, binary_prefix):
        self.cuda_arch = cuda_arch
        self.binary_prefix = binary_prefix
        self.cutlass = cutlass_path
        self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format(
            cutlass=cutlass_path
        )
        self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
        self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format(
            arch=cuda_arch
        )
        self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing"
        self.cmd = "nvcc {cflags} {src} -o {output}"

    def _compile(self, op):
        """Compile op["src"] into binary_prefix/op["name"]; no-op if it already exists."""
        os.makedirs(self.binary_prefix, exist_ok=True)
        opath = os.path.join(self.binary_prefix, op["name"])
        if os.path.exists(opath):
            return
        fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu")
        fi.write(op["src"])
        fi.close()
        cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
        # Best effort: a failed compile leaves no binary, and evaluate()
        # will report -1 for this op later. Log instead of failing silently.
        if os.system(cmd) != 0:
            logger.warning("Failed to compile profiler for %s", op["name"])
        os.unlink(fi.name)

    def compile_all(self, ops, use_multiprocessing=False):
        """Compile all profiler executables."""
        if use_multiprocessing:
            # Context manager guarantees worker processes are terminated and
            # joined; a bare Pool() left here would leak worker processes.
            with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
                pool.map(self._compile, ops)
        else:
            for op in ops:
                self._compile(op)

    def evaluate(self, op, args):
        """Run the profiler executable corresponding to op_name with args.

        Returns the measured runtime reported by the profiler as a float,
        or -1 if the profiler failed to run or produced unparsable output.
        """
        op_name = op["name"]
        opath = os.path.join(self.binary_prefix, op_name)
        if not os.path.exists(opath):
            self._compile(op)
        cmd = [opath]
        if args is not None:
            cmd.append(str(args[0]))
            cmd.append(str(args[1]))
            cmd.append(str(args[2]))
            # Batched GEMM profilers take a fourth (batch-count) argument.
            if len(args) > 3:
                cmd.append(str(args[3]))
        try:
            sp = subprocess.run(cmd, capture_output=True, check=True)
            rt = float(sp.stdout)
            logger.info("%s, %f", op_name, rt)
        except (subprocess.CalledProcessError, ValueError):
            # ValueError: profiler ran but printed non-numeric output; treat
            # the same as a failed run so one bad kernel doesn't abort profiling.
            rt = -1
        return rt
344-
345- class CutlassGemmProfiler (object ):
151+ class CutlassGemmProfiler :
346152 """Profile all candidate kernels and select the best one."""
347153
348154 def __init__ (self , sm , cutlass_path , binary_path ):
@@ -364,7 +170,9 @@ def get_default(self, out_dtype, batched=False):
364170 """Return the default kernel for the requested architecture.
365171 For now, the default kernel was picked arbitrary.
366172 """
367- ops = GENERATOR_FUNC_TABLE [self .sm ](out_dtype , batched )
173+ ops = GENERATOR_FUNC_TABLE [self .sm ](
174+ out_dtype , op_creator = partial (create_gemm_operator , batched = batched )
175+ )
368176 default_kernel_name = DEFAULT_KERNELS [self .sm ][out_dtype ]
369177 filtered = list (filter (lambda op : op ["name" ] == default_kernel_name , ops ))
370178 assert len (filtered ) == 1
@@ -380,7 +188,9 @@ def profile(
380188 if (M , N , K ) in self .cache :
381189 return self .cache [(M , N , K )]
382190
383- ops = GENERATOR_FUNC_TABLE [self .sm ](out_dtype , batched )
191+ ops = GENERATOR_FUNC_TABLE [self .sm ](
192+ out_dtype , op_creator = partial (create_gemm_operator , batched = batched )
193+ )
384194 ops = list (filter (lambda op : self .check_align (op ["name" ], M ), ops ))
385195
386196 for op in ops :
0 commit comments