Skip to content

Commit 3464bbb

Browse files
tqchenyongwww
authored and committed
[WEB] WebGPU Codegen (apache#14048)
This PR provides an implementation of WebGPU codegen. Previously we relied on SPIRV codegen for WebGPU, which is deprecated in favor of the WGSL shading language. It passes limited testing on elementwise operators via Chrome. We will likely do further iterations. It also cleans up some legacy code organization in intrinsics.
1 parent 6863520 commit 3464bbb

File tree

19 files changed

+870
-153
lines changed

19 files changed

+870
-153
lines changed

include/tvm/tir/op.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,15 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
678678
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
679679
Span span = Span());
680680

681+
/*!
682+
* \brief Fast_erf_float expression from Eigen
683+
*
684+
* \param arg The input expression.
685+
* \param bits The number of bits in the type.
686+
* \return The constructed expression.
687+
*/
688+
TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits);
689+
681690
// Intrinsic operators
682691
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
683692
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \

include/tvm/topi/elemwise.h

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include <tvm/tir/builtin.h>
2828
#include <tvm/tir/expr.h>
29+
#include <tvm/tir/op.h>
2930
#include <tvm/topi/tags.h>
3031

3132
#include <algorithm>
@@ -455,54 +456,6 @@ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp",
455456
}
456457
}
457458

458-
/*!
459-
* \brief Fast_erf_float expression from Eigen
460-
* https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290
461-
* \param arg The input expression.
462-
* \param bits The number of bits in the type.
463-
*/
464-
inline PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) {
465-
auto plus_4 = make_const(DataType::Float(bits), 4.f);
466-
auto minus_4 = make_const(DataType::Float(bits), -4.f);
467-
468-
// The monomial coefficients of the numerator polynomial (odd).
469-
auto alpha_1 = make_const(DataType::Float(bits), -1.60960333262415e-02f);
470-
auto alpha_3 = make_const(DataType::Float(bits), -2.95459980854025e-03f);
471-
auto alpha_5 = make_const(DataType::Float(bits), -7.34990630326855e-04f);
472-
auto alpha_7 = make_const(DataType::Float(bits), -5.69250639462346e-05f);
473-
auto alpha_9 = make_const(DataType::Float(bits), -2.10102402082508e-06f);
474-
auto alpha_11 = make_const(DataType::Float(bits), 2.77068142495902e-08f);
475-
auto alpha_13 = make_const(DataType::Float(bits), -2.72614225801306e-10f);
476-
477-
// The monomial coefficients of the denominator polynomial (even).
478-
auto beta_0 = make_const(DataType::Float(bits), -1.42647390514189e-02f);
479-
auto beta_2 = make_const(DataType::Float(bits), -7.37332916720468e-03f);
480-
auto beta_4 = make_const(DataType::Float(bits), -1.68282697438203e-03f);
481-
auto beta_6 = make_const(DataType::Float(bits), -2.13374055278905e-04f);
482-
auto beta_8 = make_const(DataType::Float(bits), -1.45660718464996e-05f);
483-
484-
// clamp x
485-
auto x = tvm::max(tvm::min(arg, plus_4), minus_4);
486-
auto x2 = x * x;
487-
488-
// Evaluate the numerator polynomial p.
489-
auto p = x2 * alpha_13 + alpha_11;
490-
p = x2 * p + alpha_9;
491-
p = x2 * p + alpha_7;
492-
p = x2 * p + alpha_5;
493-
p = x2 * p + alpha_3;
494-
p = x2 * p + alpha_1;
495-
p = x * p;
496-
497-
// Evaluate the denominator polynomial p.
498-
auto q = x2 * beta_8 + beta_6;
499-
q = x2 * q + beta_4;
500-
q = x2 * q + beta_2;
501-
q = x2 * q + beta_0;
502-
503-
return p / q;
504-
}
505-
506459
/*!
507460
* \brief Fast_erf_float expression from Eigen
508461
*/

src/target/intrin_rule.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,22 @@ TVM_REGISTER_OP("tir.nearbyint")
118118
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
119119
DispatchPureExtern<FloatSuffix>);
120120

121+
// Rewrites a one-argument call expression into the fast_erf_float_expr
// approximation of its argument.
//
// \param e The expression to rewrite; it must be a CallNode with exactly one
//          argument (both conditions are enforced with ICHECK below).
// \return The expression built by fast_erf_float_expr when the argument is a
//         16-bit or 32-bit float. Any other argument type goes to LOG(FATAL).
//
// A warning is logged on every invocation, before any validation happens.
// NOTE(review): the fatal message names "Metal" although this function lives
// in the generic intrin_rule.cc - confirm whether the wording is intentional.
PrimExpr DispatchFastErf(const PrimExpr& e) {
122+
LOG(WARNING) << "fast_erf will be used instead of erf";
123+
const CallNode* call = e.as<CallNode>();
124+
ICHECK(call != nullptr);
125+
ICHECK_EQ(call->args.size(), 1);
126+
PrimExpr arg = call->args[0];
127+
// The bit width is taken from the argument's dtype and forwarded to the helper.
int bits = arg.dtype().bits();
128+
PrimExpr res;
129+
// Only float16 and float32 arguments are accepted.
if (arg.dtype().is_float() && (bits == 16 || bits == 32)) {
130+
res = fast_erf_float_expr(arg, bits);
131+
} else {
132+
// Presumably LOG(FATAL) does not return, so `res` is never returned unset;
// verify against the logging macro's definition.
LOG(FATAL) << "Unsupported type in Metal fast_erf";
133+
}
134+
return res;
135+
}
136+
121137
} // namespace intrin
122138

123139
namespace legalize {

src/target/intrin_rule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
7777
}
7878
}
7979

80+
// Dispatch ERF to fast erf when it is not available.
81+
PrimExpr DispatchFastErf(const PrimExpr& e);
82+
8083
} // namespace intrin
8184
} // namespace codegen
8285
} // namespace tvm

src/target/source/codegen_c.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
262262
*/
263263
void RegisterHandleType(const VarNode* buf_var, DataType t);
264264
// override
265-
void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final;
265+
void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) override;
266266
/*! \brief reserves common C keywords */
267267
void ReserveKeywordsAsUnique();
268268

@@ -281,10 +281,10 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
281281
const Op& builtin_call_extern_ = builtin::call_extern();
282282
const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
283283
Integer constants_byte_alignment_ = 16;
284-
285-
private:
286284
/*! \brief whether to print in SSA form */
287285
bool print_ssa_form_{false};
286+
287+
private:
288288
/*! \brief set of volatile buf access */
289289
std::unordered_set<const VarNode*> volatile_buf_;
290290
// deep comparison of PrimExpr

src/target/source/codegen_metal.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
5555
// clear previous generated state.
5656
this->InitFuncState(f);
5757
// skip the first underscore, so SSA variable starts from _1
58-
name_supply_->FreshName("_");
58+
name_supply_->FreshName("v_");
5959

6060
// add to alloc buffer type.
6161
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);

src/target/source/codegen_source_base.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
4343
}
4444
}
4545
SSAEntry e;
46-
e.vid = name_supply_->FreshName("_");
46+
// use v_ prefix so it works for most systems
47+
e.vid = name_supply_->FreshName("v_");
4748
e.scope_id = static_cast<int>(scope_mark_.size() - 1);
4849
ssa_assign_map_[src] = e;
4950
this->PrintIndent();

0 commit comments

Comments
 (0)