Skip to content

Commit 643089d

Browse files
committed
[Transform] Modify FuseTIR pass to propagate buffer attributes
Arguments of a fused TIR PrimFunc generated from a fused Relax function do not retain all the buffer attributes of their original PrimFuncs, because the buffers are re-created from the StructInfo of the Relax vars. This patch collects a mapping from Relax vars to their corresponding TIR buffers in a fused Relax function, and uses that information to propagate buffer attributes such as `axis_separators` and `storage_scope`.
1 parent 4183229 commit 643089d

File tree

2 files changed

+193
-21
lines changed

2 files changed

+193
-21
lines changed

src/relax/transform/fuse_tir.cc

Lines changed: 110 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
362362

363363
namespace relax {
364364

365+
/*!
 * \brief Resolve the parameter index of each output of a call_tir_inplace call.
 *
 * Each entry of \p inplace_indices is either a non-negative index (the output
 * aliases the input parameter at that position) or a negative value (the output
 * is a fresh parameter, appended after the \p num_inputs input parameters).
 *
 * \param inplace_indices The in-place indices from CallTIRInplaceAttrs.
 * \param num_inputs The number of input parameters of the PrimFunc.
 * \return For each output, the index of the PrimFunc parameter backing it.
 */
static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
                                              int num_inputs) {
  Array<Integer> output_indices;
  // Fresh (non-aliasing) outputs occupy the parameter slots after the inputs.
  int next_fresh_slot = num_inputs;
  for (const auto& index : inplace_indices) {
    int value = index.IntValue();
    if (value < 0) {
      // Negative marker: this output gets its own parameter slot.
      output_indices.push_back(Integer(next_fresh_slot));
      ++next_fresh_slot;
    } else {
      // Non-negative: this output is written in place over input `value`.
      output_indices.push_back(Integer(value));
    }
  }
  return output_indices;
}
381+
382+
class RelaxToTIRVarMapCollector : public ExprVisitor {
383+
void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place = false) {
384+
GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
385+
tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
386+
const auto& buffer_map = prim_func_->buffer_map;
387+
const auto& tir_args = prim_func_->params;
388+
389+
const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
390+
391+
Array<Expr> relax_results;
392+
if (lhs_var->IsInstance<TupleNode>()) {
393+
relax_results = Downcast<Tuple>(lhs_var)->fields;
394+
} else {
395+
CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
396+
relax_results = {Downcast<Var>(lhs_var)};
397+
}
398+
399+
size_t num_inputs = relax_args.size();
400+
size_t num_outputs = relax_results.size();
401+
402+
Array<Integer> output_idxs;
403+
if (in_place) {
404+
const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
405+
CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
406+
output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
407+
} else {
408+
for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
409+
output_idxs.push_back(i);
410+
}
411+
}
412+
for (size_t i = 0; i < tir_args.size(); ++i) {
413+
const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
414+
if (i < num_inputs) {
415+
const auto& relax_var = Downcast<Var>(relax_args[i]);
416+
relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);
417+
}
418+
if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i); it != output_idxs.end()) {
419+
int result_idx = it - output_idxs.begin();
420+
const auto& inplace_out_var = Downcast<Var>(relax_results[result_idx]);
421+
relax_to_tir_var_map_.Set(inplace_out_var, buffer_map[tir_var]);
422+
}
423+
}
424+
}
425+
426+
public:
427+
explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
428+
static Map<Var, tir::Buffer> Collect(const IRModule& mod, const Function& func) {
429+
RelaxToTIRVarMapCollector visitor(mod);
430+
visitor(func->body);
431+
return visitor.relax_to_tir_var_map_;
432+
}
433+
void VisitBinding_(const VarBindingNode* binding) final {
434+
const auto& lhs_var = binding->var;
435+
const auto& value = binding->value;
436+
if (const CallNode* call = value.as<CallNode>()) {
437+
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
438+
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
439+
440+
ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
441+
<< "Only call_tir and call_tir_inplace are supported in primitive function, but got: "
442+
<< GetRef<Expr>(call);
443+
if (call->op == call_tir_inplace_op_) {
444+
CollectVarMapping(call, lhs_var, /*in_place*/ true);
445+
} else {
446+
CollectVarMapping(call, lhs_var);
447+
}
448+
}
449+
}
450+
451+
private:
452+
/*! \brief The IRModule */
453+
const IRModule& mod_;
454+
// size_t call_num_inputs_ = -1;
455+
Map<Var, tir::Buffer> relax_to_tir_var_map_;
456+
};
457+
365458
class FusedTIRConstructor : public ExprVisitor {
366459
public:
367460
/*!
@@ -391,10 +484,15 @@ class FusedTIRConstructor : public ExprVisitor {
391484
: mod_(mod), func_name_(func_name) {}
392485

393486
void VisitExpr_(const FunctionNode* func) final {
487+
auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, GetRef<Function>(func));
394488
std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
395489
for (const Var& relax_param : func->params) {
396490
size_t size_before = prim_func_params.size();
397-
CollectPrimFuncParams(relax_param, &prim_func_params);
491+
if (relax_to_tir_var_map.count(relax_param)) {
492+
CollectPrimFuncParams(relax_param, &prim_func_params, relax_to_tir_var_map[relax_param]);
493+
} else {
494+
CollectPrimFuncParams(relax_param, &prim_func_params);
495+
}
398496

399497
auto param_buffers = [&]() -> Array<tir::Buffer> {
400498
Array<tir::Buffer> out;
@@ -676,23 +774,6 @@ class FusedTIRConstructor : public ExprVisitor {
676774
MapArgsToBuffer(arg_list, buffer_list);
677775
}
678776

679-
static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
680-
int num_inputs) {
681-
Array<Integer> ret;
682-
int last_idx = num_inputs;
683-
for (auto idx : inplace_indices) {
684-
int i = idx.IntValue();
685-
if (i >= 0) {
686-
ret.push_back(Integer(i));
687-
} else {
688-
ret.push_back(Integer(last_idx));
689-
last_idx++;
690-
}
691-
}
692-
693-
return ret;
694-
}
695-
696777
static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
697778
const Array<Integer>& output_indices) {
698779
size_t n = func->params.size();
@@ -798,8 +879,9 @@ class FusedTIRConstructor : public ExprVisitor {
798879
* \param name_hint The name hint for params and buffers
799880
* \param out The vector into which to collect the params/buffers
800881
*/
801-
static void CollectPrimFuncParams(const Var& relax_param,
802-
std::vector<Variant<tir::Var, tir::Buffer>>* out) {
882+
static void CollectPrimFuncParams(
883+
const Var& relax_param, std::vector<Variant<tir::Var, tir::Buffer>>* out,
884+
const tvm::runtime::Optional<tir::Buffer>& tir_buffer_param = NullOpt) {
803885
auto struct_info = GetStructInfo(relax_param);
804886

805887
CHECK(!struct_info.as<TupleStructInfoNode>())
@@ -814,7 +896,14 @@ class FusedTIRConstructor : public ExprVisitor {
814896
const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
815897
ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
816898
DataType dtype = tensor->dtype;
817-
tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
899+
tir::Buffer buffer;
900+
if (tir_buffer_param.defined()) {
901+
buffer =
902+
tir::decl_buffer(shape_expr->values, dtype, name_hint, tir_buffer_param.value().scope(),
903+
tir_buffer_param.value()->axis_separators);
904+
} else {
905+
buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
906+
}
818907
out->push_back(std::move(buffer));
819908

820909
} else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {

tests/python/relax/test_transform_fuse_tir.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2314,5 +2314,88 @@ def take(
23142314
_check(Before, Before)
23152315

23162316

2317+
def test_fuse_with_axis_separators():
    """FuseTIR must propagate buffer attributes (here, axis_separators) from the
    callee PrimFunc's buffers onto the parameters and intermediate allocations
    of the fused PrimFunc, rather than rebuilding them from StructInfo alone."""

    @I.ir_module
    class Before:
        # Elementwise add whose buffers all carry axis_separators=[1]; the fused
        # function must preserve this attribute on its own buffers.
        @T.prim_func(private=True)
        def add(a: T.handle, b: T.handle, c: T.handle):
            A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
            B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])

            for iters in T.grid(T.int64(16), T.int64(32)):
                with T.block("compute"):
                    i, j = T.axis.remap("SS", iters)
                    C[i, j] = A[i, j] + B[i, j]

        # Primitive relax function: two chained call_tir calls to `add`,
        # which FuseTIR should collapse into a single PrimFunc.
        @R.function(private=True)
        def fused_function(
            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
            R.func_attr({"Primitive": 1})
            cls = Before
            with R.dataflow():
                w = R.call_tir(
                    cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
                )
                out = R.call_tir(
                    cls.add, [w, z], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
                )
                R.output(out)
            return out

        @R.function
        def main(
            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
            cls = Before
            with R.dataflow():
                gv = cls.fused_function(x, y, z)
                R.output(gv)
            return gv

    @I.ir_module
    class Expected:
        # The fused PrimFunc: note axis_separators=[1] on every parameter buffer
        # AND on the intermediate allocation — the property under test.
        @T.prim_func(private=True)
        def fused_function(x: T.handle, y: T.handle, z: T.handle, c: T.handle):
            T.func_attr({"tir.noalias": True})
            X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
            Y = T.match_buffer(y, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
            Z = T.match_buffer(z, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
            Temp = T.alloc_buffer(X.shape, "float32", axis_separators=[1])
            for iters in T.grid(*X.shape):
                with T.block("compute_Y"):
                    i, j = T.axis.remap("SS", iters)
                    Temp[i, j] = X[i, j] + Y[i, j]

            for iters in T.grid(*X.shape):
                with T.block("compute_Z"):
                    i, j = T.axis.remap("SS", iters)
                    C[i, j] = Temp[i, j] + Z[i, j]

        @R.function
        def main(
            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
            cls = Expected
            with R.dataflow():
                gv = R.call_tir(
                    cls.fused_function,
                    [x, y, z],
                    out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32"),
                )
                R.output(gv)
            return gv

    _check(Before, Expected)
2398+
2399+
23172400
if __name__ == "__main__":
23182401
tvm.testing.main()

0 commit comments

Comments
 (0)