Skip to content

Commit bd7f1f8

Browse files
authored
[TIR] Validate tir::Buffer axis_separators on construction (#17219)
* [TIR] Validate tir::Buffer axis_separators on construction Prior to this commit, the `axis_separators` field of a TIR buffer wasn't validated until the `tir.FlattenBuffer` legalization pass. Delaying the error until this point makes it difficult to determine where it invalid `axis_separators` were initially defined. This commit updates the `tir::Buffer` constructor to validate the `axis_separators` field immediately, allowing these invalid values to be caught on construction. Closes #17215 * Update metaschedule primitive to only set axis_separators of alloc * Allow axis separators to be increasing, rather than strictly increasing
1 parent cd09ab6 commit bd7f1f8

File tree

4 files changed

+51
-25
lines changed

4 files changed

+51
-25
lines changed

src/tir/ir/buffer.cc

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -334,24 +334,37 @@ inline Array<PrimExpr> BufferOffset(const BufferNode* n, Array<PrimExpr> index,
334334
return offsets;
335335
}
336336

337-
Buffer Buffer::GetFlattenedBuffer() const {
338-
auto self = operator->();
339-
337+
static void ValidateAxisSeparators(const Array<IntImm>& axis_separators, size_t buffer_dim) {
340338
// These checks ensure that all output axes contain at least one
341339
// input axis.
342-
for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) {
343-
auto sep = self->axis_separators[i]->value;
344-
auto next_sep = self->axis_separators[i + 1]->value;
345-
ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order.";
346-
}
347-
if (self->axis_separators.size()) {
348-
auto first_sep = self->axis_separators[0]->value;
349-
ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, "
350-
<< "so that first output axis contains at least one input axis";
351-
auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value;
352-
ICHECK_LT(last_sep, self->shape.size())
353-
<< "Last output axis must contain at least one input axis.";
340+
for (size_t i = 0; (i + 1) < axis_separators.size(); i++) {
341+
auto sep = axis_separators[i]->value;
342+
auto next_sep = axis_separators[i + 1]->value;
343+
CHECK_LE(sep, next_sep) << "ValueError: "
344+
<< "Axis separators must be in increasing order, "
345+
<< "but axis_separators[" << i << "] = " << sep
346+
<< " is greater than or equal to axis_separators[" << (i + 1)
347+
<< "] = " << next_sep << ".";
348+
}
349+
if (axis_separators.size()) {
350+
auto first_sep = axis_separators[0]->value;
351+
CHECK_GE(first_sep, 0) << "ValueError: "
352+
<< "All axis separators must be non-negative. "
353+
<< "However, the axis_separators[0] = " << first_sep;
354+
auto last_sep = axis_separators[axis_separators.size() - 1]->value;
355+
CHECK_LE(last_sep, buffer_dim)
356+
<< "ValueError: "
357+
<< "All axis separators must be within the range "
358+
<< "0 <= sep <= buffer_dim. "
359+
<< "However, the last axis_separators[" << (axis_separators.size() - 1)
360+
<< "] = " << last_sep << " is greater than the buffer's dimensionality of " << buffer_dim;
354361
}
362+
}
363+
364+
Buffer Buffer::GetFlattenedBuffer() const {
365+
auto self = operator->();
366+
367+
ValidateAxisSeparators(self->axis_separators, self->shape.size());
355368

356369
Array<PrimExpr> output_shape;
357370
if (self->strides.size()) {
@@ -565,6 +578,8 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
565578
ICHECK(data->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>())
566579
<< "Variable " << data->name_hint << " does not point to a primitive.";
567580

581+
ValidateAxisSeparators(axis_separators, shape.size());
582+
568583
auto n = make_object<BufferNode>();
569584
n->data = std::move(data);
570585
n->dtype = dtype;

src/tir/schedule/primitive/layout_transformation.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,11 +1485,16 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
14851485
if (it != buffer_var_map_.end()) {
14861486
const Buffer& new_source_buffer = it->second;
14871487
Buffer new_target_buffer = match_buffer->buffer;
1488-
new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
1489-
if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) {
1490-
LOG(WARNING)
1491-
<< "Target buffer in match_buffer doesn't have the same dimensionality as its source "
1492-
"buffer. `axis_separators` for the target buffer might be incorrect.";
1488+
1489+
if (new_target_buffer->shape.size() == new_source_buffer->shape.size()) {
1490+
new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
1491+
} else {
1492+
new_target_buffer.CopyOnWrite()->axis_separators =
1493+
Array<IntImm>(new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0));
1494+
LOG(WARNING) << "Buffer view " << new_target_buffer
1495+
<< " has different dimensionality than backing buffer " << new_source_buffer
1496+
<< ". The `axis_separators` for " << new_target_buffer << "."
1497+
<< "`axis_separators` for the view might be incorrect.";
14931498
}
14941499
buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer;
14951500
return MatchBufferRegion(new_target_buffer,

tests/python/tir-base/test_tir_buffer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,10 @@ def test_buffer_index_merge_mult_mod():
109109
A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))
110110

111111
def assert_simplified_equal(index_simplified, index_direct):
112-
tvm.ir.assert_structural_equal(
113-
index_simplified, index_direct
114-
), "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct)
112+
(
113+
tvm.ir.assert_structural_equal(index_simplified, index_direct),
114+
"index_simplified=%s, index_direct=%s" % (index_simplified, index_direct),
115+
)
115116

116117
idxd = tvm.tir.indexdiv
117118
idxm = tvm.tir.indexmod
@@ -276,5 +277,10 @@ def test_buffer_flatten_uses_axis_separators():
276277
tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32])
277278

278279

280+
def test_invalid_axis_separators_raises_exception():
281+
with pytest.raises(ValueError):
282+
tvm.tir.decl_buffer([1], axis_separators=[1, 2])
283+
284+
279285
if __name__ == "__main__":
280286
tvm.testing.main()

tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer((128, 128), "flo
9494
for i, j in T.grid(128, 128):
9595
with T.block("B"):
9696
vi, vj = T.axis.remap("SS", [i, j])
97-
B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
97+
B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0])
9898
B_subregion0[()] = A[vi, vj] * T.float32(2)
9999
for i, j in T.grid(128, 128):
100100
with T.block("C"):
101101
vi, vj = T.axis.remap("SS", [i, j])
102-
B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
102+
B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0])
103103
C[vi, vj] = B_subregion1[()] + T.float32(1)
104104

105105

0 commit comments

Comments
 (0)