Commit 6656fb6

[Unity][Transform] Extract partial-tuple-usage from FuseTIR
Prior to this commit, the `FuseTIR` pass explicitly tracked usage of tuple arguments in order to minimize the set of arguments provided to each kernel. The additional tracking and handling of partially-used tuples made it difficult to follow the primary changes being made by `FuseTIR`.

This commit implements the same functionality in terms of the `ExpandTupleArguments` and `RemoveUnusedParameters` transforms, introduced in #16115 and #16116 respectively. By applying these passes before the main `FuseTIR` changes, partial tuple usage is already handled at that point.

This commit is intended to minimize any changes to user-facing behavior, so these pre-process passes are currently applied internally by `FuseTIR`. This may be avoided in the future by pulling this internal delegation out into a lowering pipeline.
1 parent 45eeb8c commit 6656fb6
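
To illustrate the partial-tuple-usage case that the pre-process passes now take over, here is a minimal sketch using TVM's Python API. The module and names below are illustrative, not taken from this commit, and it assumes the Python bindings `relax.transform.ExpandTupleArguments` and `relax.transform.RemoveUnusedParameters` from #16115/#16116 are available:

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def fused(
        x: R.Tuple(R.Tensor((10,), "float32"), R.Tensor((10,), "float32"))
    ) -> R.Tensor((10,), "float32"):
        R.func_attr({"Primitive": True})
        with R.dataflow():
            # Only the first tuple field is used by the function body.
            lv0 = x[0]
            gv = R.add(lv0, lv0)
            R.output(gv)
        return gv

# ExpandTupleArguments replaces the tuple parameter with one parameter per
# field; RemoveUnusedParameters then drops the field that is never read.
# After these two passes the inner FuseTIR logic no longer needs to track
# per-field tuple usage itself.
Module = tvm.transform.Sequential(
    [
        relax.transform.ExpandTupleArguments(),
        relax.transform.RemoveUnusedParameters(),
    ]
)(Module)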

File tree

2 files changed (+107, -159 lines)

src/relax/transform/fuse_tir.cc

Lines changed: 100 additions & 152 deletions
@@ -385,58 +385,45 @@ class FusedTIRConstructor : public ExprVisitor {
       : mod_(mod), func_name_(func_name) {}

   void VisitExpr_(const FunctionNode* func) final {
-    // Step 1. Create buffers for function params
-
-    // Record which fields in a tuple passed as a parameter are actually accessed by the function.
-    std::unordered_set<const Object*> tuple_param;
-    for (auto param : func->params) {
-      if (GetStructInfo(param)->IsInstance<TupleStructInfoNode>()) {
-        tuple_param.insert(param.get());
-      }
-    }
-
-    PostOrderVisit(func->body, [=, &tuple_param](Expr e) {
-      if (auto tup_get = e.as<TupleGetItemNode>();
-          tup_get && tuple_param.count(tup_get->tuple.get())) {
-        func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index);
-      }
-    });
-
+    std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
     for (const Var& relax_param : func->params) {
-      auto sinfo = GetStructInfo(relax_param);
-      if (sinfo->IsInstance<ShapeStructInfoNode>()) {
-        // It's a symbolic shape var, no need to alloc Buffers.
-        continue;
-      }
-
-      auto [params, buffers] = [=]() {
-        if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
-          // Add only those tuple fields which are actually used by the function body into the
-          // function parameters.
-          int index = 0;
-          Array<tir::Var> params;
-          Array<tir::Buffer> buffers;
-          for (auto i : func_info_.used_tuple_field_indices[relax_param.get()]) {
-            auto [ret_params, ret_buffers] =
-                CreateParamsAndBuffers(tuple->fields[i], relax_param->name_hint(), index);
-            ICHECK_EQ(ret_params.size(), ret_buffers.size());
-            // Adding tuple field results to the end of params and buffers.
-            params.insert(params.end(), ret_params.begin(), ret_params.end());
-            buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
-            index += ret_params.size();
+      size_t size_before = prim_func_params.size();
+      CollectPrimFuncParams(relax_param, &prim_func_params);
+
+      auto param_buffers = [&]() -> Array<tir::Buffer> {
+        Array<tir::Buffer> out;
+        for (size_t i = size_before; i < prim_func_params.size(); i++) {
+          if (auto buf = prim_func_params[i].as<tir::Buffer>()) {
+            out.push_back(buf.value());
           }
-          return std::make_pair(params, buffers);
-        } else {
-          return CreateParamsAndBuffers(sinfo, relax_param->name_hint());
         }
+        return out;
       }();

-      ICHECK_EQ(params.size(), buffers.size());
-      for (size_t i = 0; i < params.size(); ++i) {
-        func_info_.buffer_map.Set(params[i], buffers[i]);
-        func_info_.params.push_back(params[i]);
+      func_info_.expr2buffers.Set(relax_param, param_buffers);
+    }
+
+    // Move all scalar params after buffer params.
+    std::stable_sort(prim_func_params.begin(), prim_func_params.end(),
+                     [](const auto& a, const auto& b) {
+                       bool a_is_var = a.template as<tir::VarNode>();
+                       bool b_is_var = b.template as<tir::VarNode>();
+                       return a_is_var < b_is_var;
+                     });
+
+    for (const auto& param : prim_func_params) {
+      if (auto opt = param.as<tir::Buffer>()) {
+        auto buffer = opt.value();
+        // Differentiate buffer name and param name by adding prefix
+        // `p_` to the buffer name. Every symbol should be unique in
+        // TVMScript, and while they can be de-duplicated when
+        // printed, it's more readable when done explicitly. Since
+        // Buffer is used more than param it gets the name with better
+        // readability.
+        tir::Var param = tir::Var("p_" + buffer->name, PrimType(DataType::Handle()));
+        func_info_.params.push_back(param);
+        func_info_.buffer_map.Set(param, buffer);
       }
-      func_info_.expr2buffers.Set(relax_param, buffers);
     }

     // Step 2. Visit Function body and create intermediate buffers
@@ -458,13 +445,9 @@ class FusedTIRConstructor : public ExprVisitor {
     }

     // Step 4. Append symbolic vars
-    const relax::Var& last_relax_param = func->params.back();
-    if (GetStructInfo(last_relax_param)->IsInstance<ShapeStructInfoNode>()) {
-      auto [params, buffers] =
-          CreateParamsAndBuffers(GetStructInfo(last_relax_param), last_relax_param->name_hint());
-      ICHECK(buffers.empty());
-      for (size_t i = 0; i < params.size(); ++i) {
-        func_info_.params.push_back(params[i]);
+    for (const auto& param : prim_func_params) {
+      if (auto var = param.as<tir::Var>()) {
+        func_info_.params.push_back(var.value());
       }
     }

@@ -548,12 +531,7 @@ class FusedTIRConstructor : public ExprVisitor {
     int end_buf_idx = 0;
     const TupleType& tuple_type = Downcast<TupleType>(tuple_get_item->tuple->checked_type());
     for (int i = 0; i < tuple_get_item->index; ++i) {
-      auto it = func_info_.used_tuple_field_indices.find(tuple_get_item->tuple.get());
-      // If this tuple is not passed as a parameter, or if the field at the index i is actually
-      // used, the corresponding buffer needs to be taken into account by this function.
-      if (it == func_info_.used_tuple_field_indices.end() || it->second.count(i)) {
-        begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
-      }
+      begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
     }
     end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
     func_info_.expr2buffers.Set(
@@ -719,64 +697,46 @@ class FusedTIRConstructor : public ExprVisitor {
   }

   /*!
-   * \brief Create an TIR func params and buffers with specified relax type and shape
+   * \brief Collect TIR func params and buffers with specified relax type and shape
    * \param struct_info The struct info
    * \param name_hint The name hint for params and buffers
-   * \param index The index used for unique name_hint if type is Tuple.
-   *        -1 means no need to add postfix since the relax param is not a Tuple.
-   * \return The created TIR func params and buffers
+   * \param out The vector into which to collect the params/buffers
    */
-  static std::pair<Array<tir::Var>, Array<tir::Buffer>> CreateParamsAndBuffers(
-      StructInfo struct_info, const String& name_hint, int index = -1) {
-    Array<tir::Var> params;
-    Array<tir::Buffer> buffers;
-    // The symbolic shape params must be defined at the end of the param list.
-    bool symbolic_shape_param_started = false;
+  static void CollectPrimFuncParams(const Var& relax_param,
+                                    std::vector<Variant<tir::Var, tir::Buffer>>* out) {
+    auto struct_info = GetStructInfo(relax_param);
+
+    CHECK(!struct_info.as<TupleStructInfoNode>())
+        << "InternalError: "
+        << "All tuple parameters should be expanded before this point in FuseTIR. "
+        << "However, parameter " << relax_param << " has struct info " << struct_info;
+
+    auto name_hint = relax_param->name_hint();
+
     if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
-      // Case 1. the relax param is a Tensor, we directly create a tir var and buffer
+      // Case 1. The relax param is a Tensor, we directly create a tir var and buffer
       const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
-      ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape.";
-      CHECK(!symbolic_shape_param_started)
-          << "The symbolic shape params must be defined at the end of the param "
-             "list.";
-      String name = index == -1 ? name_hint : name_hint + "_" + std::to_string(index);
+      ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
       DataType dtype = tensor->dtype;
-      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name);
-      // Differentiate buffer name and param name by adding prefix `v_` to param
-      // Every symbol should be unique in TVMScript, and Buffer is used more than param
-      // So we decide to make sure buffer names have better readability.
-      tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle()));
-      params.push_back(std::move(param));
-      buffers.push_back(std::move(buffer));
-    } else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
-      // Case 2. the relax param is a Tuple, we recursively visit each field until it's a Tensor
-      // Enable postfix
-      CHECK(!symbolic_shape_param_started)
-          << "The symbolic shape params must be defined at the end of the param "
-             "list.";
-      if (index == -1) index = 0;
-      for (size_t i = 0; i < tuple->fields.size(); ++i) {
-        auto [ret_params, ret_buffers] = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
-        ICHECK_EQ(ret_params.size(), ret_buffers.size());
-        // Adding tuple field results to the end of params and buffers.
-        params.insert(params.end(), ret_params.begin(), ret_params.end());
-        buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
-        index += ret_params.size();
-      }
+      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      out->push_back(std::move(buffer));
+
+    } else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
+      // Case 2. The relax param is a scalar, we directly create a tir var
+      ICHECK(prim_value->value->IsInstance<tir::VarNode>());
+      out->push_back(Downcast<tir::Var>(prim_value->value));
+
     } else if (const auto* shape_expr = struct_info.as<ShapeStructInfoNode>()) {
-      // Case 3. the relax param is a scalar, we directly create a tir var
-      symbolic_shape_param_started = true;
-      ICHECK(index == -1) << "TypeError: The ShapeExprNode should not be in a Tuple field.";
+      // Case 3. The relax param is a tuple of scalars, each represented as a tir var
       for (const auto& var : shape_expr->values.value()) {
         ICHECK(var->IsInstance<tir::VarNode>());
-        params.push_back(Downcast<tir::Var>(var));
+        out->push_back(Downcast<tir::Var>(var));
       }
     } else {
       ICHECK(false) << "TypeError: The param type of PrimFunc is expected to be Tensor, Tuple or "
                        "ShapeExpr, but got "
                     << struct_info->GetTypeKey();
     }
-    return std::make_pair(params, buffers);
   }

   /*!
@@ -870,9 +830,6 @@ class FusedTIRConstructor : public ExprVisitor {
   /*! \brief The map from symbolic var to its corresponding var in the fused function */
   tir::SymbolicMatcher symbolic_var_matcher =
       tir::SymbolicMatcher(&analyzer, &symbolic_var_remap);
-
-  /*! \brief Record indices of tuple fields that are actually accessed. */
-  std::unordered_map<const Object*, std::unordered_set<size_t>> used_tuple_field_indices;
 };

 /*! \brief The IRModule */
@@ -987,34 +944,35 @@ class TIRFuseMutator : public ExprMutator {
     Array<PrimExpr> tir_vars;
     for (size_t i = 0; i < call->args.size(); ++i) {
       auto arg = call->args[i];
-      Array<Expr> flattened;
-      if (GetStructInfo(relax_func->params[i])->IsInstance<TupleStructInfoNode>()) {
-        // Add only those tuple fields which are actually used by the function body
-        auto tup_get_indices = GetTupleAccessedIndices(relax_func.get(), relax_func->params[i]);
-        for (size_t tup_get_ind : tup_get_indices) {
-          auto flattened_inner = FlattenArg(builder_->Emit(TupleGetItem(arg, tup_get_ind)));
-          flattened.insert(flattened.end(), flattened_inner.begin(), flattened_inner.end());
+      auto sinfo = GetStructInfo(arg);
+
+      ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
+             !sinfo.as<TupleStructInfoNode>())
+          << "InternalError: "
+          << "All tuple parameters should be expanded before this point in FuseTIR. "
+          << "However, argument " << arg << " with struct info " << arg->struct_info_
+          << " is passed as argument " << i << " to Primitive Relax function " << old_gv
+          << ", which expects parameter " << relax_func->params[i] << " to have struct info "
+          << relax_func->params[i]->struct_info_;
+
+      if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
+        CHECK(shape->values.defined())
+            << "FuseTIR requires all shape input has struct_info value.";
+        for (const PrimExpr& prim_value : shape->values.value()) {
+          CHECK(prim_value->IsInstance<tir::VarNode>())
+              << "All shape inputs are expected to be single tir var.";
+          tir_vars.push_back(prim_value);
         }
-      } else {
-        flattened.push_back(arg);
-      }
+      } else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
+        CHECK(prim_value->value.defined())
+            << "FuseTIR requires all R.Prim arguments to have a known value.";
+        PrimExpr expr = prim_value->value.value();
+        CHECK(expr->IsInstance<tir::VarNode>())
+            << "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var.";
+        tir_vars.push_back(expr);

-      for (const Expr& e : flattened) {
-        StructInfo sinfo = GetStructInfo(e);
-        if (sinfo->IsInstance<TensorStructInfoNode>()) {
-          arg_list.push_back(e);
-        } else if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
-          CHECK(shape->values.defined())
-              << "FuseTIR requires all shape input has struct_info value.";
-          for (const PrimExpr& prim_value : shape->values.value()) {
-            CHECK(prim_value->IsInstance<tir::VarNode>())
-                << "All shape inputs are expected to be single tir var.";
-            tir_vars.push_back(prim_value);
-          }
-        } else {
-          LOG(FATAL) << "The flattened arg is expected to be either tensor or shape, but got "
-                     << sinfo->GetTypeKey();
-        }
+      } else {
+        arg_list.push_back(arg);
       }
     }
     // Step b. Create call_tir
@@ -1042,23 +1000,6 @@ class TIRFuseMutator : public ExprMutator {
     return call;
   }

-  /********** Helper Functions **********/
-
-  /*! \brief Flatten the call args if it's Tuple by emitting `TupleGetItem`. */
-  Array<Expr> FlattenArg(const Expr& arg) {
-    if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(arg)) {
-      Array<Expr> arg_list;
-      for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
-        Expr new_arg = builder_->Emit(TupleGetItem(arg, i));
-        Array<Expr> flattened = FlattenArg(new_arg);
-        arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
-      }
-      return arg_list;
-    } else {
-      return {arg};
-    }
-  }
-
  private:
   /*! \brief The IRModule */
   const IRModule& mod_;
@@ -1076,10 +1017,17 @@ namespace transform {
 Pass FuseTIR() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
       [=](IRModule m, PassContext pc) { return relax::FuseTIR(m); };
-  return CreateModulePass(/*pass_function=*/pass_func,  //
-                          /*opt_level=*/0,              //
-                          /*pass_name=*/"FuseTIR",      //
-                          /*required=*/{});
+  auto inner_pass = CreateModulePass(/*pass_function=*/pass_func,   //
+                                     /*opt_level=*/0,               //
+                                     /*pass_name=*/"FuseTIRInner",  //
+                                     /*required=*/{});
+  return tvm::transform::Sequential(
+      {
+          ExpandTupleArguments(),
+          RemoveUnusedParameters(),
+          inner_pass,
+      },
+      "FuseTIR");
 }

 TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);
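
The user-facing entry point is unchanged: `relax.transform.FuseTIR()` still returns a single pass object, now a `Sequential` internally. A minimal usage sketch from Python, assuming `mod` is an IRModule that has already been through `FuseOps`:

from tvm import relax

# FuseTIR now internally runs ExpandTupleArguments and RemoveUnusedParameters
# before the inner "FuseTIRInner" module pass, so callers invoke it exactly as
# before.
mod = relax.transform.FuseTIR()(mod)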

tests/python/relax/test_transform_fuse_tir.py

Lines changed: 7 additions & 7 deletions
@@ -205,7 +205,7 @@ def fused_exp_squeeze(x):
         with bb.function("main", [x]):
             with bb.dataflow():
                 lv = bb.emit_te(fused_exp_squeeze, x)
-                lv2 = bb.emit_te(fused_exp_squeeze, lv)
+                lv2 = bb.call_te(fused_exp_squeeze, lv)
                 gv = bb.emit_output(lv2)
             bb.emit_func_output(gv)
         return bb.get()
@@ -245,7 +245,7 @@ def fused_exp_exp_squeeze(x):
         x = relax.Var("x", R.Tensor([10, 20], "float32"))
         with bb.function("main", [x]):
             with bb.dataflow():
-                lv = bb.emit_te(fused_exp_exp_squeeze, x)
+                lv = bb.call_te(fused_exp_exp_squeeze, x)
                 gv = bb.emit_output(lv)
             bb.emit_func_output(gv)
         return bb.get()
@@ -257,7 +257,7 @@ def test_fuse_with_tuple_as_param():
     def before():
         bb = relax.BlockBuilder()
         x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
-        with bb.function("fused_exp_add", [x], attrs={"Primitive": True}):
+        with bb.function("fused_exp_add", [x], attrs={"Primitive": True}, private=True):
             with bb.dataflow():
                 lv0 = bb.emit(relax.TupleGetItem(x, 0))
                 lv1 = bb.emit(relax.TupleGetItem(x, 1))
@@ -300,7 +300,7 @@ def test_fuse_with_nested_tuple_as_param():
     def before():
         bb = relax.BlockBuilder()
         x = relax.Var("x", tuple_struct_info)
-        with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}):
+        with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}, private=True):
             with bb.dataflow():
                 lv0 = bb.emit(relax.TupleGetItem(x, 0))
                 lv0_exp = bb.emit_te(topi.exp, lv0)
@@ -373,7 +373,7 @@ def fused_exp_squeeze(x):
         with bb.function("main", [x]):
             with bb.dataflow():
                 lv = bb.emit_te(fused_exp_squeeze, x)
-                lv2 = bb.emit_te(topi.add, lv, relax.const(1, "float32"))
+                lv2 = bb.call_te(topi.add, lv, relax.const(1, "float32"))
                 gv = bb.emit_output(lv2)
             bb.emit_func_output(gv)
         return bb.get()
@@ -414,7 +414,7 @@ def fused_add_exp_squeeze(x, y):
         x = relax.Var("x", R.Tensor([10, 20], "float32"))
         with bb.function("main", [x]):
             with bb.dataflow():
-                lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
+                lv = bb.call_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
                 gv = bb.emit_output(lv)
             bb.emit_func_output(gv)
         return bb.get()
@@ -1268,7 +1268,7 @@ def reshape(
                     (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
                 ]

-        @R.function
+        @R.function(private=True)
         def fused_reshape(
             lv: R.Tuple(
                 R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")
