Skip to content

Commit 6a3fadc

Browse files
authored
[Unity][Transform] Handle call_tir_inplace in FuseTIR and FuseOps (#16487)
* WIP initial commit * Handle in-place calls in FuseTIR * Formatting * Add test case for FuseOps * Address review comments related to clarity * Use a set to ensure in-place indices will be unique * Add test case where PrimFunc is used both in-place and DPS * Explicitly check for duplicate index
1 parent 6a2459f commit 6a3fadc

File tree

4 files changed

+600
-33
lines changed

4 files changed

+600
-33
lines changed

src/relax/transform/fuse_ops.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ class GraphCreator : public ExprVisitor {
183183
ICHECK_NOTNULL(binding_var_node);
184184

185185
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
186+
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
187+
186188
OpPatternKind pattern = OpPatternKind::kOpaque;
187189
Array<Expr> args = call->args;
188190

@@ -191,7 +193,7 @@ class GraphCreator : public ExprVisitor {
191193
// - Otherwise, the pattern of the current binding variable node is set to `kOpaque`, and we
192194
// recurse into the call expression.
193195
const auto* op = call->op.as<OpNode>();
194-
if (op == call_tir_op_.get()) {
196+
if (op == call_tir_op_.get() || op == call_tir_inplace_op_.get()) {
195197
const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
196198
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
197199

@@ -377,7 +379,8 @@ class FunctionCreator : public ExprMutator {
377379
* function accordingly
378380
* \param binding The binding to be appended
379381
* \note Allowed bindings are:
380-
* - VarBinding with value being a call node calling `relax.call_tir`.
382+
* - VarBinding with value being a call node calling `relax.call_tir` or
383+
* `relax.call_tir_inplace`.
381384
* - VarBinding with value being a tuple-get-item node.
382385
* // TODO(tvm-team): handle match shape
383386
*/
@@ -387,7 +390,8 @@ class FunctionCreator : public ExprMutator {
387390

388391
if (const auto* var_binding = binding.as<VarBindingNode>()) {
389392
if (const auto* call = var_binding->value.as<CallNode>()) {
390-
if (call->op == Op::Get("relax.call_tir")) {
393+
if (call->op == Op::Get("relax.call_tir") ||
394+
call->op == Op::Get("relax.call_tir_inplace")) {
391395
// Update the name of the function.
392396
name_hint_ = name_hint_ + "_" + Downcast<GlobalVar>(call->args[0])->name_hint;
393397

src/relax/transform/fuse_tir.cc

Lines changed: 128 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* under the License.
1818
*/
1919
#include <tvm/relax/analysis.h>
20+
#include <tvm/relax/attrs/op.h>
2021
#include <tvm/relax/expr_functor.h>
2122
#include <tvm/relax/struct_info.h>
2223
#include <tvm/relax/transform.h>
@@ -367,17 +368,22 @@ class FusedTIRConstructor : public ExprVisitor {
367368
* \brief Construct a fused TIR PrimFunc from a relax sub-function
368369
* \param mod The IRModule
369370
* \param gv The global var of relax subfunction to be fused into one PrimFunc
370-
* \return The fused TIR PrimFunc
371+
* \return The fused TIR PrimFunc and the in-place indices (non-empty for an in-place call)
371372
*/
372-
static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) {
373+
static std::pair<tir::PrimFunc, Array<Integer>> GetFusedTIR(const IRModule& mod,
374+
const GlobalVar& gv) {
373375
FusedTIRConstructor visitor(mod, gv->name_hint);
374376
BaseFunc f = mod->Lookup(gv);
375377
CHECK(f->IsInstance<relax::FunctionNode>())
376378
<< "Expected relax functions, but got: " << f->GetTypeKey();
377379
CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive))
378380
<< "Expected a function with attr `kPrimitive`";
379381
visitor(Downcast<relax::Function>(f));
380-
return visitor.fused_tir_;
382+
Array<Integer> inplace_indices;
383+
for (size_t idx : visitor.inplace_indices_) {
384+
inplace_indices.push_back(Integer(idx));
385+
}
386+
return {visitor.fused_tir_, inplace_indices};
381387
}
382388

383389
private:
@@ -438,9 +444,38 @@ class FusedTIRConstructor : public ExprVisitor {
438444
auto it = func_info_.expr2buffers.find(body);
439445
ICHECK(it != func_info_.expr2buffers.end())
440446
<< "Fail to detect output buffers for function body";
447+
441448
const Array<tir::Buffer>& buffers = (*it).second;
449+
450+
// map of input buffers to indices (helpful for detecting in-place inputs)
451+
std::unordered_map<tir::Buffer, size_t, ObjectPtrHash, ObjectPtrEqual> buffer_to_idx;
452+
std::unordered_map<tir::Var, size_t, ObjectPtrHash, ObjectPtrEqual> input_to_idx;
453+
for (size_t i = 0; i < func_info_.params.size(); i++) {
454+
input_to_idx[func_info_.params[i]] = i;
455+
}
456+
for (auto [var, buffer] : func_info_.buffer_map) {
457+
if (auto it = input_to_idx.find(var); it != input_to_idx.end()) {
458+
buffer_to_idx[buffer] = (*it).second;
459+
}
460+
}
461+
462+
// numbered separately because the number of output *vars* might differ from the
463+
// number of outputs if there are in-place inputs
464+
int out_idx = 0;
442465
for (size_t i = 0; i < buffers.size(); ++i) {
443-
tir::Var param = tir::Var("p_output" + std::to_string(i), PrimType(DataType::Handle()));
466+
// Do not add output vars for in-place inputs
467+
// (i.e., already listed in the buffer map. This would result
468+
// in duplicates in the buffer map otherwise)
469+
if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end()) {
470+
auto idx = (*it).second;
471+
CHECK(!inplace_indices_.count(idx))
472+
<< "In-place index " << idx << " used twice! An argument must be aliased.";
473+
inplace_indices_.insert(idx);
474+
continue;
475+
}
476+
477+
tir::Var param = tir::Var("p_output" + std::to_string(out_idx), PrimType(DataType::Handle()));
478+
out_idx++;
444479
func_info_.buffer_map.Set(param, buffers[i]);
445480
func_info_.params.push_back(param);
446481
func_info_.output_buffers.insert(buffers[i].get());
@@ -476,8 +511,11 @@ class FusedTIRConstructor : public ExprVisitor {
476511
void VisitExpr_(const CallNode* call) final {
477512
ExprVisitor::VisitExpr_(call);
478513
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
479-
ICHECK(call->op == call_tir_op_)
480-
<< "Only call_tir is supported in primitive function, but got: " << GetRef<Expr>(call);
514+
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
515+
516+
ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
517+
<< "Only call_tir and call_tir_inplace are supported in primitive function, but got: "
518+
<< GetRef<Expr>(call);
481519

482520
// Step 1. Get Global var and PrimFunc
483521
GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
@@ -503,7 +541,7 @@ class FusedTIRConstructor : public ExprVisitor {
503541
MapInputBuffer(prim_func, call->args[1]);
504542
const Array<Array<PrimExpr>>& output_buffer_shapes = GetCallTIROutputShapes(call);
505543

506-
AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, output_buffer_shapes);
544+
AllocateIntermediateBuffer(call, prim_func, output_buffer_shapes);
507545

508546
// Step 6. Update tir_vars
509547
if (call->args.size() > 2) {
@@ -566,7 +604,8 @@ class FusedTIRConstructor : public ExprVisitor {
566604
*/
567605
static Array<Array<PrimExpr>> GetCallTIROutputShapes(const CallNode* call) {
568606
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
569-
ICHECK(call->op.same_as(call_tir_op_));
607+
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
608+
ICHECK(call->op.same_as(call_tir_op_) || call->op.same_as(call_tir_inplace_op_));
570609
ICHECK_EQ(call->sinfo_args.size(), 1);
571610
auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) {
572611
const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
@@ -611,7 +650,7 @@ class FusedTIRConstructor : public ExprVisitor {
611650
}
612651
}
613652
}
614-
// Make sure every buffers are mapped.
653+
// Make sure every buffer is mapped.
615654
ICHECK_EQ(buffer_idx, buffers.size());
616655
}
617656

@@ -639,28 +678,49 @@ class FusedTIRConstructor : public ExprVisitor {
639678
MapArgsToBuffer(arg_list, buffer_list);
640679
}
641680

642-
static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func, size_t output_size) {
681+
static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
682+
int num_inputs) {
683+
Array<Integer> ret;
684+
int last_idx = num_inputs;
685+
for (auto idx : inplace_indices) {
686+
int i = idx.IntValue();
687+
if (i >= 0) {
688+
ret.push_back(Integer(i));
689+
} else {
690+
ret.push_back(Integer(last_idx));
691+
last_idx++;
692+
}
693+
}
694+
695+
return ret;
696+
}
697+
698+
static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
699+
const Array<Integer>& output_indices) {
643700
size_t n = func->params.size();
644701
int symbolic_var_index = -1;
702+
size_t output_size = output_indices.size();
645703
ICHECK_GE(n, output_size);
646-
for (size_t i = 0; i < n; ++i) {
647-
const tir::Var& param = func->params[i];
704+
705+
Array<tir::Var> ret;
706+
for (auto idx : output_indices) {
707+
int i = idx.IntValue();
708+
const tir::Var& param = func->params[static_cast<size_t>(i)];
648709
if (param->dtype.is_int() || param->dtype.is_uint()) {
649710
if (symbolic_var_index == -1) symbolic_var_index = i;
650711
} else if (param->dtype.is_handle()) {
651712
CHECK(symbolic_var_index == -1) << "The scalar input should be at the ending of the "
652713
"parameter list.";
714+
ret.push_back(param);
653715
} else {
654716
LOG(FATAL) << "The params of PrimFunc are expected to be Buffer handle or scalar, but got: "
655717
<< param->dtype;
656718
}
657719
}
720+
658721
size_t end_index = symbolic_var_index == -1 ? n : symbolic_var_index;
659722
ICHECK_GE(end_index, output_size);
660-
size_t begin_index = end_index - output_size;
661-
Array<tir::Var> output_params{func->params.begin() + begin_index,
662-
func->params.begin() + end_index};
663-
return output_params;
723+
return ret;
664724
}
665725

666726
/*!
@@ -670,18 +730,39 @@ class FusedTIRConstructor : public ExprVisitor {
670730
* \param func The old TIR PrimFunc
671731
* \param output_shapes The shape of output params.
672732
*/
673-
void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func,
733+
void AllocateIntermediateBuffer(const CallNode* call, const tir::PrimFunc& func,
674734
const Array<Array<PrimExpr>>& output_shapes) {
735+
bool is_inplace = (call->op == Op::Get("relax.call_tir_inplace"));
736+
675737
size_t n = func->params.size();
738+
int num_inputs = Downcast<Tuple>(call->args[1])->fields.size();
676739
size_t output_size = output_shapes.size();
677740
ICHECK_GE(n, output_size);
678-
// Allocate intermediate buffer
679-
Array<tir::Buffer> alloc_buffers;
680-
Array<tir::Var> output_params = GetPrimFuncOutputParams(func, output_size);
741+
Array<tir::Buffer> output_buffers;
742+
Array<Integer> output_idxs;
743+
if (is_inplace) {
744+
const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
745+
CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
746+
output_idxs = std::move(GetInplaceOutputIndices(attrs->inplace_indices, num_inputs));
747+
} else {
748+
for (size_t i = 0; i < output_size; i++) {
749+
output_idxs.push_back(num_inputs + i);
750+
}
751+
}
752+
753+
Array<tir::Var> output_params = GetPrimFuncOutputParams(func, output_idxs);
754+
auto input_buffers = func_info_.expr2buffers.Get(call->args[1]);
681755
for (size_t i = 0; i < output_size; ++i) {
682756
const tir::Var& param = output_params[i];
683757
const tir::Buffer& buffer = func->buffer_map.at(param);
684758

759+
// if this is an inplace output, do not do an intermediate allocation
760+
if (output_idxs[i].IntValue() < num_inputs) {
761+
CHECK(input_buffers.defined()) << "Inplace functions must have some defined input";
762+
output_buffers.push_back(input_buffers.value()[output_idxs[i].IntValue()]);
763+
continue;
764+
}
765+
685766
auto unify_name_hints = [this, &buffer]() {
686767
String base_name = buffer->name;
687768
String unique_name = base_name + "_intermediate";
@@ -703,14 +784,14 @@ class FusedTIRConstructor : public ExprVisitor {
703784
n->name = unify_name_hints();
704785
tir::Buffer new_buffer(n);
705786
func_info_.alloc_buffers.push_back(new_buffer);
706-
alloc_buffers.push_back(new_buffer);
787+
output_buffers.push_back(new_buffer);
707788

708789
// Match the shape of the output buffer with the shape
709790
func_info_.symbolic_var_matcher.Match(buffer->shape, n->shape);
710791
func_info_.buffer_subst_map.Set(buffer, new_buffer);
711792
}
712793
// Update expr2buffers
713-
func_info_.expr2buffers.Set(expr, alloc_buffers);
794+
func_info_.expr2buffers.Set(GetRef<Expr>(call), output_buffers);
714795
}
715796

716797
/*!
@@ -858,6 +939,8 @@ class FusedTIRConstructor : public ExprVisitor {
858939
FuseFuncInfo func_info_;
859940
/*! \brief The tir function after fusion*/
860941
tir::PrimFunc fused_tir_;
942+
/*! \brief Indices of inputs that are used for in-place computation */
943+
std::unordered_set<size_t> inplace_indices_;
861944
};
862945

863946
std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var& tuple_var) {
@@ -897,8 +980,11 @@ class TIRFuseMutator : public ExprMutator {
897980
for (const auto& [gv, func] : mod->functions) {
898981
// Only fuse primitive relax functions
899982
if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
900-
tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
901-
mutator.fused_tir_funcs_.Set(gv, fused_tir);
983+
const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, gv);
984+
mutator.fused_tir_funcs_.Set(gv, prim_func);
985+
if (!indices.empty()) {
986+
mutator.inplace_indices_.Set(gv, indices);
987+
}
902988
}
903989
}
904990

@@ -945,6 +1031,7 @@ class TIRFuseMutator : public ExprMutator {
9451031

9461032
Expr VisitExpr_(const CallNode* op) final {
9471033
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
1034+
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
9481035

9491036
Call call = Downcast<Call>(builder_->Normalize(ExprMutator::VisitExpr_(op)));
9501037

@@ -985,26 +1072,34 @@ class TIRFuseMutator : public ExprMutator {
9851072
CHECK(prim_value->value.defined())
9861073
<< "FuseTIR requires all R.Prim arguments to have a known value.";
9871074
PrimExpr expr = prim_value->value.value();
988-
CHECK(expr->IsInstance<tir::VarNode>())
989-
<< "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var.";
1075+
CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently requires all R.Prim "
1076+
"arguments to provide a single tir::Var.";
9901077
tir_vars.push_back(expr);
9911078

9921079
} else {
9931080
arg_list.push_back(arg);
9941081
}
9951082
}
996-
// Step b. Create call_tir
1083+
// Step b. Create call_tir or call_tir_inplace
9971084
Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
9981085
if (!tir_vars.empty()) {
9991086
call_args.push_back(ShapeExpr(tir_vars));
10001087
}
1001-
return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)});
1088+
Op call_op = call_tir_op_;
1089+
Attrs call_attrs = call->attrs;
1090+
if (auto it = inplace_indices_.find(old_gv); it != inplace_indices_.end()) {
1091+
call_op = call_tir_inplace_op_;
1092+
auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
1093+
inplace_attrs->inplace_indices = (*it).second;
1094+
call_attrs = Attrs(inplace_attrs);
1095+
}
1096+
return Call(call_op, call_args, call_attrs, {GetStructInfo(call)});
10021097
} else {
10031098
// Case 1.2. The callee function is not primitive, nothing to do.
10041099
return call;
10051100
}
1006-
} else if (call->op == call_tir_op_) {
1007-
// Case 2. It is a call_tir, re-emit the PrimFunc.
1101+
} else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) {
1102+
// Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc.
10081103
if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
10091104
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
10101105
GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
@@ -1023,6 +1118,9 @@ class TIRFuseMutator : public ExprMutator {
10231118
const IRModule& mod_;
10241119
/*! \brief The map from global var of primitive relax function to generated prim func. */
10251120
Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
1121+
/*! \brief The map from global var of primitive relax function to in-place indices
1122+
* (if there are any). */
1123+
Map<GlobalVar, Array<Integer>> inplace_indices_;
10261124
};
10271125

10281126
IRModule FuseTIR(IRModule mod) {

0 commit comments

Comments
 (0)