@@ -89,8 +89,14 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) {
8989 index = self + delta;
9090 index = Select ((self & (width - 1 )) + delta >= width, self, index);
9191 }
92+ // reinterpret var as int32
93+ bool is_int32 = var.dtype ().is_int () && var.dtype ().bits () == 32 ;
94+ PrimExpr source = is_int32 ? var : reinterpret (DataType::Int (32 ), var);
9295 PrimExpr res = Call (DataType::Int (32 ), builtin::call_pure_extern (),
93- {StringImm (" llvm.amdgcn.ds.bpermute" ), index << 2 , var});
96+ {StringImm (" llvm.amdgcn.ds.bpermute" ), index << 2 , source});
97+ if (!is_int32) {
98+ res = reinterpret (var.dtype (), res);
99+ }
94100 return res;
95101}
96102
@@ -114,28 +120,35 @@ TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
114120 .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchShuffle);
115121
116122TVM_REGISTER_OP (" tir.floor" )
117- .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
123+ .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
124+ DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1 >);
118125
119126TVM_REGISTER_OP (" tir.ceil" )
120- .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
127+ .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
128+ DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1 >);
121129
122130TVM_REGISTER_OP (" tir.round" )
123- .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
131+ .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
132+ DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1 >);
124133
125134TVM_REGISTER_OP (" tir.nearbyint" )
126- .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
135+ .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
136+ DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1 >);
127137
128138TVM_REGISTER_OP (" tir.trunc" )
129- .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
139+ .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
140+ DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1 >);
130141
131142TVM_REGISTER_OP (" tir.fabs" )
132- .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
143+ .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
144+ DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1 >);
133145
134- TVM_REGISTER_OP (" tir.exp" ).set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic " ,
135- DispatchPureExternOCML );
146+ TVM_REGISTER_OP (" tir.exp" ).set_attr<FLowerIntrinsic>(
147+ " rocm.FLowerIntrinsic " , DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1 > );
136148
137149TVM_REGISTER_OP (" tir.exp2" )
138- .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
150+ .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
151+ DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1 >);
139152
140153TVM_REGISTER_OP (" tir.exp10" )
141154 .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
@@ -146,35 +159,38 @@ TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
146159TVM_REGISTER_OP (" tir.fma" ).set_attr<FLowerIntrinsic>(
147160 " rocm.FLowerIntrinsic" , DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3 >);
148161
149- TVM_REGISTER_OP (" tir.log" ).set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic " ,
150- DispatchPureExternOCML );
162+ TVM_REGISTER_OP (" tir.log" ).set_attr<FLowerIntrinsic>(
163+ " rocm.FLowerIntrinsic " , DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1 > );
151164
152165TVM_REGISTER_OP (" tir.log2" )
153- .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
166+ .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
167+ DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1 >);
154168
155169TVM_REGISTER_OP (" tir.log10" )
156- .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
170+ .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
171+ DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1 >);
157172
158173TVM_REGISTER_OP (" tir.sqrt" )
159- .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
174+ .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
175+ DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1 >);
160176
161- TVM_REGISTER_OP (" tir.pow" ).set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic " ,
162- DispatchPureExternOCML );
177+ TVM_REGISTER_OP (" tir.pow" ).set_attr<FLowerIntrinsic>(
178+ " rocm.FLowerIntrinsic " , DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2 > );
163179
164180TVM_REGISTER_OP (" tir.tanh" )
165181 .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
166182
167183TVM_REGISTER_OP (" tir.tan" ).set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" ,
168184 DispatchPureExternOCML);
169185
170- TVM_REGISTER_OP (" tir.cos" ).set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic " ,
171- DispatchPureExternOCML );
186+ TVM_REGISTER_OP (" tir.cos" ).set_attr<FLowerIntrinsic>(
187+ " rocm.FLowerIntrinsic " , DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1 > );
172188
173189TVM_REGISTER_OP (" tir.cosh" )
174190 .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
175191
176- TVM_REGISTER_OP (" tir.sin" ).set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic " ,
177- DispatchPureExternOCML );
192+ TVM_REGISTER_OP (" tir.sin" ).set_attr<FLowerIntrinsic>(
193+ " rocm.FLowerIntrinsic " , DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1 > );
178194
179195TVM_REGISTER_OP (" tir.sinh" )
180196 .set_attr<FLowerIntrinsic>(" rocm.FLowerIntrinsic" , DispatchPureExternOCML);
0 commit comments