Skip to content

Commit 54aaec9

Browse files
authored
Refactor index handling in BufferStore and BufferLoad to promote 64-bit integers (#796)
- Updated index processing in `BufferStore` and `BufferLoad` to ensure that integer indices with less than 64 bits are promoted to 64-bit integers. - Introduced a new array to store the modified indices before updating the original indices, enhancing clarity and maintainability of the code.
1 parent 7467f2b commit 54aaec9

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/transform/config_index_bitwidth.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,35 +123,43 @@ class IndexLegalizer : public IRMutatorWithAnalyzer {
123123
auto buffer_store =
124124
Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
125125
auto indices = buffer_store->indices;
126+
Array<PrimExpr> new_indices;
126127
for (auto index : indices) {
127128
if (index->dtype.is_int() && index->dtype.bits() < 64) {
128129
auto int_bound = analyzer_->const_int_bound(index);
129130
if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
130131
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
131132
Int64Promoter promoter;
132133
index = promoter(index);
134+
new_indices.push_back(index);
135+
continue;
133136
}
134137
}
138+
new_indices.push_back(index);
135139
}
136-
buffer_store.CopyOnWrite()->indices = indices;
140+
buffer_store.CopyOnWrite()->indices = new_indices;
137141
return std::move(buffer_store);
138142
}
139143

140144
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
141145
auto buffer_load =
142146
Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
143147
auto indices = buffer_load->indices;
148+
Array<PrimExpr> new_indices;
144149
for (auto index : indices) {
145150
if (index->dtype.is_int() && index->dtype.bits() < 64) {
146151
auto int_bound = analyzer_->const_int_bound(index);
147152
if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
148153
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
149154
Int64Promoter promoter;
150155
index = promoter(index);
156+
new_indices.push_back(index);
157+
continue;
151158
}
152159
}
160+
new_indices.push_back(index);
153161
}
154-
buffer_load.CopyOnWrite()->indices = indices;
162+
buffer_load.CopyOnWrite()->indices = new_indices;
155163
return std::move(buffer_load);
156164
}
157165
};

0 commit comments

Comments
 (0)