Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 79 additions & 45 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
# pylint: disable=invalid-name, dangerous-default-value
"""Conv2d kernel generator and profiler for CUTLASS."""
import os
import pickle
from functools import partial
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .gen_gemm import CutlassGemmProfiler
Expand All @@ -40,6 +42,7 @@ def create_conv2d_operator_with_epilogue(
tile_description,
data_type,
alignment,
alignment_epilogue,
swizzling_functor,
split_k_slices,
):
Expand Down Expand Up @@ -78,7 +81,7 @@ def create_conv2d_operator_with_epilogue(

A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment_epilogue)

op = Conv2dOperation(
conv_kind,
Expand Down Expand Up @@ -110,6 +113,7 @@ def enumerate_conv2d_operators(
conv_kind,
stride_support,
split_k_slices,
alignment_c,
tile_descriptions,
data_type,
alignment_constraints,
Expand All @@ -128,47 +132,49 @@ def enumerate_conv2d_operators(

for split_k_slice in split_k_slices:
for tile in tile_descriptions:
for alignment in alignment_constraints:

A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)

if element_c == DataType.s32 and A.alignment == 1:
tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128)
tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128)

op = Conv2dOperation(
conv_kind,
IteratorAlgorithm.Optimized,
tile.minimum_compute_capability,
tile,
A,
B,
C,
element_epilogue,
stride_support,
EpilogueFunctor.LinearCombination,
swizzling_functor,
split_k_slice,
)

ret.append(
{
"src": profiler_emitter.emit(
kernel_emitter.emit(op, emit_reduction=split_k_slice > 1),
op.procedural_name(),
element_output=element_c,
split_k_slices=split_k_slice,
),
"name": op.procedural_name(),
"tile_description": tile,
"alignment": alignment,
"data_type": data_type,
"swizzle_functor": swizzling_functor,
"split_k_slices": split_k_slice,
}
)
for alignmentAB in alignment_constraints:
for alignmentC in alignment_c:

A = TensorDescription(element_a, LayoutType.TensorNHWC, alignmentAB)
B = TensorDescription(element_b, LayoutType.TensorNHWC, alignmentAB)
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignmentC)

if element_c == DataType.s32 and A.alignment == 1:
tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128)
tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128)

op = Conv2dOperation(
conv_kind,
IteratorAlgorithm.Optimized,
tile.minimum_compute_capability,
tile,
A,
B,
C,
element_epilogue,
stride_support,
EpilogueFunctor.LinearCombination,
swizzling_functor,
split_k_slice,
)

ret.append(
{
"src": profiler_emitter.emit(
kernel_emitter.emit(op, emit_reduction=split_k_slice > 1),
op.procedural_name(),
element_output=element_c,
split_k_slices=split_k_slice,
),
"name": op.procedural_name(),
"tile_description": tile,
"alignment": alignmentAB,
"alignment_epilogue": alignmentC,
"data_type": data_type,
"swizzle_functor": swizzling_functor,
"split_k_slices": split_k_slice,
}
)

return ret

Expand All @@ -181,7 +187,11 @@ def __init__(self, sm, cutlass_path, binary_path):
self.sm = sm
assert sm in GENERATOR_FUNC_TABLE, f"sm{sm} not supported yet."
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
self.cache = {}
self.cache_path = os.path.join(binary_path, "cutlass_conv2d_cache.pickle")
if os.path.exists(self.cache_path):
self.cache = pickle.load(open(self.cache_path, "rb"))
else:
self.cache = {}

def get_default(
self,
Expand Down Expand Up @@ -216,6 +226,7 @@ def get_default(
tile_description,
data_type,
alignment,
alignment,
swizzling_functor,
split_k_slices=1,
)
Expand Down Expand Up @@ -265,12 +276,32 @@ def select_op(
if workload in self.cache:
return self.cache[workload]

def alignments(dtype):
if dtype in ["float16"]:
alignments = [8, 4, 2, 1]
elif dtype in ["float", "float32"]:
alignments = [4, 2, 1]
else:
raise ValueError("Unsupported data type: %s" % dtype)
return alignments

alignments_c = [align for align in alignments(out_dtype) if OC % align == 0]

if not profile_all_alignments:
alignments_c = [alignments_c[0]]

ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype,
data_dtype,
weight_dtype,
partial(enumerate_conv2d_operators, conv_kind, stride_support, split_k_slices),
lambda align: all([dim % align == 0 for dim in [IC, OC]]),
partial(
enumerate_conv2d_operators,
conv_kind,
stride_support,
split_k_slices,
alignments_c,
),
lambda align: all([dim % align == 0 for dim in [IC]]),
use_3xtf32,
profile_all_alignments,
# Use fp32 accumulation for wgrad to align with cuDNN
Expand All @@ -294,6 +325,8 @@ def select_op(

op = min(ops, key=lambda i: i["runtime"])
self.cache[workload] = op
with open(self.cache_path, "wb") as f:
pickle.dump(self.cache, f)
return op

def profile(
Expand Down Expand Up @@ -350,6 +383,7 @@ def profile(
op["tile_description"],
op["data_type"],
op["alignment"],
op["alignment_epilogue"],
op["swizzle_functor"],
op["split_k_slices"],
)
Expand Down
11 changes: 10 additions & 1 deletion python/tvm/contrib/cutlass/gen_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# under the License.
# pylint: disable=invalid-name
"""GEMM kernel generator and profiler for CUTLASS."""
import os
import pickle

from .gemm_operation import EmitGemmInstance, GemmOperation
from .gemm_profiler import GemmProfilerEmitter
from .gen_tensor_op import EPILOGUE_MAP, GENERATOR_FUNC_TABLE, ProfilerEngine
Expand Down Expand Up @@ -152,7 +155,11 @@ def __init__(self, sm, cutlass_path, binary_path):
assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, f"sm{sm} not supported yet."
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
self.sm = sm
self.cache = {}
self.cache_path = os.path.join(binary_path, "cutlass_gemm_cache.pickle")
if os.path.exists(self.cache_path):
self.cache = pickle.load(open(self.cache_path, "rb"))
else:
self.cache = {}

def get_default(
self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False
Expand Down Expand Up @@ -242,6 +249,8 @@ def select_op(

op = min(ops, key=lambda i: i["runtime"])
self.cache[(M, N, K)] = op
with open(self.cache_path, "wb") as f:
pickle.dump(self.cache, f)
return op

def profile(
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,9 @@ def generate_sm80_tensor_op_16816(

def get_default_tile_descriptions(block_k_factor):
return [
([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc),
([128, 256, int(32 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc),
([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc),
([256, 64, int(32 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc),
([256, 64, int(32 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc),
([64, 256, int(32 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc),
([128, 128, int(32 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
Expand All @@ -228,6 +229,9 @@ def get_default_tile_descriptions(block_k_factor):
([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc_smem_limited),
([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc_smem_limited),
([128, 128, int(64 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc),
([256, 64, int(64 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc),
([64, 256, int(64 * block_k_factor)], 3, [1, 4, 1], min_cc, max_cc),
([128, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
([128, 64, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
([64, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
([64, 64, int(64 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc),
Expand Down