
Commit 4b60b69 (1 parent: 4e505c0)
[Unity][Transform] Extract partial-tuple-usage from FuseTIR
Prior to this commit, the `FuseTIR` pass explicitly tracked usage of tuple arguments in order to minimize the set of arguments provided to each kernel. This additional tracking and handling of partially-used tuples made it difficult to follow the primary changes being made by `FuseTIR`.

This commit implements the same functionality in terms of the `ExpandTupleArguments` and `RemoveUnusedParameters` transforms, introduced in #16115 and #16116 respectively. By running these passes first, partial tuple usage is already handled by the time the main `FuseTIR` logic runs.

This commit is intended to minimize any changes to user-facing behavior, so these pre-process passes are currently applied internally by `FuseTIR`. This internal delegation may be avoided in the future by pulling it out into a lowering pipeline.
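As a rough illustration of the new flow (a sketch, not code from this commit: the module and names below are hypothetical, and assume the two passes are exposed under `tvm.relax.transform` as introduced in #16115 and #16116):

```python
import tvm
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(
        a: R.Tensor([16], "float32"), b: R.Tensor([16], "float32")
    ) -> R.Tensor([16], "float32"):
        cls = Module
        gv = cls.subroutine(R.tuple(a, b))
        return gv

    @R.function(private=True)
    def subroutine(
        pair: R.Tuple(R.Tensor([16], "float32"), R.Tensor([16], "float32")),
    ) -> R.Tensor([16], "float32"):
        # Only the first tuple field is used; the second is dead.
        gv = pair[0]
        return gv


# ExpandTupleArguments replaces the tuple parameter with one parameter per
# field; RemoveUnusedParameters then drops the untouched field.
seq = tvm.transform.Sequential(
    [
        relax.transform.ExpandTupleArguments(),
        relax.transform.RemoveUnusedParameters(),
    ]
)
print(seq(Module))
```

After the sequence runs, `subroutine` should accept a single tensor parameter, so the fused-kernel construction never sees a partially-used tuple.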

2 files changed: +18 −25 lines
src/relax/transform/fuse_tir.cc

Lines changed: 13 additions & 20 deletions
```diff
@@ -376,13 +376,6 @@ class FusedTIRConstructor : public ExprVisitor {
       }
     }
 
-    PostOrderVisit(func->body, [=, &tuple_param](Expr e) {
-      if (auto tup_get = e.as<TupleGetItemNode>();
-          tup_get && tuple_param.count(tup_get->tuple.get())) {
-        func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index);
-      }
-    });
-
     for (const Var& relax_param : func->params) {
       auto sinfo = GetStructInfo(relax_param);
       if (sinfo->IsInstance<ShapeStructInfoNode>()) {
@@ -397,7 +390,7 @@ class FusedTIRConstructor : public ExprVisitor {
       int index = 0;
       Array<tir::Var> params;
       Array<tir::Buffer> buffers;
-      for (auto i : func_info_.used_tuple_field_indices[relax_param.get()]) {
+      for (size_t i = 0; i < tuple->fields.size(); i++) {
         auto [ret_params, ret_buffers] =
             CreateParamsAndBuffers(tuple->fields[i], relax_param->name_hint(), index);
         ICHECK_EQ(ret_params.size(), ret_buffers.size());
@@ -529,12 +522,7 @@ class FusedTIRConstructor : public ExprVisitor {
     int end_buf_idx = 0;
     const TupleType& tuple_type = Downcast<TupleType>(tuple_get_item->tuple->checked_type());
     for (int i = 0; i < tuple_get_item->index; ++i) {
-      auto it = func_info_.used_tuple_field_indices.find(tuple_get_item->tuple.get());
-      // If this tuple is not passed as a parameter, or if the field at the index i is actually
-      // used, the corresponding buffer needs to be taken into account by this function.
-      if (it == func_info_.used_tuple_field_indices.end() || it->second.count(i)) {
-        begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
-      }
+      begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
     }
     end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
     func_info_.expr2buffers.Set(
@@ -835,8 +823,6 @@ class FusedTIRConstructor : public ExprVisitor {
   std::string global_name = "fused";
   /*! \brief The map from symbolic var to its corresponding var in the fused function */
   tir::SymbolicMatcher symbolic_var_matcher = tir::SymbolicMatcher(&symbolic_var_remap);
-  /*! \brief Record indices of tuple fields that are actually accessed. */
-  std::unordered_map<const Object*, std::unordered_set<size_t>> used_tuple_field_indices;
 };
 
 /*! \brief The IRModule */
@@ -1040,10 +1026,17 @@ namespace transform {
 Pass FuseTIR() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
       [=](IRModule m, PassContext pc) { return relax::FuseTIR(m); };
-  return CreateModulePass(/*pass_function=*/pass_func,  //
-                          /*opt_level=*/0,              //
-                          /*pass_name=*/"FuseTIR",      //
-                          /*required=*/{});
+  auto inner_pass = CreateModulePass(/*pass_function=*/pass_func,   //
+                                     /*opt_level=*/0,               //
+                                     /*pass_name=*/"FuseTIRInner",  //
+                                     /*required=*/{});
+  return tvm::transform::Sequential(
+      {
+          ExpandTupleArguments(),
+          RemoveUnusedParameters(),
+          inner_pass,
+      },
+      "FuseTIR");
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);
```
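For reference, a rough Python-level sketch of the same composition (the real pass is constructed in C++ as shown above; `FuseTIRInner` is built internally and has no Python constructor, so it appears here only as a comment):

```python
import tvm
from tvm import relax

# Sketch only: approximates the structure of the new FuseTIR pass.
fuse_tir_like = tvm.transform.Sequential(
    [
        relax.transform.ExpandTupleArguments(),
        relax.transform.RemoveUnusedParameters(),
        # ...followed by the original FuseTIR logic ("FuseTIRInner"),
        # which is constructed internally in C++ and not exposed here.
    ],
    name="FuseTIR",
)
```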

tests/python/relax/test_transform_fuse_tir.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -205,7 +205,7 @@ def fused_exp_squeeze(x):
     with bb.function("main", [x]):
         with bb.dataflow():
             lv = bb.emit_te(fused_exp_squeeze, x)
-            lv2 = bb.emit_te(fused_exp_squeeze, lv)
+            lv2 = bb.call_te(fused_exp_squeeze, lv)
             gv = bb.emit_output(lv2)
         bb.emit_func_output(gv)
     return bb.get()
```
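The test updates replace `bb.emit_te` with `bb.call_te` where the result feeds directly into another emit. A minimal sketch of that distinction, assuming the public `relax.BlockBuilder` API:

```python
from tvm import relax, topi
from tvm.script import relax as R

bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
    with bb.dataflow():
        # emit_te builds a call into a TE-derived PrimFunc and binds it to a var.
        lv = bb.emit_te(topi.exp, x)
        # call_te only constructs the call expression; emit_output provides
        # the binding instead of a separate emit.
        gv = bb.emit_output(bb.call_te(topi.exp, lv))
    bb.emit_func_output(gv)
mod = bb.get()
```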
```diff
@@ -245,7 +245,7 @@ def fused_exp_exp_squeeze(x):
     x = relax.Var("x", R.Tensor([10, 20], "float32"))
     with bb.function("main", [x]):
         with bb.dataflow():
-            lv = bb.emit_te(fused_exp_exp_squeeze, x)
+            lv = bb.call_te(fused_exp_exp_squeeze, x)
             gv = bb.emit_output(lv)
         bb.emit_func_output(gv)
     return bb.get()
@@ -373,7 +373,7 @@ def fused_exp_squeeze(x):
     with bb.function("main", [x]):
         with bb.dataflow():
             lv = bb.emit_te(fused_exp_squeeze, x)
-            lv2 = bb.emit_te(topi.add, lv, relax.const(1, "float32"))
+            lv2 = bb.call_te(topi.add, lv, relax.const(1, "float32"))
             gv = bb.emit_output(lv2)
         bb.emit_func_output(gv)
     return bb.get()
@@ -414,7 +414,7 @@ def fused_add_exp_squeeze(x, y):
     x = relax.Var("x", R.Tensor([10, 20], "float32"))
     with bb.function("main", [x]):
         with bb.dataflow():
-            lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
+            lv = bb.call_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
             gv = bb.emit_output(lv)
         bb.emit_func_output(gv)
     return bb.get()
```
```diff
@@ -1115,7 +1115,7 @@ def reshape(
                     (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
                 ]
 
-            @R.function
+            @R.function(private=True)
             def fused_reshape(
                 lv: R.Tuple(
                     R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")
```
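The expected output now marks the fused function with `@R.function(private=True)`. As I understand it, a private function carries no `global_symbol` attribute, which is what lets signature-changing passes such as `RemoveUnusedParameters` rewrite it. A minimal sketch:

```python
from tvm.script import ir as I, relax as R


@I.ir_module
class Example:
    @R.function(private=True)  # no "global_symbol" attribute is attached
    def fused_fn(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        return x


print(Example["fused_fn"].attrs)  # no "global_symbol" entry for a private function
```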
