Skip to content

Commit

Permalink
[Target][Legalization]Add Tir Level Legalization Function Registratio…
Browse files Browse the repository at this point in the history
…n And Update Intrinsic Lowering Pass (apache#7936)
  • Loading branch information
zxybazh authored and Trevor Morris committed May 6, 2021
1 parent be89d82 commit cf13cab
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 110 deletions.
24 changes: 15 additions & 9 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,41 +112,47 @@ TVM_REGISTER_OP("tir.ceil")
TVM_REGISTER_OP("tir.round")
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
DispatchPureExtern<FloatSuffix>);

} // namespace intrin

namespace legalize {

using namespace tir;

TVM_REGISTER_OP("tir.rsqrt")
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
.set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
auto one = make_const(call->args[0].dtype(), 1);
return one / sqrt(call->args[0]);
});

TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tir.sigmoid")
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
.set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
auto one = make_const(call->args[0].dtype(), 1);
return one / (one + exp(-call->args[0]));
});

TVM_REGISTER_OP("tir.isfinite")
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
.set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
return isfinite(call->args[0]);
});

TVM_REGISTER_OP("tir.isinf")
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
.set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
return isinf(call->args[0]);
});

TVM_REGISTER_OP("tir.q_multiply_shift")
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
.set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;

const tir::CallNode* call = e.as<tir::CallNode>();
Expand Down Expand Up @@ -222,6 +228,6 @@ TVM_REGISTER_OP("tir.q_multiply_shift")
}
});

} // namespace intrin
} // namespace legalize
} // namespace codegen
} // namespace tvm
109 changes: 55 additions & 54 deletions src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
namespace tvm {
namespace codegen {
namespace llvm {
namespace intrin {
using tir::FLowerIntrinsic;

TVM_REGISTER_OP("tir.prefetch")
Expand All @@ -43,20 +44,6 @@ TVM_REGISTER_OP("tir.exp2")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);

// TODO(tvm-team): migrate the legalization transformations as a separate
// set of rules in TIR that can be shared across backends.
TVM_REGISTER_OP("tir.exp10")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using tir::make_zero;
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
PrimExpr ret = exp(x * ln10);
return ret;
});

TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
"llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);

Expand Down Expand Up @@ -99,8 +86,37 @@ TVM_REGISTER_OP("tir.nearbyint")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);

TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
"llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);

TVM_REGISTER_OP("tir.popcount")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);

TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
"llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);

TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
"llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
} // namespace intrin

namespace legalize {
using tir::FLegalize;

// LLVM legalization rule for tir.exp10.
// Rewrites the call through the identity exp10(x) == exp(x * ln(10)).
TVM_REGISTER_OP("tir.exp10")
    .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
      using tir::make_const;
      const tir::CallNode* call = e.as<tir::CallNode>();
      ICHECK(call != nullptr);
      const PrimExpr& x = call->args[0];
      // ln(10) spelled out to full double precision so float64 operands keep
      // their accuracy; make_const materializes the constant in x's dtype.
      PrimExpr ln10 = make_const(x.dtype(), 2.302585092994046);
      return exp(x * ln10);
    });

TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using tir::make_zero;
const tir::CallNode* call = e.as<tir::CallNode>();
Expand All @@ -118,28 +134,16 @@ TVM_REGISTER_OP("tir.tanh")
return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
});

TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
"llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);

TVM_REGISTER_OP("tir.popcount")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);

TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
[](const PrimExpr& e) -> PrimExpr {
const tir::CallNode* call =
e.as<tir::CallNode>();
ICHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr tan_x = sin(x) / cos(x);
return tan_x;
});

TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
"llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
// LLVM legalization rule for tir.tan: the call is rewritten as the quotient
// sin(x) / cos(x) of its single operand.
TVM_REGISTER_OP("tir.tan")
    .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
      const auto* call = e.as<tir::CallNode>();
      ICHECK(call != nullptr);
      const PrimExpr& operand = call->args[0];
      return sin(operand) / cos(operand);
    });

TVM_REGISTER_OP("tir.cosh")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using tir::make_zero;
const tir::CallNode* call = e.as<tir::CallNode>();
Expand All @@ -153,11 +157,8 @@ TVM_REGISTER_OP("tir.cosh")
return ret;
});

TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
"llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);

TVM_REGISTER_OP("tir.sinh")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using tir::make_zero;
const tir::CallNode* call = e.as<tir::CallNode>();
Expand All @@ -171,21 +172,21 @@ TVM_REGISTER_OP("tir.sinh")
return ret;
});

TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
"llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
ICHECK_EQ(call->args.size(), 1);
Array<PrimExpr> cargs;
cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
cargs.push_back(IntImm(DataType::UInt(32), 2));
cargs.push_back(call->args[0]);
cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef
// LLVM requires that the return type must match the first argument type
auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs);
return cast(call->dtype, clz);
});

// LLVM legalization rule for tir.clz: wraps the operand in a call_llvm_intrin
// call to ::llvm::Intrinsic::ctlz and casts the result to the call's dtype.
TVM_REGISTER_OP("tir.clz")
    .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
      const auto* call = e.as<tir::CallNode>();
      ICHECK(call != nullptr);
      ICHECK_EQ(call->args.size(), 1);
      const PrimExpr& operand = call->args[0];
      const DataType u32 = DataType::UInt(32);

      Array<PrimExpr> intrin_args;
      intrin_args.push_back(IntImm(u32, ::llvm::Intrinsic::ctlz));
      // NOTE(review): the literal 2 is presumably the number of arguments used
      // to resolve the intrinsic signature -- confirm against call_llvm_intrin.
      intrin_args.push_back(IntImm(u32, 2));
      intrin_args.push_back(operand);
      intrin_args.push_back(IntImm(DataType::Int(1), 1));  // is_zero_undef

      // LLVM requires that the return type must match the first argument type,
      // so the intrinsic is typed after the operand and cast afterwards.
      auto ctlz_call = tir::Call(operand->dtype, tir::builtin::call_llvm_intrin(), intrin_args);
      return cast(call->dtype, ctlz_call);
    });

} // namespace legalize
} // namespace llvm
} // namespace codegen
} // namespace tvm
Expand Down
53 changes: 28 additions & 25 deletions src/target/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
namespace tvm {
namespace codegen {
namespace spirv {
using tir::FLowerIntrinsic;

// num_signature means number of arguments used to query signature
template <unsigned id>
PrimExpr CallGLSLIntrin(PrimExpr e, const Array<PrimExpr>& args) {
Expand Down Expand Up @@ -59,6 +57,8 @@ inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) {
return CallGLSLIntrin<id>(e);
}

namespace intrin {
using tir::FLowerIntrinsic;
TVM_REGISTER_OP("tir.floor")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Floor>);

Expand Down Expand Up @@ -98,29 +98,6 @@ TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Tanh>);

TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
"vulkan.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
ICHECK_EQ(call->args.size(), 1);
PrimExpr arg = call->args[0];
PrimExpr msb;
if (arg.dtype().bits() == 64) {
// SPIR-V FindUMsb intrinsic only supports 32 bit input
auto int32 = DataType::Int(32);
PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32);
PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg);
PrimExpr msb_hi = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_hi32});
PrimExpr msb_lo = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_lo32});
msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32);
} else if (arg.dtype().bits() == 32) {
msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e);
} else {
LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer.";
}
return PrimExpr(arg.dtype().bits() - 1) - msb;
});

// WebGPU rules.
TVM_REGISTER_OP("tir.floor")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Floor>);
Expand Down Expand Up @@ -151,7 +128,33 @@ TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",

TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Tanh>);
} // namespace intrin

namespace legalize {
using tir::FLegalize;
// Vulkan legalization rule for tir.clz, built on the GLSL FindUMsb intrinsic:
// the result is (bit width - 1) minus the value FindUMsb yields for the operand.
TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>(
    "vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr {
      const auto* call = e.as<tir::CallNode>();
      ICHECK(call != nullptr);
      ICHECK_EQ(call->args.size(), 1);
      PrimExpr operand = call->args[0];
      const int bit_width = operand.dtype().bits();
      PrimExpr msb_index;
      switch (bit_width) {
        case 64: {
          // SPIR-V FindUMsb intrinsic only supports 32 bit input, so each
          // 32-bit half is queried on its own and the answers are combined.
          const DataType i32 = DataType::Int(32);
          PrimExpr upper_half = tvm::tir::Cast(i32, operand >> 32);
          PrimExpr lower_half = tvm::tir::Cast(i32, operand);
          PrimExpr upper_msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {upper_half});
          PrimExpr lower_msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {lower_half});
          // The lower half decides the answer only when the upper half is zero.
          msb_index = tvm::if_then_else(upper_half == 0, lower_msb, upper_msb + 32);
          break;
        }
        case 32:
          msb_index = CallGLSLIntrin<GLSLstd450FindUMsb>(e);
          break;
        default:
          LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer.";
      }
      return PrimExpr(bit_width - 1) - msb_index;
    });
} // namespace legalize
} // namespace spirv
} // namespace codegen
} // namespace tvm
43 changes: 22 additions & 21 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,33 +39,34 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt_;
using FLowerGeneral = runtime::TypedPackedFunc<PrimExpr(PrimExpr)>;

IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "")
: IRMutatorWithAnalyzer(analyzer) {
std::vector<std::string> patterns_;
patterns_.push_back(target + ".FLowerIntrinsic");

std::vector<std::string> patterns;
patterns.push_back(target + ".FLowerIntrinsic");
patterns.push_back(target + ".FLegalize");
bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
if (is_llvm_aarch64) {
patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
}

patterns_.push_back("default.FLowerIntrinsic");

fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
if (target == "stackvm") {
support_bitwise_op_ = false;
patterns.push_back(target + ".aarch64.FLowerIntrinsic");
patterns.push_back(target + ".aarch64.FLegalize");
}

for (const std::string& pattern : patterns_)
if (Op::HasAttrMap(pattern))
lower_intrin_maps_.push_back(Op::GetAttrMap<FLowerIntrinsic>(pattern));
patterns.push_back("default.FLowerIntrinsic");
patterns.push_back("default.FLegalize");

for (const std::string& pattern : patterns)
if (Op::HasAttrMap(pattern)) {
attr_maps_.push_back(Op::GetAttrMap<FLowerGeneral>(pattern));
if (fma_ == nullptr) {
fma_ = (*attr_maps_.rbegin()).get(Op::Get("tir.fma"), nullptr);
}
}
}

PrimExpr VisitExpr_(const CallNode* op) final {
if (auto* ptr_op = op->op.as<OpNode>()) {
for (const auto& f_lower_intrin_map : lower_intrin_maps_) {
FLowerIntrinsic f = f_lower_intrin_map.get(GetRef<Op>(ptr_op), nullptr);
for (const auto& f_attr_map : attr_maps_) {
FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr);
if (f != nullptr) {
PrimExpr e = GetRef<PrimExpr>(op);
PrimExpr r = f(e);
Expand Down Expand Up @@ -269,7 +270,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
PrimExpr rhs = SwapBroadcastCast(b);

if (fma_ != nullptr && op->dtype.is_float()) {
PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
if (r.defined()) return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
Expand All @@ -280,9 +281,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}

// patterns
std::vector<OpAttrMap<FLowerIntrinsic>> lower_intrin_maps_;
const PackedFunc* fma_{nullptr};
// attribute maps for both FLowerIntrinsic and FLegalize rules; one vector can hold them only while the two function types are identical
std::vector<OpAttrMap<FLowerGeneral>> attr_maps_;
FLowerGeneral fma_{nullptr};
bool support_bitwise_op_{true};
};

Expand Down
Loading

0 comments on commit cf13cab

Please sign in to comment.