@@ -379,7 +379,11 @@ class FusedTIRConstructor : public ExprVisitor {
379379 CHECK (f->HasNonzeroAttr (relax::attr::kPrimitive ))
380380 << " Expected a function with attr `kPrimitive`" ;
381381 visitor (Downcast<relax::Function>(f));
382- return {visitor.fused_tir_ , visitor.inplace_indices_ };
382+ Array<Integer> inplace_indices;
383+ for (size_t idx : visitor.inplace_indices_ ) {
384+ inplace_indices.push_back (Integer (idx));
385+ }
386+ return {visitor.fused_tir_ , inplace_indices};
383387 }
384388
385389 private:
@@ -444,10 +448,10 @@ class FusedTIRConstructor : public ExprVisitor {
444448 const Array<tir::Buffer>& buffers = (*it).second ;
445449
446450 // map of input buffers to indices (helpful for detecting in-place inputs)
447- std::unordered_map<tir::Buffer, Integer , ObjectPtrHash, ObjectPtrEqual> buffer_to_idx;
448- std::unordered_map<tir::Var, Integer , ObjectPtrHash, ObjectPtrEqual> input_to_idx;
451+ std::unordered_map<tir::Buffer, size_t , ObjectPtrHash, ObjectPtrEqual> buffer_to_idx;
452+ std::unordered_map<tir::Var, size_t , ObjectPtrHash, ObjectPtrEqual> input_to_idx;
449453 for (size_t i = 0 ; i < func_info_.params .size (); i++) {
450- input_to_idx[func_info_.params [i]] = Integer (i) ;
454+ input_to_idx[func_info_.params [i]] = i ;
451455 }
452456 for (auto [var, buffer] : func_info_.buffer_map ) {
453457 if (auto it = input_to_idx.find (var); it != input_to_idx.end ()) {
@@ -463,7 +467,7 @@ class FusedTIRConstructor : public ExprVisitor {
463467 // (i.e., already listed in the buffer map. This would result
464468 // in duplicates in the buffer map otherwise)
465469 if (auto it = buffer_to_idx.find (buffers[i]); it != buffer_to_idx.end ()) {
466- inplace_indices_.push_back ((*it).second );
470+ inplace_indices_.insert ((*it).second );
467471 continue ;
468472 }
469473
@@ -933,7 +937,7 @@ class FusedTIRConstructor : public ExprVisitor {
933937 /* ! \brief The tir function after fusion*/
934938 tir::PrimFunc fused_tir_;
935939 /* ! \brief Indices of inputs that are used for in-place computation */
936- Array<Integer > inplace_indices_;
940+ std::unordered_set< size_t > inplace_indices_;
937941};
938942
939943std::vector<size_t > GetTupleAccessedIndices (const FunctionNode* func, const Var& tuple_var) {
0 commit comments