Skip to content

Commit 17027ac

Browse files
author
Gokulnath Srinivasan
committed
Add LLVM Legalization for tir.erf
1 parent 910aeaf commit 17027ac

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

src/target/llvm/intrin_rule_llvm.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,18 @@ TVM_REGISTER_OP("tir.atanh")
241241
PrimExpr one = make_const(x.dtype(), 1.0);
242242
return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5);
243243
});
244+
245+
TVM_REGISTER_OP("tir.erf").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
246+
using tir::make_const;
247+
const tir::CallNode* call = e.as<tir::CallNode>();
248+
ICHECK(call != nullptr) << "Invalid call node in erf legalization";
249+
const PrimExpr& x = call->args[0];
250+
PrimExpr sqrt_pi = sqrt(make_const(x.dtype(), M_PI));
251+
PrimExpr coeff = make_const(x.dtype(), 2.0) / sqrt_pi;
252+
PrimExpr x_cubed = x * x * x;
253+
PrimExpr inner = x + make_const(x.dtype(), 11.0 / 123.0) * x_cubed;
254+
return tanh(coeff * inner);
255+
});
244256

245257
TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
246258
const tir::CallNode* call = e.as<tir::CallNode>();

tests/python/relax/test_frontend_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2804,7 +2804,7 @@ def test_symbolic_shape_deduction():
28042804

28052805
@R.function
28062806
def expected(
2807-
data: R.Tensor(("batch", "seq"), dtype="float32")
2807+
data: R.Tensor(("batch", "seq"), dtype="float32"),
28082808
) -> R.Tensor(dtype="float32", ndim=1):
28092809
batch = T.int64()
28102810
seq = T.int64()

0 commit comments

Comments
 (0)