|
31 | 31 | #include <thrust/scan.h> |
32 | 32 | #include <thrust/sequence.h> |
33 | 33 | #include <thrust/sort.h> |
| 34 | +#include <tvm/ffi/dtype.h> |
34 | 35 | #include <tvm/ffi/function.h> |
35 | 36 |
|
36 | 37 | #include <algorithm> |
@@ -233,24 +234,24 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices |
233 | 234 | } |
234 | 235 |
|
235 | 236 | TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sort") |
236 | | -.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { |
237 | | - ICHECK_GE(args.num_args, 4); |
238 | | - auto input = args[0].cast<DLTensor*>(); |
239 | | - auto values_out = args[1].cast<DLTensor*>(); |
240 | | - auto indices_out = args[2].cast<DLTensor*>(); |
241 | | - bool is_ascend = args[3].cast<bool>(); |
242 | | - DLTensor* workspace = nullptr; |
243 | | - if (args.num_args == 5) { |
244 | | - workspace = args[4]; |
245 | | - } |
| 237 | + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { |
| 238 | + ICHECK_GE(args.size(), 4); |
| 239 | + auto input = args[0].cast<DLTensor*>(); |
| 240 | + auto values_out = args[1].cast<DLTensor*>(); |
| 241 | + auto indices_out = args[2].cast<DLTensor*>(); |
| 242 | + bool is_ascend = args[3].cast<bool>(); |
| 243 | + DLTensor* workspace = nullptr; |
| 244 | + if (args.size() == 5) { |
| 245 | + workspace = args[4].cast<DLTensor*>(); |
| 246 | + } |
246 | 247 |
|
247 | | - auto data_dtype = DLDataTypeToString(input->dtype); |
248 | | - auto out_dtype = DLDataTypeToString(indices_out->dtype); |
| 248 | + auto data_dtype = ffi::DLDataTypeToString(input->dtype); |
| 249 | + auto out_dtype = ffi::DLDataTypeToString(indices_out->dtype); |
249 | 250 |
|
250 | | - int n_values = input->shape[input->ndim - 1]; |
251 | | - thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype, |
252 | | - workspace); |
253 | | -}); |
| 251 | + int n_values = input->shape[input->ndim - 1]; |
| 252 | + thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype, |
| 253 | + workspace); |
| 254 | + }); |
254 | 255 |
|
255 | 256 | template <typename KeyType, typename ValueType> |
256 | 257 | void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out, |
@@ -281,19 +282,19 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* |
281 | 282 |
|
282 | 283 | TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") |
283 | 284 | .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { |
284 | | - ICHECK_GE(args.num_args, 5); |
| 285 | + ICHECK_GE(args.size(), 5); |
285 | 286 | auto keys_in = args[0].cast<DLTensor*>(); |
286 | 287 | auto values_in = args[1].cast<DLTensor*>(); |
287 | 288 | auto keys_out = args[2].cast<DLTensor*>(); |
288 | 289 | auto values_out = args[3].cast<DLTensor*>(); |
289 | 290 | bool for_scatter = args[4].cast<bool>(); |
290 | 291 | DLTensor* workspace = nullptr; |
291 | | - if (args.num_args == 6) { |
292 | | - workspace = args[5]; |
| 292 | + if (args.size() == 6) { |
| 293 | + workspace = args[5].cast<DLTensor*>(); |
293 | 294 | } |
294 | 295 |
|
295 | | - auto key_dtype = DLDataTypeToString(keys_in->dtype); |
296 | | - auto value_dtype = DLDataTypeToString(values_in->dtype); |
| 296 | + auto key_dtype = ffi::DLDataTypeToString(keys_in->dtype); |
| 297 | + auto value_dtype = ffi::DLDataTypeToString(values_in->dtype); |
297 | 298 |
|
298 | 299 | if (key_dtype == "int32") { |
299 | 300 | if (value_dtype == "int32") { |
@@ -395,82 +396,82 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor |
395 | 396 | } |
396 | 397 |
|
397 | 398 | TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") |
398 | | -.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { |
399 | | - ICHECK(args.num_args == 2 || args.num_args == 3 || args.num_args == 4); |
400 | | - auto data = args[0].cast<DLTensor*>(); |
401 | | - auto output = args[1].cast<DLTensor*>(); |
402 | | - bool exclusive = false; |
403 | | - DLTensor* workspace = nullptr; |
404 | | - |
405 | | - if (args.num_args >= 3) { |
406 | | - exclusive = args[2]; |
407 | | - } |
| 399 | + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { |
| 400 | + ICHECK(args.size() == 2 || args.size() == 3 || args.size() == 4); |
| 401 | + auto data = args[0].cast<DLTensor*>(); |
| 402 | + auto output = args[1].cast<DLTensor*>(); |
| 403 | + bool exclusive = false; |
| 404 | + DLTensor* workspace = nullptr; |
408 | 405 |
|
409 | | - if (args.num_args == 4) { |
410 | | - workspace = args[3]; |
411 | | - } |
| 406 | + if (args.size() >= 3) { |
| 407 | + exclusive = args[2].cast<bool>(); |
| 408 | + } |
412 | 409 |
|
413 | | - auto in_dtype = DLDataTypeToString(data->dtype); |
414 | | - auto out_dtype = DLDataTypeToString(output->dtype); |
| 410 | + if (args.size() == 4) { |
| 411 | + workspace = args[3].cast<DLTensor*>(); |
| 412 | + } |
415 | 413 |
|
416 | | - if (in_dtype == "bool") { |
417 | | - if (out_dtype == "int32") { |
418 | | - thrust_scan<bool, int>(data, output, exclusive, workspace); |
419 | | - } else if (out_dtype == "int64") { |
420 | | - thrust_scan<bool, int64_t>(data, output, exclusive, workspace); |
421 | | - } else if (out_dtype == "float32") { |
422 | | - thrust_scan<bool, float>(data, output, exclusive, workspace); |
423 | | - } else if (out_dtype == "float64") { |
424 | | - thrust_scan<bool, double>(data, output, exclusive, workspace); |
425 | | - } else { |
426 | | - LOG(FATAL) << "Unsupported output dtype: " << out_dtype |
427 | | - << ". Supported output dtypes are int32, int64, float32, and float64"; |
428 | | - } |
429 | | - } else if (in_dtype == "int32") { |
430 | | - if (out_dtype == "int32") { |
431 | | - thrust_scan<int, int>(data, output, exclusive, workspace); |
432 | | - } else if (out_dtype == "int64") { |
433 | | - thrust_scan<int, int64_t>(data, output, exclusive, workspace); |
434 | | - } else if (out_dtype == "float32") { |
435 | | - thrust_scan<int, float>(data, output, exclusive, workspace); |
436 | | - } else if (out_dtype == "float64") { |
437 | | - thrust_scan<int, double>(data, output, exclusive, workspace); |
438 | | - } else { |
439 | | - LOG(FATAL) << "Unsupported output dtype: " << out_dtype |
440 | | - << ". Supported output dtypes are int32, int64, float32, and float64"; |
441 | | - } |
442 | | - } else if (in_dtype == "int64") { |
443 | | - if (out_dtype == "int64") { |
444 | | - thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace); |
445 | | - } else if (out_dtype == "float32") { |
446 | | - thrust_scan<int64_t, float>(data, output, exclusive, workspace); |
447 | | - } else if (out_dtype == "float64") { |
448 | | - thrust_scan<int64_t, double>(data, output, exclusive, workspace); |
449 | | - } else { |
450 | | - LOG(FATAL) << "Unsupported output dtype: " << out_dtype |
451 | | - << ". Supported output dtypes are int64, float32, and float64"; |
452 | | - } |
453 | | - } else if (in_dtype == "float32") { |
454 | | - if (out_dtype == "float32") { |
455 | | - thrust_scan<float, float>(data, output, exclusive, workspace); |
456 | | - } else if (out_dtype == "float64") { |
457 | | - thrust_scan<float, double>(data, output, exclusive, workspace); |
458 | | - } else { |
459 | | - LOG(FATAL) << "Unsupported output dtype: " << out_dtype |
460 | | - << ". Supported output dtypes are float32, and float64"; |
461 | | - } |
462 | | - } else if (in_dtype == "float64") { |
463 | | - if (out_dtype == "float64") { |
464 | | - thrust_scan<double, double>(data, output, exclusive, workspace); |
465 | | - } else { |
466 | | - LOG(FATAL) << "Unsupported output dtype: " << out_dtype |
467 | | - << ". Supported output dtype is float64"; |
468 | | - } |
469 | | - } else { |
470 | | - LOG(FATAL) << "Unsupported input dtype: " << in_dtype |
471 | | - << ". Supported input dtypes are bool, int32, int64, float32, and float64"; |
472 | | - } |
473 | | -}); |
| 414 | + auto in_dtype = ffi::DLDataTypeToString(data->dtype); |
| 415 | + auto out_dtype = ffi::DLDataTypeToString(output->dtype); |
| 416 | + |
| 417 | + if (in_dtype == "bool") { |
| 418 | + if (out_dtype == "int32") { |
| 419 | + thrust_scan<bool, int>(data, output, exclusive, workspace); |
| 420 | + } else if (out_dtype == "int64") { |
| 421 | + thrust_scan<bool, int64_t>(data, output, exclusive, workspace); |
| 422 | + } else if (out_dtype == "float32") { |
| 423 | + thrust_scan<bool, float>(data, output, exclusive, workspace); |
| 424 | + } else if (out_dtype == "float64") { |
| 425 | + thrust_scan<bool, double>(data, output, exclusive, workspace); |
| 426 | + } else { |
| 427 | + LOG(FATAL) << "Unsupported output dtype: " << out_dtype |
| 428 | + << ". Supported output dtypes are int32, int64, float32, and float64"; |
| 429 | + } |
| 430 | + } else if (in_dtype == "int32") { |
| 431 | + if (out_dtype == "int32") { |
| 432 | + thrust_scan<int, int>(data, output, exclusive, workspace); |
| 433 | + } else if (out_dtype == "int64") { |
| 434 | + thrust_scan<int, int64_t>(data, output, exclusive, workspace); |
| 435 | + } else if (out_dtype == "float32") { |
| 436 | + thrust_scan<int, float>(data, output, exclusive, workspace); |
| 437 | + } else if (out_dtype == "float64") { |
| 438 | + thrust_scan<int, double>(data, output, exclusive, workspace); |
| 439 | + } else { |
| 440 | + LOG(FATAL) << "Unsupported output dtype: " << out_dtype |
| 441 | + << ". Supported output dtypes are int32, int64, float32, and float64"; |
| 442 | + } |
| 443 | + } else if (in_dtype == "int64") { |
| 444 | + if (out_dtype == "int64") { |
| 445 | + thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace); |
| 446 | + } else if (out_dtype == "float32") { |
| 447 | + thrust_scan<int64_t, float>(data, output, exclusive, workspace); |
| 448 | + } else if (out_dtype == "float64") { |
| 449 | + thrust_scan<int64_t, double>(data, output, exclusive, workspace); |
| 450 | + } else { |
| 451 | + LOG(FATAL) << "Unsupported output dtype: " << out_dtype |
| 452 | + << ". Supported output dtypes are int64, float32, and float64"; |
| 453 | + } |
| 454 | + } else if (in_dtype == "float32") { |
| 455 | + if (out_dtype == "float32") { |
| 456 | + thrust_scan<float, float>(data, output, exclusive, workspace); |
| 457 | + } else if (out_dtype == "float64") { |
| 458 | + thrust_scan<float, double>(data, output, exclusive, workspace); |
| 459 | + } else { |
| 460 | + LOG(FATAL) << "Unsupported output dtype: " << out_dtype |
| 461 | + << ". Supported output dtypes are float32, and float64"; |
| 462 | + } |
| 463 | + } else if (in_dtype == "float64") { |
| 464 | + if (out_dtype == "float64") { |
| 465 | + thrust_scan<double, double>(data, output, exclusive, workspace); |
| 466 | + } else { |
| 467 | + LOG(FATAL) << "Unsupported output dtype: " << out_dtype |
| 468 | + << ". Supported output dtype is float64"; |
| 469 | + } |
| 470 | + } else { |
| 471 | + LOG(FATAL) << "Unsupported input dtype: " << in_dtype |
| 472 | + << ". Supported input dtypes are bool, int32, int64, float32, and float64"; |
| 473 | + } |
| 474 | + }); |
474 | 475 |
|
475 | 476 | } // namespace contrib |
476 | 477 | } // namespace tvm |
0 commit comments