[Math] Dispatch T.rsqrt(x) into cuda intrin instead of 1 / T.sqrt(x) #781
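For context, the sketch below is illustrative only (plain upstream TVMScript, not code from this PR; the kernel shape, the build call, and the source inspection are assumptions): it shows the kind of `T.rsqrt` call this change targets. With the new lowering rule, the generated CUDA source is expected to call `rsqrtf` directly instead of expanding the operation to `1.0f / sqrtf(...)`.

```python
# Illustrative sketch only (assumes a CUDA-enabled TVM build); not from this PR.
import tvm
from tvm.script import tir as T


@T.prim_func
def rsqrt_kernel(A: T.Buffer((1024,), "float32"),
                 B: T.Buffer((1024,), "float32")):
    # One thread per element; each thread computes B[i] = rsqrt(A[i]).
    for i in T.thread_binding(1024, thread="threadIdx.x"):
        with T.block("compute"):
            vi = T.axis.spatial(1024, i)
            B[vi] = T.rsqrt(A[vi])


# Build for CUDA and inspect the generated kernel source. With rsqrt
# dispatched to an intrinsic, the source should contain rsqrtf(...) rather
# than a 1.0f / sqrtf(...) expansion.
mod = tvm.build(rsqrt_kernel, target="cuda")
print(mod.imported_modules[0].get_source())
```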
Changes from all commits
`intrin_rule_cuda.cc` (new file, +138 lines):

```cpp
/*!
 * \file intrin_rule_cuda.cc
 * \brief CUDA intrinsic rules.
 */
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>

#include "target/intrin_rule.h"

namespace tvm {
namespace codegen {
namespace intrin {
// Add float suffix to the intrinsics, CUDA fast math.
using tir::FLowerIntrinsic;

struct CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float()) {
      switch (t.bits()) {
        case 64:
          return name;
        case 32:
          return name + 'f';
        case 16: {
          if (name == "fabs") {
            return "__habs";
          } else if (name == "round") {
            return "hrint";
          } else {
            return "h" + name;
          }
        }
        default:
          return "";
      }
    } else if (t.is_bfloat16()) {
      if (name == "fabs") {
        return "__habs";
      } else if (name == "round") {
        return "hrint";
      } else {
        return "h" + name;
      }
    } else if (t.is_int() || t.is_uint()) {
      switch (t.bits()) {
        case 32:
          return "__" + name;
        case 64:
          return "__" + name + "ll";
        default:
          return "";
      }
    }
    return "";
  }
};

struct CUDAFastMath : public CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float() && t.bits() == 32) {
      return "__" + name + 'f';
    } else {
      return CUDAMath::operator()(t, name);
    }
    return "";
  }
};

struct CUDAFastMathTan : public CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float()) {
      switch (t.bits()) {
        case 64:
          return name;
        // `__tanf` seems to produce some values too deviant from numpy tan
        // version. So, let's use just `tanf` instead.
        case 32:
          return name + 'f';
        case 16:
          return 'h' + name;
        default:
          return "";
      }
    }
    return "";
  }
};

struct CUDAPopcount {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_uint()) {
      switch (t.bits()) {
        case 32:
          return "__popc";
        case 64:
          return "__popcll";
        default:
          return "";
      }
    }
    return "";
  }
};

struct CUDAWarpIntrinsic {
  const Op operator()(DataType t, const Op &orig_op) const {
    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
      return Op::Get("tir.cuda.__shfl_sync");
    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
      return Op::Get("tir.cuda.__shfl_up_sync");
    } else {
      ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
      return Op::Get("tir.cuda.__shfl_down_sync");
    }
  }
};

static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr &e) {
  const CallNode *call = e.as<CallNode>();
  return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args);
}

template <typename T> static PrimExpr DispatchCUDAShuffle(const PrimExpr &e) {
  const CallNode *call = e.as<CallNode>();
  ICHECK(call != nullptr);
  ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
  Array<PrimExpr> cuda_args{
      {call->args[0], call->args[1], call->args[2], call->args[3]}};
  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
}

TVM_REGISTER_OP("tir.rsqrt")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
                               DispatchPureExtern<CUDAMath>);

}  // namespace intrin
}  // namespace codegen
}  // namespace tvm
```
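The registration at the bottom is the core of the change: `tir.rsqrt` is now lowered by name through `CUDAMath`, yielding `rsqrt` for float64, `rsqrtf` for float32, and `hrsqrt` for float16 and bfloat16, instead of being legalized to `1 / sqrt(x)`. The `hrsqrt` spelling is also why the CUDA template header below gains a matching `hrsqrt(half_t)` overload.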
`src/tl_templates/cuda/common.h`:

```diff
@@ -55,6 +55,11 @@ TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
   return half_t(__habs(x.to_half()));
 }
 
+// hrsqrt function for half_t
+TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
+  return half_t(hrsqrt(x.to_half()));
+}
+
 // Pack two half values.
 TL_DEVICE unsigned __pack_half2(const half x, const half y) {
   unsigned v0 = *((unsigned short *)&x);
```
Review comment on lines +58 to +62 (Contributor):

🛠️ Refactor suggestion — Fix infinite recursion in `hrsqrt(half_t)` and use the device `rsqrtf`.

Suggested change:

```diff
-// hrsqrt function for half_t
-TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
-  return half_t(hrsqrt(x.to_half()));
-}
+// hrsqrt function for half_t
+TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
+  float xf = static_cast<float>(x);
+  float rf = rsqrtf(xf);
+  return half_t(rf);
+}
```

💡 Verification agent — Add a `bfloat16_t` overload of `hrsqrt` to match the lowering for BF16.

Proposed addition (outside the shown hunk):

```cpp
TL_PATCH TL_DEVICE bfloat16_t hrsqrt(const bfloat16_t x) {
  float xf = static_cast<float>(x);
  float rf = rsqrtf(xf);
  return bfloat16_t(rf);
}
```

Verification script run against the header:

```bash
#!/bin/bash
# Search for existing bfloat16_t overloads or mentions in cuda common.h
rg -n "bfloat16_t" -C3 src/tl_templates/cuda/common.h
# Search specifically for TL_DEVICE functions taking bfloat16_t
rg -nP --type=cpp "TL_DEVICE.*bfloat16_t" -C3 src/tl_templates/cuda/common.h
```

Conclusion: add the `bfloat16_t` `hrsqrt` overload shown above.
A separate review comment on the Python `get_cmake_path` helper (outside the hunks shown above):

🛠️ Refactor suggestion — Harden `get_cmake_path`: handle `None` from `shutil.which`, support an environment override, and raise `RuntimeError`.

The current code calls `os.path.exists` on a possibly-`None` value and throws `TypeError`; it also only checks for `cmake` (not `cmake3`) and raises a generic `Exception`.

🪛 Ruff (0.12.2) flags the same function:
- Line 249: Create your own exception (TRY002)
- Line 249: Avoid specifying long messages outside the exception class (TRY003)
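A possible shape for the hardened helper, sketched under stated assumptions: the function name `get_cmake_path` comes from the review, but the `CMAKE` environment variable, the `cmake`/`cmake3` search order, and the error message are illustrative choices, not the project's actual code.

```python
# Hedged sketch of the reviewer's suggestion; names other than get_cmake_path
# (e.g. the CMAKE environment variable) are assumptions for illustration.
import os
import shutil


def get_cmake_path() -> str:
    """Locate a usable cmake executable, or raise RuntimeError."""
    # Optional override via environment variable (assumed name: CMAKE).
    env_cmake = os.environ.get("CMAKE")
    if env_cmake and os.path.exists(env_cmake):
        return env_cmake
    # shutil.which returns None when the tool is missing, so guard against it
    # instead of passing None to os.path.exists.
    for candidate in ("cmake", "cmake3"):
        path = shutil.which(candidate)
        if path is not None:
            return path
    raise RuntimeError(
        "cmake not found: install cmake/cmake3 or set the CMAKE environment variable"
    )
```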