@@ -385,58 +385,45 @@ class FusedTIRConstructor : public ExprVisitor {
       : mod_(mod), func_name_(func_name) {}
 
   void VisitExpr_(const FunctionNode* func) final {
-    // Step 1. Create buffers for function params
-
-    // Record which fields in a tuple passed as a parameter are actually accessed by the function.
-    std::unordered_set<const Object*> tuple_param;
-    for (auto param : func->params) {
-      if (GetStructInfo(param)->IsInstance<TupleStructInfoNode>()) {
-        tuple_param.insert(param.get());
-      }
-    }
-
-    PostOrderVisit(func->body, [=, &tuple_param](Expr e) {
-      if (auto tup_get = e.as<TupleGetItemNode>();
-          tup_get && tuple_param.count(tup_get->tuple.get())) {
-        func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index);
-      }
-    });
-
+    std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
     for (const Var& relax_param : func->params) {
-      auto sinfo = GetStructInfo(relax_param);
-      if (sinfo->IsInstance<ShapeStructInfoNode>()) {
-        // It's a symbolic shape var, no need to alloc Buffers.
-        continue;
-      }
-
-      auto [params, buffers] = [=]() {
-        if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
-          // Add only those tuple fields which are actually used by the function body into the
-          // function parameters.
-          int index = 0;
-          Array<tir::Var> params;
-          Array<tir::Buffer> buffers;
-          for (auto i : func_info_.used_tuple_field_indices[relax_param.get()]) {
-            auto [ret_params, ret_buffers] =
-                CreateParamsAndBuffers(tuple->fields[i], relax_param->name_hint(), index);
-            ICHECK_EQ(ret_params.size(), ret_buffers.size());
-            // Adding tuple field results to the end of params and buffers.
-            params.insert(params.end(), ret_params.begin(), ret_params.end());
-            buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
-            index += ret_params.size();
+      size_t size_before = prim_func_params.size();
+      CollectPrimFuncParams(relax_param, &prim_func_params);
+
+      auto param_buffers = [&]() -> Array<tir::Buffer> {
+        Array<tir::Buffer> out;
+        for (size_t i = size_before; i < prim_func_params.size(); i++) {
+          if (auto buf = prim_func_params[i].as<tir::Buffer>()) {
+            out.push_back(buf.value());
           }
-          return std::make_pair(params, buffers);
-        } else {
-          return CreateParamsAndBuffers(sinfo, relax_param->name_hint());
         }
+        return out;
       }();
 
-      ICHECK_EQ(params.size(), buffers.size());
-      for (size_t i = 0; i < params.size(); ++i) {
-        func_info_.buffer_map.Set(params[i], buffers[i]);
-        func_info_.params.push_back(params[i]);
+      func_info_.expr2buffers.Set(relax_param, param_buffers);
+    }
+
+    // Move all scalar params after buffer params.
+    std::stable_sort(prim_func_params.begin(), prim_func_params.end(),
+                     [](const auto& a, const auto& b) {
+                       bool a_is_var = a.template as<tir::VarNode>();
+                       bool b_is_var = b.template as<tir::VarNode>();
+                       return a_is_var < b_is_var;
+                     });
+
+    for (const auto& param : prim_func_params) {
+      if (auto opt = param.as<tir::Buffer>()) {
+        auto buffer = opt.value();
+        // Differentiate the buffer name from the param name by prefixing
+        // the buffer name with `p_` to form the param. Every symbol should
+        // be unique in TVMScript, and while names can be de-duplicated when
+        // printed, it's more readable when done explicitly. Since the
+        // buffer is referenced more often than the param, it keeps the more
+        // readable name.
+        tir::Var param = tir::Var("p_" + buffer->name, PrimType(DataType::Handle()));
+        func_info_.params.push_back(param);
+        func_info_.buffer_map.Set(param, buffer);
       }
-      func_info_.expr2buffers.Set(relax_param, buffers);
     }
 
     // Step 2. Visit Function body and create intermediate buffers
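
Aside on the `std::stable_sort` above: comparing two bools with `<` orders `false` before `true`, so buffer entries move ahead of scalar vars, and the stable sort preserves the original relative order within each group. A minimal, self-contained sketch of the same comparator trick, using plain `std::pair` stand-ins rather than TVM's `Variant`:

#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

int main() {
  // first = "is a scalar var", second = original position in the list.
  std::vector<std::pair<bool, int>> params = {{true, 0}, {false, 1}, {true, 2}, {false, 3}};

  // `a.first < b.first` is true only when a is a buffer (false) and b is a
  // scalar (true), so buffers partition to the front; stable_sort keeps the
  // original order inside each partition.
  std::stable_sort(params.begin(), params.end(),
                   [](const auto& a, const auto& b) { return a.first < b.first; });

  assert(params[0].second == 1 && params[1].second == 3);  // buffers, in order
  assert(params[2].second == 0 && params[3].second == 2);  // scalars, in order
  return 0;
}
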
@@ -458,13 +445,9 @@ class FusedTIRConstructor : public ExprVisitor {
     }
 
     // Step 4. Append symbolic vars
-    const relax::Var& last_relax_param = func->params.back();
-    if (GetStructInfo(last_relax_param)->IsInstance<ShapeStructInfoNode>()) {
-      auto [params, buffers] =
-          CreateParamsAndBuffers(GetStructInfo(last_relax_param), last_relax_param->name_hint());
-      ICHECK(buffers.empty());
-      for (size_t i = 0; i < params.size(); ++i) {
-        func_info_.params.push_back(params[i]);
+    for (const auto& param : prim_func_params) {
+      if (auto var = param.as<tir::Var>()) {
+        func_info_.params.push_back(var.value());
       }
     }
 
@@ -548,12 +531,7 @@ class FusedTIRConstructor : public ExprVisitor {
     int end_buf_idx = 0;
     const TupleType& tuple_type = Downcast<TupleType>(tuple_get_item->tuple->checked_type());
     for (int i = 0; i < tuple_get_item->index; ++i) {
-      auto it = func_info_.used_tuple_field_indices.find(tuple_get_item->tuple.get());
-      // If this tuple is not passed as a parameter, or if the field at the index i is actually
-      // used, the corresponding buffer needs to be taken into account by this function.
-      if (it == func_info_.used_tuple_field_indices.end() || it->second.count(i)) {
-        begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
-      }
+      begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
     }
     end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
     func_info_.expr2buffers.Set(
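
The simplified loop above reduces the buffer lookup to a plain prefix sum over per-field tensor counts, since every tuple field now contributes its buffers unconditionally. A standalone sketch of the offset arithmetic, with a hypothetical `sizes` vector standing in for `GetTotalTensorSize` over the tuple fields:

#include <cassert>
#include <utility>
#include <vector>

// Returns the half-open range [begin, end) of flattened buffer indices
// belonging to tuple field `index`, given each field's tensor count.
std::pair<int, int> FieldBufferRange(const std::vector<int>& sizes, int index) {
  int begin = 0;
  for (int i = 0; i < index; ++i) begin += sizes[i];  // mirrors begin_buf_idx
  return {begin, begin + sizes[index]};               // mirrors end_buf_idx
}

int main() {
  // A tuple whose three fields hold 1, 2, and 1 tensors respectively.
  auto [begin, end] = FieldBufferRange({1, 2, 1}, 1);
  assert(begin == 1 && end == 3);  // field 1 occupies buffers [1, 3)
  return 0;
}
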
@@ -719,64 +697,46 @@ class FusedTIRConstructor : public ExprVisitor {
   }
 
   /*!
-   * \brief Create an TIR func params and buffers with specified relax type and shape
+   * \brief Collect TIR func params and buffers with specified relax type and shape
    * \param struct_info The struct info
    * \param name_hint The name hint for params and buffers
-   * \param index The index used for unique name_hint if type is Tuple.
-   *              -1 means no need to add postfix since the relax param is not a Tuple.
-   * \return The created TIR func params and buffers
+   * \param out The vector into which to collect the params/buffers
    */
-  static std::pair<Array<tir::Var>, Array<tir::Buffer>> CreateParamsAndBuffers(
-      StructInfo struct_info, const String& name_hint, int index = -1) {
-    Array<tir::Var> params;
-    Array<tir::Buffer> buffers;
-    // The symbolic shape params must be defined at the end of the param list.
-    bool symbolic_shape_param_started = false;
+  static void CollectPrimFuncParams(const Var& relax_param,
+                                    std::vector<Variant<tir::Var, tir::Buffer>>* out) {
+    auto struct_info = GetStructInfo(relax_param);
+
+    CHECK(!struct_info.as<TupleStructInfoNode>())
+        << "InternalError: "
+        << "All tuple parameters should be expanded before this point in FuseTIR. "
+        << "However, parameter " << relax_param << " has struct info " << struct_info;
+
+    auto name_hint = relax_param->name_hint();
+
     if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
-      // Case 1. the relax param is a Tensor, we directly create a tir var and buffer
+      // Case 1. The relax param is a Tensor, we directly create a tir var and buffer
       const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
-      ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape.";
-      CHECK(!symbolic_shape_param_started)
-          << "The symbolic shape params must be defined at the end of the param "
-             "list.";
-      String name = index == -1 ? name_hint : name_hint + "_" + std::to_string(index);
+      ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters to have a known shape.";
       DataType dtype = tensor->dtype;
-      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name);
-      // Differentiate buffer name and param name by adding prefix `v_` to param
-      // Every symbol should be unique in TVMScript, and Buffer is used more than param
-      // So we decide to make sure buffer names have better readability.
-      tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle()));
-      params.push_back(std::move(param));
-      buffers.push_back(std::move(buffer));
-    } else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
-      // Case 2. the relax param is a Tuple, we recursively visit each field until it's a Tensor
-      // Enable postfix
-      CHECK(!symbolic_shape_param_started)
-          << "The symbolic shape params must be defined at the end of the param "
-             "list.";
-      if (index == -1) index = 0;
-      for (size_t i = 0; i < tuple->fields.size(); ++i) {
-        auto [ret_params, ret_buffers] = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
-        ICHECK_EQ(ret_params.size(), ret_buffers.size());
-        // Adding tuple field results to the end of params and buffers.
-        params.insert(params.end(), ret_params.begin(), ret_params.end());
-        buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
-        index += ret_params.size();
-      }
+      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      out->push_back(std::move(buffer));
+
+    } else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
+      // Case 2. The relax param is a scalar, we directly create a tir var
+      ICHECK(prim_value->value->IsInstance<tir::VarNode>());
+      out->push_back(Downcast<tir::Var>(prim_value->value));
+
     } else if (const auto* shape_expr = struct_info.as<ShapeStructInfoNode>()) {
-      // Case 3. the relax param is a scalar, we directly create a tir var
-      symbolic_shape_param_started = true;
-      ICHECK(index == -1) << "TypeError: The ShapeExprNode should not be in a Tuple field.";
+      // Case 3. The relax param is a tuple of scalars, each represented as a tir var
       for (const auto& var : shape_expr->values.value()) {
         ICHECK(var->IsInstance<tir::VarNode>());
-        params.push_back(Downcast<tir::Var>(var));
+        out->push_back(Downcast<tir::Var>(var));
       }
     } else {
       ICHECK(false) << "TypeError: The param type of PrimFunc is expected to be Tensor, Tuple or "
                        "ShapeExpr, but got "
                     << struct_info->GetTypeKey();
     }
-    return std::make_pair(params, buffers);
   }
 
   /*!
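
`CollectPrimFuncParams` now emits one ordered list of `Variant<tir::Var, tir::Buffer>`, and the caller splits it into buffers and scalar vars afterwards (buffers first, each behind a fresh `p_`-prefixed handle, then the trailing vars). A rough sketch of that collect-then-dispatch shape using `std::variant`; the `Var`/`Buffer` structs here are illustrative stand-ins, not TVM's types:

#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Illustrative stand-ins for tir::Var and tir::Buffer.
struct Var { std::string name; };
struct Buffer { std::string name; };
using Param = std::variant<Buffer, Var>;

int main() {
  // One flat, ordered list, as produced by CollectPrimFuncParams.
  std::vector<Param> params = {Buffer{"x"}, Var{"n"}, Buffer{"y"}};

  // First pass: buffers, each reached through a "p_"-prefixed handle var.
  for (const auto& p : params) {
    if (const auto* buf = std::get_if<Buffer>(&p)) {
      std::cout << "handle p_" << buf->name << " -> buffer " << buf->name << "\n";
    }
  }

  // Second pass: the trailing symbolic/scalar vars.
  for (const auto& p : params) {
    if (const auto* var = std::get_if<Var>(&p)) {
      std::cout << "symbolic var " << var->name << "\n";
    }
  }
  return 0;
}
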
@@ -870,9 +830,6 @@ class FusedTIRConstructor : public ExprVisitor {
     /*! \brief The map from symbolic var to its corresponding var in the fused function */
     tir::SymbolicMatcher symbolic_var_matcher =
         tir::SymbolicMatcher(&analyzer, &symbolic_var_remap);
-
-    /*! \brief Record indices of tuple fields that are actually accessed. */
-    std::unordered_map<const Object*, std::unordered_set<size_t>> used_tuple_field_indices;
   };
 
   /*! \brief The IRModule */
@@ -987,34 +944,35 @@ class TIRFuseMutator : public ExprMutator {
       Array<PrimExpr> tir_vars;
       for (size_t i = 0; i < call->args.size(); ++i) {
         auto arg = call->args[i];
-        Array<Expr> flattened;
-        if (GetStructInfo(relax_func->params[i])->IsInstance<TupleStructInfoNode>()) {
-          // Add only those tuple fields which are actually used by the function body
-          auto tup_get_indices = GetTupleAccessedIndices(relax_func.get(), relax_func->params[i]);
-          for (size_t tup_get_ind : tup_get_indices) {
-            auto flattened_inner = FlattenArg(builder_->Emit(TupleGetItem(arg, tup_get_ind)));
-            flattened.insert(flattened.end(), flattened_inner.begin(), flattened_inner.end());
+        auto sinfo = GetStructInfo(arg);
+
+        ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
+               !sinfo.as<TupleStructInfoNode>())
+            << "InternalError: "
+            << "All tuple parameters should be expanded before this point in FuseTIR. "
+            << "However, argument " << arg << " with struct info " << arg->struct_info_
+            << " is passed as argument " << i << " to Primitive Relax function " << old_gv
+            << ", which expects parameter " << relax_func->params[i] << " to have struct info "
+            << relax_func->params[i]->struct_info_;
+
+        if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
+          CHECK(shape->values.defined())
+              << "FuseTIR requires all shape inputs to have a defined struct_info value.";
+          for (const PrimExpr& prim_value : shape->values.value()) {
+            CHECK(prim_value->IsInstance<tir::VarNode>())
+                << "Each shape input is expected to be a single tir::Var.";
+            tir_vars.push_back(prim_value);
           }
-        } else {
-          flattened.push_back(arg);
-        }
+        } else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
+          CHECK(prim_value->value.defined())
+              << "FuseTIR requires all R.Prim arguments to have a known value.";
+          PrimExpr expr = prim_value->value.value();
+          CHECK(expr->IsInstance<tir::VarNode>())
+              << "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var.";
+          tir_vars.push_back(expr);
 
-        for (const Expr& e : flattened) {
-          StructInfo sinfo = GetStructInfo(e);
-          if (sinfo->IsInstance<TensorStructInfoNode>()) {
-            arg_list.push_back(e);
-          } else if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
-            CHECK(shape->values.defined())
-                << "FuseTIR requires all shape input has struct_info value.";
-            for (const PrimExpr& prim_value : shape->values.value()) {
-              CHECK(prim_value->IsInstance<tir::VarNode>())
-                  << "All shape inputs are expected to be single tir var.";
-              tir_vars.push_back(prim_value);
-            }
-          } else {
-            LOG(FATAL) << "The flattened arg is expected to be either tensor or shape, but got "
-                       << sinfo->GetTypeKey();
-          }
+        } else {
+          arg_list.push_back(arg);
         }
       }
       // Step b. Create call_tir
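
Step a above routes each call argument into one of two sinks: tensors stay positional arguments of `call_tir`, while shape and `R.Prim` arguments must lower to bare `tir::Var`s appended to the trailing `tir_vars` list. A simplified stand-alone version of that routing; the three structs are hypothetical mock-ups of the struct-info cases, not TVM types:

#include <string>
#include <variant>
#include <vector>

// Mock-ups of the three struct-info cases handled above.
struct TensorArg { std::string name; };              // TensorStructInfo
struct ShapeArg { std::vector<std::string> vars; };  // ShapeStructInfo
struct PrimArg { std::string var; };                 // PrimStructInfo
using Arg = std::variant<TensorArg, ShapeArg, PrimArg>;

// Tensors go to arg_list; shape/prim vars are appended to tir_vars.
void SplitArgs(const std::vector<Arg>& args, std::vector<std::string>* arg_list,
               std::vector<std::string>* tir_vars) {
  for (const auto& arg : args) {
    if (const auto* tensor = std::get_if<TensorArg>(&arg)) {
      arg_list->push_back(tensor->name);
    } else if (const auto* shape = std::get_if<ShapeArg>(&arg)) {
      tir_vars->insert(tir_vars->end(), shape->vars.begin(), shape->vars.end());
    } else if (const auto* prim = std::get_if<PrimArg>(&arg)) {
      tir_vars->push_back(prim->var);
    }
  }
}

int main() {
  std::vector<std::string> arg_list, tir_vars;
  SplitArgs({TensorArg{"x"}, ShapeArg{{"n", "m"}}, PrimArg{"k"}}, &arg_list, &tir_vars);
  // arg_list == {"x"}; tir_vars == {"n", "m", "k"}.
  return 0;
}
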
@@ -1042,23 +1000,6 @@ class TIRFuseMutator : public ExprMutator {
     return call;
   }
 
-  /********* Helper Functions *********/
-
-  /*! \brief Flatten the call args if it's Tuple by emitting `TupleGetItem`. */
-  Array<Expr> FlattenArg(const Expr& arg) {
-    if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(arg)) {
-      Array<Expr> arg_list;
-      for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
-        Expr new_arg = builder_->Emit(TupleGetItem(arg, i));
-        Array<Expr> flattened = FlattenArg(new_arg);
-        arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
-      }
-      return arg_list;
-    } else {
-      return {arg};
-    }
-  }
-
  private:
   /*! \brief The IRModule */
   const IRModule& mod_;
@@ -1076,10 +1017,17 @@ namespace transform {
 Pass FuseTIR() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
       [=](IRModule m, PassContext pc) { return relax::FuseTIR(m); };
-  return CreateModulePass(/*pass_function=*/pass_func,  //
-                          /*opt_level=*/0,              //
-                          /*pass_name=*/"FuseTIR",      //
-                          /*required=*/{});
+  auto inner_pass = CreateModulePass(/*pass_function=*/pass_func,   //
+                                     /*opt_level=*/0,               //
+                                     /*pass_name=*/"FuseTIRInner",  //
+                                     /*required=*/{});
+  return tvm::transform::Sequential(
+      {
+          ExpandTupleArguments(),
+          RemoveUnusedParameters(),
+          inner_pass,
+      },
+      "FuseTIR");
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);
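
With this change `FuseTIR()` returns a `Sequential` that runs `ExpandTupleArguments` and `RemoveUnusedParameters` before the fusion itself, which is what lets the visitor and mutator above simply reject tuple parameters. A hedged sketch of invoking the composed pass from C++; this assumes a TVM build where these headers and the pass exist, and `mod` is any valid Relax IRModule:

#include <tvm/ir/module.h>
#include <tvm/relax/transform.h>

tvm::IRModule ApplyFuseTIR(tvm::IRModule mod) {
  // The returned Pass is the Sequential{ExpandTupleArguments,
  // RemoveUnusedParameters, FuseTIRInner} built above; callers still
  // see a single pass named "FuseTIR".
  tvm::transform::Pass pass = tvm::relax::transform::FuseTIR();
  return pass(mod);  // Pass::operator() runs with the current PassContext.
}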