Skip to content

Commit 4c564a6

Browse files
author
Gokulnath Srinivasan
committed
Add LLVM Legalization for tir.erf
1 parent 910aeaf commit 4c564a6

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/target/llvm/intrin_rule_llvm.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,23 @@ TVM_REGISTER_OP("tir.atanh")
241241
PrimExpr one = make_const(x.dtype(), 1.0);
242242
return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5);
243243
});
244+
245+
TVM_REGISTER_OP("tir.erf").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
246+
using tir::make_const;
247+
const tir::CallNode* call = e.as<tir::CallNode>();
248+
ICHECK(call != nullptr) << "Invalid call node in erf legalization";
249+
const PrimExpr& x = call->args[0];
250+
PrimExpr abs_x = tvm::abs(x);
251+
PrimExpr t = make_const(x.dtype(), 1.0) / (make_const(x.dtype(), 1.0) + make_const(x.dtype(), 0.3275911) * abs_x);
252+
PrimExpr a1 = make_const(x.dtype(), 0.254829592);
253+
PrimExpr a2 = make_const(x.dtype(), -0.284496736);
254+
PrimExpr a3 = make_const(x.dtype(), 1.421413741);
255+
PrimExpr a4 = make_const(x.dtype(), -1.453152027);
256+
PrimExpr a5 = make_const(x.dtype(), 1.061405429);
257+
PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t);
258+
PrimExpr approx = make_const(x.dtype(), 1.0) - poly * exp(-abs_x * abs_x);
259+
return tvm::tir::Select(x < 0, -approx, approx);
260+
});
244261

245262
TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
246263
const tir::CallNode* call = e.as<tir::CallNode>();

0 commit comments

Comments
 (0)