Skip to content

Commit 3ad4eb3

Browse files
authored
[CUDA] Fix thrust with latest FFI refactor (#18024)
1 parent 731f133 commit 3ad4eb3

File tree

2 files changed

+96
-94
lines changed

2 files changed

+96
-94
lines changed

cmake/modules/CUDA.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ if(USE_CUDA)
109109
message(STATUS "Build with Thrust support")
110110
tvm_file_glob(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
111111
add_library(tvm_thrust_objs OBJECT ${CONTRIB_THRUST_SRC})
112+
target_link_libraries(tvm_thrust_objs PRIVATE tvm_ffi_header)
112113
target_compile_options(tvm_thrust_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>)
113114
target_compile_definitions(tvm_thrust_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
114115
if (NOT USE_THRUST MATCHES ${IS_TRUE_PATTERN})

src/runtime/contrib/thrust/thrust.cu

Lines changed: 95 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <thrust/scan.h>
3232
#include <thrust/sequence.h>
3333
#include <thrust/sort.h>
34+
#include <tvm/ffi/dtype.h>
3435
#include <tvm/ffi/function.h>
3536

3637
#include <algorithm>
@@ -233,24 +234,24 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices
233234
}
234235

235236
TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
236-
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
237-
ICHECK_GE(args.num_args, 4);
238-
auto input = args[0].cast<DLTensor*>();
239-
auto values_out = args[1].cast<DLTensor*>();
240-
auto indices_out = args[2].cast<DLTensor*>();
241-
bool is_ascend = args[3].cast<bool>();
242-
DLTensor* workspace = nullptr;
243-
if (args.num_args == 5) {
244-
workspace = args[4];
245-
}
237+
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
238+
ICHECK_GE(args.size(), 4);
239+
auto input = args[0].cast<DLTensor*>();
240+
auto values_out = args[1].cast<DLTensor*>();
241+
auto indices_out = args[2].cast<DLTensor*>();
242+
bool is_ascend = args[3].cast<bool>();
243+
DLTensor* workspace = nullptr;
244+
if (args.size() == 5) {
245+
workspace = args[4].cast<DLTensor*>();
246+
}
246247

247-
auto data_dtype = DLDataTypeToString(input->dtype);
248-
auto out_dtype = DLDataTypeToString(indices_out->dtype);
248+
auto data_dtype = ffi::DLDataTypeToString(input->dtype);
249+
auto out_dtype = ffi::DLDataTypeToString(indices_out->dtype);
249250

250-
int n_values = input->shape[input->ndim - 1];
251-
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype,
252-
workspace);
253-
});
251+
int n_values = input->shape[input->ndim - 1];
252+
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype,
253+
workspace);
254+
});
254255

255256
template <typename KeyType, typename ValueType>
256257
void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out,
@@ -281,19 +282,19 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor*
281282

282283
TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
283284
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
284-
ICHECK_GE(args.num_args, 5);
285+
ICHECK_GE(args.size(), 5);
285286
auto keys_in = args[0].cast<DLTensor*>();
286287
auto values_in = args[1].cast<DLTensor*>();
287288
auto keys_out = args[2].cast<DLTensor*>();
288289
auto values_out = args[3].cast<DLTensor*>();
289290
bool for_scatter = args[4].cast<bool>();
290291
DLTensor* workspace = nullptr;
291-
if (args.num_args == 6) {
292-
workspace = args[5];
292+
if (args.size() == 6) {
293+
workspace = args[5].cast<DLTensor*>();
293294
}
294295

295-
auto key_dtype = DLDataTypeToString(keys_in->dtype);
296-
auto value_dtype = DLDataTypeToString(values_in->dtype);
296+
auto key_dtype = ffi::DLDataTypeToString(keys_in->dtype);
297+
auto value_dtype = ffi::DLDataTypeToString(values_in->dtype);
297298

298299
if (key_dtype == "int32") {
299300
if (value_dtype == "int32") {
@@ -395,82 +396,82 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor
395396
}
396397

397398
TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan")
398-
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
399-
ICHECK(args.num_args == 2 || args.num_args == 3 || args.num_args == 4);
400-
auto data = args[0].cast<DLTensor*>();
401-
auto output = args[1].cast<DLTensor*>();
402-
bool exclusive = false;
403-
DLTensor* workspace = nullptr;
404-
405-
if (args.num_args >= 3) {
406-
exclusive = args[2];
407-
}
399+
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
400+
ICHECK(args.size() == 2 || args.size() == 3 || args.size() == 4);
401+
auto data = args[0].cast<DLTensor*>();
402+
auto output = args[1].cast<DLTensor*>();
403+
bool exclusive = false;
404+
DLTensor* workspace = nullptr;
408405

409-
if (args.num_args == 4) {
410-
workspace = args[3];
411-
}
406+
if (args.size() >= 3) {
407+
exclusive = args[2].cast<bool>();
408+
}
412409

413-
auto in_dtype = DLDataTypeToString(data->dtype);
414-
auto out_dtype = DLDataTypeToString(output->dtype);
410+
if (args.size() == 4) {
411+
workspace = args[3].cast<DLTensor*>();
412+
}
415413

416-
if (in_dtype == "bool") {
417-
if (out_dtype == "int32") {
418-
thrust_scan<bool, int>(data, output, exclusive, workspace);
419-
} else if (out_dtype == "int64") {
420-
thrust_scan<bool, int64_t>(data, output, exclusive, workspace);
421-
} else if (out_dtype == "float32") {
422-
thrust_scan<bool, float>(data, output, exclusive, workspace);
423-
} else if (out_dtype == "float64") {
424-
thrust_scan<bool, double>(data, output, exclusive, workspace);
425-
} else {
426-
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
427-
<< ". Supported output dtypes are int32, int64, float32, and float64";
428-
}
429-
} else if (in_dtype == "int32") {
430-
if (out_dtype == "int32") {
431-
thrust_scan<int, int>(data, output, exclusive, workspace);
432-
} else if (out_dtype == "int64") {
433-
thrust_scan<int, int64_t>(data, output, exclusive, workspace);
434-
} else if (out_dtype == "float32") {
435-
thrust_scan<int, float>(data, output, exclusive, workspace);
436-
} else if (out_dtype == "float64") {
437-
thrust_scan<int, double>(data, output, exclusive, workspace);
438-
} else {
439-
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
440-
<< ". Supported output dtypes are int32, int64, float32, and float64";
441-
}
442-
} else if (in_dtype == "int64") {
443-
if (out_dtype == "int64") {
444-
thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace);
445-
} else if (out_dtype == "float32") {
446-
thrust_scan<int64_t, float>(data, output, exclusive, workspace);
447-
} else if (out_dtype == "float64") {
448-
thrust_scan<int64_t, double>(data, output, exclusive, workspace);
449-
} else {
450-
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
451-
<< ". Supported output dtypes are int64, float32, and float64";
452-
}
453-
} else if (in_dtype == "float32") {
454-
if (out_dtype == "float32") {
455-
thrust_scan<float, float>(data, output, exclusive, workspace);
456-
} else if (out_dtype == "float64") {
457-
thrust_scan<float, double>(data, output, exclusive, workspace);
458-
} else {
459-
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
460-
<< ". Supported output dtypes are float32, and float64";
461-
}
462-
} else if (in_dtype == "float64") {
463-
if (out_dtype == "float64") {
464-
thrust_scan<double, double>(data, output, exclusive, workspace);
465-
} else {
466-
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
467-
<< ". Supported output dtype is float64";
468-
}
469-
} else {
470-
LOG(FATAL) << "Unsupported input dtype: " << in_dtype
471-
<< ". Supported input dtypes are bool, int32, int64, float32, and float64";
472-
}
473-
});
414+
auto in_dtype = ffi::DLDataTypeToString(data->dtype);
415+
auto out_dtype = ffi::DLDataTypeToString(output->dtype);
416+
417+
if (in_dtype == "bool") {
418+
if (out_dtype == "int32") {
419+
thrust_scan<bool, int>(data, output, exclusive, workspace);
420+
} else if (out_dtype == "int64") {
421+
thrust_scan<bool, int64_t>(data, output, exclusive, workspace);
422+
} else if (out_dtype == "float32") {
423+
thrust_scan<bool, float>(data, output, exclusive, workspace);
424+
} else if (out_dtype == "float64") {
425+
thrust_scan<bool, double>(data, output, exclusive, workspace);
426+
} else {
427+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
428+
<< ". Supported output dtypes are int32, int64, float32, and float64";
429+
}
430+
} else if (in_dtype == "int32") {
431+
if (out_dtype == "int32") {
432+
thrust_scan<int, int>(data, output, exclusive, workspace);
433+
} else if (out_dtype == "int64") {
434+
thrust_scan<int, int64_t>(data, output, exclusive, workspace);
435+
} else if (out_dtype == "float32") {
436+
thrust_scan<int, float>(data, output, exclusive, workspace);
437+
} else if (out_dtype == "float64") {
438+
thrust_scan<int, double>(data, output, exclusive, workspace);
439+
} else {
440+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
441+
<< ". Supported output dtypes are int32, int64, float32, and float64";
442+
}
443+
} else if (in_dtype == "int64") {
444+
if (out_dtype == "int64") {
445+
thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace);
446+
} else if (out_dtype == "float32") {
447+
thrust_scan<int64_t, float>(data, output, exclusive, workspace);
448+
} else if (out_dtype == "float64") {
449+
thrust_scan<int64_t, double>(data, output, exclusive, workspace);
450+
} else {
451+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
452+
<< ". Supported output dtypes are int64, float32, and float64";
453+
}
454+
} else if (in_dtype == "float32") {
455+
if (out_dtype == "float32") {
456+
thrust_scan<float, float>(data, output, exclusive, workspace);
457+
} else if (out_dtype == "float64") {
458+
thrust_scan<float, double>(data, output, exclusive, workspace);
459+
} else {
460+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
461+
<< ". Supported output dtypes are float32, and float64";
462+
}
463+
} else if (in_dtype == "float64") {
464+
if (out_dtype == "float64") {
465+
thrust_scan<double, double>(data, output, exclusive, workspace);
466+
} else {
467+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
468+
<< ". Supported output dtype is float64";
469+
}
470+
} else {
471+
LOG(FATAL) << "Unsupported input dtype: " << in_dtype
472+
<< ". Supported input dtypes are bool, int32, int64, float32, and float64";
473+
}
474+
});
474475

475476
} // namespace contrib
476477
} // namespace tvm

0 commit comments

Comments
 (0)