Skip to content

Commit 347adf5

Browse files
committed
use functools.partial
1 parent df1946b commit 347adf5

File tree

2 files changed

+7
-26
lines changed

2 files changed

+7
-26
lines changed

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
# pylint: disable=invalid-name
1818
"""GEMM kernel generator and profiler for CUTLASS."""
19+
from functools import partial
1920
import os
2021
import re
2122
import tempfile
@@ -37,7 +38,7 @@
3738
)
3839

3940

40-
def _create_gemm_operator(
41+
def create_gemm_operator(
4142
tile_descriptions,
4243
data_type,
4344
alignment_constraints,
@@ -132,25 +133,6 @@ def _create_gemm_operator(
132133
return ret
133134

134135

135-
def create_gemm_operator(batched):
136-
# TODO: replace with partial
137-
def op_creator(
138-
tile_descriptions,
139-
data_type,
140-
alignment_constraints,
141-
swizzling_functor=SwizzlingFunctor.Identity8,
142-
):
143-
return _create_gemm_operator(
144-
tile_descriptions,
145-
data_type,
146-
alignment_constraints,
147-
swizzling_functor,
148-
batched=batched,
149-
)
150-
151-
return op_creator
152-
153-
154136
GENERATOR_FUNC_TABLE = {
155137
75: generate_sm75_tensor_op_1688,
156138
80: generate_sm80_tensor_op_16816,
@@ -193,7 +175,7 @@ def get_default(self, out_dtype, batched=False):
193175
For now, the default kernel was picked arbitrary.
194176
"""
195177
ops = GENERATOR_FUNC_TABLE[self.sm](
196-
out_dtype, op_creator=create_gemm_operator(batched)
178+
out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
197179
)
198180
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
199181
filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
@@ -211,7 +193,7 @@ def profile(
211193
return self.cache[(M, N, K)]
212194

213195
ops = GENERATOR_FUNC_TABLE[self.sm](
214-
out_dtype, op_creator=create_gemm_operator(batched)
196+
out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
215197
)
216198
ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
217199

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name
18-
"""GEMM kernel generator and profiler for CUTLASS."""
18+
"""Common functions and classes for CUTLASS GEMM and Conv2d geneator."""
1919
import logging
2020
import os
21-
import re
2221
import tempfile
2322
import subprocess
2423
import multiprocessing
@@ -56,7 +55,7 @@ def generate_tensor_op_common(
5655

5756

5857
def generate_sm75_tensor_op_1688(out_dtype, op_creator):
59-
"""Generate GEMM kernels for Turing."""
58+
"""Generate GEMM or Conv2D kernels for Turing."""
6059
assert out_dtype in ["float32", "float16"]
6160
math_instructions = {
6261
"float32": [
@@ -102,7 +101,7 @@ def get_tile_descriptions(math_inst):
102101

103102

104103
def generate_sm80_tensor_op_16816(out_dtype, op_creator):
105-
"""Generate GEMM kernels for Ampere."""
104+
"""Generate GEMM or Conv2D kernels for Ampere."""
106105
assert out_dtype in ["float32", "float16"]
107106
math_instructions = {
108107
"float32": [

0 commit comments

Comments
 (0)