@@ -362,6 +362,114 @@ class BlockNameDeduplicator : public tir::StmtMutator {
362362
363363namespace relax {
364364
365+ static Array<Integer> GetInplaceOutputIndices (const Array<Integer>& inplace_indices,
366+ int num_inputs) {
367+ Array<Integer> ret;
368+ int last_idx = num_inputs;
369+ for (auto idx : inplace_indices) {
370+ int i = idx.IntValue ();
371+ if (i >= 0 ) {
372+ ret.push_back (Integer (i));
373+ } else {
374+ CHECK_EQ (i, -1 ) << " The only negative index expected in inplace_indices is -1, but got " << i;
375+ ret.push_back (Integer (last_idx));
376+ last_idx++;
377+ }
378+ }
379+
380+ return ret;
381+ }
382+
383+ class RelaxToTIRVarMapCollector : public ExprVisitor {
384+ public:
385+ explicit RelaxToTIRVarMapCollector (const IRModule& mod) : mod_(mod) {}
386+ static Map<Expr, tir::Buffer> Collect (const IRModule& mod, const Function& func) {
387+ RelaxToTIRVarMapCollector visitor (mod);
388+ visitor (func->body );
389+ return visitor.relax_to_tir_var_map_ ;
390+ }
391+
392+ private:
393+ void VisitBinding_ (const VarBindingNode* binding) final {
394+ current_var_ = binding->var ;
395+ ExprVisitor::VisitBinding_ (binding);
396+ }
397+
398+ void VisitExpr_ (const CallNode* call) {
399+ static const Op& call_tir_op_ = Op::Get (" relax.call_tir" );
400+ static const Op& call_tir_inplace_op_ = Op::Get (" relax.call_tir_inplace" );
401+
402+ ICHECK (call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
403+ << " Only call_tir and call_tir_inplace are supported in primitive function, but got: "
404+ << GetRef<Expr>(call);
405+ CollectVarMapping (call, current_var_, call->op == call_tir_inplace_op_);
406+ }
407+
408+ void CollectVarMapping (const CallNode* call, const Expr& lhs_var, bool in_place) {
409+ GlobalVar gv = Downcast<GlobalVar>(call->args [0 ]);
410+ tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup (gv));
411+ const auto & buffer_map = prim_func_->buffer_map ;
412+ const auto & tir_args = prim_func_->params ;
413+
414+ const auto & relax_args = Downcast<Tuple>(call->args [1 ])->fields ;
415+
416+ Array<Expr> relax_results;
417+ if (lhs_var->IsInstance <TupleNode>()) {
418+ relax_results = Downcast<Tuple>(lhs_var)->fields ;
419+ } else {
420+ CHECK (lhs_var->IsInstance <VarNode>()) << " The lhs_var is expected to be either tuple or var" ;
421+ relax_results = {Downcast<Var>(lhs_var)};
422+ }
423+
424+ size_t num_inputs = relax_args.size ();
425+ size_t num_outputs = relax_results.size ();
426+
427+ Array<Integer> output_idxs;
428+ if (in_place) {
429+ const auto * attrs = call->attrs .as <CallTIRInplaceAttrs>();
430+ CHECK (attrs) << " Must have CallTIRInplaceAttrs for an in-place call" ;
431+ output_idxs = GetInplaceOutputIndices (attrs->inplace_indices , num_inputs);
432+ } else {
433+ for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
434+ output_idxs.push_back (i);
435+ }
436+ }
437+
438+ // If the `expr` is already seen (present in the map), validate whether the mapped buffer is
439+ // structurally equal to the `new_buf` passed
440+ auto ValidateBufferCompatibility = [this ](tir::Buffer new_buf, Expr expr) {
441+ if (auto it = relax_to_tir_var_map_.find (expr); it != relax_to_tir_var_map_.end ()) {
442+ ICHECK (StructuralEqual ()((*it).second , new_buf))
443+ << " Inconsistent buffers " << (*it).second << " and " << new_buf
444+ << " mapped to the same relax var: " << expr;
445+ }
446+ };
447+ for (size_t i = 0 ; i < tir_args.size (); ++i) {
448+ const auto & tir_var = tir_args[i];
449+ if (auto tir_buffer = buffer_map.Get (tir_var)) {
450+ if (i < num_inputs) {
451+ const auto & relax_var = relax_args[i];
452+ ValidateBufferCompatibility (tir_buffer.value (), relax_var);
453+ relax_to_tir_var_map_.Set (relax_var, tir_buffer.value ());
454+ }
455+ if (auto it = std::find (output_idxs.begin (), output_idxs.end (), i);
456+ it != output_idxs.end ()) {
457+ int result_idx = it - output_idxs.begin ();
458+ const auto & relax_var = relax_results[result_idx];
459+ ValidateBufferCompatibility (tir_buffer.value (), relax_var);
460+ relax_to_tir_var_map_.Set (relax_var, tir_buffer.value ());
461+ }
462+ }
463+ }
464+ }
465+
466+ private:
467+ /* ! \brief The IRModule */
468+ const IRModule& mod_;
469+ Map<Expr, tir::Buffer> relax_to_tir_var_map_;
470+ Var current_var_;
471+ };
472+
365473class FusedTIRConstructor : public ExprVisitor {
366474 public:
367475 /* !
@@ -391,10 +499,11 @@ class FusedTIRConstructor : public ExprVisitor {
391499 : mod_(mod), func_name_(func_name) {}
392500
393501 void VisitExpr_ (const FunctionNode* func) final {
502+ auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect (mod_, GetRef<Function>(func));
394503 std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
395504 for (const Var& relax_param : func->params ) {
396505 size_t size_before = prim_func_params.size ();
397- CollectPrimFuncParams (relax_param, &prim_func_params);
506+ CollectPrimFuncParams (relax_param, &prim_func_params, relax_to_tir_var_map. Get (relax_param) );
398507
399508 auto param_buffers = [&]() -> Array<tir::Buffer> {
400509 Array<tir::Buffer> out;
@@ -676,23 +785,6 @@ class FusedTIRConstructor : public ExprVisitor {
676785 MapArgsToBuffer (arg_list, buffer_list);
677786 }
678787
679- static Array<Integer> GetInplaceOutputIndices (const Array<Integer>& inplace_indices,
680- int num_inputs) {
681- Array<Integer> ret;
682- int last_idx = num_inputs;
683- for (auto idx : inplace_indices) {
684- int i = idx.IntValue ();
685- if (i >= 0 ) {
686- ret.push_back (Integer (i));
687- } else {
688- ret.push_back (Integer (last_idx));
689- last_idx++;
690- }
691- }
692-
693- return ret;
694- }
695-
696788 static Array<tir::Var> GetPrimFuncOutputParams (const tir::PrimFunc& func,
697789 const Array<Integer>& output_indices) {
698790 size_t n = func->params .size ();
@@ -799,7 +891,8 @@ class FusedTIRConstructor : public ExprVisitor {
799891 * \param out The vector into which to collect the params/buffers
800892 */
801893 static void CollectPrimFuncParams (const Var& relax_param,
802- std::vector<Variant<tir::Var, tir::Buffer>>* out) {
894+ std::vector<Variant<tir::Var, tir::Buffer>>* out,
895+ const tvm::runtime::Optional<tir::Buffer>& tir_buffer_param) {
803896 auto struct_info = GetStructInfo (relax_param);
804897
805898 CHECK (!struct_info.as <TupleStructInfoNode>())
@@ -814,7 +907,14 @@ class FusedTIRConstructor : public ExprVisitor {
814907 const auto * shape_expr = tensor->shape .as <ShapeExprNode>();
815908 ICHECK (shape_expr) << " FuseTIR expects all Tensor parameters have a known shape." ;
816909 DataType dtype = tensor->dtype ;
817- tir::Buffer buffer = tir::decl_buffer (shape_expr->values , dtype, name_hint);
910+ tir::Buffer buffer;
911+ if (tir_buffer_param.defined ()) {
912+ buffer =
913+ tir::decl_buffer (shape_expr->values , dtype, name_hint, tir_buffer_param.value ().scope (),
914+ tir_buffer_param.value ()->axis_separators );
915+ } else {
916+ buffer = tir::decl_buffer (shape_expr->values , dtype, name_hint);
917+ }
818918 out->push_back (std::move (buffer));
819919
820920 } else if (const auto * prim_value = struct_info.as <PrimStructInfoNode>()) {
0 commit comments