Skip to content

Commit b09e72b

Browse files
authored
[TIR] Legalize dtype of constants in IndexMap (#14385)
Previously, the legalization was only handled by propagating the dtype of the indices to the transformed indices. As a result, output indices whose value did not depend on the input index would be left with the incorrect dtype.
1 parent ad6fbec commit b09e72b

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

src/tir/schedule/primitive/layout_transformation.cc

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,8 +1095,17 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>&
10951095

10961096
Array<Var> initial_indices;
10971097
Map<Var, PrimExpr> var_map;
1098+
std::optional<DataType> index_dtype = std::nullopt;
10981099

10991100
for (size_t i = 0; i < args.size(); ++i) {
1101+
if (index_dtype.has_value()) {
1102+
ICHECK_EQ(*index_dtype, args[i]->dtype)
1103+
<< "Buffer index " << args[i] << " has dtype " << args[i]->dtype
1104+
<< ", but previous index for the same buffer access used index type " << *index_dtype;
1105+
} else {
1106+
index_dtype = args[i]->dtype;
1107+
}
1108+
11001109
if (args[i]->dtype != initial_indices_orig[i].dtype()) {
11011110
auto new_idx = Var(initial_indices_orig[i]->name_hint, args[i]->dtype);
11021111
initial_indices.push_back(new_idx);
@@ -1108,8 +1117,13 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>&
11081117

11091118
if (!var_map.empty()) {
11101119
auto final_indices = index_map->final_indices.Map([&](PrimExpr index) {
1111-
return SubstituteWithDataTypeLegalization(index,
1112-
[&](const Var& var) { return var_map.Get(var); });
1120+
if (auto* ptr = index.as<IntImmNode>()) {
1121+
ICHECK(index_dtype.has_value());
1122+
return tir::make_const(*index_dtype, ptr->value);
1123+
} else {
1124+
return SubstituteWithDataTypeLegalization(index,
1125+
[&](const Var& var) { return var_map.Get(var); });
1126+
}
11131127
});
11141128
Optional<IndexMap> opt_inverse_index_map =
11151129
Downcast<Optional<IndexMap>>(index_map->inverse_index_map);

tests/python/unittest/test_tir_schedule_transform_layout.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,5 +1049,41 @@ def func(A: T.Buffer(T.int64(58), "int32")):
10491049
)
10501050

10511051

1052+
def test_index_map_dtype_legalize_with_constant():
1053+
"""Legalization of inverse containing a constant output
1054+
1055+
The index map `lambda i,j: [i, j//8, j % 8]` has an inverse `lambda i,j,k: [i, 8*j+k]`.
1056+
"""
1057+
1058+
@T.prim_func
1059+
def func(A: T.Buffer(T.int64(16), "int32")):
1060+
for i in T.grid(T.int64(16)):
1061+
with T.block("block"):
1062+
vi = T.axis.remap("S", [i])
1063+
A[vi] = 0
1064+
1065+
sch = tir.Schedule(func)
1066+
1067+
# Triggering the error requires an IndexMap that introduces padding
1068+
func = lambda i: [
1069+
# And a constant to be one of the output indices.
1070+
tir.const(0, i.dtype),
1071+
(i + 1) // 8,
1072+
(i + 1) % 8,
1073+
]
1074+
1075+
# Previously, the legalization was only handled by propagating the
1076+
# dtype of the indices to the transformed indices. As a result,
1077+
# output indices whose value did not depend on the input index
1078+
# would be left with the incorrect dtype.
1079+
1080+
# Prior to the bugfix, this resulted in the following error is
1081+
# raised from the IterVar constructor.
1082+
#
1083+
# TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) :
1084+
# The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32)
1085+
sch.transform_layout(block="block", buffer="A", index_map=func, pad_value=0)
1086+
1087+
10521088
if __name__ == "__main__":
10531089
tvm.testing.main()

0 commit comments

Comments
 (0)