Skip to content

Commit fa834f6

Browse files
authored
Prevent simplifying unit IterVar in CreatePrimFunc (#11292)
Simplifying unit iter vars in CreatePrimFunc changes the semantics of the PrimFunc, which then needs different handling in analysis. This reverts commit 26cefab.
1 parent e7f1224 commit fa834f6

File tree

3 files changed

+31
-39
lines changed

3 files changed

+31
-39
lines changed

src/te/operation/create_primfunc.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,8 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
142142

143143
const PrimExpr& dom_min = analyzer->Simplify(iter_var->dom->min);
144144
const PrimExpr& dom_extent = analyzer->Simplify(iter_var->dom->extent);
145-
Range iter_var_dom = Range::FromMinExtent(dom_min, dom_extent);
146-
analyzer->Bind(new_var, iter_var_dom);
147-
iter_vars.push_back(IterVar(iter_var_dom, new_var, iter_var->iter_type, iter_var->thread_tag,
148-
iter_var->span));
145+
iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var,
146+
iter_var->iter_type, iter_var->thread_tag, iter_var->span));
149147
}
150148
};
151149
f_push_block_vars(compute_op->axis);

tests/python/unittest/test_meta_schedule_tune_relay.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,14 @@ def main( # type: ignore
6262
for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3):
6363
with T.block("T_layout_trans"):
6464
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
65-
T.reads(placeholder[0, ax4, ax2, ax3])
65+
T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3])
6666
T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
67-
T_layout_trans[ax0, ax1, ax2, ax3, ax4] = placeholder[0, ax4, ax2, ax3]
67+
T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
68+
ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore
69+
placeholder[ax0, ax1 * 3 + ax4, ax2, ax3],
70+
T.float32(0),
71+
dtype="float32",
72+
)
6873

6974

7075
@tvm.script.ir_module
@@ -79,19 +84,18 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B
7984
for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3):
8085
with T.block("data_pad"):
8186
i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
82-
T.reads(placeholder[0, 0, i2_1 - 2, i3_1 - 2, i4_1]) # type: ignore
87+
T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1])
8388
T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1])
84-
data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[0, 0, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716
89+
data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716
8590
for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5):
8691
with T.block("conv2d_NCHWc"):
8792
n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
88-
T.reads(data_pad[0, 0, oh + kh, ow + kw, ic], placeholder_1[oc_chunk, 0, kh, kw, ic, oc_block]) # type: ignore
93+
T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore
8994
T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
9095
T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]})
9196
with T.init():
9297
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
93-
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[0, 0, oh + kh, ow + kw, ic] * placeholder_1[oc_chunk, 0, kh, kw, ic, oc_block] # type: ignore
94-
98+
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore
9599

96100
@tvm.script.ir_module
97101
class tvmgen_default_fused_layout_transform_1:
@@ -104,9 +108,9 @@ def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.
104108
for i0, i1, i2, i3 in T.grid(1, 8, 16, 16):
105109
with T.block("T_layout_trans"):
106110
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
107-
T.reads(placeholder[0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore
111+
T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore
108112
T.writes(T_layout_trans[ax0, ax1, ax2, ax3])
109-
T_layout_trans[ax0, ax1, ax2, ax3] = placeholder[0, ax1 // 4, ax2, ax3, ax1 % 4] # type: ignore
113+
T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore
110114

111115
# fmt: on
112116
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument

tests/python/unittest/test_te_create_primfunc.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=missing-function-docstring,missing-module-docstring
18-
import sys
19-
import pytest
2018
import numpy as np
2119
import tvm
2220
import tvm.testing
@@ -526,28 +524,20 @@ def test_int64_indices():
526524
assert loop.extent.dtype == "int64"
527525

528526

529-
def te_reshape():
530-
A = te.placeholder((128, 128), name="A")
531-
B = topi.reshape(A, [8, 16, 128])
532-
return [A, B]
533-
534-
535-
@T.prim_func
536-
def tir_reshape(
537-
A: T.Buffer[(128, 128), "float32"], T_reshape: T.Buffer[(8, 16, 128), "float32"]
538-
) -> None:
539-
T.func_attr({"global_symbol": "main", "tir.noalias": True})
540-
for i0, i1, i2 in T.grid(8, 16, 128):
541-
with T.block("T_reshape"):
542-
ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
543-
T.reads(A[ax0 * 16 + ax1, ax2])
544-
T.writes(T_reshape[ax0, ax1, ax2])
545-
T_reshape[ax0, ax1, ax2] = A[ax0 * 16 + ax1, ax2]
546-
547-
548-
def test_reshape():
549-
_check_workload(te_reshape, tir_reshape)
550-
551-
552527
if __name__ == "__main__":
553-
sys.exit(pytest.main([__file__] + sys.argv[1:]))
528+
test_unique_name_complete_block()
529+
test_unique_name_reduction_block()
530+
test_matmul()
531+
test_element_wise()
532+
test_conv2d()
533+
test_multi_output()
534+
test_extern()
535+
test_arg_order()
536+
test_error_reporting()
537+
test_constant()
538+
test_select_simplify()
539+
test_tensor_attr()
540+
test_tensor_layout_attr()
541+
test_argmax_idx_val()
542+
test_argmax_val_idx()
543+
test_int64_indices()

0 commit comments

Comments
 (0)