 * \file intrin_rule_cuda.cc
 * \brief CUDA intrinsic rules.
 */
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>

#include "target/intrin_rule.h"
9-
10- namespace tvm {
11- namespace codegen {
12- namespace intrin {
13- // Add float suffix to the intrinsics, CUDA fast math.
14- using tir::FLowerIntrinsic;
15-
16- struct CUDAMath {
17- std::string operator ()(DataType t, std::string name) const {
18- if (t.is_float ()) {
19- switch (t.bits ()) {
20- case 64 :
21- return name;
22- case 32 :
23- return name + ' f' ;
24- case 16 : {
25- if (name == " fabs" ) {
26- return " __habs" ;
27- } else if (name == " round" ) {
28- return " hrint" ;
29- } else {
30- return " h" + name;
31- }
32- }
33- default :
34- return " " ;
35- }
36- } else if (t.is_bfloat16 ()) {
37- if (name == " fabs" ) {
38- return " __habs" ;
39- } else if (name == " round" ) {
40- return " hrint" ;
41- } else {
42- return " h" + name;
43- }
44- } else if (t.is_int () || t.is_uint ()) {
45- switch (t.bits ()) {
46- case 32 :
47- return " __" + name;
48- case 64 :
49- return " __" + name + " ll" ;
50- default :
51- return " " ;
52- }
53- }
54- return " " ;
55- }
56- };
57-
58- struct CUDAFastMath : public CUDAMath {
59- std::string operator ()(DataType t, std::string name) const {
60- if (t.is_float () && t.bits () == 32 ) {
61- return " __" + name + ' f' ;
62- } else {
63- return CUDAMath::operator ()(t, name);
64- }
65- return " " ;
66- }
67- };
68-
69- struct CUDAFastMathTan : public CUDAMath {
70- std::string operator ()(DataType t, std::string name) const {
71- if (t.is_float ()) {
72- switch (t.bits ()) {
73- case 64 :
74- return name;
75- // `__tanf` seems to produce some values too deviant from numpy tan version.
76- // So, let's use just `tanf` instead.
77- case 32 :
78- return name + ' f' ;
79- case 16 :
80- return ' h' + name;
81- default :
82- return " " ;
83- }
84- }
85- return " " ;
86- }
87- };
88-
89- struct CUDAPopcount {
90- std::string operator ()(DataType t, std::string name) const {
91- if (t.is_uint ()) {
92- switch (t.bits ()) {
93- case 32 :
94- return " __popc" ;
95- case 64 :
96- return " __popcll" ;
97- default :
98- return " " ;
99- }
100- }
101- return " " ;
102- }
103- };
104-
105- struct CUDAWarpIntrinsic {
106- const Op operator ()(DataType t, const Op& orig_op) const {
107- if (orig_op.same_as (builtin::tvm_warp_shuffle ())) {
108- return Op::Get (" tir.cuda.__shfl_sync" );
109- } else if (orig_op.same_as (builtin::tvm_warp_shuffle_up ())) {
110- return Op::Get (" tir.cuda.__shfl_up_sync" );
111- } else {
112- ICHECK (orig_op.same_as (builtin::tvm_warp_shuffle_down ()));
113- return Op::Get (" tir.cuda.__shfl_down_sync" );
114- }
115- }
116- };
117-
118- static PrimExpr DispatchCUDAWarpActiveMask (const PrimExpr& e) {
119- const CallNode* call = e.as <CallNode>();
120- return Call (call->dtype , Op::Get (" tir.cuda.__activemask" ), call->args );
121- }
122-
123- template <typename T>
124- static PrimExpr DispatchCUDAShuffle ( const PrimExpr& e) {
125- const CallNode* call = e. as <CallNode>( );
126- ICHECK (call != nullptr );
127- ICHECK_EQ (call-> args . size (), 5 ); // mask, value, warp_id, width, warp_size
128- Array<PrimExpr> cuda_args{ {call->args [0 ], call->args [1 ], call->args [2 ], call->args [3 ]}};
129- return Call (call->dtype , T ()(call->dtype , Downcast<Op>(call->op )), cuda_args);
130- }
131-
132- TVM_REGISTER_OP (" tir.rsqrt" )
133- .set_attr<FLowerIntrinsic>(" cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>);
134-
135- } // namespace intrin
136- } // namespace codegen
137- } // namespace tvm
138-
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>

#include "target/intrin_rule.h"
9+
10+ namespace tvm {
11+ namespace codegen {
12+ namespace intrin {
13+ // Add float suffix to the intrinsics, CUDA fast math.
14+ using tir::FLowerIntrinsic;
15+
16+ struct CUDAMath {
17+ std::string operator ()(DataType t, std::string name) const {
18+ if (t.is_float ()) {
19+ switch (t.bits ()) {
20+ case 64 :
21+ return name;
22+ case 32 :
23+ return name + ' f' ;
24+ case 16 : {
25+ if (name == " fabs" ) {
26+ return " __habs" ;
27+ } else if (name == " round" ) {
28+ return " hrint" ;
29+ } else {
30+ return " h" + name;
31+ }
32+ }
33+ default :
34+ return " " ;
35+ }
36+ } else if (t.is_bfloat16 ()) {
37+ if (name == " fabs" ) {
38+ return " __habs" ;
39+ } else if (name == " round" ) {
40+ return " hrint" ;
41+ } else {
42+ return " h" + name;
43+ }
44+ } else if (t.is_int () || t.is_uint ()) {
45+ switch (t.bits ()) {
46+ case 32 :
47+ return " __" + name;
48+ case 64 :
49+ return " __" + name + " ll" ;
50+ default :
51+ return " " ;
52+ }
53+ }
54+ return " " ;
55+ }
56+ };
57+
58+ struct CUDAFastMath : public CUDAMath {
59+ std::string operator ()(DataType t, std::string name) const {
60+ if (t.is_float () && t.bits () == 32 ) {
61+ return " __" + name + ' f' ;
62+ } else {
63+ return CUDAMath::operator ()(t, name);
64+ }
65+ return " " ;
66+ }
67+ };
68+
69+ struct CUDAFastMathTan : public CUDAMath {
70+ std::string operator ()(DataType t, std::string name) const {
71+ if (t.is_float ()) {
72+ switch (t.bits ()) {
73+ case 64 :
74+ return name;
75+ // `__tanf` seems to produce some values too deviant from numpy tan
76+ // version. So, let's use just `tanf` instead.
77+ case 32 :
78+ return name + ' f' ;
79+ case 16 :
80+ return ' h' + name;
81+ default :
82+ return " " ;
83+ }
84+ }
85+ return " " ;
86+ }
87+ };
88+
89+ struct CUDAPopcount {
90+ std::string operator ()(DataType t, std::string name) const {
91+ if (t.is_uint ()) {
92+ switch (t.bits ()) {
93+ case 32 :
94+ return " __popc" ;
95+ case 64 :
96+ return " __popcll" ;
97+ default :
98+ return " " ;
99+ }
100+ }
101+ return " " ;
102+ }
103+ };
104+
105+ struct CUDAWarpIntrinsic {
106+ const Op operator ()(DataType t, const Op & orig_op) const {
107+ if (orig_op.same_as (builtin::tvm_warp_shuffle ())) {
108+ return Op::Get (" tir.cuda.__shfl_sync" );
109+ } else if (orig_op.same_as (builtin::tvm_warp_shuffle_up ())) {
110+ return Op::Get (" tir.cuda.__shfl_up_sync" );
111+ } else {
112+ ICHECK (orig_op.same_as (builtin::tvm_warp_shuffle_down ()));
113+ return Op::Get (" tir.cuda.__shfl_down_sync" );
114+ }
115+ }
116+ };
117+
118+ static PrimExpr DispatchCUDAWarpActiveMask (const PrimExpr & e) {
119+ const CallNode * call = e.as <CallNode>();
120+ return Call (call->dtype , Op::Get (" tir.cuda.__activemask" ), call->args );
121+ }
122+
123+ template <typename T> static PrimExpr DispatchCUDAShuffle ( const PrimExpr &e) {
124+ const CallNode *call = e. as <CallNode>();
125+ ICHECK ( call != nullptr );
126+ ICHECK_EQ (call-> args . size (), 5 ); // mask, value, warp_id, width, warp_size
127+ Array<PrimExpr> cuda_args{
128+ {call->args [0 ], call->args [1 ], call->args [2 ], call->args [3 ]}};
129+ return Call (call->dtype , T ()(call->dtype , Downcast<Op>(call->op )), cuda_args);
130+ }
131+
132+ TVM_REGISTER_OP (" tir.rsqrt" )
133+ .set_attr<FLowerIntrinsic>(" cuda.FLowerIntrinsic" ,
134+ DispatchPureExtern<CUDAMath>);
135+
136+ } // namespace intrin
137+ } // namespace codegen
138+ } // namespace tvm
0 commit comments