@@ -241,6 +241,23 @@ TVM_REGISTER_OP("tir.atanh")
241241 PrimExpr one = make_const (x.dtype (), 1.0 );
242242 return (log (one + x) - log (one - x)) * make_const (x.dtype (), 0.5 );
243243 });
244+
245+ TVM_REGISTER_OP (" tir.erf" ).set_attr<FLegalize>(" llvm.FLegalize" , [](const PrimExpr& e) -> PrimExpr {
246+ using tir::make_const;
247+ const tir::CallNode* call = e.as <tir::CallNode>();
248+ ICHECK (call != nullptr ) << " Invalid call node in erf legalization" ;
249+ const PrimExpr& x = call->args [0 ];
250+ PrimExpr abs_x = tvm::abs (x);
251+ PrimExpr t = make_const (x.dtype (), 1.0 ) / (make_const (x.dtype (), 1.0 ) + make_const (x.dtype (), 0.3275911 ) * abs_x);
252+ PrimExpr a1 = make_const (x.dtype (), 0.254829592 );
253+ PrimExpr a2 = make_const (x.dtype (), -0.284496736 );
254+ PrimExpr a3 = make_const (x.dtype (), 1.421413741 );
255+ PrimExpr a4 = make_const (x.dtype (), -1.453152027 );
256+ PrimExpr a5 = make_const (x.dtype (), 1.061405429 );
257+ PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t);
258+ PrimExpr approx = make_const (x.dtype (), 1.0 ) - poly * exp (-abs_x * abs_x);
259+ return tvm::tir::Select (x < 0 , -approx, approx);
260+ });
244261
245262TVM_REGISTER_OP (" tir.clz" ).set_attr<FLegalize>(" llvm.FLegalize" , [](const PrimExpr& e) -> PrimExpr {
246263 const tir::CallNode* call = e.as <tir::CallNode>();
0 commit comments