@@ -80,7 +80,7 @@ struct float16 {
8080// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
8181// and sort axis is dk. sort_num should have dimension of
8282// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
83- TVM_FFI_STATIC_INIT_BLOCK ( {
83+ void RegisterArgsortNMS () {
8484 namespace refl = tvm::ffi::reflection;
8585 refl::GlobalDef ().def_packed (
8686 " tvm.contrib.sort.argsort_nms" , [](ffi::PackedArgs args, ffi::Any* ret) {
@@ -157,7 +157,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
157157 }
158158 }
159159 });
160- });
160+ }
161161
162162template <typename DataType, typename OutType>
163163void sort_impl (
@@ -222,7 +222,7 @@ void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
222222// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
223223// and sort axis is dk. sort_num should have dimension of
224224// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
225- TVM_FFI_STATIC_INIT_BLOCK ( {
225+ void RegisterArgsort () {
226226 namespace refl = tvm::ffi::reflection;
227227 refl::GlobalDef ().def_packed (" tvm.contrib.sort.argsort" , [](ffi::PackedArgs args, ffi::Any* ret) {
228228 auto input = args[0 ].cast <DLTensor*>();
@@ -311,10 +311,19 @@ TVM_FFI_STATIC_INIT_BLOCK({
311311 LOG (FATAL) << " Unsupported input dtype: " << data_dtype;
312312 }
313313 });
314- }) ;
314+ };
315315
316316
317- void SortPacked (ffi::PackedArgs args, ffi::Any* ret) {
317+ // Sort implemented C library sort.
318+ // Return sorted tensor.
319+ // By default, the last axis will be used to sort.
320+ // sort_num specify the number of elements to be sorted.
321+ // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
322+ // and sort axis is dk. sort_num should have dimension of
323+ // (d1, d2, ..., d(k-1), d(k+1), ..., dn).
324+ void RegisterSort () {
325+ namespace refl = tvm::ffi::reflection;
326+ refl::GlobalDef ().def_packed (" tvm.contrib.sort.sort" , [](ffi::PackedArgs args, ffi::Any* ret) {
318327 auto input = args[0 ].cast <DLTensor*>();
319328 auto output = args[1 ].cast <DLTensor*>();
320329 int32_t axis = args[2 ].cast <int32_t >();
@@ -348,18 +357,8 @@ void SortPacked(ffi::PackedArgs args, ffi::Any* ret) {
348357 } else {
349358 LOG (FATAL) << " Unsupported input dtype: " << data_dtype;
350359 }
360+ });
351361}
352- // Sort implemented C library sort.
353- // Return sorted tensor.
354- // By default, the last axis will be used to sort.
355- // sort_num specify the number of elements to be sorted.
356- // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
357- // and sort axis is dk. sort_num should have dimension of
358- // (d1, d2, ..., d(k-1), d(k+1), ..., dn).
359- TVM_FFI_STATIC_INIT_BLOCK ({
360- namespace refl = tvm::ffi::reflection;
361- refl::GlobalDef ().def_packed (" tvm.contrib.sort.sort" , SortPacked);
362- });
363362
364363template <typename DataType, typename IndicesType>
365364void topk (DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis,
@@ -454,7 +453,7 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i
454453// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
455454// and sort axis is dk. sort_num should have dimension of
456455// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
457- TVM_FFI_STATIC_INIT_BLOCK ( {
456+ void RegisterTopk () {
458457 namespace refl = tvm::ffi::reflection;
459458 refl::GlobalDef ().def_packed (" tvm.contrib.sort.topk" , [](ffi::PackedArgs args, ffi::Any* ret) {
460459 auto input = args[0 ].cast <DLTensor*>();
@@ -576,6 +575,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
576575 LOG (FATAL) << " Unsupported input dtype: " << data_dtype;
577576 }
578577 });
578+ }
579+
580+ TVM_FFI_STATIC_INIT_BLOCK ({
581+ RegisterArgsortNMS ();
582+ RegisterArgsort ();
583+ RegisterSort ();
584+ RegisterTopk ();
579585});
580586
581587} // namespace contrib
0 commit comments