2121 * \file Use external Thrust library call
2222 */
2323
24+ #include < dlpack/dlpack.h>
25+ #include < thrust/detail/caching_allocator.h>
2426#include < thrust/device_ptr.h>
2527#include < thrust/device_vector.h>
26- #include < thrust/sort.h>
2728#include < thrust/gather.h>
2829#include < thrust/scan.h>
2930#include < thrust/sequence.h>
30-
31+ # include < thrust/sort.h >
3132#include < tvm/runtime/registry.h>
32- # include < dlpack/dlpack.h >
33+
3334#include < algorithm>
34- #include < vector>
3535#include < functional>
36+ #include < vector>
37+
38+ #include " ../../cuda/cuda_common.h"
3639
3740namespace tvm {
3841namespace contrib {
3942
4043using namespace runtime ;
4144
45+ auto get_thrust_exec_policy () {
46+ return thrust::cuda::par_nosync (thrust::detail::single_device_tls_caching_allocator ())
47+ .on (GetCUDAStream ());
48+ }
49+
4250// Performs sorting along axis -1 and returns both sorted values and indices.
43- template <typename DataType, typename IndicesType>
44- void thrust_sort (DLTensor* input,
45- DLTensor* out_values,
46- DLTensor* out_indices,
47- bool is_ascend,
51+ template <typename DataType, typename IndicesType>
52+ void thrust_sort (DLTensor* input, DLTensor* out_values, DLTensor* out_indices, bool is_ascend,
4853 int n_values) {
49- thrust::device_ptr<DataType> data_ptr (static_cast <DataType *>(input->data ));
50- thrust::device_ptr<DataType> values_ptr (static_cast <DataType *>(out_values->data ));
51- thrust::device_ptr<IndicesType> indices_ptr (static_cast <IndicesType *>(out_indices->data ));
54+ thrust::device_ptr<DataType> data_ptr (static_cast <DataType*>(input->data ));
55+ thrust::device_ptr<DataType> values_ptr (static_cast <DataType*>(out_values->data ));
56+ thrust::device_ptr<IndicesType> indices_ptr (static_cast <IndicesType*>(out_indices->data ));
57+
58+ auto policy = get_thrust_exec_policy ();
5259
5360 size_t size = 1 ;
5461 for (int i = 0 ; i < input->ndim ; ++i) {
5562 size *= input->shape [i];
5663 }
57- thrust::copy (data_ptr, data_ptr + size, values_ptr);
64+ thrust::copy (policy, data_ptr, data_ptr + size, values_ptr);
5865
5966 if (size == static_cast <size_t >(input->shape [input->ndim - 1 ])) {
6067 // A fast path for single segment case
6168 thrust::sequence (indices_ptr, indices_ptr + n_values);
6269 if (is_ascend) {
63- thrust::sort_by_key (values_ptr, values_ptr + n_values, indices_ptr);
70+ thrust::sort_by_key (policy, values_ptr, values_ptr + n_values, indices_ptr);
6471 } else {
65- thrust::sort_by_key (values_ptr, values_ptr + n_values, indices_ptr,
72+ thrust::sort_by_key (policy, values_ptr, values_ptr + n_values, indices_ptr,
6673 thrust::greater<DataType>());
6774 }
6875 } else {
@@ -74,9 +81,9 @@ void thrust_sort(DLTensor* input,
7481
7582 // First, sort values and store the sorted order in argsort_order.
7683 if (is_ascend) {
77- thrust::stable_sort_by_key (values_ptr, values_ptr + size, argsort_order.begin ());
84+ thrust::stable_sort_by_key (policy, values_ptr, values_ptr + size, argsort_order.begin ());
7885 } else {
79- thrust::stable_sort_by_key (values_ptr, values_ptr + size, argsort_order.begin (),
86+ thrust::stable_sort_by_key (policy, values_ptr, values_ptr + size, argsort_order.begin (),
8087 thrust::greater<DataType>());
8188 }
8289
@@ -85,36 +92,33 @@ void thrust_sort(DLTensor* input,
8592 auto counting_iter = thrust::counting_iterator<int64_t >(0 );
8693 auto linear_index_to_sort_axis_index = [n_values] __host__ __device__ (int64_t i) {
8794 return i % n_values;
88- }; // NOLINT(*)
89- auto init_indices_iter = thrust::make_transform_iterator (counting_iter,
90- linear_index_to_sort_axis_index);
95+ }; // NOLINT(*)
96+ auto init_indices_iter =
97+ thrust::make_transform_iterator (counting_iter, linear_index_to_sort_axis_index);
9198
9299 // This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr
93- thrust::gather (argsort_order.begin (), argsort_order.end (), init_indices_iter, indices_ptr);
100+ thrust::gather (policy, argsort_order.begin (), argsort_order.end (), init_indices_iter,
101+ indices_ptr);
94102
95103 thrust::device_vector<int > segment_ids (size);
96104 auto linear_index_to_segment_id = [n_values] __host__ __device__ (int64_t i) {
97105 return i / n_values;
98- }; // NOLINT(*)
106+ }; // NOLINT(*)
99107 // We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr
100- thrust::transform (argsort_order.begin (), argsort_order.end (), segment_ids.begin (),
108+ thrust::transform (policy, argsort_order.begin (), argsort_order.end (), segment_ids.begin (),
101109 linear_index_to_segment_id);
102110
103111 // The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ...
104112 // values_ptr and indices_ptr will also be sorted in the order of segmend_ids above
105113 // Since sorting has been done in a stable way, relative orderings of values and indices
106114 // in the segment do not change and hence they remain sorted.
107115 auto key_val_zip = thrust::make_zip_iterator (thrust::make_tuple (values_ptr, indices_ptr));
108- thrust::stable_sort_by_key (segment_ids.begin (), segment_ids.end (), key_val_zip);
116+ thrust::stable_sort_by_key (policy, segment_ids.begin (), segment_ids.end (), key_val_zip);
109117 }
110118}
111119
112- void thrust_sort_common (DLTensor* input,
113- DLTensor* values_out,
114- DLTensor* indices_out,
115- bool is_ascend,
116- int sort_len,
117- std::string data_dtype,
120+ void thrust_sort_common (DLTensor* input, DLTensor* values_out, DLTensor* indices_out,
121+ bool is_ascend, int sort_len, std::string data_dtype,
118122 std::string out_dtype) {
119123 if (data_dtype == " float32" ) {
120124 if (out_dtype == " int32" ) {
@@ -152,7 +156,7 @@ void thrust_sort_common(DLTensor* input,
152156 } else {
153157 LOG (FATAL) << " Unsupported output dtype: " << out_dtype;
154158 }
155- } else if (data_dtype == " int64" ) {
159+ } else if (data_dtype == " int64" ) {
156160 if (out_dtype == " int32" ) {
157161 thrust_sort<int64_t , int32_t >(input, values_out, indices_out, is_ascend, sort_len);
158162 } else if (out_dtype == " int64" ) {
@@ -169,8 +173,7 @@ void thrust_sort_common(DLTensor* input,
169173 }
170174}
171175
172- TVM_REGISTER_GLOBAL (" tvm.contrib.thrust.sort" )
173- .set_body([](TVMArgs args, TVMRetValue* ret) {
176+ TVM_REGISTER_GLOBAL (" tvm.contrib.thrust.sort" ).set_body([](TVMArgs args, TVMRetValue* ret) {
174177 ICHECK_GE (args.num_args , 4 );
175178 DLTensor* input = args[0 ];
176179 DLTensor* values_out = args[1 ];
@@ -181,97 +184,94 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
181184 auto out_dtype = DLDataType2String (indices_out->dtype );
182185
183186 int n_values = input->shape [input->ndim - 1 ];
184- thrust_sort_common (input, values_out, indices_out, is_ascend, n_values,
185- data_dtype, out_dtype);
187+ thrust_sort_common (input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype);
186188});
187189
188- template <typename KeyType, typename ValueType>
189- void thrust_stable_sort_by_key (DLTensor* keys_in,
190- DLTensor* values_in,
191- DLTensor* keys_out,
192- DLTensor* values_out,
193- bool for_scatter) {
190+ template <typename KeyType, typename ValueType>
191+ void thrust_stable_sort_by_key (DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out,
192+ DLTensor* values_out, bool for_scatter) {
194193 const auto size = keys_in->shape [0 ];
195- thrust::device_ptr<KeyType> keys_in_ptr (static_cast <KeyType *>(keys_in->data ));
196- thrust::device_ptr<ValueType> values_in_ptr (static_cast <ValueType *>(values_in->data ));
197- thrust::device_ptr<KeyType> keys_out_ptr (static_cast <KeyType *>(keys_out->data ));
198- thrust::device_ptr<ValueType> values_out_ptr (static_cast <ValueType *>(values_out->data ));
194+ thrust::device_ptr<KeyType> keys_in_ptr (static_cast <KeyType*>(keys_in->data ));
195+ thrust::device_ptr<ValueType> values_in_ptr (static_cast <ValueType*>(values_in->data ));
196+ thrust::device_ptr<KeyType> keys_out_ptr (static_cast <KeyType*>(keys_out->data ));
197+ thrust::device_ptr<ValueType> values_out_ptr (static_cast <ValueType*>(values_out->data ));
198+
199+ auto policy = get_thrust_exec_policy ();
199200
200201 if (for_scatter) {
201- thrust::transform (keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__ (KeyType k) {
202- if (k < 0 ) return k + static_cast <KeyType>(size);
203- return k;
204- });
202+ thrust::transform (policy, keys_in_ptr, keys_in_ptr + size, keys_out_ptr,
203+ [size] __device__ (KeyType k) {
204+ if (k < 0 ) return k + static_cast <KeyType>(size);
205+ return k;
206+ });
205207 } else {
206- thrust::copy (keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
208+ thrust::copy (policy, keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
207209 }
208- thrust::copy (values_in_ptr, values_in_ptr + size, values_out_ptr);
210+ thrust::copy (policy, values_in_ptr, values_in_ptr + size, values_out_ptr);
209211
210- thrust::stable_sort_by_key (keys_out_ptr, keys_out_ptr + size, values_out_ptr);
212+ thrust::stable_sort_by_key (policy, keys_out_ptr, keys_out_ptr + size, values_out_ptr);
211213}
212214
213215TVM_REGISTER_GLOBAL (" tvm.contrib.thrust.stable_sort_by_key" )
214- .set_body([](TVMArgs args, TVMRetValue* ret) {
215- ICHECK_GE (args.num_args , 5 );
216- DLTensor* keys_in = args[0 ];
217- DLTensor* values_in = args[1 ];
218- DLTensor* keys_out = args[2 ];
219- DLTensor* values_out = args[3 ];
220- bool for_scatter = args[4 ];
221-
222- auto key_dtype = DLDataType2String (keys_in->dtype );
223- auto value_dtype = DLDataType2String (values_in->dtype );
224-
225- if (key_dtype == " int32" ) {
226- if (value_dtype == " int32" ) {
227- thrust_stable_sort_by_key<int , int >(keys_in, values_in, keys_out, values_out,
228- for_scatter);
229- } else if (value_dtype == " int64" ) {
230- thrust_stable_sort_by_key<int , int64_t >(keys_in, values_in, keys_out, values_out,
216+ .set_body([](TVMArgs args, TVMRetValue* ret) {
217+ ICHECK_GE (args.num_args , 5 );
218+ DLTensor* keys_in = args[0 ];
219+ DLTensor* values_in = args[1 ];
220+ DLTensor* keys_out = args[2 ];
221+ DLTensor* values_out = args[3 ];
222+ bool for_scatter = args[4 ];
223+
224+ auto key_dtype = DLDataType2String (keys_in->dtype );
225+ auto value_dtype = DLDataType2String (values_in->dtype );
226+
227+ if (key_dtype == " int32" ) {
228+ if (value_dtype == " int32" ) {
229+ thrust_stable_sort_by_key<int , int >(keys_in, values_in, keys_out, values_out,
231230 for_scatter);
232- } else if (value_dtype == " float32" ) {
233- thrust_stable_sort_by_key<int , float >(keys_in, values_in, keys_out, values_out,
234- for_scatter);
235- } else {
236- LOG (FATAL) << " Unsupported value dtype: " << value_dtype;
237- }
238- } else if (key_dtype == " int64" ) {
239- if (value_dtype == " int32" ) {
240- thrust_stable_sort_by_key<int64_t , int >(keys_in, values_in, keys_out, values_out,
241- for_scatter);
242- } else if (value_dtype == " int64" ) {
243- thrust_stable_sort_by_key<int64_t , int64_t >(keys_in, values_in, keys_out, values_out,
231+ } else if (value_dtype == " int64" ) {
232+ thrust_stable_sort_by_key<int , int64_t >(keys_in, values_in, keys_out, values_out,
244233 for_scatter);
245- } else if (value_dtype == " float32" ) {
246- thrust_stable_sort_by_key<int64_t , float >(keys_in, values_in, keys_out, values_out,
234+ } else if (value_dtype == " float32" ) {
235+ thrust_stable_sort_by_key<int , float >(keys_in, values_in, keys_out, values_out,
247236 for_scatter);
248- } else {
249- LOG (FATAL) << " Unsupported value dtype: " << value_dtype;
250- }
251- } else if (key_dtype == " float32" ) {
252- if (value_dtype == " int32" ) {
253- thrust_stable_sort_by_key<float , int >(keys_in, values_in, keys_out, values_out,
254- for_scatter);
255- } else if (value_dtype == " int64" ) {
256- thrust_stable_sort_by_key<float , int64_t >(keys_in, values_in, keys_out, values_out,
257- for_scatter);
258- } else if (value_dtype == " float32" ) {
259- thrust_stable_sort_by_key<float , float >(keys_in, values_in, keys_out, values_out,
260- for_scatter);
261- } else {
262- LOG (FATAL) << " Unsupported value dtype: " << value_dtype;
263- }
264- } else {
265- LOG (FATAL) << " Unsupported key dtype: " << key_dtype;
266- }
267- });
237+ } else {
238+ LOG (FATAL) << " Unsupported value dtype: " << value_dtype;
239+ }
240+ } else if (key_dtype == " int64" ) {
241+ if (value_dtype == " int32" ) {
242+ thrust_stable_sort_by_key<int64_t , int >(keys_in, values_in, keys_out, values_out,
243+ for_scatter);
244+ } else if (value_dtype == " int64" ) {
245+ thrust_stable_sort_by_key<int64_t , int64_t >(keys_in, values_in, keys_out, values_out,
246+ for_scatter);
247+ } else if (value_dtype == " float32" ) {
248+ thrust_stable_sort_by_key<int64_t , float >(keys_in, values_in, keys_out, values_out,
249+ for_scatter);
250+ } else {
251+ LOG (FATAL) << " Unsupported value dtype: " << value_dtype;
252+ }
253+ } else if (key_dtype == " float32" ) {
254+ if (value_dtype == " int32" ) {
255+ thrust_stable_sort_by_key<float , int >(keys_in, values_in, keys_out, values_out,
256+ for_scatter);
257+ } else if (value_dtype == " int64" ) {
258+ thrust_stable_sort_by_key<float , int64_t >(keys_in, values_in, keys_out, values_out,
259+ for_scatter);
260+ } else if (value_dtype == " float32" ) {
261+ thrust_stable_sort_by_key<float , float >(keys_in, values_in, keys_out, values_out,
262+ for_scatter);
263+ } else {
264+ LOG (FATAL) << " Unsupported value dtype: " << value_dtype;
265+ }
266+ } else {
267+ LOG (FATAL) << " Unsupported key dtype: " << key_dtype;
268+ }
269+ });
268270
269- template <typename InType, typename OutType>
270- void thrust_scan (DLTensor* data,
271- DLTensor* output,
272- bool exclusive) {
273- thrust::device_ptr<InType> data_ptr (static_cast <InType *>(data->data ));
274- thrust::device_ptr<OutType> output_ptr (static_cast <OutType *>(output->data ));
271+ template <typename InType, typename OutType>
272+ void thrust_scan (DLTensor* data, DLTensor* output, bool exclusive) {
273+ thrust::device_ptr<InType> data_ptr (static_cast <InType*>(data->data ));
274+ thrust::device_ptr<OutType> output_ptr (static_cast <OutType*>(output->data ));
275275 const auto scan_size = data->shape [data->ndim - 1 ];
276276
277277 if (scan_size == 0 ) return ;
@@ -281,19 +281,20 @@ void thrust_scan(DLTensor* data,
281281
282282 const bool need_cast = std::is_same<InType, OutType>::value == false ;
283283
284- auto data_cast_ptr = thrust::make_transform_iterator (data_ptr, [] __host__ __device__ (InType v) {
285- return static_cast <OutType>(v);
286- }); // NOLINT(*)
284+ auto data_cast_ptr = thrust::make_transform_iterator (
285+ data_ptr, [] __host__ __device__ (InType v) { return static_cast <OutType>(v); }); // NOLINT(*)
286+
287+ auto policy = get_thrust_exec_policy ();
287288
288289 if (size == static_cast <size_t >(data->shape [data->ndim - 1 ])) {
289290 if (exclusive && need_cast) {
290- thrust::exclusive_scan (data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
291+ thrust::exclusive_scan (policy, data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
291292 } else if (exclusive && !need_cast) {
292- thrust::exclusive_scan (data_ptr, data_ptr + scan_size, output_ptr);
293+ thrust::exclusive_scan (policy, data_ptr, data_ptr + scan_size, output_ptr);
293294 } else if (!exclusive && need_cast) {
294- thrust::inclusive_scan (data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
295+ thrust::inclusive_scan (policy, data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
295296 } else {
296- thrust::inclusive_scan (data_ptr, data_ptr + scan_size, output_ptr);
297+ thrust::inclusive_scan (policy, data_ptr, data_ptr + scan_size, output_ptr);
297298 }
298299 } else {
299300 // Use thrust segmented scan to compute scan on the inner most axis
@@ -305,18 +306,18 @@ void thrust_scan(DLTensor* data,
305306 auto counting_iter = thrust::counting_iterator<size_t >(0 );
306307 // Without __host__ annotation, cub crashes
307308 auto linear_index_to_scan_key = [scan_size] __host__ __device__ (size_t i) {
308- return i / scan_size;
309- }; // NOLINT(*)
309+ return i / scan_size;
310+ }; // NOLINT(*)
310311 auto key_iter = thrust::make_transform_iterator (counting_iter, linear_index_to_scan_key);
311312
312313 if (exclusive && need_cast) {
313- thrust::exclusive_scan_by_key (key_iter, key_iter + size, data_cast_ptr, output_ptr);
314+ thrust::exclusive_scan_by_key (policy, key_iter, key_iter + size, data_cast_ptr, output_ptr);
314315 } else if (exclusive && !need_cast) {
315- thrust::exclusive_scan_by_key (key_iter, key_iter + size, data_ptr, output_ptr);
316+ thrust::exclusive_scan_by_key (policy, key_iter, key_iter + size, data_ptr, output_ptr);
316317 } else if (!exclusive && need_cast) {
317- thrust::inclusive_scan_by_key (key_iter, key_iter + size, data_cast_ptr, output_ptr);
318+ thrust::inclusive_scan_by_key (policy, key_iter, key_iter + size, data_cast_ptr, output_ptr);
318319 } else {
319- thrust::inclusive_scan_by_key (key_iter, key_iter + size, data_ptr, output_ptr);
320+ thrust::inclusive_scan_by_key (policy, key_iter, key_iter + size, data_ptr, output_ptr);
320321 }
321322 }
322323}
0 commit comments