diff --git a/src/transform/config_index_bitwidth.cc b/src/transform/config_index_bitwidth.cc index cc87cce05..58ca0da7f 100644 --- a/src/transform/config_index_bitwidth.cc +++ b/src/transform/config_index_bitwidth.cc @@ -123,6 +123,7 @@ class IndexLegalizer : public IRMutatorWithAnalyzer { auto buffer_store = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); auto indices = buffer_store->indices; + Array new_indices; for (auto index : indices) { if (index->dtype.is_int() && index->dtype.bits() < 64) { auto int_bound = analyzer_->const_int_bound(index); @@ -130,10 +131,13 @@ class IndexLegalizer : public IRMutatorWithAnalyzer { int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) { Int64Promoter promoter; index = promoter(index); + new_indices.push_back(index); + continue; } } + new_indices.push_back(index); } - buffer_store.CopyOnWrite()->indices = indices; + buffer_store.CopyOnWrite()->indices = new_indices; return std::move(buffer_store); } @@ -141,6 +145,7 @@ class IndexLegalizer : public IRMutatorWithAnalyzer { auto buffer_load = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); auto indices = buffer_load->indices; + Array new_indices; for (auto index : indices) { if (index->dtype.is_int() && index->dtype.bits() < 64) { auto int_bound = analyzer_->const_int_bound(index); @@ -148,10 +153,13 @@ class IndexLegalizer : public IRMutatorWithAnalyzer { int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) { Int64Promoter promoter; index = promoter(index); + new_indices.push_back(index); + continue; } } + new_indices.push_back(index); } - buffer_load.CopyOnWrite()->indices = indices; + buffer_load.CopyOnWrite()->indices = new_indices; return std::move(buffer_load); } };