Skip to content

Commit c3a1124

Browse files
committed
fix
1 parent 8c0b009 commit c3a1124

File tree

2 files changed

+23
-23
lines changed

2 files changed

+23
-23
lines changed

src/tir/schedule/primitive/blockize_tensorize.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19+
#include <tvm/tir/data_type_rewriter.h>
20+
1921
#include <functional>
2022

2123
#include "../ir_comparator.h"
@@ -523,6 +525,19 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int
523525
}
524526
PrimFunc intrin_desc = intrin->desc;
525527
PrimFunc intrin_impl = DeepCopy(intrin->impl);
528+
529+
int index_dtype_bits = -1;
530+
auto f_update_max_dtype_bits_from_region = [&](const Array<BufferRegion>& buffer_regions) {
531+
for (const BufferRegion& buffer_region : buffer_regions) {
532+
for (const auto& range : buffer_region->region) {
533+
index_dtype_bits = std::max(index_dtype_bits, range->min.dtype().bits());
534+
}
535+
}
536+
};
537+
f_update_max_dtype_bits_from_region(block_realize->block->reads);
538+
f_update_max_dtype_bits_from_region(block_realize->block->writes);
539+
ICHECK(index_dtype_bits > 0);
540+
intrin_impl = IndexDataTypeNormalizer(DataType::Int(index_dtype_bits)).Rewrite(intrin_impl);
526541
// Step 2: Structural pattern matching
527542
TensorizeComparator comparator(self->mod, /*assert_mode=*/true);
528543
comparator.VisitStmt(block_realize, intrin_desc->body);

src/tir/transforms/lower_match_buffer.cc

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,17 @@
2323
*/
2424

2525
#include <tvm/arith/analyzer.h>
26-
#include <tvm/tir/data_type_rewriter.h>
2726
#include <tvm/tir/function.h>
2827
#include <tvm/tir/op.h>
2928
#include <tvm/tir/stmt_functor.h>
3029
#include <tvm/tir/transform.h>
3130

32-
#include "../../printer/text_printer.h"
3331
#include "../ir/functor_common.h"
3432
#include "ir_utils.h"
3533

3634
namespace tvm {
3735
namespace tir {
38-
class MatchBufferLower : public DataTypeLegalizer {
36+
class MatchBufferLower : public StmtExprMutator {
3937
public:
4038
explicit MatchBufferLower(const PrimFunc& func) {
4139
for (const Var& param : func->params) {
@@ -190,14 +188,14 @@ class MatchBufferLower : public DataTypeLegalizer {
190188
Array<PrimExpr> buffer_start_indices = source_buffer->ElemOffset(indices);
191189
if (buffer_start_indices.size() == 1) {
192190
Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset");
193-
CHECK(analyzer_.CanProve(truncmod(buffer_start_indices[0], buffer->offset_factor) == 0))
191+
CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0))
194192
<< "The source elem_offset " << buffer_start_indices[0]
195193
<< " does not satisfy the offset_factor " << buffer->offset_factor << ".";
196194
} else {
197195
// Non-zero elem_offset is ill-defined for non-flat memory.
198196
// If needed in the future, will require `Array<PrimExpr>
199197
// elem_offsets`, with one offset for each flattened index.
200-
Bind(buffer->elem_offset, make_zero(buffer->elem_offset.dtype()));
198+
Bind(buffer->elem_offset, 0);
201199
}
202200
}
203201

@@ -231,7 +229,7 @@ class MatchBufferLower : public DataTypeLegalizer {
231229
}
232230

233231
void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") {
234-
CHECK_EQ(arg.dtype().code(), value.dtype().code())
232+
CHECK_EQ(arg.dtype(), value.dtype())
235233
<< "The data type mismatched: " << arg->dtype << " vs. " << value->dtype;
236234
// Handle recursive case
237235
value = Substitute(std::move(value), var_map_);
@@ -240,7 +238,7 @@ class MatchBufferLower : public DataTypeLegalizer {
240238
auto it = var_map_.find(v);
241239
if (it == var_map_.end()) {
242240
var_map_.Set(v, value);
243-
// analyzer_.Bind(v, value);
241+
analyzer_.Bind(v, value);
244242
} else {
245243
AssertBinding((*it).second, value, arg_name);
246244
}
@@ -249,21 +247,10 @@ class MatchBufferLower : public DataTypeLegalizer {
249247
}
250248
}
251249

252-
PrimExpr LookUpArgBind(const PrimExpr& arg) {
253-
if (arg->IsInstance<VarNode>()) {
254-
Var v = Downcast<Var>(arg);
255-
if (auto it = var_map_.find(v); it != var_map_.end()) {
256-
return (*it).second;
257-
}
258-
}
259-
return arg;
260-
}
261-
262250
void AssertBinding(const PrimExpr& lhs, const PrimExpr& rhs,
263251
const std::string& arg_name = "argument") {
264-
CHECK(analyzer_.CanProve(LookUpArgBind(lhs) == rhs))
265-
<< "The buffer match constraint for " << arg_name << " unmet: " << lhs << "==" << rhs
266-
<< ".";
252+
CHECK(analyzer_.CanProve(lhs == rhs)) << "The buffer match constraint for " << arg_name
253+
<< " unmet: " << lhs << "==" << rhs << ".";
267254
}
268255

269256
private:
@@ -277,9 +264,7 @@ class MatchBufferLower : public DataTypeLegalizer {
277264

278265
PrimFunc LowerMatchBuffer(PrimFunc func) {
279266
auto fptr = func.CopyOnWrite();
280-
// LOG(INFO) << "BeforeLMB:\n" << tir::AsTVMScript(func);
281267
fptr->body = MatchBufferLower(func)(std::move(fptr->body));
282-
// LOG(INFO) << "AfterLMB:\n" << tir::AsTVMScript(func);
283268
return func;
284269
}
285270

@@ -297,4 +282,4 @@ TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchB
297282
} // namespace transform
298283

299284
} // namespace tir
300-
} // namespace tvm
285+
} // namespace tvm

0 commit comments

Comments
 (0)