Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,15 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
Span span = Span());

/*!
* \brief Fast_erf_float expression from Eigen
*
* \param arg The input expression.
* \param bits The number of bits in the type.
* \return The constructed expression.
*/
TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits);

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
Expand Down
49 changes: 1 addition & 48 deletions include/tvm/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/topi/tags.h>

#include <algorithm>
Expand Down Expand Up @@ -455,54 +456,6 @@ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp",
}
}

/*!
* \brief Fast_erf_float expression from Eigen
* https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290
* \param arg The input expression.
* \param bits The number of bits in the type.
*/
inline PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) {
auto plus_4 = make_const(DataType::Float(bits), 4.f);
auto minus_4 = make_const(DataType::Float(bits), -4.f);

// The monomial coefficients of the numerator polynomial (odd).
auto alpha_1 = make_const(DataType::Float(bits), -1.60960333262415e-02f);
auto alpha_3 = make_const(DataType::Float(bits), -2.95459980854025e-03f);
auto alpha_5 = make_const(DataType::Float(bits), -7.34990630326855e-04f);
auto alpha_7 = make_const(DataType::Float(bits), -5.69250639462346e-05f);
auto alpha_9 = make_const(DataType::Float(bits), -2.10102402082508e-06f);
auto alpha_11 = make_const(DataType::Float(bits), 2.77068142495902e-08f);
auto alpha_13 = make_const(DataType::Float(bits), -2.72614225801306e-10f);

// The monomial coefficients of the denominator polynomial (even).
auto beta_0 = make_const(DataType::Float(bits), -1.42647390514189e-02f);
auto beta_2 = make_const(DataType::Float(bits), -7.37332916720468e-03f);
auto beta_4 = make_const(DataType::Float(bits), -1.68282697438203e-03f);
auto beta_6 = make_const(DataType::Float(bits), -2.13374055278905e-04f);
auto beta_8 = make_const(DataType::Float(bits), -1.45660718464996e-05f);

// clamp x
auto x = tvm::max(tvm::min(arg, plus_4), minus_4);
auto x2 = x * x;

// Evaluate the numerator polynomial p.
auto p = x2 * alpha_13 + alpha_11;
p = x2 * p + alpha_9;
p = x2 * p + alpha_7;
p = x2 * p + alpha_5;
p = x2 * p + alpha_3;
p = x2 * p + alpha_1;
p = x * p;

// Evaluate the denominator polynomial p.
auto q = x2 * beta_8 + beta_6;
q = x2 * q + beta_4;
q = x2 * q + beta_2;
q = x2 * q + beta_0;

return p / q;
}

/*!
* \brief Fast_erf_float expression from Eigen
*/
Expand Down
16 changes: 16 additions & 0 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,22 @@ TVM_REGISTER_OP("tir.nearbyint")
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
DispatchPureExtern<FloatSuffix>);

PrimExpr DispatchFastErf(const PrimExpr& e) {
LOG(WARNING) << "fast_erf will be used instead of erf";
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
ICHECK_EQ(call->args.size(), 1);
PrimExpr arg = call->args[0];
int bits = arg.dtype().bits();
PrimExpr res;
if (arg.dtype().is_float() && (bits == 16 || bits == 32)) {
res = fast_erf_float_expr(arg, bits);
} else {
LOG(FATAL) << "Unsupported type in Metal fast_erf";
}
return res;
}

} // namespace intrin

namespace legalize {
Expand Down
3 changes: 3 additions & 0 deletions src/target/intrin_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
}
}

// Dispatch ERF to fast erf when it is not available.
PrimExpr DispatchFastErf(const PrimExpr& e);

} // namespace intrin
} // namespace codegen
} // namespace tvm
Expand Down
6 changes: 3 additions & 3 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
*/
void RegisterHandleType(const VarNode* buf_var, DataType t);
// override
void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final;
void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) override;
/*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique();

Expand All @@ -281,10 +281,10 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
const Op& builtin_call_extern_ = builtin::call_extern();
const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
Integer constants_byte_alignment_ = 16;

private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};

private:
/*! \brief set of volatile buf access */
std::unordered_set<const VarNode*> volatile_buf_;
// deep comparison of PrimExpr
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
name_supply_->FreshName("_");
name_supply_->FreshName("v_");

// add to alloc buffer type.
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
Expand Down
3 changes: 2 additions & 1 deletion src/target/source/codegen_source_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
}
}
SSAEntry e;
e.vid = name_supply_->FreshName("_");
// use v_ prefix so it works for most systems
e.vid = name_supply_->FreshName("v_");
e.scope_id = static_cast<int>(scope_mark_.size() - 1);
ssa_assign_map_[src] = e;
this->PrintIndent();
Expand Down
Loading