Skip to content

Commit 24d0375

Browse files
committed
update tensorize
1 parent c3a1124 commit 24d0375

File tree

3 files changed

+16
-14
lines changed

3 files changed

+16
-14
lines changed

src/tir/schedule/primitive/blockize_tensorize.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int
537537
f_update_max_dtype_bits_from_region(block_realize->block->reads);
538538
f_update_max_dtype_bits_from_region(block_realize->block->writes);
539539
ICHECK(index_dtype_bits > 0);
540+
LOG(INFO) << "normalize to " << index_dtype_bits << " bits";
540541
intrin_impl = IndexDataTypeNormalizer(DataType::Int(index_dtype_bits)).Rewrite(intrin_impl);
541542
// Step 2: Structural pattern matching
542543
TensorizeComparator comparator(self->mod, /*assert_mode=*/true);

src/tir/transforms/lower_match_buffer.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class MatchBufferLower : public StmtExprMutator {
195195
// Non-zero elem_offset is ill-defined for non-flat memory.
196196
// If needed in the future, will require `Array<PrimExpr>
197197
// elem_offsets`, with one offset for each flattened index.
198-
Bind(buffer->elem_offset, 0);
198+
Bind(buffer->elem_offset, make_const(buffer->elem_offset.dtype(), 0));
199199
}
200200
}
201201

@@ -206,7 +206,7 @@ class MatchBufferLower : public StmtExprMutator {
206206
if (!buffer->strides.empty()) {
207207
ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
208208
if (source_buffer->strides.empty()) {
209-
PrimExpr stride = make_const(DataType::Int(32), 1);
209+
PrimExpr stride = make_const(buffer->strides.back().dtype(), 1);
210210
for (size_t i = buffer->shape.size(); i > 0; --i) {
211211
const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
212212
Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
@@ -230,7 +230,8 @@ class MatchBufferLower : public StmtExprMutator {
230230

231231
void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") {
232232
CHECK_EQ(arg.dtype(), value.dtype())
233-
<< "The data type mismatched: " << arg->dtype << " vs. " << value->dtype;
233+
<< "The data type mismatched: " << arg->dtype << " vs. " << value->dtype
234+
<< " when binding " << arg_name << " to " << value;
234235
// Handle recursive case
235236
value = Substitute(std::move(value), var_map_);
236237
if (arg->IsInstance<VarNode>()) {
@@ -282,4 +283,4 @@ TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchB
282283
} // namespace transform
283284

284285
} // namespace tir
285-
} // namespace tvm
286+
} // namespace tvm

tests/python/unittest/test_tir_schedule_tensorize.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -697,34 +697,34 @@ def tensorized_matmul_int64_shape(
697697
]
698698
)
699699
T.writes(C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)])
700-
A_elem_offset = T.var("int32")
701-
B_elem_offset = T.var("int32")
702-
C_elem_offset = T.var("int32")
700+
A_elem_offset = T.var("int64")
701+
B_elem_offset = T.var("int64")
702+
C_elem_offset = T.var("int64")
703703
A_sub = T.match_buffer(
704704
A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
705-
[16, 16],
705+
[T.int64(16), T.int64(16)],
706706
elem_offset=A_elem_offset,
707707
)
708708
B_sub = T.match_buffer(
709709
B[vj * T.int64(16) : vj * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
710-
[16, 16],
710+
[T.int64(16), T.int64(16)],
711711
elem_offset=B_elem_offset,
712712
)
713713
C_sub = T.match_buffer(
714714
C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)],
715-
[16, 16],
715+
[T.int64(16), T.int64(16)],
716716
elem_offset=C_elem_offset,
717717
)
718718
T.evaluate(
719719
T.tvm_mma_sync(
720720
C_sub.data,
721-
T.floordiv(C_sub.elem_offset, 256),
721+
T.floordiv(C_sub.elem_offset, T.int64(256)),
722722
A_sub.data,
723-
T.floordiv(A_sub.elem_offset, 256),
723+
T.floordiv(A_sub.elem_offset, T.int64(256)),
724724
B_sub.data,
725-
T.floordiv(B_sub.elem_offset, 256),
725+
T.floordiv(B_sub.elem_offset, T.int64(256)),
726726
C_sub.data,
727-
T.floordiv(C_sub.elem_offset, 256),
727+
T.floordiv(C_sub.elem_offset, T.int64(256)),
728728
dtype="handle",
729729
)
730730
)

0 commit comments

Comments
 (0)