Skip to content

Commit 0fb9ae6

Browse files
committed
lint fix
1 parent d4e9158 commit 0fb9ae6

File tree

2 files changed

+134
-135
lines changed

2 files changed

+134
-135
lines changed

src/target/intrin_rule_cuda.cc

Lines changed: 134 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -2,137 +2,137 @@
22
* \file intrin_rule_cuda.cc
33
* \brief CUDA intrinsic rules.
44
*/
5-
#include <tvm/tir/builtin.h>
6-
#include <tvm/tir/op_attr_types.h>
7-
8-
#include "target/intrin_rule.h"
9-
10-
namespace tvm {
11-
namespace codegen {
12-
namespace intrin {
13-
// Add float suffix to the intrinsics, CUDA fast math.
14-
using tir::FLowerIntrinsic;
15-
16-
struct CUDAMath {
17-
std::string operator()(DataType t, std::string name) const {
18-
if (t.is_float()) {
19-
switch (t.bits()) {
20-
case 64:
21-
return name;
22-
case 32:
23-
return name + 'f';
24-
case 16: {
25-
if (name == "fabs") {
26-
return "__habs";
27-
} else if (name == "round") {
28-
return "hrint";
29-
} else {
30-
return "h" + name;
31-
}
32-
}
33-
default:
34-
return "";
35-
}
36-
} else if (t.is_bfloat16()) {
37-
if (name == "fabs") {
38-
return "__habs";
39-
} else if (name == "round") {
40-
return "hrint";
41-
} else {
42-
return "h" + name;
43-
}
44-
} else if (t.is_int() || t.is_uint()) {
45-
switch (t.bits()) {
46-
case 32:
47-
return "__" + name;
48-
case 64:
49-
return "__" + name + "ll";
50-
default:
51-
return "";
52-
}
53-
}
54-
return "";
55-
}
56-
};
57-
58-
struct CUDAFastMath : public CUDAMath {
59-
std::string operator()(DataType t, std::string name) const {
60-
if (t.is_float() && t.bits() == 32) {
61-
return "__" + name + 'f';
62-
} else {
63-
return CUDAMath::operator()(t, name);
64-
}
65-
return "";
66-
}
67-
};
68-
69-
struct CUDAFastMathTan : public CUDAMath {
70-
std::string operator()(DataType t, std::string name) const {
71-
if (t.is_float()) {
72-
switch (t.bits()) {
73-
case 64:
74-
return name;
75-
// `__tanf` seems to produce some values too deviant from numpy tan version.
76-
// So, let's use just `tanf` instead.
77-
case 32:
78-
return name + 'f';
79-
case 16:
80-
return 'h' + name;
81-
default:
82-
return "";
83-
}
84-
}
85-
return "";
86-
}
87-
};
88-
89-
struct CUDAPopcount {
90-
std::string operator()(DataType t, std::string name) const {
91-
if (t.is_uint()) {
92-
switch (t.bits()) {
93-
case 32:
94-
return "__popc";
95-
case 64:
96-
return "__popcll";
97-
default:
98-
return "";
99-
}
100-
}
101-
return "";
102-
}
103-
};
104-
105-
struct CUDAWarpIntrinsic {
106-
const Op operator()(DataType t, const Op& orig_op) const {
107-
if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
108-
return Op::Get("tir.cuda.__shfl_sync");
109-
} else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
110-
return Op::Get("tir.cuda.__shfl_up_sync");
111-
} else {
112-
ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
113-
return Op::Get("tir.cuda.__shfl_down_sync");
114-
}
115-
}
116-
};
117-
118-
// Lower a warp-activemask builtin call to the CUDA __activemask op.
static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) {
  const CallNode* call = e.as<CallNode>();
  // Guard the downcast before dereferencing, consistent with
  // DispatchCUDAShuffle below (the original dereferenced unchecked).
  ICHECK(call != nullptr);
  return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args);
}
122-
123-
template <typename T>
124-
static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) {
125-
const CallNode* call = e.as<CallNode>();
126-
ICHECK(call != nullptr);
127-
ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
128-
Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
129-
return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
130-
}
131-
132-
// Register the CUDA lowering rule for tir.rsqrt: dispatch to a pure extern
// call whose spelling is chosen per-dtype by CUDAMath.
TVM_REGISTER_OP("tir.rsqrt")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
                               DispatchPureExtern<CUDAMath>);
134-
135-
} // namespace intrin
136-
} // namespace codegen
137-
} // namespace tvm
138-
5+
#include <tvm/tir/builtin.h>
6+
#include <tvm/tir/op_attr_types.h>
7+
8+
#include "target/intrin_rule.h"
9+
10+
namespace tvm {
11+
namespace codegen {
12+
namespace intrin {
13+
// Add float suffix to the intrinsics, CUDA fast math.
14+
using tir::FLowerIntrinsic;
15+
16+
struct CUDAMath {
17+
std::string operator()(DataType t, std::string name) const {
18+
if (t.is_float()) {
19+
switch (t.bits()) {
20+
case 64:
21+
return name;
22+
case 32:
23+
return name + 'f';
24+
case 16: {
25+
if (name == "fabs") {
26+
return "__habs";
27+
} else if (name == "round") {
28+
return "hrint";
29+
} else {
30+
return "h" + name;
31+
}
32+
}
33+
default:
34+
return "";
35+
}
36+
} else if (t.is_bfloat16()) {
37+
if (name == "fabs") {
38+
return "__habs";
39+
} else if (name == "round") {
40+
return "hrint";
41+
} else {
42+
return "h" + name;
43+
}
44+
} else if (t.is_int() || t.is_uint()) {
45+
switch (t.bits()) {
46+
case 32:
47+
return "__" + name;
48+
case 64:
49+
return "__" + name + "ll";
50+
default:
51+
return "";
52+
}
53+
}
54+
return "";
55+
}
56+
};
57+
58+
struct CUDAFastMath : public CUDAMath {
59+
std::string operator()(DataType t, std::string name) const {
60+
if (t.is_float() && t.bits() == 32) {
61+
return "__" + name + 'f';
62+
} else {
63+
return CUDAMath::operator()(t, name);
64+
}
65+
return "";
66+
}
67+
};
68+
69+
struct CUDAFastMathTan : public CUDAMath {
70+
std::string operator()(DataType t, std::string name) const {
71+
if (t.is_float()) {
72+
switch (t.bits()) {
73+
case 64:
74+
return name;
75+
// `__tanf` seems to produce some values too deviant from numpy tan
76+
// version. So, let's use just `tanf` instead.
77+
case 32:
78+
return name + 'f';
79+
case 16:
80+
return 'h' + name;
81+
default:
82+
return "";
83+
}
84+
}
85+
return "";
86+
}
87+
};
88+
89+
struct CUDAPopcount {
90+
std::string operator()(DataType t, std::string name) const {
91+
if (t.is_uint()) {
92+
switch (t.bits()) {
93+
case 32:
94+
return "__popc";
95+
case 64:
96+
return "__popcll";
97+
default:
98+
return "";
99+
}
100+
}
101+
return "";
102+
}
103+
};
104+
105+
struct CUDAWarpIntrinsic {
106+
const Op operator()(DataType t, const Op &orig_op) const {
107+
if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
108+
return Op::Get("tir.cuda.__shfl_sync");
109+
} else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
110+
return Op::Get("tir.cuda.__shfl_up_sync");
111+
} else {
112+
ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
113+
return Op::Get("tir.cuda.__shfl_down_sync");
114+
}
115+
}
116+
};
117+
118+
// Lower a warp-activemask builtin call to the CUDA __activemask op.
static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr &e) {
  const CallNode *call = e.as<CallNode>();
  // Guard the downcast before dereferencing, consistent with
  // DispatchCUDAShuffle below (the original dereferenced unchecked).
  ICHECK(call != nullptr);
  return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args);
}
122+
123+
// Lower a TVM warp-shuffle call to its CUDA op. T maps the original op to
// the CUDA op (see CUDAWarpIntrinsic).
template <typename T>
static PrimExpr DispatchCUDAShuffle(const PrimExpr &e) {
  const CallNode *call = e.as<CallNode>();
  ICHECK(call != nullptr);
  ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
  // CUDA shuffles take four arguments; the trailing warp_size is dropped.
  Array<PrimExpr> lowered_args{
      {call->args[0], call->args[1], call->args[2], call->args[3]}};
  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), lowered_args);
}
131+
132+
// Register the CUDA lowering rule for tir.rsqrt: dispatch to a pure extern
// call whose spelling is chosen per-dtype by CUDAMath.
TVM_REGISTER_OP("tir.rsqrt")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
135+
136+
} // namespace intrin
137+
} // namespace codegen
138+
} // namespace tvm

src/tl_templates/cuda/common.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
6060
return half_t(hrsqrt(x.to_half()));
6161
}
6262

63-
6463
// Pack two half values.
6564
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
6665
unsigned v0 = *((unsigned short *)&x);

0 commit comments

Comments
 (0)