Skip to content

Commit 5c00398

Browse files
committed
Begin conv2d support
1 parent 6159b8e commit 5c00398

File tree

2 files changed

+422
-0
lines changed

2 files changed

+422
-0
lines changed
Lines changed: 386 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,386 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
18+
"""Generator for CUTLASS Conv2D kernels."""
19+
import enum
20+
import os.path
21+
import shutil
22+
23+
from library import *
24+
25+
###################################################################################################
26+
27+
#
28+
class Conv2dOperation:
29+
#
30+
def __init__(
31+
self,
32+
conv_kind,
33+
iterator_algorithm,
34+
arch,
35+
tile_description,
36+
A,
37+
B,
38+
C,
39+
element_epilogue,
40+
stride_support,
41+
epilogue_functor=EpilogueFunctor.LinearCombination,
42+
swizzling_functor=SwizzlingFunctor.Identity1,
43+
):
44+
45+
self.operation_kind = OperationKind.Conv2d
46+
self.arch = arch
47+
self.tile_description = tile_description
48+
self.conv_kind = conv_kind
49+
self.A = A
50+
self.B = B
51+
self.C = C
52+
self.element_epilogue = element_epilogue
53+
self.epilogue_functor = epilogue_functor
54+
self.iterator_algorithm = iterator_algorithm
55+
self.stride_support = stride_support
56+
self.swizzling_functor = swizzling_functor
57+
58+
#
59+
def is_complex(self):
60+
complex_operators = [
61+
MathOperation.multiply_add_complex,
62+
MathOperation.multiply_add_complex_gaussian,
63+
]
64+
return self.tile_description.math_instruction.math_operation in complex_operators
65+
66+
#
67+
def accumulator_type(self):
68+
accum = self.tile_description.math_instruction.element_accumulator
69+
70+
if self.is_complex():
71+
return get_complex_from_real(accum)
72+
73+
return accum
74+
75+
#
76+
def core_name(self):
77+
""" The basic operation kind is prefixed with a letter indicating the accumulation type. """
78+
79+
intermediate_type = ""
80+
81+
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
82+
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
83+
if (
84+
self.tile_description.math_instruction.element_a != self.A.element
85+
and self.tile_description.math_instruction.element_a != self.accumulator_type()
86+
):
87+
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
88+
else:
89+
inst_shape = ""
90+
91+
return "%s%s%s%s_%s" % (
92+
ShortDataTypeNames[self.accumulator_type()],
93+
inst_shape,
94+
intermediate_type,
95+
ConvKindNames[self.conv_kind],
96+
IteratorAlgorithmNames[self.iterator_algorithm],
97+
)
98+
99+
#
100+
def extended_name(self):
101+
""" Append data types if they differ from compute type. """
102+
if (
103+
self.C.element != self.tile_description.math_instruction.element_accumulator
104+
and self.A.element != self.tile_description.math_instruction.element_accumulator
105+
):
106+
extended_name = "${element_c}_${core_name}_${element_a}"
107+
elif (
108+
self.C.element == self.tile_description.math_instruction.element_accumulator
109+
and self.A.element != self.tile_description.math_instruction.element_accumulator
110+
):
111+
extended_name = "${core_name}_${element_a}"
112+
else:
113+
extended_name = "${core_name}"
114+
115+
extended_name = SubstituteTemplate(
116+
extended_name,
117+
{
118+
"element_a": DataTypeNames[self.A.element],
119+
"element_c": DataTypeNames[self.C.element],
120+
"core_name": self.core_name(),
121+
},
122+
)
123+
124+
return extended_name
125+
126+
#
127+
def layout_name(self):
128+
return "%s" % (ShortLayoutTypeNames[self.A.layout])
129+
130+
#
131+
def configuration_name(self):
132+
""" The full procedural name indicates architecture, extended name, tile size, and layout. """
133+
134+
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
135+
136+
threadblock = "%dx%d_%dx%d" % (
137+
self.tile_description.threadblock_shape[0],
138+
self.tile_description.threadblock_shape[1],
139+
self.tile_description.threadblock_shape[2],
140+
self.tile_description.stages,
141+
)
142+
143+
if self.stride_support == StrideSupport.Unity:
144+
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}_unity_stride"
145+
else:
146+
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
147+
148+
return SubstituteTemplate(
149+
configuration_name,
150+
{
151+
"opcode_class": opcode_class_name,
152+
"extended_name": self.extended_name(),
153+
"threadblock": threadblock,
154+
"layout": self.layout_name(),
155+
"alignment": "%d" % self.A.alignment,
156+
},
157+
)
158+
159+
#
160+
def procedural_name(self):
161+
""" The full procedural name indicates architecture, extended name, tile size, and layout. """
162+
return self.configuration_name()
163+
164+
165+
###################################################################################################
166+
#
167+
# Emits single instances of a CUTLASS device-wide operator
168+
#
169+
###################################################################################################
170+
171+
172+
class EmitConv2dInstance:
173+
def __init__(self):
174+
self.template = """
175+
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
176+
using ${operation_name}_base =
177+
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
178+
${element_a},
179+
${layout_a},
180+
${element_b},
181+
${layout_b},
182+
${element_c},
183+
${layout_c},
184+
${element_accumulator},
185+
${opcode_class},
186+
${arch},
187+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
188+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
189+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
190+
${epilogue_functor}<
191+
${element_c},
192+
${epilogue_vector_length},
193+
${element_accumulator},
194+
${element_epilogue}
195+
>,
196+
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
197+
${stages},
198+
${math_operator},
199+
${iterator_algorithm},
200+
${stride_support},
201+
${align_a},
202+
${align_b}
203+
>::Kernel;
204+
"""
205+
206+
def emit(self, operation):
207+
208+
warp_shape = [
209+
int(
210+
operation.tile_description.threadblock_shape[idx]
211+
/ operation.tile_description.warp_count[idx]
212+
)
213+
for idx in range(3)
214+
]
215+
216+
epilogue_vector_length = int(
217+
min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
218+
/ DataTypeSize[operation.C.element]
219+
)
220+
221+
values = {
222+
"operation_name": operation.procedural_name(),
223+
"conv_kind": ConvKindTag[operation.conv_kind],
224+
"conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(),
225+
"element_a": DataTypeTag[operation.A.element],
226+
"layout_a": LayoutTag[operation.A.layout],
227+
"element_b": DataTypeTag[operation.B.element],
228+
"layout_b": LayoutTag[operation.B.layout],
229+
"element_c": DataTypeTag[operation.C.element],
230+
"layout_c": LayoutTag[operation.C.layout],
231+
"element_accumulator": DataTypeTag[operation.accumulator_type()],
232+
"opcode_class": OpcodeClassTag[
233+
operation.tile_description.math_instruction.opcode_class
234+
],
235+
"arch": "cutlass::arch::Sm%d" % operation.arch,
236+
"threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
237+
"threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
238+
"threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
239+
"warp_shape_m": str(warp_shape[0]),
240+
"warp_shape_n": str(warp_shape[1]),
241+
"warp_shape_k": str(warp_shape[2]),
242+
"instruction_shape_m": str(
243+
operation.tile_description.math_instruction.instruction_shape[0]
244+
),
245+
"instruction_shape_n": str(
246+
operation.tile_description.math_instruction.instruction_shape[1]
247+
),
248+
"instruction_shape_k": str(
249+
operation.tile_description.math_instruction.instruction_shape[2]
250+
),
251+
"epilogue_vector_length": str(epilogue_vector_length),
252+
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
253+
"element_epilogue": str(DataTypeTag[operation.element_epilogue]),
254+
"swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
255+
"stages": str(operation.tile_description.stages),
256+
"iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm],
257+
"iterator_algorithm_name": IteratorAlgorithmNames[
258+
operation.iterator_algorithm
259+
].capitalize(),
260+
"stride_support": StrideSupportTag[operation.stride_support],
261+
"math_operator": "cutlass::arch::OpMultiplyAddComplex"
262+
if operation.is_complex()
263+
else MathOperationTag[operation.tile_description.math_instruction.math_operation],
264+
"align_a": str(operation.A.alignment),
265+
"align_b": str(operation.B.alignment),
266+
}
267+
268+
return SubstituteTemplate(self.template, values)
269+
270+
271+
class EmitConv2dConfigurationLibrary:
272+
def __init__(self, operation_path, configuration_name):
273+
self.configuration_name = configuration_name
274+
self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)
275+
276+
self.instance_emitter = EmitConv2dInstance()
277+
278+
self.instance_template = """
279+
${operation_instance}
280+
281+
// Derived class
282+
struct ${operation_name} :
283+
public ${operation_name}_base { };
284+
285+
///////////////////////////////////////////////////////////////////////////////////////////////////
286+
287+
"""
288+
self.header_template = """
289+
/*
290+
Generated by conv2d_operation.py - Do not edit.
291+
*/
292+
293+
///////////////////////////////////////////////////////////////////////////////////////////////////
294+
295+
#include "cutlass/cutlass.h"
296+
#include "cutlass/library/library.h"
297+
#include "cutlass/library/manifest.h"
298+
299+
#include "library_internal.h"
300+
#include "conv2d_operation.h"
301+
302+
///////////////////////////////////////////////////////////////////////////////////////////////////
303+
"""
304+
305+
self.configuration_header = """
306+
307+
namespace cutlass {
308+
namespace library {
309+
310+
// Initialize all instances
311+
void initialize_${configuration_name}(Manifest &manifest) {
312+
313+
"""
314+
315+
self.configuration_instance = """
316+
using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
317+
${operation_name}>;
318+
319+
manifest.append(new cutlass::library::Conv2dOperation<
320+
Operation_${operation_name}>(
321+
"${operation_name}"));
322+
323+
"""
324+
325+
self.configuration_epilogue = """
326+
}
327+
"""
328+
self.epilogue_template = """
329+
330+
///////////////////////////////////////////////////////////////////////////////////////////////////
331+
332+
} // namespace library
333+
} // namespace cutlass
334+
335+
///////////////////////////////////////////////////////////////////////////////////////////////////
336+
337+
"""
338+
339+
#
340+
def __enter__(self):
341+
self.configuration_file = open(self.configuration_path, "w")
342+
self.configuration_file.write(
343+
SubstituteTemplate(
344+
self.header_template, {"configuration_name": self.configuration_name}
345+
)
346+
)
347+
self.operations = []
348+
return self
349+
350+
#
351+
def emit(self, operation):
352+
self.operations.append(operation)
353+
self.configuration_file.write(
354+
SubstituteTemplate(
355+
self.instance_template,
356+
{
357+
"configuration_name": self.configuration_name,
358+
"operation_name": operation.procedural_name(),
359+
"operation_instance": self.instance_emitter.emit(operation),
360+
},
361+
)
362+
)
363+
364+
#
365+
def __exit__(self, exception_type, exception_value, traceback):
366+
367+
self.configuration_file.write(
368+
SubstituteTemplate(
369+
self.configuration_header, {"configuration_name": self.configuration_name}
370+
)
371+
)
372+
373+
for operation in self.operations:
374+
self.configuration_file.write(
375+
SubstituteTemplate(
376+
self.configuration_instance,
377+
{
378+
"configuration_name": self.configuration_name,
379+
"operation_name": operation.procedural_name(),
380+
},
381+
)
382+
)
383+
384+
self.configuration_file.write(self.configuration_epilogue)
385+
self.configuration_file.write(self.epilogue_template)
386+
self.configuration_file.close()

0 commit comments

Comments
 (0)