Skip to content

Commit 0a836ab

Browse files
committed
Use a set to ensure in-place indices will be unique
1 parent ee1ed13 commit 0a836ab

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

src/relax/transform/fuse_tir.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,11 @@ class FusedTIRConstructor : public ExprVisitor {
379379
CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive))
380380
<< "Expected a function with attr `kPrimitive`";
381381
visitor(Downcast<relax::Function>(f));
382-
return {visitor.fused_tir_, visitor.inplace_indices_};
382+
Array<Integer> inplace_indices;
383+
for (size_t idx : visitor.inplace_indices_) {
384+
inplace_indices.push_back(Integer(idx));
385+
}
386+
return {visitor.fused_tir_, inplace_indices};
383387
}
384388

385389
private:
@@ -444,10 +448,10 @@ class FusedTIRConstructor : public ExprVisitor {
444448
const Array<tir::Buffer>& buffers = (*it).second;
445449

446450
// map of input buffers to indices (helpful for detecting in-place inputs)
447-
std::unordered_map<tir::Buffer, Integer, ObjectPtrHash, ObjectPtrEqual> buffer_to_idx;
448-
std::unordered_map<tir::Var, Integer, ObjectPtrHash, ObjectPtrEqual> input_to_idx;
451+
std::unordered_map<tir::Buffer, size_t, ObjectPtrHash, ObjectPtrEqual> buffer_to_idx;
452+
std::unordered_map<tir::Var, size_t, ObjectPtrHash, ObjectPtrEqual> input_to_idx;
449453
for (size_t i = 0; i < func_info_.params.size(); i++) {
450-
input_to_idx[func_info_.params[i]] = Integer(i);
454+
input_to_idx[func_info_.params[i]] = i;
451455
}
452456
for (auto [var, buffer] : func_info_.buffer_map) {
453457
if (auto it = input_to_idx.find(var); it != input_to_idx.end()) {
@@ -463,7 +467,7 @@ class FusedTIRConstructor : public ExprVisitor {
463467
// (i.e., already listed in the buffer map. This would result
464468
// in duplicates in the buffer map otherwise)
465469
if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end()) {
466-
inplace_indices_.push_back((*it).second);
470+
inplace_indices_.insert((*it).second);
467471
continue;
468472
}
469473

@@ -933,7 +937,7 @@ class FusedTIRConstructor : public ExprVisitor {
933937
/*! \brief The tir function after fusion*/
934938
tir::PrimFunc fused_tir_;
935939
/*! \brief Indices of inputs that are used for in-place computation */
936-
Array<Integer> inplace_indices_;
940+
std::unordered_set<size_t> inplace_indices_;
937941
};
938942

939943
std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var& tuple_var) {

0 commit comments

Comments
 (0)