From cb5683cbf8ff9bca08faaf5c861584c3cdaeb990 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Thu, 29 May 2025 16:29:33 -0400
Subject: [PATCH] [CUDA] Fix thrust with latest FFI refactor

This PR fixes the thrust integration with the latest FFI refactor.
---
 cmake/modules/CUDA.cmake             |   1 +
 src/runtime/contrib/thrust/thrust.cu | 189 ++++++++++++++-------------
 2 files changed, 96 insertions(+), 94 deletions(-)

diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index f9dd4a890369..84261c6ea0ae 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -109,6 +109,7 @@ if(USE_CUDA)
     message(STATUS "Build with Thrust support")
     tvm_file_glob(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
     add_library(tvm_thrust_objs OBJECT ${CONTRIB_THRUST_SRC})
+    target_link_libraries(tvm_thrust_objs PRIVATE tvm_ffi_header)
     target_compile_options(tvm_thrust_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>)
     target_compile_definitions(tvm_thrust_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
     if (NOT USE_THRUST MATCHES ${IS_TRUE_PATTERN})
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 19f82b1855b4..6b6b9df834ab 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -31,6 +31,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -233,24 +234,24 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices
 }

 TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
-.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
-  ICHECK_GE(args.num_args, 4);
-  auto input = args[0].cast<DLTensor*>();
-  auto values_out = args[1].cast<DLTensor*>();
-  auto indices_out = args[2].cast<DLTensor*>();
-  bool is_ascend = args[3].cast<bool>();
-  DLTensor* workspace = nullptr;
-  if (args.num_args == 5) {
-    workspace = args[4];
-  }
+    .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
+      ICHECK_GE(args.size(), 4);
+      auto input = args[0].cast<DLTensor*>();
+      auto values_out = args[1].cast<DLTensor*>();
+      auto indices_out = args[2].cast<DLTensor*>();
+      bool is_ascend = args[3].cast<bool>();
+      DLTensor* workspace = nullptr;
+      if (args.size() == 5) {
+        workspace = args[4].cast<DLTensor*>();
+      }

-  auto data_dtype = DLDataTypeToString(input->dtype);
-  auto out_dtype = DLDataTypeToString(indices_out->dtype);
+      auto data_dtype = ffi::DLDataTypeToString(input->dtype);
+      auto out_dtype = ffi::DLDataTypeToString(indices_out->dtype);

-  int n_values = input->shape[input->ndim - 1];
-  thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype,
-                     workspace);
-});
+      int n_values = input->shape[input->ndim - 1];
+      thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype,
+                         workspace);
+    });

 template
 void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out,
@@ -281,19 +282,19 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor*

 TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
     .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
-      ICHECK_GE(args.num_args, 5);
+      ICHECK_GE(args.size(), 5);
       auto keys_in = args[0].cast<DLTensor*>();
       auto values_in = args[1].cast<DLTensor*>();
       auto keys_out = args[2].cast<DLTensor*>();
       auto values_out = args[3].cast<DLTensor*>();
       bool for_scatter = args[4].cast<bool>();
       DLTensor* workspace = nullptr;
-      if (args.num_args == 6) {
-        workspace = args[5];
+      if (args.size() == 6) {
+        workspace = args[5].cast<DLTensor*>();
       }

-      auto key_dtype = DLDataTypeToString(keys_in->dtype);
-      auto value_dtype = DLDataTypeToString(values_in->dtype);
+      auto key_dtype = ffi::DLDataTypeToString(keys_in->dtype);
+      auto value_dtype = ffi::DLDataTypeToString(values_in->dtype);

       if (key_dtype == "int32") {
         if (value_dtype == "int32") {
@@ -395,82 +396,82 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor
 }

 TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan")
-.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
-  ICHECK(args.num_args == 2 || args.num_args == 3 || args.num_args == 4);
-  auto data = args[0].cast<DLTensor*>();
-  auto output = args[1].cast<DLTensor*>();
-  bool exclusive = false;
-  DLTensor* workspace = nullptr;
-
-  if (args.num_args >= 3) {
-    exclusive = args[2];
-  }
+    .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
+      ICHECK(args.size() == 2 || args.size() == 3 || args.size() == 4);
+      auto data = args[0].cast<DLTensor*>();
+      auto output = args[1].cast<DLTensor*>();
+      bool exclusive = false;
+      DLTensor* workspace = nullptr;

-  if (args.num_args == 4) {
-    workspace = args[3];
-  }
+      if (args.size() >= 3) {
+        exclusive = args[2].cast<bool>();
+      }

-  auto in_dtype = DLDataTypeToString(data->dtype);
-  auto out_dtype = DLDataTypeToString(output->dtype);
+      if (args.size() == 4) {
+        workspace = args[3].cast<DLTensor*>();
+      }

-  if (in_dtype == "bool") {
-    if (out_dtype == "int32") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else if (out_dtype == "int64") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else if (out_dtype == "float32") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else if (out_dtype == "float64") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are int32, int64, float32, and float64";
-    }
-  } else if (in_dtype == "int32") {
-    if (out_dtype == "int32") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else if (out_dtype == "int64") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else if (out_dtype == "float32") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else if (out_dtype == "float64") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are int32, int64, float32, and float64";
-    }
-  } else if (in_dtype == "int64") {
-    if (out_dtype == "int64") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else if (out_dtype == "float32") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else if (out_dtype == "float64") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are int64, float32, and float64";
-    }
-  } else if (in_dtype == "float32") {
-    if (out_dtype == "float32") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else if (out_dtype == "float64") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are float32, and float64";
-    }
-  } else if (in_dtype == "float64") {
-    if (out_dtype == "float64") {
-      thrust_scan(data, output, exclusive, workspace);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtype is float64";
-    }
-  } else {
-    LOG(FATAL) << "Unsupported input dtype: " << in_dtype
-               << ". Supported input dtypes are bool, int32, int64, float32, and float64";
-  }
-});
+      auto in_dtype = ffi::DLDataTypeToString(data->dtype);
+      auto out_dtype = ffi::DLDataTypeToString(output->dtype);
+
+      if (in_dtype == "bool") {
+        if (out_dtype == "int32") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else if (out_dtype == "int64") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else if (out_dtype == "float32") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else if (out_dtype == "float64") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else {
+          LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                     << ". Supported output dtypes are int32, int64, float32, and float64";
+        }
+      } else if (in_dtype == "int32") {
+        if (out_dtype == "int32") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else if (out_dtype == "int64") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else if (out_dtype == "float32") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else if (out_dtype == "float64") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else {
+          LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                     << ". Supported output dtypes are int32, int64, float32, and float64";
+        }
+      } else if (in_dtype == "int64") {
+        if (out_dtype == "int64") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else if (out_dtype == "float32") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else if (out_dtype == "float64") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else {
+          LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                     << ". Supported output dtypes are int64, float32, and float64";
+        }
+      } else if (in_dtype == "float32") {
+        if (out_dtype == "float32") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else if (out_dtype == "float64") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else {
+          LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                     << ". Supported output dtypes are float32, and float64";
+        }
+      } else if (in_dtype == "float64") {
+        if (out_dtype == "float64") {
+          thrust_scan(data, output, exclusive, workspace);
+        } else {
+          LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                     << ". Supported output dtype is float64";
+        }
+      } else {
+        LOG(FATAL) << "Unsupported input dtype: " << in_dtype
+                   << ". Supported input dtypes are bool, int32, int64, float32, and float64";
+      }
+    });

 }  // namespace contrib
 }  // namespace tvm
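
For readers migrating similar packed-function wrappers, the sketch below distills the calling-convention changes this patch applies: `args.num_args` becomes `args.size()`, every argument (including optional trailing ones) is extracted with an explicit `cast<T>()`, and `DLDataTypeToString` is referenced through the `ffi::` namespace. It is a minimal sketch, not part of the patch: the header paths and the registered global name "tvm.contrib.demo.last_dim" are assumptions for illustration only.

// Minimal sketch of the post-refactor FFI calling convention used above.
// The include paths are assumptions; adjust them to your checkout.
#include <dlpack/dlpack.h>
#include <tvm/ffi/function.h>  // assumed home of TVM_FFI_REGISTER_GLOBAL, ffi::PackedArgs, ffi::Any

namespace tvm {
namespace contrib {

// Hypothetical global mirroring the argument handling of the thrust wrappers.
TVM_FFI_REGISTER_GLOBAL("tvm.contrib.demo.last_dim")
    .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
      // args.num_args was replaced by args.size() in the FFI refactor,
      // and each argument is now pulled out with an explicit cast<T>().
      auto* data = args[0].cast<DLTensor*>();
      bool verbose = args[1].cast<bool>();
      // Optional trailing argument: also needs an explicit cast now.
      DLTensor* workspace = nullptr;
      if (args.size() == 3) {
        workspace = args[2].cast<DLTensor*>();
      }
      (void)verbose;
      (void)workspace;
      // Return the extent of the last axis, the same quantity the sort
      // wrapper above computes as n_values.
      *ret = static_cast<int64_t>(data->shape[data->ndim - 1]);
    });

}  // namespace contrib
}  // namespace tvm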