Skip to content

Commit f7c3b5a

Browse files
committed
add profiler boilerplate for conv2d
1 parent ca1ae27 commit f7c3b5a

File tree

3 files changed

+91
-8
lines changed

3 files changed

+91
-8
lines changed

python/tvm/contrib/cutlass/gen_conv2d.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
"""Conv2d kernel generator and profiler for CUTLASS."""
1919
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
2020
from .conv2d_profiler import Conv2dProfilerEmitter
21+
from .gemm_profiler import GemmProfilerEmitter
22+
from gen_tensor_op import (
23+
ProfilerEngine,
24+
generate_sm75_tensor_op_1688,
25+
generate_sm80_tensor_op_16816,
26+
)
2127
from .library import (
2228
EpilogueFunctor,
2329
SwizzlingFunctor,
@@ -36,7 +42,6 @@
3642

3743

3844
def create_conv2d_operator(
39-
layout,
4045
tile_descriptions,
4146
data_type,
4247
alignment_constraints,
@@ -51,6 +56,7 @@ def create_conv2d_operator(
5156
element_a, element_b, element_c, element_epilogue = data_type
5257
iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]
5358

59+
layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
5460
for tile in tile_descriptions:
5561
for alignment in alignment_constraints:
5662

@@ -105,3 +111,81 @@ def create_conv2d_operator(
105111
ret.append(op_entry)
106112

107113
return ret
114+
115+
116+
# Maps a CUDA compute capability (sm_75, sm_80) to the generator function
# that enumerates candidate tensor-op kernels for that architecture.
GENERATOR_FUNC_TABLE = {
    75: generate_sm75_tensor_op_1688,
    80: generate_sm80_tensor_op_16816,
}
120+
121+
122+
# TODO(masahi): A sensible way to pick reasonable default kernels
# Fallback kernel names per compute capability and output dtype, used when
# profiling is skipped.  The choices are arbitrary (see get_default below).
DEFAULT_KERNELS = {
    75: {
        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
    },
    80: {
        "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
        "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
    },
}
133+
134+
135+
class CutlassConv2DProfiler:
    """Profile all candidate kernels and select the best one.

    Parameters
    ----------
    sm : int
        Compute capability of the target GPU (75 or 80).
    cutlass_path : str
        Path to the CUTLASS source tree, forwarded to ProfilerEngine.
    binary_path : str
        Directory where compiled profiler binaries are placed.
    """

    def __init__(self, sm, cutlass_path, binary_path):
        assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm
        self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
        self.sm = sm
        # Maps a workload key (M, N, K, out_dtype) to the best op found for it.
        self.cache = {}

    def check_align(self, op_name, M):
        """Filter out kernels that cannot be supported."""
        # TODO: parse the alignment out of op_name and validate it against M.
        return True

    def get_default(self, out_dtype):
        """Return the default kernel for the requested architecture.

        For now, the default kernel was picked arbitrarily.
        """
        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
        default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
        filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
        assert len(filtered) == 1
        return filtered[0]

    def profile(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False):
        """Profile and select the best kernel from candidate kernels.

        If profile_all is False, return immediately after the first applicable
        kernel is found.  If use_multiprocessing is True, compile all profiler
        executables in parallel.
        """
        # BUG FIX: the cache key must include out_dtype -- candidate ops are
        # generated per dtype, so caching on (M, N, K) alone could return a
        # kernel profiled for a different output dtype.
        workload = (M, N, K, out_dtype)
        if workload in self.cache:
            return self.cache[workload]

        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))

        # A negative runtime marks an op that has not been measured (or whose
        # measurement failed); such ops are filtered out below.
        for op in ops:
            op["runtime"] = -1

        if profile_all:
            self.engine.compile_all(ops, use_multiprocessing)

        for op in ops:
            out = self.engine.evaluate(op, [M, N, K])
            op["runtime"] = out
            if out > 0 and profile_all is False:
                break

        valid_ops = filter(lambda op: op["runtime"] > 0, ops)
        output = sorted(valid_ops, key=lambda i: i["runtime"])
        # Fail loudly instead of raising a bare IndexError when nothing ran.
        if not output:
            raise RuntimeError("No valid kernel found for M=%d, N=%d, K=%d" % (M, N, K))
        self.cache[workload] = output[0]
        return output[0]

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343

4444

4545
def _create_gemm_operator(
46-
layouts,
4746
tile_descriptions,
4847
data_type,
4948
alignment_constraints,
@@ -60,6 +59,10 @@ def _create_gemm_operator(
6059
if batched:
6160
swizzling_functor = SwizzlingFunctor.Batched
6261

62+
layouts = [
63+
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
64+
]
65+
6366
for layout in layouts:
6467
for tile_description in tile_descriptions:
6568
for alignment in alignment_constraints:
@@ -135,15 +138,14 @@ def _create_gemm_operator(
135138

136139

137140
def create_gemm_operator(batched):
141+
# TODO: replace with partial
138142
def op_creator(
139-
layouts,
140143
tile_descriptions,
141144
data_type,
142145
alignment_constraints,
143146
swizzling_functor=SwizzlingFunctor.Identity8,
144147
):
145148
return _create_gemm_operator(
146-
layouts,
147149
tile_descriptions,
148150
data_type,
149151
alignment_constraints,

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ def generate_tensor_op_common(
4343
):
4444
"""Common kernel generator to be used by archtecture specific generators."""
4545
ops = []
46-
layouts = [
47-
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
48-
]
4946
for math_inst in math_instructions:
5047
tile_descriptions = get_tile_descriptions(math_inst)
5148
data_type = [
@@ -55,7 +52,7 @@ def generate_tensor_op_common(
5552
math_inst.element_accumulator,
5653
]
5754

58-
out = op_creator(layouts, tile_descriptions, data_type, alignment_constraints)
55+
out = op_creator(tile_descriptions, data_type, alignment_constraints)
5956

6057
ops.extend(out)
6158

0 commit comments

Comments
 (0)