@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
362362
363363namespace relax {
364364
365+ static Array<Integer> GetInplaceOutputIndices (const Array<Integer>& inplace_indices,
366+ int num_inputs) {
367+ Array<Integer> ret;
368+ int last_idx = num_inputs;
369+ for (auto idx : inplace_indices) {
370+ int i = idx.IntValue ();
371+ if (i >= 0 ) {
372+ ret.push_back (Integer (i));
373+ } else {
374+ ret.push_back (Integer (last_idx));
375+ last_idx++;
376+ }
377+ }
378+
379+ return ret;
380+ }
381+
382+ class RelaxToTIRVarMapCollector : public ExprVisitor {
383+ void CollectVarMapping (const CallNode* call, const Expr& lhs_var, bool in_place = false ) {
384+ GlobalVar gv = Downcast<GlobalVar>(call->args [0 ]);
385+ tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup (gv));
386+ const auto & buffer_map = prim_func_->buffer_map ;
387+ const auto & tir_args = prim_func_->params ;
388+
389+ const auto & relax_args = Downcast<Tuple>(call->args [1 ])->fields ;
390+
391+ Array<Expr> relax_results;
392+ if (lhs_var->IsInstance <TupleNode>()) {
393+ relax_results = Downcast<Tuple>(lhs_var)->fields ;
394+ } else {
395+ CHECK (lhs_var->IsInstance <VarNode>()) << " The lhs_var is expected to be either tuple or var" ;
396+ relax_results = {Downcast<Var>(lhs_var)};
397+ }
398+
399+ size_t num_inputs = relax_args.size ();
400+ size_t num_outputs = relax_results.size ();
401+
402+ Array<Integer> output_idxs;
403+ if (in_place) {
404+ const auto * attrs = call->attrs .as <CallTIRInplaceAttrs>();
405+ CHECK (attrs) << " Must have CallTIRInplaceAttrs for an in-place call" ;
406+ output_idxs = GetInplaceOutputIndices (attrs->inplace_indices , num_inputs);
407+ } else {
408+ for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
409+ output_idxs.push_back (i);
410+ }
411+ }
412+ for (size_t i = 0 ; i < tir_args.size (); ++i) {
413+ const auto & tir_var = Downcast<tir::Var>(tir_args[i]);
414+ if (i < num_inputs) {
415+ const auto & relax_var = Downcast<Var>(relax_args[i]);
416+ relax_to_tir_var_map_.Set (relax_var, buffer_map[tir_var]);
417+ }
418+ if (auto it = std::find (output_idxs.begin (), output_idxs.end (), i); it != output_idxs.end ()) {
419+ int result_idx = it - output_idxs.begin ();
420+ const auto & inplace_out_var = Downcast<Var>(relax_results[result_idx]);
421+ relax_to_tir_var_map_.Set (inplace_out_var, buffer_map[tir_var]);
422+ }
423+ }
424+ }
425+
426+ public:
427+ explicit RelaxToTIRVarMapCollector (const IRModule& mod) : mod_(mod) {}
428+ static Map<Var, tir::Buffer> Collect (const IRModule& mod, const Function& func) {
429+ RelaxToTIRVarMapCollector visitor (mod);
430+ visitor (func->body );
431+ return visitor.relax_to_tir_var_map_ ;
432+ }
433+ void VisitBinding_ (const VarBindingNode* binding) final {
434+ const auto & lhs_var = binding->var ;
435+ const auto & value = binding->value ;
436+ if (const CallNode* call = value.as <CallNode>()) {
437+ static const Op& call_tir_op_ = Op::Get (" relax.call_tir" );
438+ static const Op& call_tir_inplace_op_ = Op::Get (" relax.call_tir_inplace" );
439+
440+ ICHECK (call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
441+ << " Only call_tir and call_tir_inplace are supported in primitive function, but got: "
442+ << GetRef<Expr>(call);
443+ if (call->op == call_tir_inplace_op_) {
444+ CollectVarMapping (call, lhs_var, /* in_place*/ true );
445+ } else {
446+ CollectVarMapping (call, lhs_var);
447+ }
448+ }
449+ }
450+
451+ private:
452+ /* ! \brief The IRModule */
453+ const IRModule& mod_;
454+ // size_t call_num_inputs_ = -1;
455+ Map<Var, tir::Buffer> relax_to_tir_var_map_;
456+ };
457+
365458class FusedTIRConstructor : public ExprVisitor {
366459 public:
367460 /* !
@@ -391,10 +484,15 @@ class FusedTIRConstructor : public ExprVisitor {
391484 : mod_(mod), func_name_(func_name) {}
392485
393486 void VisitExpr_ (const FunctionNode* func) final {
487+ auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect (mod_, GetRef<Function>(func));
394488 std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
395489 for (const Var& relax_param : func->params ) {
396490 size_t size_before = prim_func_params.size ();
397- CollectPrimFuncParams (relax_param, &prim_func_params);
491+ if (relax_to_tir_var_map.count (relax_param)) {
492+ CollectPrimFuncParams (relax_param, &prim_func_params, relax_to_tir_var_map[relax_param]);
493+ } else {
494+ CollectPrimFuncParams (relax_param, &prim_func_params);
495+ }
398496
399497 auto param_buffers = [&]() -> Array<tir::Buffer> {
400498 Array<tir::Buffer> out;
@@ -676,23 +774,6 @@ class FusedTIRConstructor : public ExprVisitor {
676774 MapArgsToBuffer (arg_list, buffer_list);
677775 }
678776
679- static Array<Integer> GetInplaceOutputIndices (const Array<Integer>& inplace_indices,
680- int num_inputs) {
681- Array<Integer> ret;
682- int last_idx = num_inputs;
683- for (auto idx : inplace_indices) {
684- int i = idx.IntValue ();
685- if (i >= 0 ) {
686- ret.push_back (Integer (i));
687- } else {
688- ret.push_back (Integer (last_idx));
689- last_idx++;
690- }
691- }
692-
693- return ret;
694- }
695-
696777 static Array<tir::Var> GetPrimFuncOutputParams (const tir::PrimFunc& func,
697778 const Array<Integer>& output_indices) {
698779 size_t n = func->params .size ();
@@ -798,8 +879,9 @@ class FusedTIRConstructor : public ExprVisitor {
798879 * \param name_hint The name hint for params and buffers
799880 * \param out The vector into which to collect the params/buffers
800881 */
801- static void CollectPrimFuncParams (const Var& relax_param,
802- std::vector<Variant<tir::Var, tir::Buffer>>* out) {
882+ static void CollectPrimFuncParams (
883+ const Var& relax_param, std::vector<Variant<tir::Var, tir::Buffer>>* out,
884+ const tvm::runtime::Optional<tir::Buffer>& tir_buffer_param = NullOpt) {
803885 auto struct_info = GetStructInfo (relax_param);
804886
805887 CHECK (!struct_info.as <TupleStructInfoNode>())
@@ -814,7 +896,14 @@ class FusedTIRConstructor : public ExprVisitor {
814896 const auto * shape_expr = tensor->shape .as <ShapeExprNode>();
815897 ICHECK (shape_expr) << " FuseTIR expects all Tensor parameters have a known shape." ;
816898 DataType dtype = tensor->dtype ;
817- tir::Buffer buffer = tir::decl_buffer (shape_expr->values , dtype, name_hint);
899+ tir::Buffer buffer;
900+ if (tir_buffer_param.defined ()) {
901+ buffer =
902+ tir::decl_buffer (shape_expr->values , dtype, name_hint, tir_buffer_param.value ().scope (),
903+ tir_buffer_param.value ()->axis_separators );
904+ } else {
905+ buffer = tir::decl_buffer (shape_expr->values , dtype, name_hint);
906+ }
818907 out->push_back (std::move (buffer));
819908
820909 } else if (const auto * prim_value = struct_info.as <PrimStructInfoNode>()) {
0 commit comments