cmake/modules/CUDA.cmake (1 change: 1 addition & 0 deletions)

@@ -109,6 +109,7 @@ if(USE_CUDA)
     message(STATUS "Build with Thrust support")
     tvm_file_glob(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
     add_library(tvm_thrust_objs OBJECT ${CONTRIB_THRUST_SRC})
+    target_link_libraries(tvm_thrust_objs PRIVATE tvm_ffi_header)
     target_compile_options(tvm_thrust_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>)
     target_compile_definitions(tvm_thrust_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
     if (NOT USE_THRUST MATCHES ${IS_TRUE_PATTERN})
src/runtime/contrib/thrust/thrust.cu (189 changes: 95 additions & 94 deletions)

@@ -31,6 +31,7 @@
 #include <thrust/scan.h>
 #include <thrust/sequence.h>
 #include <thrust/sort.h>
+#include <tvm/ffi/dtype.h>
 #include <tvm/ffi/function.h>

 #include <algorithm>
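Note on the new include: the registrations below compare the result of ffi::DLDataTypeToString against names such as "int32" and "float32". A minimal hedged sketch of that assumed mapping, standing outside this diff and assuming the usual tvm::ffi namespace:

// Hedged sketch (not part of the diff): the dtype dispatch in this file assumes
// ffi::DLDataTypeToString maps a DLDataType to names such as "float32".
#include <dlpack/dlpack.h>        // DLDataType, kDLFloat
#include <tvm/ffi/dtype.h>        // assumed home of tvm::ffi::DLDataTypeToString
#include <tvm/runtime/logging.h>  // ICHECK

void dtype_name_sketch() {
  DLDataType f32{kDLFloat, 32, 1};  // type code, bit width, lanes
  auto name = tvm::ffi::DLDataTypeToString(f32);
  ICHECK(name == "float32");  // same string comparison the registrations use
}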
@@ -233,24 +234,24 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices
 }

 TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
     .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
-      ICHECK_GE(args.num_args, 4);
+      ICHECK_GE(args.size(), 4);
       auto input = args[0].cast<DLTensor*>();
       auto values_out = args[1].cast<DLTensor*>();
       auto indices_out = args[2].cast<DLTensor*>();
       bool is_ascend = args[3].cast<bool>();
       DLTensor* workspace = nullptr;
-      if (args.num_args == 5) {
-        workspace = args[4];
+      if (args.size() == 5) {
+        workspace = args[4].cast<DLTensor*>();
       }

-      auto data_dtype = DLDataTypeToString(input->dtype);
-      auto out_dtype = DLDataTypeToString(indices_out->dtype);
+      auto data_dtype = ffi::DLDataTypeToString(input->dtype);
+      auto out_dtype = ffi::DLDataTypeToString(indices_out->dtype);

       int n_values = input->shape[input->ndim - 1];
       thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype,
                          workspace);
     });

 template <typename KeyType, typename ValueType>
 void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out,
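Across all three registrations in this file the migration is mechanical: args.num_args becomes args.size(), implicit conversions out of ffi::Any (for example workspace = args[4]) become explicit cast<T>() calls, and DLDataTypeToString is now qualified as ffi::DLDataTypeToString. A minimal hypothetical registration showing just that accessor pattern, assuming the same includes and namespaces as thrust.cu (the global name and body are illustrative, not part of this PR):

// Hypothetical registration; only the args.size() / args[i].cast<T>() usage
// mirrors this PR, everything else is illustrative.
TVM_FFI_REGISTER_GLOBAL("tvm.contrib.example.identity")
    .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
      ICHECK_GE(args.size(), 2);                // new: size() instead of num_args
      auto input = args[0].cast<DLTensor*>();   // explicit cast out of ffi::Any
      bool flag = args[1].cast<bool>();
      DLTensor* workspace = nullptr;
      if (args.size() == 3) {                   // optional trailing argument
        workspace = args[2].cast<DLTensor*>();  // new: no implicit DLTensor* conversion
      }
      (void)input;
      (void)flag;
      (void)workspace;
      (void)ret;
    });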
@@ -281,19 +282,19 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor*

 TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
     .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
-      ICHECK_GE(args.num_args, 5);
+      ICHECK_GE(args.size(), 5);
       auto keys_in = args[0].cast<DLTensor*>();
       auto values_in = args[1].cast<DLTensor*>();
       auto keys_out = args[2].cast<DLTensor*>();
       auto values_out = args[3].cast<DLTensor*>();
       bool for_scatter = args[4].cast<bool>();
       DLTensor* workspace = nullptr;
-      if (args.num_args == 6) {
-        workspace = args[5];
+      if (args.size() == 6) {
+        workspace = args[5].cast<DLTensor*>();
       }

-      auto key_dtype = DLDataTypeToString(keys_in->dtype);
-      auto value_dtype = DLDataTypeToString(values_in->dtype);
+      auto key_dtype = ffi::DLDataTypeToString(keys_in->dtype);
+      auto value_dtype = ffi::DLDataTypeToString(values_in->dtype);

       if (key_dtype == "int32") {
         if (value_dtype == "int32") {
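For readers unfamiliar with the underlying Thrust call, a hedged sketch of what a stable sort-by-key over device memory looks like; the real thrust_stable_sort_by_key in this file differs in details such as workspace handling and the for_scatter path:

#include <thrust/device_ptr.h>
#include <thrust/sort.h>

// Illustrative helper only; keys and values are assumed to be device pointers.
template <typename KeyType, typename ValueType>
void stable_sort_by_key_sketch(KeyType* keys, ValueType* values, int n) {
  thrust::device_ptr<KeyType> keys_ptr(keys);
  thrust::device_ptr<ValueType> values_ptr(values);
  // stable_sort_by_key preserves the relative order of equal keys while
  // permuting the values array alongside the keys.
  thrust::stable_sort_by_key(keys_ptr, keys_ptr + n, values_ptr);
}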
@@ -395,82 +396,82 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor
 }

 TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan")
     .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
-      ICHECK(args.num_args == 2 || args.num_args == 3 || args.num_args == 4);
+      ICHECK(args.size() == 2 || args.size() == 3 || args.size() == 4);
       auto data = args[0].cast<DLTensor*>();
       auto output = args[1].cast<DLTensor*>();
       bool exclusive = false;
       DLTensor* workspace = nullptr;

-      if (args.num_args >= 3) {
-        exclusive = args[2];
+      if (args.size() >= 3) {
+        exclusive = args[2].cast<bool>();
       }

-      if (args.num_args == 4) {
-        workspace = args[3];
+      if (args.size() == 4) {
+        workspace = args[3].cast<DLTensor*>();
       }

-      auto in_dtype = DLDataTypeToString(data->dtype);
-      auto out_dtype = DLDataTypeToString(output->dtype);
+      auto in_dtype = ffi::DLDataTypeToString(data->dtype);
+      auto out_dtype = ffi::DLDataTypeToString(output->dtype);

       if (in_dtype == "bool") {
         if (out_dtype == "int32") {
           thrust_scan<bool, int>(data, output, exclusive, workspace);
         } else if (out_dtype == "int64") {
           thrust_scan<bool, int64_t>(data, output, exclusive, workspace);
         } else if (out_dtype == "float32") {
           thrust_scan<bool, float>(data, output, exclusive, workspace);
         } else if (out_dtype == "float64") {
           thrust_scan<bool, double>(data, output, exclusive, workspace);
         } else {
           LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                      << ". Supported output dtypes are int32, int64, float32, and float64";
         }
       } else if (in_dtype == "int32") {
         if (out_dtype == "int32") {
           thrust_scan<int, int>(data, output, exclusive, workspace);
         } else if (out_dtype == "int64") {
           thrust_scan<int, int64_t>(data, output, exclusive, workspace);
         } else if (out_dtype == "float32") {
           thrust_scan<int, float>(data, output, exclusive, workspace);
         } else if (out_dtype == "float64") {
           thrust_scan<int, double>(data, output, exclusive, workspace);
         } else {
           LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                      << ". Supported output dtypes are int32, int64, float32, and float64";
         }
       } else if (in_dtype == "int64") {
         if (out_dtype == "int64") {
           thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace);
         } else if (out_dtype == "float32") {
           thrust_scan<int64_t, float>(data, output, exclusive, workspace);
         } else if (out_dtype == "float64") {
           thrust_scan<int64_t, double>(data, output, exclusive, workspace);
         } else {
           LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                      << ". Supported output dtypes are int64, float32, and float64";
         }
       } else if (in_dtype == "float32") {
         if (out_dtype == "float32") {
           thrust_scan<float, float>(data, output, exclusive, workspace);
         } else if (out_dtype == "float64") {
           thrust_scan<float, double>(data, output, exclusive, workspace);
         } else {
           LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                      << ". Supported output dtypes are float32, and float64";
         }
       } else if (in_dtype == "float64") {
         if (out_dtype == "float64") {
           thrust_scan<double, double>(data, output, exclusive, workspace);
         } else {
           LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                      << ". Supported output dtype is float64";
         }
       } else {
         LOG(FATAL) << "Unsupported input dtype: " << in_dtype
                    << ". Supported input dtypes are bool, int32, int64, float32, and float64";
       }
     });

} // namespace contrib
} // namespace tvm
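Behind the sum_scan registration, the exclusive flag selects between Thrust's two prefix-sum variants. A small self-contained sketch of the difference follows; the real thrust_scan in this file additionally dispatches on the input/output dtypes and accepts an optional workspace tensor:

#include <thrust/device_vector.h>
#include <thrust/scan.h>

void scan_sketch() {
  int host[4] = {1, 2, 3, 4};
  thrust::device_vector<int> in(host, host + 4);
  thrust::device_vector<int> out(4);
  // Inclusive scan: out[i] = in[0] + ... + in[i]    -> 1 3 6 10
  thrust::inclusive_scan(in.begin(), in.end(), out.begin());
  // Exclusive scan: out[i] = in[0] + ... + in[i-1]  -> 0 1 3 6
  thrust::exclusive_scan(in.begin(), in.end(), out.begin());
}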