1717 * under the License.
1818 */
1919#include < tvm/relax/analysis.h>
20+ #include < tvm/relax/attrs/op.h>
2021#include < tvm/relax/expr_functor.h>
2122#include < tvm/relax/struct_info.h>
2223#include < tvm/relax/transform.h>
@@ -367,17 +368,22 @@ class FusedTIRConstructor : public ExprVisitor {
367368 * \brief Construct a fused TIR PrimFunc from a relax sub-function
368369 * \param mod The IRModule
369370 * \param gv The global var of relax subfunction to be fused into one PrimFunc
370- * \return The fused TIR PrimFunc
371+ * \return The fused TIR PrimFunc and the in-place indices (non-empty for an in-place call)
371372 */
372- static tir::PrimFunc GetFusedTIR (const IRModule& mod, const GlobalVar& gv) {
373+ static std::pair<tir::PrimFunc, Array<Integer>> GetFusedTIR (const IRModule& mod,
374+ const GlobalVar& gv) {
373375 FusedTIRConstructor visitor (mod, gv->name_hint );
374376 BaseFunc f = mod->Lookup (gv);
375377 CHECK (f->IsInstance <relax::FunctionNode>())
376378 << " Expected relax functions, but got: " << f->GetTypeKey ();
377379 CHECK (f->HasNonzeroAttr (relax::attr::kPrimitive ))
378380 << " Expected a function with attr `kPrimitive`" ;
379381 visitor (Downcast<relax::Function>(f));
380- return visitor.fused_tir_ ;
382+ Array<Integer> inplace_indices;
383+ for (size_t idx : visitor.inplace_indices_ ) {
384+ inplace_indices.push_back (Integer (idx));
385+ }
386+ return {visitor.fused_tir_ , inplace_indices};
381387 }
382388
383389 private:
@@ -438,9 +444,38 @@ class FusedTIRConstructor : public ExprVisitor {
438444 auto it = func_info_.expr2buffers .find (body);
439445 ICHECK (it != func_info_.expr2buffers .end ())
440446 << " Fail to detect output buffers for function body" ;
447+
441448 const Array<tir::Buffer>& buffers = (*it).second ;
449+
450+ // map of input buffers to indices (helpful for detecting in-place inputs)
451+ std::unordered_map<tir::Buffer, size_t , ObjectPtrHash, ObjectPtrEqual> buffer_to_idx;
452+ std::unordered_map<tir::Var, size_t , ObjectPtrHash, ObjectPtrEqual> input_to_idx;
453+ for (size_t i = 0 ; i < func_info_.params .size (); i++) {
454+ input_to_idx[func_info_.params [i]] = i;
455+ }
456+ for (auto [var, buffer] : func_info_.buffer_map ) {
457+ if (auto it = input_to_idx.find (var); it != input_to_idx.end ()) {
458+ buffer_to_idx[buffer] = (*it).second ;
459+ }
460+ }
461+
462+ // numbered separately because the number of output *vars* might differ from the
463+ // number of outputs if there are in-place inputs
464+ int out_idx = 0 ;
442465 for (size_t i = 0 ; i < buffers.size (); ++i) {
443- tir::Var param = tir::Var (" p_output" + std::to_string (i), PrimType (DataType::Handle ()));
466+ // Do not add output vars for in-place inputs
467+ // (i.e., already listed in the buffer map. This would result
468+ // in duplicates in the buffer map otherwise)
469+ if (auto it = buffer_to_idx.find (buffers[i]); it != buffer_to_idx.end ()) {
470+ auto idx = (*it).second ;
471+ CHECK (!inplace_indices_.count (idx))
472+ << " In-place index " << idx << " used twice! An argument must be aliased." ;
473+ inplace_indices_.insert (idx);
474+ continue ;
475+ }
476+
477+ tir::Var param = tir::Var (" p_output" + std::to_string (out_idx), PrimType (DataType::Handle ()));
478+ out_idx++;
444479 func_info_.buffer_map .Set (param, buffers[i]);
445480 func_info_.params .push_back (param);
446481 func_info_.output_buffers .insert (buffers[i].get ());
@@ -476,8 +511,11 @@ class FusedTIRConstructor : public ExprVisitor {
476511 void VisitExpr_ (const CallNode* call) final {
477512 ExprVisitor::VisitExpr_ (call);
478513 static const Op& call_tir_op_ = Op::Get (" relax.call_tir" );
479- ICHECK (call->op == call_tir_op_)
480- << " Only call_tir is supported in primitive function, but got: " << GetRef<Expr>(call);
514+ static const Op& call_tir_inplace_op_ = Op::Get (" relax.call_tir_inplace" );
515+
516+ ICHECK (call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
517+ << " Only call_tir and call_tir_inplace are supported in primitive function, but got: "
518+ << GetRef<Expr>(call);
481519
482520 // Step 1. Get Global var and PrimFunc
483521 GlobalVar gv = Downcast<GlobalVar>(call->args [0 ]);
@@ -503,7 +541,7 @@ class FusedTIRConstructor : public ExprVisitor {
503541 MapInputBuffer (prim_func, call->args [1 ]);
504542 const Array<Array<PrimExpr>>& output_buffer_shapes = GetCallTIROutputShapes (call);
505543
506- AllocateIntermediateBuffer (GetRef<Expr>( call) , prim_func, output_buffer_shapes);
544+ AllocateIntermediateBuffer (call, prim_func, output_buffer_shapes);
507545
508546 // Step 6. Update tir_vars
509547 if (call->args .size () > 2 ) {
@@ -566,7 +604,8 @@ class FusedTIRConstructor : public ExprVisitor {
566604 */
567605 static Array<Array<PrimExpr>> GetCallTIROutputShapes (const CallNode* call) {
568606 static const Op& call_tir_op_ = Op::Get (" relax.call_tir" );
569- ICHECK (call->op .same_as (call_tir_op_));
607+ static const Op& call_tir_inplace_op_ = Op::Get (" relax.call_tir_inplace" );
608+ ICHECK (call->op .same_as (call_tir_op_) || call->op .same_as (call_tir_inplace_op_));
570609 ICHECK_EQ (call->sinfo_args .size (), 1 );
571610 auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) {
572611 const auto * shape_expr = sinfo->shape .as <ShapeExprNode>();
@@ -611,7 +650,7 @@ class FusedTIRConstructor : public ExprVisitor {
611650 }
612651 }
613652 }
614- // Make sure every buffers are mapped.
653+ // Make sure every buffer is mapped.
615654 ICHECK_EQ (buffer_idx, buffers.size ());
616655 }
617656
@@ -639,28 +678,49 @@ class FusedTIRConstructor : public ExprVisitor {
639678 MapArgsToBuffer (arg_list, buffer_list);
640679 }
641680
642- static Array<tir::Var> GetPrimFuncOutputParams (const tir::PrimFunc& func, size_t output_size) {
681+ static Array<Integer> GetInplaceOutputIndices (const Array<Integer>& inplace_indices,
682+ int num_inputs) {
683+ Array<Integer> ret;
684+ int last_idx = num_inputs;
685+ for (auto idx : inplace_indices) {
686+ int i = idx.IntValue ();
687+ if (i >= 0 ) {
688+ ret.push_back (Integer (i));
689+ } else {
690+ ret.push_back (Integer (last_idx));
691+ last_idx++;
692+ }
693+ }
694+
695+ return ret;
696+ }
697+
698+ static Array<tir::Var> GetPrimFuncOutputParams (const tir::PrimFunc& func,
699+ const Array<Integer>& output_indices) {
643700 size_t n = func->params .size ();
644701 int symbolic_var_index = -1 ;
702+ size_t output_size = output_indices.size ();
645703 ICHECK_GE (n, output_size);
646- for (size_t i = 0 ; i < n; ++i) {
647- const tir::Var& param = func->params [i];
704+
705+ Array<tir::Var> ret;
706+ for (auto idx : output_indices) {
707+ int i = idx.IntValue ();
708+ const tir::Var& param = func->params [static_cast <size_t >(i)];
648709 if (param->dtype .is_int () || param->dtype .is_uint ()) {
649710 if (symbolic_var_index == -1 ) symbolic_var_index = i;
650711 } else if (param->dtype .is_handle ()) {
651712 CHECK (symbolic_var_index == -1 ) << " The scalar input should be at the ending of the "
652713 " parameter list." ;
714+ ret.push_back (param);
653715 } else {
654716 LOG (FATAL) << " The params of PrimFunc are expected to be Buffer handle or scalar, but got: "
655717 << param->dtype ;
656718 }
657719 }
720+
658721 size_t end_index = symbolic_var_index == -1 ? n : symbolic_var_index;
659722 ICHECK_GE (end_index, output_size);
660- size_t begin_index = end_index - output_size;
661- Array<tir::Var> output_params{func->params .begin () + begin_index,
662- func->params .begin () + end_index};
663- return output_params;
723+ return ret;
664724 }
665725
666726 /* !
@@ -670,18 +730,39 @@ class FusedTIRConstructor : public ExprVisitor {
670730 * \param func The old TIR PrimFunc
671731 * \param output_shapes The shape of output params.
672732 */
673- void AllocateIntermediateBuffer (const Expr& expr , const tir::PrimFunc& func,
733+ void AllocateIntermediateBuffer (const CallNode* call , const tir::PrimFunc& func,
674734 const Array<Array<PrimExpr>>& output_shapes) {
735+ bool is_inplace = (call->op == Op::Get (" relax.call_tir_inplace" ));
736+
675737 size_t n = func->params .size ();
738+ int num_inputs = Downcast<Tuple>(call->args [1 ])->fields .size ();
676739 size_t output_size = output_shapes.size ();
677740 ICHECK_GE (n, output_size);
678- // Allocate intermediate buffer
679- Array<tir::Buffer> alloc_buffers;
680- Array<tir::Var> output_params = GetPrimFuncOutputParams (func, output_size);
741+ Array<tir::Buffer> output_buffers;
742+ Array<Integer> output_idxs;
743+ if (is_inplace) {
744+ const auto * attrs = call->attrs .as <CallTIRInplaceAttrs>();
745+ CHECK (attrs) << " Must have CallTIRInplaceAttrs for an in-place call" ;
746+ output_idxs = std::move (GetInplaceOutputIndices (attrs->inplace_indices , num_inputs));
747+ } else {
748+ for (size_t i = 0 ; i < output_size; i++) {
749+ output_idxs.push_back (num_inputs + i);
750+ }
751+ }
752+
753+ Array<tir::Var> output_params = GetPrimFuncOutputParams (func, output_idxs);
754+ auto input_buffers = func_info_.expr2buffers .Get (call->args [1 ]);
681755 for (size_t i = 0 ; i < output_size; ++i) {
682756 const tir::Var& param = output_params[i];
683757 const tir::Buffer& buffer = func->buffer_map .at (param);
684758
759+ // if this is an inplace output, do not do an intermediate allocation
760+ if (output_idxs[i].IntValue () < num_inputs) {
761+ CHECK (input_buffers.defined ()) << " Inplace functions must have some defined input" ;
762+ output_buffers.push_back (input_buffers.value ()[output_idxs[i].IntValue ()]);
763+ continue ;
764+ }
765+
685766 auto unify_name_hints = [this , &buffer]() {
686767 String base_name = buffer->name ;
687768 String unique_name = base_name + " _intermediate" ;
@@ -703,14 +784,14 @@ class FusedTIRConstructor : public ExprVisitor {
703784 n->name = unify_name_hints ();
704785 tir::Buffer new_buffer (n);
705786 func_info_.alloc_buffers .push_back (new_buffer);
706- alloc_buffers .push_back (new_buffer);
787+ output_buffers .push_back (new_buffer);
707788
708789 // Match the shape of the output buffer with the shape
709790 func_info_.symbolic_var_matcher .Match (buffer->shape , n->shape );
710791 func_info_.buffer_subst_map .Set (buffer, new_buffer);
711792 }
712793 // Update expr2buffers
713- func_info_.expr2buffers .Set (expr, alloc_buffers );
794+ func_info_.expr2buffers .Set (GetRef<Expr>(call), output_buffers );
714795 }
715796
716797 /* !
@@ -858,6 +939,8 @@ class FusedTIRConstructor : public ExprVisitor {
858939 FuseFuncInfo func_info_;
859940 /* ! \brief The tir function after fusion*/
860941 tir::PrimFunc fused_tir_;
942+ /* ! \brief Indices of inputs that are used for in-place computation */
943+ std::unordered_set<size_t > inplace_indices_;
861944};
862945
863946std::vector<size_t > GetTupleAccessedIndices (const FunctionNode* func, const Var& tuple_var) {
@@ -897,8 +980,11 @@ class TIRFuseMutator : public ExprMutator {
897980 for (const auto & [gv, func] : mod->functions ) {
898981 // Only fuse primitive relax functions
899982 if (func->IsInstance <relax::FunctionNode>() && func->HasNonzeroAttr (attr::kPrimitive )) {
900- tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR (mod, gv);
901- mutator.fused_tir_funcs_ .Set (gv, fused_tir);
983+ const auto & [prim_func, indices] = FusedTIRConstructor::GetFusedTIR (mod, gv);
984+ mutator.fused_tir_funcs_ .Set (gv, prim_func);
985+ if (!indices.empty ()) {
986+ mutator.inplace_indices_ .Set (gv, indices);
987+ }
902988 }
903989 }
904990
@@ -945,6 +1031,7 @@ class TIRFuseMutator : public ExprMutator {
9451031
9461032 Expr VisitExpr_ (const CallNode* op) final {
9471033 static const Op& call_tir_op_ = Op::Get (" relax.call_tir" );
1034+ static const Op& call_tir_inplace_op_ = Op::Get (" relax.call_tir_inplace" );
9481035
9491036 Call call = Downcast<Call>(builder_->Normalize (ExprMutator::VisitExpr_ (op)));
9501037
@@ -985,26 +1072,34 @@ class TIRFuseMutator : public ExprMutator {
9851072 CHECK (prim_value->value .defined ())
9861073 << " FuseTIR requires all R.Prim arguments to have a known value." ;
9871074 PrimExpr expr = prim_value->value .value ();
988- CHECK (expr->IsInstance <tir::VarNode>())
989- << " FuseTIR currently requires all R.Prim arguments to provide a single tir::Var." ;
1075+ CHECK (expr->IsInstance <tir::VarNode>()) << " FuseTIR currently requires all R.Prim "
1076+ " arguments to provide a single tir::Var." ;
9901077 tir_vars.push_back (expr);
9911078
9921079 } else {
9931080 arg_list.push_back (arg);
9941081 }
9951082 }
996- // Step b. Create call_tir
1083+ // Step b. Create call_tir or call_tir_inplace
9971084 Array<Expr> call_args = {fused_tir_gv, Tuple (arg_list)};
9981085 if (!tir_vars.empty ()) {
9991086 call_args.push_back (ShapeExpr (tir_vars));
10001087 }
1001- return Call (call_tir_op_, call_args, call->attrs , {GetStructInfo (call)});
1088+ Op call_op = call_tir_op_;
1089+ Attrs call_attrs = call->attrs ;
1090+ if (auto it = inplace_indices_.find (old_gv); it != inplace_indices_.end ()) {
1091+ call_op = call_tir_inplace_op_;
1092+ auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
1093+ inplace_attrs->inplace_indices = (*it).second ;
1094+ call_attrs = Attrs (inplace_attrs);
1095+ }
1096+ return Call (call_op, call_args, call_attrs, {GetStructInfo (call)});
10021097 } else {
10031098 // Case 1.2. The callee function is not primitive, nothing to do.
10041099 return call;
10051100 }
1006- } else if (call->op == call_tir_op_) {
1007- // Case 2. It is a call_tir, re-emit the PrimFunc.
1101+ } else if (call->op == call_tir_op_ || call-> op == call_tir_inplace_op_ ) {
1102+ // Case 2. It is a call_tir or call_tir_inplace , re-emit the PrimFunc.
10081103 if (const auto * gv = call->args [0 ].as <GlobalVarNode>()) {
10091104 tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup (GetRef<GlobalVar>(gv)));
10101105 GlobalVar new_gv = this ->builder_ ->AddFunction (func, gv->name_hint );
@@ -1023,6 +1118,9 @@ class TIRFuseMutator : public ExprMutator {
10231118 const IRModule& mod_;
10241119 /* ! \brief The map from global var of primitive relax function to generated prim func. */
10251120 Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
1121+ /* ! \brief The map from global var of primitive relax function to in-place indices
1122+ * (if there are any). */
1123+ Map<GlobalVar, Array<Integer>> inplace_indices_;
10261124};
10271125
10281126IRModule FuseTIR (IRModule mod) {
0 commit comments