Skip to content

Commit fe4687b

Browse files
committed
fixed cmd arguement
1 parent ab114f5 commit fe4687b

File tree

2 files changed

+6
-10
lines changed

2 files changed

+6
-10
lines changed

python/tvm/contrib/cutlass/gen_conv2d.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,10 @@ def profile(
157157
self,
158158
d_shape,
159159
w_shape,
160+
out_shape,
160161
padding,
161162
stride,
162163
dilation,
163-
out_shape,
164164
out_dtype,
165165
profile_all=True,
166166
use_multiprocessing=False,
@@ -188,13 +188,13 @@ def profile(
188188
"--k=%d" % OC,
189189
"--c=%d" % IC,
190190
"--r=%d" % R,
191-
"--s=%d" % IC,
191+
"--s=%d" % S,
192192
"--pad_h=%d" % padding[0],
193-
"--pad_w=%d," % padding[1],
193+
"--pad_w=%d" % padding[1],
194194
"--stride_h=%d" % stride[0],
195-
"--stride_w=%d" % stride[0],
195+
"--stride_w=%d" % stride[1],
196196
"--dilation_h=%d" % dilation[0],
197-
"--dilation_w=%d" % dilation[0],
197+
"--dilation_w=%d" % dilation[1],
198198
]
199199
for op in ops:
200200
out = self.engine.evaluate(op, args)
@@ -224,6 +224,4 @@ def profile(
224224
alignment = gemm_profile_result["alignment"]
225225
data_type = gemm_profile_result["data_type"]
226226

227-
out = create_conv2d_operator([tile_description], data_type, [alignment])[0]
228-
# print(out["src"])
229-
return out
227+
return create_conv2d_operator([tile_description], data_type, [alignment])[0]

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,6 @@ def _compile(self, op):
190190
fi.write(op["src"])
191191
fi.close()
192192
cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
193-
print(cmd)
194193
os.system(cmd)
195194
os.unlink(fi.name)
196195

@@ -213,7 +212,6 @@ def evaluate(self, op, args):
213212
for arg in args:
214213
cmd.append(str(arg))
215214
try:
216-
print("".join(cmd))
217215
sp = subprocess.run(cmd, capture_output=True, check=True)
218216
rt = float(sp.stdout)
219217
logger.info("%s, %f", op_name, rt)

0 commit comments

Comments
 (0)