Skip to content

Commit f7d17a1

Browse files
committed
removed im2col profiling for conv2d
1 parent b724f44 commit f7d17a1

File tree

1 file changed

+36
-56
lines changed

1 file changed

+36
-56
lines changed

python/tvm/contrib/cutlass/gen_conv2d.py

Lines changed: 36 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -169,59 +169,39 @@ def profile(
169169
If profile_all is False, return immediately after the first applicable kernel is found.
170170
If use_multiprocessing is True, compile all profiler executables in parallel.
171171
"""
172-
if True:
173-
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
174-
N, H, W, IC = d_shape
175-
OC, R, S, _ = w_shape
176-
ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))
177-
178-
for op in ops:
179-
op["runtime"] = -1
180-
181-
if profile_all:
182-
self.engine.compile_all(ops, use_multiprocessing)
183-
184-
args = [
185-
"--n=%d" % N,
186-
"--h=%d" % H,
187-
"--w=%d" % W,
188-
"--k=%d" % OC,
189-
"--c=%d" % IC,
190-
"--r=%d" % R,
191-
"--s=%d" % S,
192-
"--pad_h=%d" % padding[0],
193-
"--pad_w=%d" % padding[1],
194-
"--stride_h=%d" % stride[0],
195-
"--stride_w=%d" % stride[1],
196-
"--dilation_h=%d" % dilation[0],
197-
"--dilation_w=%d" % dilation[1],
198-
]
199-
for op in ops:
200-
out = self.engine.evaluate(op, args)
201-
op["runtime"] = out
202-
if out > 0 and profile_all is False:
203-
break
204-
205-
valid_ops = filter(lambda op: op["runtime"] > 0, ops)
206-
output = sorted(valid_ops, key=lambda i: i["runtime"])
207-
# self.cache[(M, N, K)] = output[0]
208-
return output[0]
209-
210-
else:
211-
B, _, _, IC = d_shape
212-
OC, R, S, _ = w_shape
213-
_, P, Q, _ = out_shape
214-
215-
M = B * P * Q
216-
N = OC
217-
K = R * S * IC
218-
219-
gemm_profile_result = self.gemm_profiler.profile(
220-
M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
221-
)
222-
223-
tile_description = gemm_profile_result["tile_description"]
224-
alignment = gemm_profile_result["alignment"]
225-
data_type = gemm_profile_result["data_type"]
226-
227-
return create_conv2d_operator([tile_description], data_type, [alignment])[0]
172+
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
173+
N, H, W, IC = d_shape
174+
OC, R, S, _ = w_shape
175+
ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))
176+
177+
for op in ops:
178+
op["runtime"] = -1
179+
180+
if profile_all:
181+
self.engine.compile_all(ops, use_multiprocessing)
182+
183+
args = [
184+
"--n=%d" % N,
185+
"--h=%d" % H,
186+
"--w=%d" % W,
187+
"--k=%d" % OC,
188+
"--c=%d" % IC,
189+
"--r=%d" % R,
190+
"--s=%d" % S,
191+
"--pad_h=%d" % padding[0],
192+
"--pad_w=%d" % padding[1],
193+
"--stride_h=%d" % stride[0],
194+
"--stride_w=%d" % stride[1],
195+
"--dilation_h=%d" % dilation[0],
196+
"--dilation_w=%d" % dilation[1],
197+
]
198+
for op in ops:
199+
out = self.engine.evaluate(op, args)
200+
op["runtime"] = out
201+
if out > 0 and profile_all is False:
202+
break
203+
204+
valid_ops = filter(lambda op: op["runtime"] > 0, ops)
205+
output = sorted(valid_ops, key=lambda i: i["runtime"])
206+
# self.cache[(M, N, K)] = output[0]
207+
return output[0]

0 commit comments

Comments
 (0)