Skip to content

Commit 5d4c01e

Browse files
authored
[Thrust] Use no sync exec policy and caching allocator (#16386)
1 parent e2e33dd commit 5d4c01e

File tree

2 files changed

+126
-122
lines changed

2 files changed

+126
-122
lines changed

src/runtime/contrib/thrust/thrust.cu

Lines changed: 123 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -21,48 +21,55 @@
2121
* \file Use external Thrust library call
2222
*/
2323

24+
#include <dlpack/dlpack.h>
25+
#include <thrust/detail/caching_allocator.h>
2426
#include <thrust/device_ptr.h>
2527
#include <thrust/device_vector.h>
26-
#include <thrust/sort.h>
2728
#include <thrust/gather.h>
2829
#include <thrust/scan.h>
2930
#include <thrust/sequence.h>
30-
31+
#include <thrust/sort.h>
3132
#include <tvm/runtime/registry.h>
32-
#include <dlpack/dlpack.h>
33+
3334
#include <algorithm>
34-
#include <vector>
3535
#include <functional>
36+
#include <vector>
37+
38+
#include "../../cuda/cuda_common.h"
3639

3740
namespace tvm {
3841
namespace contrib {
3942

4043
using namespace runtime;
4144

45+
auto get_thrust_exec_policy() {
46+
return thrust::cuda::par_nosync(thrust::detail::single_device_tls_caching_allocator())
47+
.on(GetCUDAStream());
48+
}
49+
4250
// Performs sorting along axis -1 and returns both sorted values and indices.
43-
template<typename DataType, typename IndicesType>
44-
void thrust_sort(DLTensor* input,
45-
DLTensor* out_values,
46-
DLTensor* out_indices,
47-
bool is_ascend,
51+
template <typename DataType, typename IndicesType>
52+
void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, bool is_ascend,
4853
int n_values) {
49-
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
50-
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
51-
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));
54+
thrust::device_ptr<DataType> data_ptr(static_cast<DataType*>(input->data));
55+
thrust::device_ptr<DataType> values_ptr(static_cast<DataType*>(out_values->data));
56+
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType*>(out_indices->data));
57+
58+
auto policy = get_thrust_exec_policy();
5259

5360
size_t size = 1;
5461
for (int i = 0; i < input->ndim; ++i) {
5562
size *= input->shape[i];
5663
}
57-
thrust::copy(data_ptr, data_ptr + size, values_ptr);
64+
thrust::copy(policy, data_ptr, data_ptr + size, values_ptr);
5865

5966
if (size == static_cast<size_t>(input->shape[input->ndim - 1])) {
6067
// A fast path for single segment case
6168
thrust::sequence(indices_ptr, indices_ptr + n_values);
6269
if (is_ascend) {
63-
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
70+
thrust::sort_by_key(policy, values_ptr, values_ptr + n_values, indices_ptr);
6471
} else {
65-
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
72+
thrust::sort_by_key(policy, values_ptr, values_ptr + n_values, indices_ptr,
6673
thrust::greater<DataType>());
6774
}
6875
} else {
@@ -74,9 +81,9 @@ void thrust_sort(DLTensor* input,
7481

7582
// First, sort values and store the sorted order in argsort_order.
7683
if (is_ascend) {
77-
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin());
84+
thrust::stable_sort_by_key(policy, values_ptr, values_ptr + size, argsort_order.begin());
7885
} else {
79-
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin(),
86+
thrust::stable_sort_by_key(policy, values_ptr, values_ptr + size, argsort_order.begin(),
8087
thrust::greater<DataType>());
8188
}
8289

@@ -85,36 +92,33 @@ void thrust_sort(DLTensor* input,
8592
auto counting_iter = thrust::counting_iterator<int64_t>(0);
8693
auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) {
8794
return i % n_values;
88-
}; // NOLINT(*)
89-
auto init_indices_iter = thrust::make_transform_iterator(counting_iter,
90-
linear_index_to_sort_axis_index);
95+
}; // NOLINT(*)
96+
auto init_indices_iter =
97+
thrust::make_transform_iterator(counting_iter, linear_index_to_sort_axis_index);
9198

9299
// This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr
93-
thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr);
100+
thrust::gather(policy, argsort_order.begin(), argsort_order.end(), init_indices_iter,
101+
indices_ptr);
94102

95103
thrust::device_vector<int> segment_ids(size);
96104
auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) {
97105
return i / n_values;
98-
}; // NOLINT(*)
106+
}; // NOLINT(*)
99107
// We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr
100-
thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(),
108+
thrust::transform(policy, argsort_order.begin(), argsort_order.end(), segment_ids.begin(),
101109
linear_index_to_segment_id);
102110

103111
// The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ...
104112
// values_ptr and indices_ptr will also be sorted in the order of segmend_ids above
105113
// Since sorting has been done in a stable way, relative orderings of values and indices
106114
// in the segment do not change and hence they remain sorted.
107115
auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr));
108-
thrust::stable_sort_by_key(segment_ids.begin(), segment_ids.end(), key_val_zip);
116+
thrust::stable_sort_by_key(policy, segment_ids.begin(), segment_ids.end(), key_val_zip);
109117
}
110118
}
111119

112-
void thrust_sort_common(DLTensor* input,
113-
DLTensor* values_out,
114-
DLTensor* indices_out,
115-
bool is_ascend,
116-
int sort_len,
117-
std::string data_dtype,
120+
void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices_out,
121+
bool is_ascend, int sort_len, std::string data_dtype,
118122
std::string out_dtype) {
119123
if (data_dtype == "float32") {
120124
if (out_dtype == "int32") {
@@ -152,7 +156,7 @@ void thrust_sort_common(DLTensor* input,
152156
} else {
153157
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
154158
}
155-
} else if (data_dtype == "int64") {
159+
} else if (data_dtype == "int64") {
156160
if (out_dtype == "int32") {
157161
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
158162
} else if (out_dtype == "int64") {
@@ -169,8 +173,7 @@ void thrust_sort_common(DLTensor* input,
169173
}
170174
}
171175

172-
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
173-
.set_body([](TVMArgs args, TVMRetValue* ret) {
176+
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort").set_body([](TVMArgs args, TVMRetValue* ret) {
174177
ICHECK_GE(args.num_args, 4);
175178
DLTensor* input = args[0];
176179
DLTensor* values_out = args[1];
@@ -181,97 +184,94 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
181184
auto out_dtype = DLDataType2String(indices_out->dtype);
182185

183186
int n_values = input->shape[input->ndim - 1];
184-
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values,
185-
data_dtype, out_dtype);
187+
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype);
186188
});
187189

188-
template<typename KeyType, typename ValueType>
189-
void thrust_stable_sort_by_key(DLTensor* keys_in,
190-
DLTensor* values_in,
191-
DLTensor* keys_out,
192-
DLTensor* values_out,
193-
bool for_scatter) {
190+
template <typename KeyType, typename ValueType>
191+
void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out,
192+
DLTensor* values_out, bool for_scatter) {
194193
const auto size = keys_in->shape[0];
195-
thrust::device_ptr<KeyType> keys_in_ptr(static_cast<KeyType *>(keys_in->data));
196-
thrust::device_ptr<ValueType> values_in_ptr(static_cast<ValueType *>(values_in->data));
197-
thrust::device_ptr<KeyType> keys_out_ptr(static_cast<KeyType *>(keys_out->data));
198-
thrust::device_ptr<ValueType> values_out_ptr(static_cast<ValueType *>(values_out->data));
194+
thrust::device_ptr<KeyType> keys_in_ptr(static_cast<KeyType*>(keys_in->data));
195+
thrust::device_ptr<ValueType> values_in_ptr(static_cast<ValueType*>(values_in->data));
196+
thrust::device_ptr<KeyType> keys_out_ptr(static_cast<KeyType*>(keys_out->data));
197+
thrust::device_ptr<ValueType> values_out_ptr(static_cast<ValueType*>(values_out->data));
198+
199+
auto policy = get_thrust_exec_policy();
199200

200201
if (for_scatter) {
201-
thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) {
202-
if (k < 0) return k + static_cast<KeyType>(size);
203-
return k;
204-
});
202+
thrust::transform(policy, keys_in_ptr, keys_in_ptr + size, keys_out_ptr,
203+
[size] __device__(KeyType k) {
204+
if (k < 0) return k + static_cast<KeyType>(size);
205+
return k;
206+
});
205207
} else {
206-
thrust::copy(keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
208+
thrust::copy(policy, keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
207209
}
208-
thrust::copy(values_in_ptr, values_in_ptr + size, values_out_ptr);
210+
thrust::copy(policy, values_in_ptr, values_in_ptr + size, values_out_ptr);
209211

210-
thrust::stable_sort_by_key(keys_out_ptr, keys_out_ptr + size, values_out_ptr);
212+
thrust::stable_sort_by_key(policy, keys_out_ptr, keys_out_ptr + size, values_out_ptr);
211213
}
212214

213215
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
214-
.set_body([](TVMArgs args, TVMRetValue* ret) {
215-
ICHECK_GE(args.num_args, 5);
216-
DLTensor* keys_in = args[0];
217-
DLTensor* values_in = args[1];
218-
DLTensor* keys_out = args[2];
219-
DLTensor* values_out = args[3];
220-
bool for_scatter = args[4];
221-
222-
auto key_dtype = DLDataType2String(keys_in->dtype);
223-
auto value_dtype = DLDataType2String(values_in->dtype);
224-
225-
if (key_dtype == "int32") {
226-
if (value_dtype == "int32") {
227-
thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
228-
for_scatter);
229-
} else if (value_dtype == "int64") {
230-
thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, keys_out, values_out,
216+
.set_body([](TVMArgs args, TVMRetValue* ret) {
217+
ICHECK_GE(args.num_args, 5);
218+
DLTensor* keys_in = args[0];
219+
DLTensor* values_in = args[1];
220+
DLTensor* keys_out = args[2];
221+
DLTensor* values_out = args[3];
222+
bool for_scatter = args[4];
223+
224+
auto key_dtype = DLDataType2String(keys_in->dtype);
225+
auto value_dtype = DLDataType2String(values_in->dtype);
226+
227+
if (key_dtype == "int32") {
228+
if (value_dtype == "int32") {
229+
thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
231230
for_scatter);
232-
} else if (value_dtype == "float32") {
233-
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
234-
for_scatter);
235-
} else {
236-
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
237-
}
238-
} else if (key_dtype == "int64") {
239-
if (value_dtype == "int32") {
240-
thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
241-
for_scatter);
242-
} else if (value_dtype == "int64") {
243-
thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
231+
} else if (value_dtype == "int64") {
232+
thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, keys_out, values_out,
244233
for_scatter);
245-
} else if (value_dtype == "float32") {
246-
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
234+
} else if (value_dtype == "float32") {
235+
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
247236
for_scatter);
248-
} else {
249-
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
250-
}
251-
} else if (key_dtype == "float32") {
252-
if (value_dtype == "int32") {
253-
thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
254-
for_scatter);
255-
} else if (value_dtype == "int64") {
256-
thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, keys_out, values_out,
257-
for_scatter);
258-
} else if (value_dtype == "float32") {
259-
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
260-
for_scatter);
261-
} else {
262-
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
263-
}
264-
} else {
265-
LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
266-
}
267-
});
237+
} else {
238+
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
239+
}
240+
} else if (key_dtype == "int64") {
241+
if (value_dtype == "int32") {
242+
thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
243+
for_scatter);
244+
} else if (value_dtype == "int64") {
245+
thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
246+
for_scatter);
247+
} else if (value_dtype == "float32") {
248+
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
249+
for_scatter);
250+
} else {
251+
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
252+
}
253+
} else if (key_dtype == "float32") {
254+
if (value_dtype == "int32") {
255+
thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
256+
for_scatter);
257+
} else if (value_dtype == "int64") {
258+
thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, keys_out, values_out,
259+
for_scatter);
260+
} else if (value_dtype == "float32") {
261+
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
262+
for_scatter);
263+
} else {
264+
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
265+
}
266+
} else {
267+
LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
268+
}
269+
});
268270

269-
template<typename InType, typename OutType>
270-
void thrust_scan(DLTensor* data,
271-
DLTensor* output,
272-
bool exclusive) {
273-
thrust::device_ptr<InType> data_ptr(static_cast<InType *>(data->data));
274-
thrust::device_ptr<OutType> output_ptr(static_cast<OutType *>(output->data));
271+
template <typename InType, typename OutType>
272+
void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive) {
273+
thrust::device_ptr<InType> data_ptr(static_cast<InType*>(data->data));
274+
thrust::device_ptr<OutType> output_ptr(static_cast<OutType*>(output->data));
275275
const auto scan_size = data->shape[data->ndim - 1];
276276

277277
if (scan_size == 0) return;
@@ -281,19 +281,20 @@ void thrust_scan(DLTensor* data,
281281

282282
const bool need_cast = std::is_same<InType, OutType>::value == false;
283283

284-
auto data_cast_ptr = thrust::make_transform_iterator(data_ptr, [] __host__ __device__(InType v) {
285-
return static_cast<OutType>(v);
286-
}); // NOLINT(*)
284+
auto data_cast_ptr = thrust::make_transform_iterator(
285+
data_ptr, [] __host__ __device__(InType v) { return static_cast<OutType>(v); }); // NOLINT(*)
286+
287+
auto policy = get_thrust_exec_policy();
287288

288289
if (size == static_cast<size_t>(data->shape[data->ndim - 1])) {
289290
if (exclusive && need_cast) {
290-
thrust::exclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
291+
thrust::exclusive_scan(policy, data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
291292
} else if (exclusive && !need_cast) {
292-
thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
293+
thrust::exclusive_scan(policy, data_ptr, data_ptr + scan_size, output_ptr);
293294
} else if (!exclusive && need_cast) {
294-
thrust::inclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
295+
thrust::inclusive_scan(policy, data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
295296
} else {
296-
thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
297+
thrust::inclusive_scan(policy, data_ptr, data_ptr + scan_size, output_ptr);
297298
}
298299
} else {
299300
// Use thrust segmented scan to compute scan on the inner most axis
@@ -305,18 +306,18 @@ void thrust_scan(DLTensor* data,
305306
auto counting_iter = thrust::counting_iterator<size_t>(0);
306307
// Without __host__ annotation, cub crashes
307308
auto linear_index_to_scan_key = [scan_size] __host__ __device__(size_t i) {
308-
return i / scan_size;
309-
}; // NOLINT(*)
309+
return i / scan_size;
310+
}; // NOLINT(*)
310311
auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key);
311312

312313
if (exclusive && need_cast) {
313-
thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
314+
thrust::exclusive_scan_by_key(policy, key_iter, key_iter + size, data_cast_ptr, output_ptr);
314315
} else if (exclusive && !need_cast) {
315-
thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
316+
thrust::exclusive_scan_by_key(policy, key_iter, key_iter + size, data_ptr, output_ptr);
316317
} else if (!exclusive && need_cast) {
317-
thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
318+
thrust::inclusive_scan_by_key(policy, key_iter, key_iter + size, data_cast_ptr, output_ptr);
318319
} else {
319-
thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
320+
thrust::inclusive_scan_by_key(policy, key_iter, key_iter + size, data_ptr, output_ptr);
320321
}
321322
}
322323
}

src/runtime/cuda/cuda_common.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ class CUDAThreadEntry {
6363
// get the threadlocal workspace
6464
static CUDAThreadEntry* ThreadLocal();
6565
};
66+
67+
inline cudaStream_t GetCUDAStream() { return CUDAThreadEntry::ThreadLocal()->stream; }
68+
6669
} // namespace runtime
6770
} // namespace tvm
6871
#endif // TVM_RUNTIME_CUDA_CUDA_COMMON_H_

0 commit comments

Comments
 (0)