Commit 5bfca2e
[Transform] Modify FuseTIR pass to propagate buffer attributes (#17075)
Arguments of a fused TIR PrimFunc generated from a fused Relax function do not retain all the buffer attributes of their original PrimFuncs, because the buffers are created from the StructInfo of the Relax vars. This patch collects a mapping from Relax vars to their corresponding TIR buffers in a fused Relax function and uses that information to propagate buffer attributes such as `axis_separators` and `storage_scope`.
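As a quick illustration of the effect (a minimal sketch, not part of this commit; the module and function names follow the new test below, everything else is standard TVM Python API):

    import tvm
    from tvm import relax

    # `Before` is a module whose PrimFunc buffers carry axis_separators,
    # e.g. the one in test_fuse_with_axis_separators below.
    After = relax.transform.FuseTIR()(Before)

    fused = After["fused_function"]  # the fused PrimFunc
    for param in fused.params:
        buf = fused.buffer_map[param]
        # With this patch, buf.axis_separators and buf.scope() match the
        # original PrimFunc's buffers instead of StructInfo-derived defaults.
        print(buf.name, buf.scope(), list(buf.axis_separators))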

2 files changed: +248 -20


src/relax/transform/fuse_tir.cc

Lines changed: 120 additions & 20 deletions
@@ -362,6 +362,114 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      CHECK_EQ(i, -1) << "The only negative index expected in inplace_indices is -1, but got " << i;
+      ret.push_back(Integer(last_idx));
+      last_idx++;
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+ public:
+  explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
+  static Map<Expr, tir::Buffer> Collect(const IRModule& mod, const Function& func) {
+    RelaxToTIRVarMapCollector visitor(mod);
+    visitor(func->body);
+    return visitor.relax_to_tir_var_map_;
+  }
+
+ private:
+  void VisitBinding_(const VarBindingNode* binding) final {
+    current_var_ = binding->var;
+    ExprVisitor::VisitBinding_(binding);
+  }
+
+  void VisitExpr_(const CallNode* call) {
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+
+    ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
+        << "Only call_tir and call_tir_inplace are supported in primitive function, but got: "
+        << GetRef<Expr>(call);
+    CollectVarMapping(call, current_var_, call->op == call_tir_inplace_op_);
+  }
+
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+
+    // If the `expr` is already seen (present in the map), validate whether the mapped buffer is
+    // structurally equal to the `new_buf` passed
+    auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) {
+      if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) {
+        ICHECK(StructuralEqual()((*it).second, new_buf))
+            << "Inconsistent buffers " << (*it).second << " and " << new_buf
+            << " mapped to the same relax var: " << expr;
+      }
+    };
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = tir_args[i];
+      if (auto tir_buffer = buffer_map.Get(tir_var)) {
+        if (i < num_inputs) {
+          const auto& relax_var = relax_args[i];
+          ValidateBufferCompatibility(tir_buffer.value(), relax_var);
+          relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
+        }
+        if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i);
+            it != output_idxs.end()) {
+          int result_idx = it - output_idxs.begin();
+          const auto& relax_var = relax_results[result_idx];
+          ValidateBufferCompatibility(tir_buffer.value(), relax_var);
+          relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
+        }
+      }
+    }
+  }
+
+ private:
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+  Map<Expr, tir::Buffer> relax_to_tir_var_map_;
+  Var current_var_;
+};
+
 class FusedTIRConstructor : public ExprVisitor {
  public:
   /*!
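A note on the helper moved into this hunk: `GetInplaceOutputIndices` maps each entry of `inplace_indices` to a position in the PrimFunc parameter list, where `-1` means "a fresh output parameter appended after the inputs". A Python mirror of the logic, for illustration only (not part of the commit):

    def get_inplace_output_indices(inplace_indices, num_inputs):
        ret, last_idx = [], num_inputs
        for i in inplace_indices:
            if i >= 0:
                ret.append(i)          # output aliases input i
            else:
                assert i == -1         # only -1 is allowed as a negative index
                ret.append(last_idx)   # fresh output param after the inputs
                last_idx += 1
        return ret

    # 3 inputs; first output written in-place into input 1, second freshly allocated:
    assert get_inplace_output_indices([1, -1], 3) == [1, 3]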
@@ -391,10 +499,11 @@ class FusedTIRConstructor : public ExprVisitor {
       : mod_(mod), func_name_(func_name) {}
 
   void VisitExpr_(const FunctionNode* func) final {
+    auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, GetRef<Function>(func));
     std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
     for (const Var& relax_param : func->params) {
       size_t size_before = prim_func_params.size();
-      CollectPrimFuncParams(relax_param, &prim_func_params);
+      CollectPrimFuncParams(relax_param, &prim_func_params, relax_to_tir_var_map.Get(relax_param));
 
       auto param_buffers = [&]() -> Array<tir::Buffer> {
         Array<tir::Buffer> out;
@@ -676,23 +785,6 @@ class FusedTIRConstructor : public ExprVisitor {
     MapArgsToBuffer(arg_list, buffer_list);
   }
 
-  static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
-                                                int num_inputs) {
-    Array<Integer> ret;
-    int last_idx = num_inputs;
-    for (auto idx : inplace_indices) {
-      int i = idx.IntValue();
-      if (i >= 0) {
-        ret.push_back(Integer(i));
-      } else {
-        ret.push_back(Integer(last_idx));
-        last_idx++;
-      }
-    }
-
-    return ret;
-  }
-
   static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
                                                  const Array<Integer>& output_indices) {
     size_t n = func->params.size();
@@ -799,7 +891,8 @@ class FusedTIRConstructor : public ExprVisitor {
    * \param out The vector into which to collect the params/buffers
    */
   static void CollectPrimFuncParams(const Var& relax_param,
-                                    std::vector<Variant<tir::Var, tir::Buffer>>* out) {
+                                    std::vector<Variant<tir::Var, tir::Buffer>>* out,
+                                    const tvm::runtime::Optional<tir::Buffer>& tir_buffer_param) {
     auto struct_info = GetStructInfo(relax_param);
 
     CHECK(!struct_info.as<TupleStructInfoNode>())

@@ -814,7 +907,14 @@ class FusedTIRConstructor : public ExprVisitor {
       const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
       ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
       DataType dtype = tensor->dtype;
-      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      tir::Buffer buffer;
+      if (tir_buffer_param.defined()) {
+        buffer =
+            tir::decl_buffer(shape_expr->values, dtype, name_hint, tir_buffer_param.value().scope(),
+                             tir_buffer_param.value()->axis_separators);
+      } else {
+        buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      }
       out->push_back(std::move(buffer));
 
     } else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
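The `CollectPrimFuncParams` change above is the heart of the fix: when the collector found a TIR buffer for a Relax parameter, the fused function's parameter buffer is declared with that buffer's storage scope and axis separators rather than the defaults. A rough Python equivalent of the new branch (a sketch under those assumptions; the pass itself is C++ and `declare_param_buffer` is a hypothetical name):

    from tvm import tir

    def declare_param_buffer(shape, dtype, name_hint, tir_buffer_param=None):
        # Mirrors the new branch in CollectPrimFuncParams: reuse the original
        # buffer's scope and axis_separators when one was collected.
        if tir_buffer_param is not None:
            return tir.decl_buffer(
                shape, dtype, name_hint,
                scope=tir_buffer_param.scope(),
                axis_separators=tir_buffer_param.axis_separators,
            )
        return tir.decl_buffer(shape, dtype, name_hint)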

tests/python/relax/test_transform_fuse_tir.py

Lines changed: 128 additions & 0 deletions
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 import tvm.testing
 from tvm import relax, topi
@@ -2314,5 +2316,131 @@ def take(
     _check(Before, Before)
 
 
+def test_fuse_with_axis_separators():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def add(a: T.handle, b: T.handle, c: T.handle):
+            A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+
+            for iters in T.grid(T.int64(16), T.int64(32)):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] + B[i, j]
+
+        @R.function(private=True)
+        def fused_function(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                w = R.call_tir(
+                    cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                out = R.call_tir(
+                    cls.add, [w, z], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                R.output(out)
+            return out
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_function(x, y, z)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused_function(x: T.handle, y: T.handle, z: T.handle, c: T.handle):
+            T.func_attr({"tir.noalias": True})
+            X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Y = T.match_buffer(y, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Z = T.match_buffer(z, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Temp = T.alloc_buffer(X.shape, "float32", axis_separators=[1])
+            for iters in T.grid(*X.shape):
+                with T.block("compute_Y"):
+                    i, j = T.axis.remap("SS", iters)
+                    Temp[i, j] = X[i, j] + Y[i, j]
+
+            for iters in T.grid(*X.shape):
+                with T.block("compute_Z"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = Temp[i, j] + Z[i, j]
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.fused_function,
+                    [x, y, z],
+                    out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
+def test_fuse_with_axis_separators_inconsistent_buffer_mapping():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def mul(a: T.handle, b: T.handle, c: T.handle):
+            A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+
+            for iters in T.grid(T.int64(16), T.int64(32)):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] * B[i, j]
+
+        @R.function(private=True)
+        def fused_function(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                out = R.call_tir(
+                    cls.mul, [x, x], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                R.output(out)
+            return out
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_function(x)
+                R.output(gv)
+            return gv
+
+    with pytest.raises(
+        tvm.TVMError, match=r"Inconsistent buffers.*and.*mapped to the same relax var:.*"
+    ):
+        relax.transform.FuseTIR()(Before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
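The new tests can be run on their own with pytest, for example:

    pytest tests/python/relax/test_transform_fuse_tir.py -k axis_separators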
