Skip to content

Commit 37bb918

Browse files
committed
more conv2d code
1 parent 5c00398 commit 37bb918

File tree

5 files changed

+162
-29
lines changed

5 files changed

+162
-29
lines changed

python/tvm/contrib/cutlass/conv2d_operation.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,11 @@ def __init__(
5555
self.stride_support = stride_support
5656
self.swizzling_functor = swizzling_functor
5757

58-
#
59-
def is_complex(self):
60-
complex_operators = [
61-
MathOperation.multiply_add_complex,
62-
MathOperation.multiply_add_complex_gaussian,
63-
]
64-
return self.tile_description.math_instruction.math_operation in complex_operators
6558

66-
#
6759
def accumulator_type(self):
68-
accum = self.tile_description.math_instruction.element_accumulator
60+
return self.tile_description.math_instruction.element_accumulator
6961

70-
if self.is_complex():
71-
return get_complex_from_real(accum)
7262

73-
return accum
74-
75-
#
7663
def core_name(self):
7764
""" The basic operation kind is prefixed with a letter indicating the accumulation type. """
7865

@@ -112,7 +99,7 @@ def extended_name(self):
11299
else:
113100
extended_name = "${core_name}"
114101

115-
extended_name = SubstituteTemplate(
102+
extended_name = substitute_template(
116103
extended_name,
117104
{
118105
"element_a": DataTypeNames[self.A.element],
@@ -145,7 +132,7 @@ def configuration_name(self):
145132
else:
146133
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
147134

148-
return SubstituteTemplate(
135+
return substitute_template(
149136
configuration_name,
150137
{
151138
"opcode_class": opcode_class_name,
@@ -258,14 +245,12 @@ def emit(self, operation):
258245
operation.iterator_algorithm
259246
].capitalize(),
260247
"stride_support": StrideSupportTag[operation.stride_support],
261-
"math_operator": "cutlass::arch::OpMultiplyAddComplex"
262-
if operation.is_complex()
263-
else MathOperationTag[operation.tile_description.math_instruction.math_operation],
248+
"math_operator": MathOperationTag[operation.tile_description.math_instruction.math_operation],
264249
"align_a": str(operation.A.alignment),
265250
"align_b": str(operation.B.alignment),
266251
}
267252

268-
return SubstituteTemplate(self.template, values)
253+
return substitute_template(self.template, values)
269254

270255

271256
class EmitConv2dConfigurationLibrary:
@@ -340,7 +325,7 @@ def __init__(self, operation_path, configuration_name):
340325
def __enter__(self):
341326
self.configuration_file = open(self.configuration_path, "w")
342327
self.configuration_file.write(
343-
SubstituteTemplate(
328+
substitute_template(
344329
self.header_template, {"configuration_name": self.configuration_name}
345330
)
346331
)
@@ -351,7 +336,7 @@ def __enter__(self):
351336
def emit(self, operation):
352337
self.operations.append(operation)
353338
self.configuration_file.write(
354-
SubstituteTemplate(
339+
substitute_template(
355340
self.instance_template,
356341
{
357342
"configuration_name": self.configuration_name,
@@ -365,14 +350,14 @@ def emit(self, operation):
365350
def __exit__(self, exception_type, exception_value, traceback):
366351

367352
self.configuration_file.write(
368-
SubstituteTemplate(
353+
substitute_template(
369354
self.configuration_header, {"configuration_name": self.configuration_name}
370355
)
371356
)
372357

373358
for operation in self.operations:
374359
self.configuration_file.write(
375-
SubstituteTemplate(
360+
substitute_template(
376361
self.configuration_instance,
377362
{
378363
"configuration_name": self.configuration_name,
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=import-outside-toplevel, invalid-name
18+
"""Instantiate a C++ source for profiling CUTLASS kernels."""
19+
from .gemm_profiler import GemmProfilerEmitter
20+
21+
22+
class Conv2dProfilerEmitter:
    """Emit profiler source for conv2d kernels by delegating to a GemmProfilerEmitter."""

    def __init__(self):
        # The conv2d profiler currently reuses the GEMM profiler emitter as-is.
        self.gemm_profiler_emitter = GemmProfilerEmitter()

    def emit(self, op_name, op_def, dtype_a, dtype_b, dtype_c, ld):
        """Return the profiler source produced by the wrapped GEMM profiler emitter.

        All arguments are forwarded unchanged, in the same order, to
        ``GemmProfilerEmitter.emit``.
        """
        # Call the emitter's ``emit`` method; the emitter object itself is not
        # invoked as a callable.
        return self.gemm_profiler_emitter.emit(op_name, op_def, dtype_a, dtype_b, dtype_c, ld)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name
18+
"""Conv2d kernel generator and profiler for CUTLASS."""
19+
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
20+
from .conv2d_profiler import Conv2dProfilerEmitter
21+
from .library import (
22+
EpilogueFunctor,
23+
SwizzlingFunctor,
24+
TensorDescription,
25+
DataTypeTag,
26+
LayoutType,
27+
MathInstruction,
28+
DataType,
29+
OpcodeClass,
30+
MathOperation,
31+
TileDescription,
32+
ConvKind,
33+
IteratorAlgorithm,
34+
StrideSupport,
35+
)
36+
37+
38+
def create_conv2d_operator(
    layout,
    tile_descriptions,
    data_type,
    alignment_constraints,
    swizzling_functor=SwizzlingFunctor.Identity4,
):
    """Exhaustively instantiate all kernels from a given configuration.

    Parameters
    ----------
    layout
        Indexable of three layouts, used for tensors A, B and C respectively.
    tile_descriptions
        Iterable of tile descriptions; each must expose ``minimum_compute_capability``.
    data_type
        4-tuple ``(element_a, element_b, element_c, element_epilogue)``.
    alignment_constraints
        Iterable of alignments for A and B. The alignment of C is capped at 8.
    swizzling_functor
        Swizzling functor passed to every generated operation.

    Returns
    -------
    list of dict
        One entry per (tile, alignment, iterator algorithm) combination with the keys
        ``"op"`` (the LinearCombination ``Conv2dOperation``), ``"name"``, ``"opdef"``,
        ``"opdef_bias"``, ``"opdef_bias_relu"``, ``"src"`` and ``"runtime"``.
    """
    ret = []

    profiler_emitter = Conv2dProfilerEmitter()
    kernel_emitter = EmitConv2dInstance()

    element_a, element_b, element_c, element_epilogue = data_type
    iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

    # Key under which each emitted kernel definition is stored, paired with its epilogue.
    epilogue_variants = (
        ("opdef", EpilogueFunctor.LinearCombination),
        ("opdef_bias", EpilogueFunctor.LinearCombinationBias),
        ("opdef_bias_relu", EpilogueFunctor.LinearCombinationRelu),
    )

    for tile in tile_descriptions:
        for alignment in alignment_constraints:
            alignment_c = min(8, alignment)

            A = TensorDescription(element_a, layout[0], alignment)
            B = TensorDescription(element_b, layout[1], alignment)
            C = TensorDescription(element_c, layout[2], alignment_c)

            for iterator_algorithm in iterator_algorithms:
                op_entry = {}
                base_op = None

                for key, epilogue in epilogue_variants:
                    conv_op = Conv2dOperation(
                        ConvKind.Fprop,
                        iterator_algorithm,
                        tile.minimum_compute_capability,
                        tile,
                        A,
                        B,
                        C,
                        element_epilogue,
                        StrideSupport.Strided,
                        epilogue,
                        swizzling_functor,
                    )
                    op_entry[key] = kernel_emitter.emit(conv_op)

                    # Keep the operation object of the plain LinearCombination variant:
                    # the name and profiler source below are derived from it, not from
                    # the emitted definition stored in op_entry.
                    if key == "opdef":
                        base_op = conv_op

                op_entry["op"] = base_op
                op_entry["name"] = base_op.procedural_name()
                # NOTE(review): leading_dim() is called on the operation object here;
                # confirm that Conv2dOperation provides it.
                op_entry["src"] = profiler_emitter.emit(
                    op_entry["name"],
                    op_entry["opdef"],
                    DataTypeTag[element_a],
                    DataTypeTag[element_b],
                    DataTypeTag[element_c],
                    base_op.leading_dim(),
                )
                # Initial runtime value; presumably replaced once the kernel is
                # profiled -- confirm against the profiler.
                op_entry["runtime"] = 9999999

                ret.append(op_entry)

    return ret

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name
18-
"""Kernel generator and profiler for CUTLASS."""
18+
"""GEMM kernel generator and profiler for CUTLASS."""
1919
import logging
2020
import os
2121
import re
@@ -45,7 +45,6 @@ def create_gemm_operator(
4545
tile_descriptions,
4646
data_type,
4747
alignment_constraints,
48-
epilogue_functor=EpilogueFunctor.LinearCombination,
4948
swizzling_functor=SwizzlingFunctor.Identity8,
5049
batched=False,
5150
):
@@ -76,7 +75,7 @@ def create_gemm_operator(
7675
B,
7776
C,
7877
element_epilogue,
79-
epilogue_functor,
78+
EpilogueFunctor.LinearCombination,
8079
swizzling_functor,
8180
)
8281
op_bias = GemmOperation(
@@ -110,7 +109,6 @@ def create_gemm_operator(
110109
swizzling_functor,
111110
)
112111

113-
kernel_emitter = EmitGemmInstance()
114112
op_entry["op"] = op
115113
op_entry["name"] = op.procedural_name()
116114
op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
@@ -342,7 +340,7 @@ def evaluate(self, op, args):
342340
return rt
343341

344342

345-
class CutlassGemmProfiler(object):
343+
class CutlassGemmProfiler:
346344
"""Profile all candidate kernels and select the best one."""
347345

348346
def __init__(self, sm, cutlass_path, binary_path):

python/tvm/contrib/cutlass/library.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,22 @@ class ConvKind(enum.Enum):
192192
}
193193

194194

195+
class StrideSupport(enum.Enum):
    """Stride modes selectable for a conv2d operation."""

    Strided = enum_auto()
    Unity = enum_auto()


# C++ enumerator text substituted into generated kernel source for each stride mode.
StrideSupportTag = {
    StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided",
    StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity",
}

# Short name for each stride mode (empty for Strided); presumably used when
# composing kernel names -- confirm against the callers.
StrideSupportNames = {
    StrideSupport.Strided: "",
    StrideSupport.Unity: "unity_stride",
}
209+
210+
195211
class IteratorAlgorithm(enum.Enum):
196212
Analytic = enum_auto()
197213
Optimized = enum_auto()

0 commit comments

Comments
 (0)