Skip to content

Commit a8816b2

Browse files
author
Gokulnath Srinivasan
committed
Add LLVM Legalization for tir.erf
1 parent 910aeaf commit a8816b2

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/target/llvm/intrin_rule_llvm.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,18 @@ TVM_REGISTER_OP("tir.atanh")
241241
PrimExpr one = make_const(x.dtype(), 1.0);
242242
return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5);
243243
});
244+
245+
TVM_REGISTER_OP("tir.erf").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
246+
using tir::make_const;
247+
const tir::CallNode* call = e.as<tir::CallNode>();
248+
ICHECK(call != nullptr) << "Invalid call node in erf legalization";
249+
const PrimExpr& x = call->args[0];
250+
PrimExpr sqrt_pi = sqrt(make_const(x.dtype(), M_PI));
251+
PrimExpr coeff = make_const(x.dtype(), 2.0) / sqrt_pi;
252+
PrimExpr x_cubed = x * x * x;
253+
PrimExpr inner = x + make_const(x.dtype(), 11.0 / 123.0) * x_cubed;
254+
return tanh(coeff * inner);
255+
});
244256

245257
TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
246258
const tir::CallNode* call = e.as<tir::CallNode>();

0 commit comments

Comments
 (0)