Skip to content

Commit adf560e

Browse files
authored
[CUTLASS] Refactor GEMM generator in preparation for conv2d (#9571)
* split non-gemm specific generator code to gen_tensor_op.py commit 250f915 Author: Masahiro Masuda <[email protected]> Date: Sun Nov 14 06:44:52 2021 +0900 remove conv2d stuff commit 1a6b27c Author: Masahiro Masuda <[email protected]> Date: Sun Nov 14 06:41:31 2021 +0900 remove unused import commit f7c3b5a Author: Masahiro Masuda <[email protected]> Date: Sun Nov 14 06:37:07 2021 +0900 add profiler boilerplate for conv2d commit ca1ae27 Author: Masahiro Masuda <[email protected]> Date: Sun Nov 14 06:22:06 2021 +0900 introduce gen_tensor_op.py commit 37bb918 Author: Masahiro Masuda <[email protected]> Date: Sun Nov 14 05:45:41 2021 +0900 more conv2d code commit 5c00398 Author: Masahiro Masuda <[email protected]> Date: Sun Nov 14 05:13:30 2021 +0900 Begin conv2d support * fix * use functools.partial * remove unused import
1 parent fb4b7e2 commit adf560e

File tree

3 files changed

+238
-211
lines changed

3 files changed

+238
-211
lines changed

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 20 additions & 210 deletions
Original file line numberDiff line numberDiff line change
@@ -15,37 +15,29 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name
18-
"""Kernel generator and profiler for CUTLASS."""
19-
import logging
20-
import os
18+
"""GEMM kernel generator and profiler for CUTLASS."""
19+
from functools import partial
2120
import re
22-
import tempfile
23-
import subprocess
24-
import multiprocessing
2521
from .gemm_operation import GemmOperation, EmitGemmInstance
2622
from .gemm_profiler import GemmProfilerEmitter
23+
from .gen_tensor_op import (
24+
ProfilerEngine,
25+
generate_sm75_tensor_op_1688,
26+
generate_sm80_tensor_op_16816,
27+
)
2728
from .library import (
2829
EpilogueFunctor,
2930
SwizzlingFunctor,
3031
TensorDescription,
3132
DataTypeTag,
3233
LayoutType,
33-
MathInstruction,
34-
DataType,
35-
OpcodeClass,
36-
MathOperation,
37-
TileDescription,
3834
)
3935

40-
logger = logging.getLogger("cutlass")
41-
4236

4337
def create_gemm_operator(
44-
layouts,
4538
tile_descriptions,
4639
data_type,
4740
alignment_constraints,
48-
epilogue_functor=EpilogueFunctor.LinearCombination,
4941
swizzling_functor=SwizzlingFunctor.Identity8,
5042
batched=False,
5143
):
@@ -59,6 +51,10 @@ def create_gemm_operator(
5951
if batched:
6052
swizzling_functor = SwizzlingFunctor.Batched
6153

54+
layouts = [
55+
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
56+
]
57+
6258
for layout in layouts:
6359
for tile_description in tile_descriptions:
6460
for alignment in alignment_constraints:
@@ -76,7 +72,7 @@ def create_gemm_operator(
7672
B,
7773
C,
7874
element_epilogue,
79-
epilogue_functor,
75+
EpilogueFunctor.LinearCombination,
8076
swizzling_functor,
8177
)
8278
op_bias = GemmOperation(
@@ -110,7 +106,6 @@ def create_gemm_operator(
110106
swizzling_functor,
111107
)
112108

113-
kernel_emitter = EmitGemmInstance()
114109
op_entry["op"] = op
115110
op_entry["name"] = op.procedural_name()
116111
op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
@@ -134,141 +129,12 @@ def create_gemm_operator(
134129
return ret
135130

136131

137-
def generate_tensor_op_common(
138-
math_instructions, alignment_constraints, get_tile_descriptions, batched=False
139-
):
140-
"""Common kernel generator to be used by architecture-specific generators."""
141-
ops = []
142-
layouts = [
143-
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
144-
]
145-
for math_inst in math_instructions:
146-
tile_descriptions = get_tile_descriptions(math_inst)
147-
data_type = [
148-
math_inst.element_a,
149-
math_inst.element_b,
150-
math_inst.element_accumulator,
151-
math_inst.element_accumulator,
152-
]
153-
154-
out = create_gemm_operator(
155-
layouts, tile_descriptions, data_type, alignment_constraints, batched=batched
156-
)
157-
158-
ops.extend(out)
159-
160-
return ops
161-
162-
163-
def generate_sm75_tensor_op_1688(out_dtype, batched=False):
164-
"""Generate GEMM kernels for Turing."""
165-
assert out_dtype in ["float32", "float16"]
166-
math_instructions = {
167-
"float32": [
168-
MathInstruction(
169-
[16, 8, 8],
170-
DataType.f16,
171-
DataType.f16,
172-
DataType.f32,
173-
OpcodeClass.TensorOp,
174-
MathOperation.multiply_add,
175-
)
176-
],
177-
"float16": [
178-
MathInstruction(
179-
[16, 8, 8],
180-
DataType.f16,
181-
DataType.f16,
182-
DataType.f16,
183-
OpcodeClass.TensorOp,
184-
MathOperation.multiply_add,
185-
)
186-
],
187-
}[out_dtype]
188-
189-
alignment_constraints = [8, 4, 2, 1]
190-
191-
def get_tile_descriptions(math_inst):
192-
min_cc = 75
193-
max_cc = 1024
194-
return [
195-
TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
196-
TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
197-
TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
198-
TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
199-
TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
200-
TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
201-
TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc),
202-
]
203-
204-
return generate_tensor_op_common(
205-
math_instructions, alignment_constraints, get_tile_descriptions, batched
206-
)
207-
208-
209-
def generate_sm80_tensor_op_16816(out_dtype, batched=False):
210-
"""Generate GEMM kernels for Ampere."""
211-
assert out_dtype in ["float32", "float16"]
212-
math_instructions = {
213-
"float32": [
214-
MathInstruction(
215-
[16, 8, 16],
216-
DataType.f16,
217-
DataType.f16,
218-
DataType.f32,
219-
OpcodeClass.TensorOp,
220-
MathOperation.multiply_add,
221-
)
222-
],
223-
"float16": [
224-
MathInstruction(
225-
[16, 8, 16],
226-
DataType.f16,
227-
DataType.f16,
228-
DataType.f16,
229-
OpcodeClass.TensorOp,
230-
MathOperation.multiply_add,
231-
)
232-
],
233-
}[out_dtype]
234-
235-
alignment_constraints = [8, 4, 2]
236-
237-
def get_tile_descriptions(math_inst):
238-
min_cc = 80
239-
max_cc = 1024
240-
max_cc_smem_limited = 80
241-
return [
242-
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
243-
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
244-
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
245-
TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
246-
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
247-
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
248-
TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
249-
TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
250-
TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
251-
TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc),
252-
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
253-
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
254-
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
255-
TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
256-
TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
257-
TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
258-
TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
259-
TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
260-
]
261-
262-
return generate_tensor_op_common(
263-
math_instructions, alignment_constraints, get_tile_descriptions, batched
264-
)
265-
266-
267132
GENERATOR_FUNC_TABLE = {
268133
75: generate_sm75_tensor_op_1688,
269134
80: generate_sm80_tensor_op_16816,
270135
}
271136

137+
272138
# TODO(masahi): A sensible way to pick reasonable default kernels
273139
DEFAULT_KERNELS = {
274140
75: {
@@ -282,67 +148,7 @@ def get_tile_descriptions(math_inst):
282148
}
283149

284150

285-
class ProfilerEngine:
286-
"""Compile and run a given profiler executable."""
287-
288-
def __init__(self, cuda_arch, cutlass_path, binary_prefix):
289-
self.cuda_arch = cuda_arch
290-
self.binary_prefix = binary_prefix
291-
self.cutlass = cutlass_path
292-
self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format(
293-
cutlass=cutlass_path
294-
)
295-
self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
296-
self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format(
297-
arch=cuda_arch
298-
)
299-
self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing"
300-
self.cmd = "nvcc {cflags} {src} -o {output}"
301-
302-
def _compile(self, op):
303-
os.makedirs(self.binary_prefix, exist_ok=True)
304-
opath = os.path.join(self.binary_prefix, op["name"])
305-
if os.path.exists(opath):
306-
return
307-
fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu")
308-
fi.write(op["src"])
309-
fi.close()
310-
cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
311-
os.system(cmd)
312-
os.unlink(fi.name)
313-
314-
def compile_all(self, ops, use_multiprocessing=False):
315-
"""Compile all profiler executables."""
316-
if use_multiprocessing:
317-
pool = multiprocessing.Pool(multiprocessing.cpu_count())
318-
pool.map(self._compile, ops)
319-
else:
320-
for op in ops:
321-
self._compile(op)
322-
323-
def evaluate(self, op, args):
324-
"""Run the profiler executable corresponding to op_name with args."""
325-
op_name = op["name"]
326-
opath = os.path.join(self.binary_prefix, op_name)
327-
if not os.path.exists(opath):
328-
self._compile(op)
329-
cmd = [opath]
330-
if args is not None:
331-
cmd.append(str(args[0]))
332-
cmd.append(str(args[1]))
333-
cmd.append(str(args[2]))
334-
if len(args) > 3:
335-
cmd.append(str(args[3]))
336-
try:
337-
sp = subprocess.run(cmd, capture_output=True, check=True)
338-
rt = float(sp.stdout)
339-
logger.info("%s, %f", op_name, rt)
340-
except subprocess.CalledProcessError:
341-
rt = -1
342-
return rt
343-
344-
345-
class CutlassGemmProfiler(object):
151+
class CutlassGemmProfiler:
346152
"""Profile all candidate kernels and select the best one."""
347153

348154
def __init__(self, sm, cutlass_path, binary_path):
@@ -364,7 +170,9 @@ def get_default(self, out_dtype, batched=False):
364170
"""Return the default kernel for the requested architecture.
365171
For now, the default kernel was picked arbitrarily.
366172
"""
367-
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
173+
ops = GENERATOR_FUNC_TABLE[self.sm](
174+
out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
175+
)
368176
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
369177
filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
370178
assert len(filtered) == 1
@@ -380,7 +188,9 @@ def profile(
380188
if (M, N, K) in self.cache:
381189
return self.cache[(M, N, K)]
382190

383-
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
191+
ops = GENERATOR_FUNC_TABLE[self.sm](
192+
out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
193+
)
384194
ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
385195

386196
for op in ops:

0 commit comments

Comments
 (0)