1616# under the License.
1717# pylint: disable=invalid-name
1818"""GEMM kernel generator and profiler for CUTLASS."""
19+ from functools import partial
1920import os
2021import re
2122import tempfile
3738)
3839
3940
40- def _create_gemm_operator (
41+ def create_gemm_operator (
4142 tile_descriptions ,
4243 data_type ,
4344 alignment_constraints ,
@@ -132,25 +133,6 @@ def _create_gemm_operator(
132133 return ret
133134
134135
135- def create_gemm_operator (batched ):
136- # TODO: replace with partial
137- def op_creator (
138- tile_descriptions ,
139- data_type ,
140- alignment_constraints ,
141- swizzling_functor = SwizzlingFunctor .Identity8 ,
142- ):
143- return _create_gemm_operator (
144- tile_descriptions ,
145- data_type ,
146- alignment_constraints ,
147- swizzling_functor ,
148- batched = batched ,
149- )
150-
151- return op_creator
152-
153-
154136GENERATOR_FUNC_TABLE = {
155137 75 : generate_sm75_tensor_op_1688 ,
156138 80 : generate_sm80_tensor_op_16816 ,
@@ -193,7 +175,7 @@ def get_default(self, out_dtype, batched=False):
193175 For now, the default kernel was picked arbitrary.
194176 """
195177 ops = GENERATOR_FUNC_TABLE [self .sm ](
196- out_dtype , op_creator = create_gemm_operator ( batched )
178+ out_dtype , op_creator = partial ( create_gemm_operator , batched = batched )
197179 )
198180 default_kernel_name = DEFAULT_KERNELS [self .sm ][out_dtype ]
199181 filtered = list (filter (lambda op : op ["name" ] == default_kernel_name , ops ))
@@ -211,7 +193,7 @@ def profile(
211193 return self .cache [(M , N , K )]
212194
213195 ops = GENERATOR_FUNC_TABLE [self .sm ](
214- out_dtype , op_creator = create_gemm_operator ( batched )
196+ out_dtype , op_creator = partial ( create_gemm_operator , batched = batched )
215197 )
216198 ops = list (filter (lambda op : self .check_align (op ["name" ], M ), ops ))
217199
0 commit comments