
Commit dd42ef2

masahi and comaniac authored
[CUTLASS] Add conv2d profiler (#9737)
* Add cutlass conv2d profiler

  Squashed commits (Author: Masahiro Masuda <[email protected]>):

  - 1c0bbb2 (Sun Dec 12 18:29:03 2021 +0900): fix lint
  - 463574c (Sun Dec 12 17:28:38 2021 +0900): fixed conv2d check
  - 588c5ab (Sun Dec 12 15:05:27 2021 +0900): update test
  - a447b57 (Sun Dec 12 14:54:52 2021 +0900): speed up profiling by removing initialization
  - 93cd039 (Sun Dec 12 08:26:29 2021 +0900): fixed nhwc cudnn depthwise conv
  - 6db7172 (Sat Dec 11 15:39:05 2021 +0900): add cache
  - f7d17a1 (Sat Dec 11 15:05:38 2021 +0900): removed im2col profiling for conv2d
  - b724f44 (Fri Dec 10 22:57:54 2021 +0900): black
  - fe4687b (Fri Dec 10 22:49:13 2021 +0900): fixed cmd argument
  - ab114f5 (Fri Dec 10 22:22:19 2021 +0900): conv2d profiler working
  - 49ee61f (Fri Dec 10 20:26:15 2021 +0900): add conv2d profiler
  - 49e2c89 (Sun Dec 12 08:03:36 2021 +0900): do not offload depthwise conv2d
  - cd83677 (Fri Dec 10 13:20:01 2021 +0900): lint fix
  - 870823c (Fri Dec 10 12:54:38 2021 +0900): add comment on IC == 3 case
  - 6b780db (Fri Dec 10 12:48:33 2021 +0900): check align on N dim
  - 308c4da (Fri Dec 10 12:34:42 2021 +0900): fixed check functions for fused cases, run infer type before MergeComposite
  - 8d6a1bf (Fri Dec 10 12:10:59 2021 +0900): test IC=3 convolution
  - ffce47d (Fri Dec 10 12:10:16 2021 +0900): use align1 kernel for unusual channel cases (IC = 3 etc.)
  - 6cdf205 (Fri Dec 10 12:06:56 2021 +0900): add dtype and layout check in pattern match
  - 7743cc6 (Fri Dec 10 10:40:53 2021 +0900): add sm75 kernels to sm80 profiling
  - efceccb (Fri Dec 10 10:40:42 2021 +0900): skip legalize when batch size is dynamic
  - 65fbc0a (Fri Dec 10 10:36:36 2021 +0900): bug fix in im2col encoding

* minor fix
* lint fix
* allow autotvm NCHW depthwise conv2d schedule even if -libs=cudnn
* Update python/tvm/contrib/cutlass/gen_conv2d.py (Co-authored-by: Cody Yu <[email protected]>)
* simplify processing of profiler outputs
* simplify further
* fix runtime check

Co-authored-by: Cody Yu <[email protected]>
1 parent 21abb6e commit dd42ef2
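For context, a minimal sketch of how the conv2d profiler added in this commit is driven end to end. The entry point tune_cutlass_kernels is the function modified in build.py below; the import path and the returned values follow TVM's CUTLASS tests and are assumptions, not part of this diff:

    # Hypothetical usage sketch: `mod` is assumed to be a Relay module already
    # partitioned for the "cutlass" target by the caller.
    from tvm.contrib.cutlass import tune_cutlass_kernels

    mod, num_cutlass_partition = tune_cutlass_kernels(
        mod,                       # partitioned Relay module (assumed to exist)
        sm=80,                     # compute capability; 75 and 80 have generator tables
        profile_all=True,          # benchmark every applicable kernel, not just the first hit
        use_multiprocessing=True,  # compile the generated profiler binaries in parallel
    )

With this change, conv2d offloads are profiled with a dedicated implicit-GEMM profiler binary instead of reusing the GEMM profiler on an im2col-mapped problem.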

7 files changed: +265 -50 lines changed

python/tvm/contrib/cutlass/build.py

Lines changed: 9 additions & 3 deletions
@@ -184,7 +184,9 @@ def handle_conv2d(
     op_type,
     d_shape,
     w_shape,
-    out_shape,
+    padding,
+    strides,
+    dilation,
     out_dtype,
     profile_all,
     use_multiprocessing,
@@ -197,7 +199,9 @@ def handle_conv2d(
     out = cutlass_profiler.profile(
         d_shape,
         w_shape,
-        out_shape,
+        padding,
+        strides,
+        dilation,
         out_dtype,
         profile_all=profile_all,
         use_multiprocessing=use_multiprocessing,
@@ -278,7 +282,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
                 op_type,
                 arg0_shape,
                 arg1_shape,
-                annotator.signature["ret_shape"],
+                annotator.op_attrs.padding,
+                annotator.op_attrs.strides,
+                annotator.op_attrs.dilation,
                 out_dtype,
                 profile_all,
                 use_multiprocessing,
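The profiler now receives the conv2d attributes (padding, strides, dilation) instead of a precomputed output shape, since the generated C++ profiler derives the output size itself (see Options::output_size() in the new file below). As a purely illustrative sketch, the same arithmetic in Python (helper name and variables are hypothetical, not part of this change):

    def conv2d_output_hw(h, w, r, s, padding, strides, dilation):
        # Effective filter extent after dilation, then the standard conv2d output size.
        dilated_r = (r - 1) * dilation[0] + 1
        dilated_s = (s - 1) * dilation[1] + 1
        out_h = (h + 2 * padding[0] - dilated_r) // strides[0] + 1
        out_w = (w + 2 * padding[1] - dilated_s) // strides[1] + 1
        return out_h, out_w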
python/tvm/contrib/cutlass/conv2d_profiler.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-outside-toplevel, invalid-name
+"""Instantiate a C++ source for profiling CUTLASS kernels."""
+
+
+class Conv2dProfilerEmitter(object):
+    """Emit a C++ source for profiling CUTLASS kernels."""
+
+    def __init__(self):
+        from jinja2 import Template
+
+        self.template = Template(
+            """
+#include <iostream>
+#include "cutlass/cutlass.h"
+#include "cutlass/conv/kernel/default_conv2d_fprop.h"
+#include "cutlass/conv/device/implicit_gemm_convolution.h"
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+
+#define CUTLASS_CHECK(status)                                                               \\
+  {                                                                                         \\
+    cutlass::Status error = status;                                                         \\
+    if (error != cutlass::Status::kSuccess) {                                               \\
+      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \\
+                << std::endl;                                                               \\
+      exit(EXIT_FAILURE);                                                                   \\
+    }                                                                                       \\
+  }
+
+{{OperatorDef}}
+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<{{OperatorName}}>;
+
+struct Options {
+  cutlass::Tensor4DCoord input_size;
+  cutlass::Tensor4DCoord filter_size;
+  cutlass::Tensor4DCoord padding;
+  cutlass::MatrixCoord conv_stride;
+  cutlass::MatrixCoord dilation;
+
+  void parse(int argc, char const **args) {
+    cutlass::CommandLine cmd(argc, args);
+    cmd.get_cmd_line_argument("n", input_size.n());
+    cmd.get_cmd_line_argument("h", input_size.h());
+    cmd.get_cmd_line_argument("w", input_size.w());
+    cmd.get_cmd_line_argument("c", input_size.c());
+    cmd.get_cmd_line_argument("k", filter_size.n());
+    cmd.get_cmd_line_argument("r", filter_size.h());
+    cmd.get_cmd_line_argument("s", filter_size.w());
+    int pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w;
+    cmd.get_cmd_line_argument("pad_h", pad_h);
+    cmd.get_cmd_line_argument("pad_w", pad_w);
+    cmd.get_cmd_line_argument("stride_h", stride_h);
+    cmd.get_cmd_line_argument("stride_w", stride_w);
+    cmd.get_cmd_line_argument("dilation_h", dilation_h);
+    cmd.get_cmd_line_argument("dilation_w", dilation_w);
+    filter_size.c() = input_size.c();
+    padding = {pad_h, pad_h, pad_w, pad_w};
+    conv_stride = {stride_h, stride_w};
+    dilation = {dilation_h, dilation_w};
+  }
+
+  cutlass::Tensor4DCoord output_size() const {
+    auto dilated_h = (filter_size.h() - 1) * dilation.row() + 1;
+    auto dilated_w = (filter_size.w() - 1) * dilation.column() + 1;
+    auto h = (input_size.h() + padding.n() + padding.h() - dilated_h) / conv_stride.row() + 1;
+    auto w = (input_size.w() + padding.w() + padding.c() - dilated_w) / conv_stride.column() + 1;
+    return cutlass::Tensor4DCoord(input_size.n(), h, w, filter_size.n());
+  }
+};
+
+double profile_convolution(Options const &options) {
+  using ElementOutput = typename ImplicitGemm::ElementC;
+  using ElementInputA = typename ImplicitGemm::ElementA;
+  using ElementInputB = typename ImplicitGemm::ElementB;
+  auto oshape = options.output_size();
+  cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> tensor_a(options.input_size);
+  cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> tensor_b(options.filter_size);
+  cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(oshape);
+  cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(oshape);
+
+  cutlass::conv::Conv2dProblemSize problem_size(
+      options.input_size,
+      options.filter_size,
+      options.padding,
+      options.conv_stride,
+      options.dilation,
+      options.output_size(),
+      cutlass::conv::Mode::kCrossCorrelation,
+      1
+  );
+
+  using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;
+  typename ImplicitGemm::Arguments arguments{
+      problem_size,
+      tensor_a.device_ref(),
+      tensor_b.device_ref(),
+      tensor_c.device_ref(),
+      tensor_c.device_ref(),
+      {ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
+  };
+
+  ImplicitGemm implicit_gemm_op;
+  size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  auto status = implicit_gemm_op.can_implement(arguments);
+  CUTLASS_CHECK(status);
+
+  status = implicit_gemm_op.initialize(arguments, workspace.get());
+  CUTLASS_CHECK(status);
+  status = implicit_gemm_op();
+  CUTLASS_CHECK(status);
+
+  cudaEvent_t events[2];
+  for (auto & event : events) {
+    cudaEventCreate(&event);
+  }
+  cudaEventRecord(events[0]);
+
+  for (int iteration = 0; iteration < 100; ++iteration) {
+    auto status = implicit_gemm_op();
+    CUTLASS_CHECK(status);
+  }
+
+  cudaEventRecord(events[1]);
+  cudaEventSynchronize(events[1]);
+  float runtime_ms = 0;
+  cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
+
+  for (auto event : events) {
+    (void)cudaEventDestroy(event);
+  }
+  return double(runtime_ms) / 100.0;
+}
+
+int main(int argc, char const **args) {
+  Options options;
+  options.parse(argc, args);
+  std::cout << profile_convolution(options) << std::endl;
+  return 0;
+}
+"""
+        )
+
+    def emit(self, op_def, op_name):
+        src = self.template.render(OperatorDef=op_def, OperatorName=op_name)
+        return src
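A minimal sketch of how this emitter is used. The kernel definition string below is a placeholder; in the real flow it is produced by EmitConv2dInstance and passed in by gen_conv2d.py, and the module path is assumed from the import in that file:

    from tvm.contrib.cutlass.conv2d_profiler import Conv2dProfilerEmitter

    # Placeholder kernel definition; gen_conv2d.py passes the emitted CUTLASS instance here.
    op_def = "/* CUTLASS Conv2dFprop instantiation aliased to MyConv2dKernel */"
    src = Conv2dProfilerEmitter().emit(op_def, "MyConv2dKernel")
    # `src` is a standalone CUDA source; ProfilerEngine compiles it into an executable
    # that takes --n/--h/--w/--c/--k/--r/--s/--pad_*/--stride_*/--dilation_* arguments
    # and prints the average kernel runtime in milliseconds over 100 iterations.
    print(src[:120])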

python/tvm/contrib/cutlass/gen_conv2d.py

Lines changed: 64 additions & 14 deletions
@@ -16,8 +16,14 @@
 # under the License.
 # pylint: disable=invalid-name
 """Conv2d kernel generator and profiler for CUTLASS."""
+import re
 from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
 from .gen_gemm import CutlassGemmProfiler
+from .conv2d_profiler import Conv2dProfilerEmitter
+from .gen_tensor_op import (
+    ProfilerEngine,
+    GENERATOR_FUNC_TABLE,
+)
 from .library import (
     EpilogueFunctor,
     SwizzlingFunctor,
@@ -39,6 +45,7 @@ def create_conv2d_operator(
     ret = []
 
     kernel_emitter = EmitConv2dInstance()
+    profiler_emitter = Conv2dProfilerEmitter()
 
     element_a, element_b, element_c, element_epilogue = data_type
     iterator_algorithms = [IteratorAlgorithm.Optimized]
@@ -72,9 +79,9 @@ def create_conv2d_operator(
                 swizzling_functor_,
             )
 
-            # TODO(masahi): Add profiler source here
             op_entry["opdef"] = kernel_emitter.emit(op)
             op_entry["op"] = op
+            op_entry["src"] = profiler_emitter.emit(op_entry["opdef"], op.procedural_name())
             op_entry["name"] = op.procedural_name()
             op_entry["runtime"] = 9999999
 
@@ -113,6 +120,9 @@ class CutlassConv2DProfiler:
     def __init__(self, sm, cutlass_path, binary_path):
         self.gemm_profiler = CutlassGemmProfiler(sm, cutlass_path, binary_path)
         self.sm = sm
+        assert sm in GENERATOR_FUNC_TABLE, "sm%d not supported yet." % sm
+        self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
+        self.cache = {}
 
     def get_default(self, out_dtype):
         gemm_profile_result = self.gemm_profiler.get_default(out_dtype)
@@ -121,27 +131,67 @@ def get_default(self, out_dtype):
         data_type = gemm_profile_result["data_type"]
         return create_conv2d_operator([tile_description], data_type, [alignment])[0]
 
+    def check_align(self, op_name, C, K):
+        """Filter out kernels that cannot be supported."""
+        aligns = re.findall(r"align[1|2|4|8]", op_name)
+        assert len(aligns) == 1
+        align = int(aligns[0][-1])
+        return all([dim % align == 0 for dim in [C, K]])
+
     def profile(
-        self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, use_multiprocessing=False
+        self,
+        d_shape,
+        w_shape,
+        padding,
+        stride,
+        dilation,
+        out_dtype,
+        profile_all=True,
+        use_multiprocessing=False,
     ):
         """Profile and select the best kernel from candidate kernels.
         If profile_all is False, return immediately after the first applicable kernel is found.
         If use_multiprocessing is True, compile all profiler executables in parallel.
         """
-        B, _, _, IC = d_shape
+        N, H, W, IC = d_shape
         OC, R, S, _ = w_shape
-        _, P, Q, _ = out_shape
+        workload = (
+            N,
+            H,
+            W,
+            IC,
+            OC,
+            R,
+            S,
+            padding[0],
+            padding[1],
+            stride[0],
+            stride[1],
+            dilation[0],
+            dilation[1],
+        )
 
-        M = B * P * Q
-        N = OC
-        K = R * S * IC
+        if workload in self.cache:
+            return self.cache[workload]
 
-        gemm_profile_result = self.gemm_profiler.profile(
-            M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
-        )
+        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
+        ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))
 
-        tile_description = gemm_profile_result["tile_description"]
-        alignment = gemm_profile_result["alignment"]
-        data_type = gemm_profile_result["data_type"]
+        if profile_all:
+            self.engine.compile_all(ops, use_multiprocessing)
 
-        return create_conv2d_operator([tile_description], data_type, [alignment])[0]
+        args = (
+            "--n=%d --h=%d --w=%d --c=%d --k=%d --r=%d --s=%d --pad_h=%d --pad_w=%d "
+            "--stride_h=%d --stride_w=%d --dilation_h=%d --dilation_w=%d"
+        ) % workload
+
+        for op in ops:
+            out = self.engine.evaluate(op, args.split(" "))
+            op["runtime"] = out
+            if out < float("inf") and not profile_all:
+                self.cache[workload] = op
+                return op
+
+        output = min(ops, key=lambda i: i["runtime"])
+        self.cache[workload] = output
+        return output
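The profile method above keys its cache on the full conv2d workload and passes the same values to the generated profiler binary on its command line. A small illustrative sketch of that argument construction (the concrete numbers are made up):

    # (N, H, W, IC, OC, R, S, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w)
    workload = (8, 32, 32, 64, 128, 3, 3, 1, 1, 1, 1, 1, 1)
    args = (
        "--n=%d --h=%d --w=%d --c=%d --k=%d --r=%d --s=%d --pad_h=%d --pad_w=%d "
        "--stride_h=%d --stride_w=%d --dilation_h=%d --dilation_w=%d"
    ) % workload
    # Split into argv-style tokens for ProfilerEngine.evaluate, one run per candidate kernel.
    print(args.split(" "))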

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 7 additions & 17 deletions
@@ -22,8 +22,7 @@
 from .gemm_profiler import GemmProfilerEmitter
 from .gen_tensor_op import (
     ProfilerEngine,
-    generate_sm75_tensor_op_1688,
-    generate_sm80_tensor_op_16816,
+    GENERATOR_FUNC_TABLE,
 )
 from .library import (
     EpilogueFunctor,
@@ -132,12 +131,6 @@ def create_gemm_operator(
     return ret
 
 
-GENERATOR_FUNC_TABLE = {
-    75: generate_sm75_tensor_op_1688,
-    80: generate_sm80_tensor_op_16816,
-}
-
-
 # TODO(masahi): A sensible way to pick reasonable default kernels
 DEFAULT_KERNELS = {
     75: {
@@ -199,19 +192,16 @@ def profile(
         )
         ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))
 
-        for op in ops:
-            op["runtime"] = -1
-
         if profile_all:
             self.engine.compile_all(ops, use_multiprocessing)
 
         for op in ops:
             out = self.engine.evaluate(op, [M, N, K])
             op["runtime"] = out
-            if out > 0 and profile_all is False:
-                break
+            if out < float("inf") and not profile_all:
+                self.cache[(M, N, K)] = op
+                return op
 
-        valid_ops = filter(lambda op: op["runtime"] > 0, ops)
-        output = sorted(valid_ops, key=lambda i: i["runtime"])
-        self.cache[(M, N, K)] = output[0]
-        return output[0]
+        output = min(ops, key=lambda i: i["runtime"])
+        self.cache[(M, N, K)] = output
+        return output
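Both the GEMM and the new conv2d profiler now share one convention for kernel selection: a kernel that could not be compiled or run is expected to carry a runtime of float("inf") (implied by the out < float("inf") check), so min() over the candidates skips it and the old -1 sentinel plus post-filtering disappear. A tiny self-contained illustration with made-up entries:

    ops = [
        {"name": "align8_kernel", "runtime": float("inf")},  # e.g. not applicable / failed to run
        {"name": "align4_kernel", "runtime": 0.42},
        {"name": "align1_kernel", "runtime": 0.71},
    ]
    best = min(ops, key=lambda op: op["runtime"])
    assert best["name"] == "align4_kernel"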
