Skip to content

Commit 9c35fe5

Browse files
committed
fix
1 parent 85b7a10 commit 9c35fe5

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

src/tir/ir/data_type_rewriter.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,8 +466,8 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) {
466466
For new_for = GetRef<For>(op);
467467
auto* n = new_for.CopyOnWrite();
468468
n->loop_var = new_loop_var;
469-
n->min = min;
470-
n->extent = extent;
469+
n->min = cast(new_loop_var.dtype(), min);
470+
n->extent = cast(new_loop_var.dtype(), extent);
471471
n->body = new_body;
472472
return std::move(new_for);
473473
} else {

tests/python/unittest/test_meta_schedule_relay_integration.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -391,63 +391,63 @@ def test_meta_schedule_te2primfunc_argument_order_and_lowering():
391391
class _fused_layout_transform:
392392
@T.prim_func
393393
def main( # type: ignore
394-
placeholder: T.Buffer[(1, 3, 16, 16), "float32"], # type: ignore
395-
T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"], # type: ignore
394+
placeholder: T.Buffer[(T.int64(1), T.int64(3), T.int64(16), T.int64(16)), "float32"], # type: ignore
395+
T_layout_trans: T.Buffer[(T.int64(1), T.int64(1), T.int64(16), T.int64(16), T.int64(3)), "float32"], # type: ignore
396396
) -> None: # type: ignore
397397
# function attr dict
398398
T.func_attr({"global_symbol": "main", "tir.noalias": True})
399399
# body
400400
# with T.block("root")
401-
for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3):
401+
for i0, i1, i2, i3, i4 in T.grid(T.int64(1), T.int64(1), T.int64(16), T.int64(16), T.int64(3)):
402402
with T.block("T_layout_trans"):
403403
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
404-
T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3])
404+
T.reads(placeholder[ax0, ax1 * T.int64(3) + ax4, ax2, ax3])
405405
T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
406406
T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
407-
ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore
408-
placeholder[ax0, ax1 * 3 + ax4, ax2, ax3],
407+
ax0 < T.int64(1) and ax1 * T.int64(3) + ax4 < T.int64(3) and ax2 < T.int64(16) and ax3 < T.int64(16), # type: ignore
408+
placeholder[ax0, ax1 * T.int64(3) + ax4, ax2, ax3],
409409
T.float32(0),
410410
dtype="float32",
411411
)
412412

413413
@tvm.script.ir_module
414414
class _fused_layout_transform_1:
415415
@T.prim_func
416-
def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None: # type: ignore
416+
def main(placeholder: T.Buffer[(T.int64(1), T.int64(2), T.int64(16), T.int64(16), T.int64(4)), "float32"], T_layout_trans: T.Buffer[(T.int64(1), T.int64(8), T.int64(16), T.int64(16)), "float32"]) -> None: # type: ignore
417417
# function attr dict
418418
T.func_attr({"global_symbol": "main", "tir.noalias": True})
419419
# body
420420
# with T.block("root")
421-
for i0, i1, i2, i3 in T.grid(1, 8, 16, 16):
421+
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(8), T.int64(16), T.int64(16)):
422422
with T.block("T_layout_trans"):
423423
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
424-
T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore
424+
T.reads(placeholder[ax0, ax1 // T.int64(4), ax2, ax3, ax1 % T.int64(4)]) # type: ignore
425425
T.writes(T_layout_trans[ax0, ax1, ax2, ax3])
426-
T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore
426+
T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < T.int64(1) and ax1 < T.int64(8) and ax2 < T.int64(16) and ax3 < T.int64(16), placeholder[ax0, ax1 // T.int64(4), ax2, ax3, ax1 % T.int64(4)], T.float32(0), dtype="float32") # type: ignore
427427

428428
@tvm.script.ir_module
429429
class _fused_nn_contrib_conv2d_NCHWc:
430430
@T.prim_func
431-
def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"], conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]) -> None: # type: ignore
431+
def main(placeholder: T.Buffer[(T.int64(1), T.int64(1), T.int64(16), T.int64(16), T.int64(3)), "float32"], placeholder_1: T.Buffer[(T.int64(2), T.int64(1), T.int64(5), T.int64(5), T.int64(3), T.int64(4)), "float32"], conv2d_NCHWc: T.Buffer[(T.int64(1), T.int64(2), T.int64(16), T.int64(16), T.int64(4)), "float32"]) -> None: # type: ignore
432432
# function attr dict
433433
T.func_attr({"global_symbol": "main", "tir.noalias": True})
434434
# body
435435
# with T.block("root")
436-
data_pad = T.alloc_buffer([1, 1, 20, 20, 3], dtype="float32")
437-
for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3):
436+
data_pad = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(20), T.int64(20), T.int64(3)], dtype="float32")
437+
for i0, i1, i2, i3, i4 in T.grid(T.int64(1), T.int64(1), T.int64(20), T.int64(20), T.int64(3)):
438438
with T.block("data_pad"):
439439
i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
440-
T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1])
440+
T.reads(placeholder[i0_1, i1_1, i2_1 - T.int64(2), i3_1 - T.int64(2), i4_1])
441441
T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1])
442-
data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716
443-
for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5):
442+
data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(T.int64(2) <= i2_1 and i2_1 < T.int64(18) and T.int64(2) <= i3_1 and i3_1 < T.int64(18), placeholder[i0_1, i1_1, i2_1 - T.int64(2), i3_1 - T.int64(2), i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716
443+
for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), T.int64(2), T.int64(16), T.int64(16), T.int64(4), T.int64(3), T.int64(5), T.int64(5)):
444444
with T.block("conv2d_NCHWc"):
445445
n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
446-
T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore
446+
T.reads(data_pad[n, ic // T.int64(3), oh + kh, ow + kw, ic % T.int64(3)], placeholder_1[oc_chunk, ic // T.int64(3), kh, kw, ic % T.int64(3), oc_block]) # type: ignore
447447
T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
448448
with T.init():
449449
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
450-
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore
450+
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // T.int64(3), oh + kh, ow + kw, ic % T.int64(3)] * placeholder_1[oc_chunk, ic // T.int64(3), kh, kw, ic % T.int64(3), oc_block] # type: ignore
451451

452452
# fmt: on
453453
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument

0 commit comments

Comments
 (0)