-
Notifications
You must be signed in to change notification settings - Fork 330
[FastMath] Disable default TVM fastmath intrinsic dispatch and add explicit fastmath op to invoke #875
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FastMath] Disable default TVM fastmath intrinsic dispatch and add explicit fastmath op to invoke #875
Changes from 4 commits
d2f68b9
2463338
029ddf3
7708cc1
65927de
a44e234
0cf3d49
ce4a1a5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -21,6 +21,79 @@ namespace tvm { | |||||||||||||||||||||||||||||||||||
| namespace codegen { | ||||||||||||||||||||||||||||||||||||
| using namespace tvm::tl::codegen; | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| struct CUDAMath { | ||||||||||||||||||||||||||||||||||||
| std::string operator()(DataType t, std::string name) const { | ||||||||||||||||||||||||||||||||||||
| if (t.is_float()) { | ||||||||||||||||||||||||||||||||||||
| switch (t.bits()) { | ||||||||||||||||||||||||||||||||||||
| case 64: | ||||||||||||||||||||||||||||||||||||
| return name; | ||||||||||||||||||||||||||||||||||||
| case 32: | ||||||||||||||||||||||||||||||||||||
| return name + 'f'; | ||||||||||||||||||||||||||||||||||||
| case 16: { | ||||||||||||||||||||||||||||||||||||
| if (name == "fabs") { | ||||||||||||||||||||||||||||||||||||
| return "__habs"; | ||||||||||||||||||||||||||||||||||||
| } else if (name == "round") { | ||||||||||||||||||||||||||||||||||||
| return "hrint"; | ||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||
| return "h" + name; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| default: | ||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| } else if (t.is_bfloat16()) { | ||||||||||||||||||||||||||||||||||||
| if (name == "fabs") { | ||||||||||||||||||||||||||||||||||||
| return "__habs"; | ||||||||||||||||||||||||||||||||||||
| } else if (name == "round") { | ||||||||||||||||||||||||||||||||||||
| return "hrint"; | ||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||
| return "h" + name; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| } else if (t.is_int() || t.is_uint()) { | ||||||||||||||||||||||||||||||||||||
| switch (t.bits()) { | ||||||||||||||||||||||||||||||||||||
| case 32: | ||||||||||||||||||||||||||||||||||||
| return "__" + name; | ||||||||||||||||||||||||||||||||||||
| case 64: | ||||||||||||||||||||||||||||||||||||
| return "__" + name + "ll"; | ||||||||||||||||||||||||||||||||||||
| default: | ||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| struct CUDAFastMath : public CUDAMath { | ||||||||||||||||||||||||||||||||||||
| std::string operator()(DataType t, std::string name) const { | ||||||||||||||||||||||||||||||||||||
| if (t.is_float() && t.bits() == 32) { | ||||||||||||||||||||||||||||||||||||
| return "__" + name + 'f'; | ||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||
| return CUDAMath::operator()(t, name); | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| struct CUDAFastMathTan : public CUDAMath { | ||||||||||||||||||||||||||||||||||||
| std::string operator()(DataType t, std::string name) const { | ||||||||||||||||||||||||||||||||||||
| if (t.is_float()) { | ||||||||||||||||||||||||||||||||||||
| switch (t.bits()) { | ||||||||||||||||||||||||||||||||||||
| case 64: | ||||||||||||||||||||||||||||||||||||
| return name; | ||||||||||||||||||||||||||||||||||||
| // `__tanf` seems to produce some values too deviant from numpy tan | ||||||||||||||||||||||||||||||||||||
| // version. So, let's use just `tanf` instead. | ||||||||||||||||||||||||||||||||||||
| case 32: | ||||||||||||||||||||||||||||||||||||
| return name + 'f'; | ||||||||||||||||||||||||||||||||||||
| case 16: | ||||||||||||||||||||||||||||||||||||
| return 'h' + name; | ||||||||||||||||||||||||||||||||||||
| default: | ||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| static std::string GetFP8Type(DataType type) { | ||||||||||||||||||||||||||||||||||||
| std::stringstream stream; | ||||||||||||||||||||||||||||||||||||
| int32_t lanes = type.lanes(); | ||||||||||||||||||||||||||||||||||||
|
|
@@ -1628,6 +1701,38 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { | |||||||||||||||||||||||||||||||||||
| op->args, true, os); | ||||||||||||||||||||||||||||||||||||
| } else if (op->op.same_as(tl::tl_shuffle_elect())) { | ||||||||||||||||||||||||||||||||||||
| os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; | ||||||||||||||||||||||||||||||||||||
| } else if (op->op.same_as(tl::__exp())) { | ||||||||||||||||||||||||||||||||||||
| CUDAFastMath math_func; | ||||||||||||||||||||||||||||||||||||
| std::string func_name = math_func(op->dtype, "exp"); | ||||||||||||||||||||||||||||||||||||
| os << func_name << "(" << PrintExpr(op->args[0]) << ")"; | ||||||||||||||||||||||||||||||||||||
| } else if (op->op.same_as(tl::__exp10())) { | ||||||||||||||||||||||||||||||||||||
| CUDAFastMath math_func; | ||||||||||||||||||||||||||||||||||||
| std::string func_name = math_func(op->dtype, "exp10"); | ||||||||||||||||||||||||||||||||||||
| os << func_name << "(" << PrintExpr(op->args[0]) << ")"; | ||||||||||||||||||||||||||||||||||||
| } else if (op->op.same_as(tl::__log())) { | ||||||||||||||||||||||||||||||||||||
| CUDAFastMath math_func; | ||||||||||||||||||||||||||||||||||||
| std::string func_name = math_func(op->dtype, "log"); | ||||||||||||||||||||||||||||||||||||
| os << func_name << "(" << PrintExpr(op->args[0]) << ")"; | ||||||||||||||||||||||||||||||||||||
| } else if (op->op.same_as(tl::__log2())) { | ||||||||||||||||||||||||||||||||||||
| CUDAFastMath math_func; | ||||||||||||||||||||||||||||||||||||
| std::string func_name = math_func(op->dtype, "log2"); | ||||||||||||||||||||||||||||||||||||
| os << func_name << "(" << PrintExpr(op->args[0]) << ")"; | ||||||||||||||||||||||||||||||||||||
| } else if (op->op.same_as(tl::__log10())) { | ||||||||||||||||||||||||||||||||||||
| CUDAFastMath math_func; | ||||||||||||||||||||||||||||||||||||
| std::string func_name = math_func(op->dtype, "log10"); | ||||||||||||||||||||||||||||||||||||
| os << func_name << "(" << PrintExpr(op->args[0]) << ")"; | ||||||||||||||||||||||||||||||||||||
| } else if (op->op.same_as(tl::__tan())) { | ||||||||||||||||||||||||||||||||||||
| CUDAFastMath math_func; | ||||||||||||||||||||||||||||||||||||
| std::string func_name = math_func(op->dtype, "tan"); | ||||||||||||||||||||||||||||||||||||
| os << func_name << "(" << PrintExpr(op->args[0]) << ")"; | ||||||||||||||||||||||||||||||||||||
| } else if (op->op.same_as(tl::__cos())) { | ||||||||||||||||||||||||||||||||||||
| CUDAFastMath math_func; | ||||||||||||||||||||||||||||||||||||
| std::string func_name = math_func(op->dtype, "cos"); | ||||||||||||||||||||||||||||||||||||
| os << func_name << "(" << PrintExpr(op->args[0]) << ")"; | ||||||||||||||||||||||||||||||||||||
| } else if (op->op.same_as(tl::__sin())) { | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1725
to
+1732
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use the tan-specific fast-math mapper We landed Please switch this site to - } else if (op->op.same_as(tl::__tan())) {
- CUDAFastMath math_func;
+ } else if (op->op.same_as(tl::__tan())) {
+ CUDAFastMathTan math_func;
std::string func_name = math_func(op->dtype, "tan");
os << func_name << "(" << PrintExpr(op->args[0]) << ")";📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||
| CUDAFastMath math_func; | ||||||||||||||||||||||||||||||||||||
| std::string func_name = math_func(op->dtype, "sin"); | ||||||||||||||||||||||||||||||||||||
| os << func_name << "(" << PrintExpr(op->args[0]) << ")"; | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1704
to
+1735
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for handling the different fast math intrinsics is very repetitive. Each |
||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||
| CodeGenC::VisitExpr_(op, os); | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
return "";on line 73 is unreachable because all paths in theif/elsestatement above it return a value. This dead code should be removed.