Skip to content

Commit 8afa6d2

Browse files
[CUTLASS][Cherry-pick] Introduce several features of cutlass profiler (#15573)
* [Contrib] Introduce several features of cutlass profiler - allow Conv2d using different alignment factors for input and epilogue, which can influence performance - store the profiler cache on disk, reducing CUTLASS profiler overhead across different runs - use the same set of default tile configurations as CUTLASS for sm80 https://github.com/NVIDIA/cutlass/blob/master/tools/library/scripts/generator.py#L1881 * skip profiling all conv2d output alignments when possible --------- Co-authored-by: Bohan Hou <[email protected]>
1 parent 760c030 commit 8afa6d2

File tree

3 files changed

+94
-47
lines changed

3 files changed

+94
-47
lines changed

python/tvm/contrib/cutlass/gen_conv2d.py

Lines changed: 79 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
# under the License.
1717
# pylint: disable=invalid-name, dangerous-default-value
1818
"""Conv2d kernel generator and profiler for CUTLASS."""
19+
import os
20+
import pickle
1921
from functools import partial
2022
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
2123
from .gen_gemm import CutlassGemmProfiler
@@ -40,6 +42,7 @@ def create_conv2d_operator_with_epilogue(
4042
tile_description,
4143
data_type,
4244
alignment,
45+
alignment_epilogue,
4346
swizzling_functor,
4447
split_k_slices,
4548
):
@@ -78,7 +81,7 @@ def create_conv2d_operator_with_epilogue(
7881

7982
A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
8083
B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
81-
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
84+
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment_epilogue)
8285

8386
op = Conv2dOperation(
8487
conv_kind,
@@ -110,6 +113,7 @@ def enumerate_conv2d_operators(
110113
conv_kind,
111114
stride_support,
112115
split_k_slices,
116+
alignment_c,
113117
tile_descriptions,
114118
data_type,
115119
alignment_constraints,
@@ -128,47 +132,49 @@ def enumerate_conv2d_operators(
128132

129133
for split_k_slice in split_k_slices:
130134
for tile in tile_descriptions:
131-
for alignment in alignment_constraints:
132-
133-
A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
134-
B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
135-
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
136-
137-
if element_c == DataType.s32 and A.alignment == 1:
138-
tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128)
139-
tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128)
140-
141-
op = Conv2dOperation(
142-
conv_kind,
143-
IteratorAlgorithm.Optimized,
144-
tile.minimum_compute_capability,
145-
tile,
146-
A,
147-
B,
148-
C,
149-
element_epilogue,
150-
stride_support,
151-
EpilogueFunctor.LinearCombination,
152-
swizzling_functor,
153-
split_k_slice,
154-
)
155-
156-
ret.append(
157-
{
158-
"src": profiler_emitter.emit(
159-
kernel_emitter.emit(op, emit_reduction=split_k_slice > 1),
160-
op.procedural_name(),
161-
element_output=element_c,
162-
split_k_slices=split_k_slice,
163-
),
164-
"name": op.procedural_name(),
165-
"tile_description": tile,
166-
"alignment": alignment,
167-
"data_type": data_type,
168-
"swizzle_functor": swizzling_functor,
169-
"split_k_slices": split_k_slice,
170-
}
171-
)
135+
for alignmentAB in alignment_constraints:
136+
for alignmentC in alignment_c:
137+
138+
A = TensorDescription(element_a, LayoutType.TensorNHWC, alignmentAB)
139+
B = TensorDescription(element_b, LayoutType.TensorNHWC, alignmentAB)
140+
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignmentC)
141+
142+
if element_c == DataType.s32 and A.alignment == 1:
143+
tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128)
144+
tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128)
145+
146+
op = Conv2dOperation(
147+
conv_kind,
148+
IteratorAlgorithm.Optimized,
149+
tile.minimum_compute_capability,
150+
tile,
151+
A,
152+
B,
153+
C,
154+
element_epilogue,
155+
stride_support,
156+
EpilogueFunctor.LinearCombination,
157+
swizzling_functor,
158+
split_k_slice,
159+
)
160+
161+
ret.append(
162+
{
163+
"src": profiler_emitter.emit(
164+
kernel_emitter.emit(op, emit_reduction=split_k_slice > 1),
165+
op.procedural_name(),
166+
element_output=element_c,
167+
split_k_slices=split_k_slice,
168+
),
169+
"name": op.procedural_name(),
170+
"tile_description": tile,
171+
"alignment": alignmentAB,
172+
"alignment_epilogue": alignmentC,
173+
"data_type": data_type,
174+
"swizzle_functor": swizzling_functor,
175+
"split_k_slices": split_k_slice,
176+
}
177+
)
172178

173179
return ret
174180

@@ -181,7 +187,11 @@ def __init__(self, sm, cutlass_path, binary_path):
181187
self.sm = sm
182188
assert sm in GENERATOR_FUNC_TABLE, f"sm{sm} not supported yet."
183189
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
184-
self.cache = {}
190+
self.cache_path = os.path.join(binary_path, "cutlass_conv2d_cache.pickle")
191+
if os.path.exists(self.cache_path):
192+
self.cache = pickle.load(open(self.cache_path, "rb"))
193+
else:
194+
self.cache = {}
185195

186196
def get_default(
187197
self,
@@ -216,6 +226,7 @@ def get_default(
216226
tile_description,
217227
data_type,
218228
alignment,
229+
alignment,
219230
swizzling_functor,
220231
split_k_slices=1,
221232
)
@@ -265,12 +276,32 @@ def select_op(
265276
if workload in self.cache:
266277
return self.cache[workload]
267278

279+
def alignments(dtype):
280+
if dtype in ["float16"]:
281+
alignments = [8, 4, 2, 1]
282+
elif dtype in ["float", "float32"]:
283+
alignments = [4, 2, 1]
284+
else:
285+
raise ValueError("Unsupported data type: %s" % dtype)
286+
return alignments
287+
288+
alignments_c = [align for align in alignments(out_dtype) if OC % align == 0]
289+
290+
if not profile_all_alignments:
291+
alignments_c = [alignments_c[0]]
292+
268293
ops = GENERATOR_FUNC_TABLE[self.sm](
269294
out_dtype,
270295
data_dtype,
271296
weight_dtype,
272-
partial(enumerate_conv2d_operators, conv_kind, stride_support, split_k_slices),
273-
lambda align: all([dim % align == 0 for dim in [IC, OC]]),
297+
partial(
298+
enumerate_conv2d_operators,
299+
conv_kind,
300+
stride_support,
301+
split_k_slices,
302+
alignments_c,
303+
),
304+
lambda align: all([dim % align == 0 for dim in [IC]]),
274305
use_3xtf32,
275306
profile_all_alignments,
276307
# Use fp32 accumulation for wgrad to align with cuDNN
@@ -294,6 +325,8 @@ def select_op(
294325

295326
op = min(ops, key=lambda i: i["runtime"])
296327
self.cache[workload] = op
328+
with open(self.cache_path, "wb") as f:
329+
pickle.dump(self.cache, f)
297330
return op
298331

299332
def profile(
@@ -350,6 +383,7 @@ def profile(
350383
op["tile_description"],
351384
op["data_type"],
352385
op["alignment"],
386+
op["alignment_epilogue"],
353387
op["swizzle_functor"],
354388
op["split_k_slices"],
355389
)

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
# under the License.
1717
# pylint: disable=invalid-name
1818
"""GEMM kernel generator and profiler for CUTLASS."""
19+
import os
20+
import pickle
21+
1922
from .gemm_operation import EmitGemmInstance, GemmOperation
2023
from .gemm_profiler import GemmProfilerEmitter
2124
from .gen_tensor_op import EPILOGUE_MAP, GENERATOR_FUNC_TABLE, ProfilerEngine
@@ -152,7 +155,11 @@ def __init__(self, sm, cutlass_path, binary_path):
152155
assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, f"sm{sm} not supported yet."
153156
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
154157
self.sm = sm
155-
self.cache = {}
158+
self.cache_path = os.path.join(binary_path, "cutlass_gemm_cache.pickle")
159+
if os.path.exists(self.cache_path):
160+
self.cache = pickle.load(open(self.cache_path, "rb"))
161+
else:
162+
self.cache = {}
156163

157164
def get_default(
158165
self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False
@@ -242,6 +249,8 @@ def select_op(
242249

243250
op = min(ops, key=lambda i: i["runtime"])
244251
self.cache[(M, N, K)] = op
252+
with open(self.cache_path, "wb") as f:
253+
pickle.dump(self.cache, f)
245254
return op
246255

247256
def profile(

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ def generate_sm80_tensor_op_16816(
213213

214214
def get_default_tile_descriptions(block_k_factor):
215215
return [
216-
([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc),
217216
([128, 256, int(32 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc),
217+
([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc),
218+
([256, 64, int(32 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc),
218219
([256, 64, int(32 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc),
219220
([64, 256, int(32 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc),
220221
([128, 128, int(32 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
@@ -228,6 +229,9 @@ def get_default_tile_descriptions(block_k_factor):
228229
([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc_smem_limited),
229230
([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc_smem_limited),
230231
([128, 128, int(64 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc),
232+
([256, 64, int(64 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc),
233+
([64, 256, int(64 * block_k_factor)], 3, [1, 4, 1], min_cc, max_cc),
234+
([128, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
231235
([128, 64, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
232236
([64, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
233237
([64, 64, int(64 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc),

0 commit comments

Comments (0)