Skip to content

Commit 2bc7f47

Browse files
committed
[Relax][Transform] Handle identical PrimFunc with distinct VDevice
Prior to this commit, if an `IRModule` contained two expressions whose argument types differed only by the `VDevice`, these would be legalized to produce a single PrimFunc. This PrimFunc would have a `tvm::attr::kTarget` annotation specific to one of those expressions, and would be incorrect for use in the other location. This commit updates the `LegalizeOps` transform to handle this case, producing multiple TIR PrimFuncs if required by the `VDevice` annotations.
1 parent 6252fa5 commit 2bc7f47

File tree

3 files changed

+204
-8
lines changed

3 files changed

+204
-8
lines changed

src/relax/transform/legalize_ops.cc

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/relax/op_attr_types.h>
2929
#include <tvm/relax/struct_info.h>
3030
#include <tvm/relax/transform.h>
31+
#include <tvm/tir/transform.h>
3132

3233
namespace tvm {
3334
namespace relax {
@@ -83,7 +84,12 @@ class LegalizeMutator : public ExprMutator {
8384
builder_->UpdateFunction(gv, f);
8485
}
8586
}
86-
return builder_->GetContextIRModule();
87+
IRModule output = builder_->GetContextIRModule();
88+
if (requires_tir_convert_ssa_) {
89+
output = tir::transform::ConvertSSA()(output);
90+
}
91+
92+
return output;
8793
}
8894

8995
private:
@@ -129,7 +135,7 @@ class LegalizeMutator : public ExprMutator {
129135
return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);
130136
}
131137

132-
Target GetTarget(const Array<StructInfo>& sinfos) {
138+
Optional<Target> GetTarget(const Array<StructInfo>& sinfos) {
133139
for (auto sinfo : sinfos) {
134140
if (const auto* tinfo = sinfo.as<TensorStructInfoNode>()) {
135141
if (tinfo->vdevice.defined()) {
@@ -142,20 +148,90 @@ class LegalizeMutator : public ExprMutator {
142148
return GetTarget(tup_sinfo->fields);
143149
}
144150
}
145-
return Target();
151+
return NullOpt;
146152
}
147153

148154
void SaveTarget(const Expr& expr) {
149155
if (expr->IsInstance<CallNode>()) {
150156
auto call = Downcast<Call>(expr);
151-
auto target = GetTarget(call->sinfo_args);
152-
const GlobalVarNode* gvar_node;
153-
if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>())) {
154-
this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);
157+
158+
if (auto target = GetTarget(call->sinfo_args)) {
159+
if (auto gvar = call->args[0].as<GlobalVar>()) {
160+
this->tmap_.Set(gvar.value(), target.value());
161+
}
155162
}
156163
}
157164
}
158165

166+
Expr BindTarget(Expr expr) {
167+
if (!expr->IsInstance<CallNode>()) {
168+
// FLegalize returned something other than a relax::Call. This
169+
// post-processing only handles cases where legalization
170+
// produces a lowered call node. In principle, this
171+
// post-processing isn't necessary, and FLegalize should already
172+
// have generated vdevice-aware kernels, so hopefully the
173+
// FLegalize implementation did so.
174+
return expr;
175+
}
176+
177+
auto call = Downcast<Call>(expr);
178+
179+
auto vdevice_target = GetTarget(call->sinfo_args);
180+
if (!vdevice_target.defined()) {
181+
// No vdevice annotation is present, so we don't need to apply
182+
// any updates.
183+
return expr;
184+
}
185+
186+
if (call->args.empty()) {
187+
return expr;
188+
}
189+
190+
auto gvar = call->args[0].as<GlobalVar>();
191+
if (!gvar.defined()) {
192+
// This is not a call into a legalized function within the
193+
// current IRModule, so no post-processing is required.
194+
return expr;
195+
}
196+
197+
auto base_func = builder_->GetContextIRModule()->Lookup(gvar.value());
198+
auto opt_prim_func = base_func.as<tir::PrimFunc>();
199+
if (!opt_prim_func) {
200+
// The call is to something other than a PrimFunc. It may be
201+
// another Relax function, in which case the legalization of its
202+
// body will handle any additional target annotations.
203+
return expr;
204+
}
205+
auto prim_func = opt_prim_func.value();
206+
207+
auto func_target = prim_func->GetAttr<Target>(tvm::attr::kTarget);
208+
if (func_target && func_target.value()->kind == vdevice_target.value()->kind) {
209+
// The function already has compatible annotations for the
210+
// target, so no modifications are required.
211+
return expr;
212+
}
213+
214+
// The FLegalize function generated a PrimFunc, but that PrimFunc
215+
// doesn't have annotations compatible with the vdevice required
216+
// by the Relax StructInfo. Update the call to instead call a
217+
// `PrimFunc` with the appropriate target annotation. In the
218+
// future, this may be treated as a bug in the FLegalize
219+
// implementation, rather than expected output from it.
220+
auto new_prim_func = WithAttr(prim_func, tvm::attr::kTarget, vdevice_target.value());
221+
auto new_gvar_name = [&]() -> std::string {
222+
std::stringstream ss;
223+
ss << gvar.value()->name_hint;
224+
ss << "_";
225+
ss << vdevice_target.value()->kind->name;
226+
return ss.str();
227+
}();
228+
auto new_gvar = builder_->AddFunction(new_prim_func, new_gvar_name);
229+
requires_tir_convert_ssa_ = true;
230+
231+
call.CopyOnWrite()->args.Set(0, new_gvar);
232+
return call;
233+
}
234+
159235
Expr VisitExpr_(const CallNode* call) final {
160236
Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
161237
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
@@ -268,8 +344,10 @@ class LegalizeMutator : public ExprMutator {
268344
}
269345
Expr legalized = legalization_func(builder_, visited_call);
270346

347+
legalized = BindTarget(legalized);
348+
271349
// Save the expected target info. into tmap_
272-
SaveTarget(legalized);
350+
// SaveTarget(legalized);
273351

274352
legalized = builder_->Normalize(legalized);
275353

@@ -305,6 +383,7 @@ class LegalizeMutator : public ExprMutator {
305383
Map<String, PackedFunc> cmap_;
306384
/*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
307385
Map<GlobalVar, Target> tmap_;
386+
bool requires_tir_convert_ssa_{false};
308387
/*!
309388
* \brief A boolean value indicating if to print warnings for CallNode whose op's
310389
* legalization function is not registered.

src/tir/transforms/ir_utils.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,42 @@ class IRConvertSSA final : public StmtExprMutator {
246246
return std::move(decl);
247247
}
248248

249+
Stmt VisitStmt_(const BlockNode* op) final {
250+
Block block = GetRef<Block>(op);
251+
252+
// The BlockNode is the point of definition for the IterVar
253+
// instances. These re-defines must be present before visiting
254+
// the body of the BlockNode.
255+
std::vector<ScopedRedefine> redefines;
256+
Array<IterVar> iter_vars = op->iter_vars.Map([&](IterVar iter_var) {
257+
if (defined_.count(iter_var->var.get())) {
258+
redefines.emplace_back(this, iter_var->var);
259+
iter_var.CopyOnWrite()->var = redefines.back().new_var;
260+
} else {
261+
defined_.insert(iter_var->var.get());
262+
}
263+
return iter_var;
264+
});
265+
Array<BufferRegion> reads =
266+
block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); });
267+
Array<BufferRegion> writes =
268+
block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); });
269+
270+
if (!reads.same_as(block->reads) || !writes.same_as(block->writes) ||
271+
!iter_vars.same_as(op->iter_vars)) {
272+
auto write_ptr = block.CopyOnWrite();
273+
write_ptr->reads = reads;
274+
write_ptr->writes = writes;
275+
write_ptr->iter_vars = iter_vars;
276+
}
277+
278+
Stmt output = Downcast<Block>(StmtExprMutator::VisitStmt_(block.get()));
279+
280+
while (redefines.size()) redefines.pop_back();
281+
282+
return output;
283+
}
284+
249285
template <typename Node>
250286
Node VisitBufferAccess(Node node) {
251287
Buffer new_buf = GetRemappedBuffer(node->buffer);

tests/python/relax/test_transform_legalize_ops.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,5 +356,86 @@ def main(
356356
tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)
357357

358358

359+
def test_legalize_with_vdevice():
360+
"""Legalization may generate kernels for multiple targets
361+
362+
This is a regression test. In previous implementations, Relax
363+
expressions whose argument types differed only by their `vdevice`
364+
would be legalized to use the same `PrimFunc`.
365+
366+
"""
367+
368+
@I.ir_module
369+
class Before:
370+
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
371+
372+
@R.function
373+
def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
374+
C = R.add(A, B)
375+
return C
376+
377+
@R.function
378+
def func_llvm(
379+
A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
380+
):
381+
C = R.add(A, B)
382+
return C
383+
384+
@I.ir_module
385+
class Expected:
386+
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
387+
388+
@R.function
389+
def func_cuda(
390+
A: R.Tensor((32, 32), dtype="float32"),
391+
B: R.Tensor((32, 32), dtype="float32"),
392+
):
393+
cls = Expected
394+
C = R.call_tir(cls.add, (A, B), out_sinfo=R.Tensor((32, 32), dtype="float32"))
395+
return C
396+
397+
@T.prim_func(private=True)
398+
def add(
399+
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
400+
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
401+
C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
402+
):
403+
T.func_attr({"tir.noalias": T.bool(True)})
404+
for iters in T.grid(T.int64(32), T.int64(32)):
405+
with T.block("T_add"):
406+
ax0, ax1 = T.axis.remap("SS", iters)
407+
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]
408+
409+
@R.function
410+
def func_llvm(
411+
A: R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
412+
B: R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
413+
):
414+
cls = Expected
415+
C = R.call_tir(
416+
cls.add_llvm,
417+
(A, B),
418+
out_sinfo=R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
419+
)
420+
return C
421+
422+
@T.prim_func(private=True)
423+
def add_llvm(
424+
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
425+
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
426+
C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
427+
):
428+
T.func_attr({"target": T.target("llvm"), "tir.noalias": T.bool(True)})
429+
for iters in T.grid(T.int64(32), T.int64(32)):
430+
with T.block("T_add"):
431+
ax0, ax1 = T.axis.remap("SS", iters)
432+
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]
433+
434+
with tvm.target.Target("cuda"):
435+
After = tvm.relax.transform.LegalizeOps()(Before)
436+
437+
tvm.ir.assert_structural_equal(Expected, After)
438+
439+
359440
if __name__ == "__main__":
360441
tvm.testing.main()

0 commit comments

Comments
 (0)