Skip to content

Commit 3a4a571

Browse files
committed
fix sort
1 parent c7ddb2c commit 3a4a571

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

src/runtime/contrib/sort/sort.cc

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ struct float16 {
8080
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
8181
// and sort axis is dk. sort_num should have dimension of
8282
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
83-
TVM_FFI_STATIC_INIT_BLOCK({
83+
void RegisterArgsortNMS() {
8484
namespace refl = tvm::ffi::reflection;
8585
refl::GlobalDef().def_packed(
8686
"tvm.contrib.sort.argsort_nms", [](ffi::PackedArgs args, ffi::Any* ret) {
@@ -157,7 +157,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
157157
}
158158
}
159159
});
160-
});
160+
}
161161

162162
template <typename DataType, typename OutType>
163163
void sort_impl(
@@ -222,7 +222,7 @@ void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
222222
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
223223
// and sort axis is dk. sort_num should have dimension of
224224
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
225-
TVM_FFI_STATIC_INIT_BLOCK({
225+
void RegisterArgsort() {
226226
namespace refl = tvm::ffi::reflection;
227227
refl::GlobalDef().def_packed("tvm.contrib.sort.argsort", [](ffi::PackedArgs args, ffi::Any* ret) {
228228
auto input = args[0].cast<DLTensor*>();
@@ -311,10 +311,19 @@ TVM_FFI_STATIC_INIT_BLOCK({
311311
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
312312
}
313313
});
314-
});
314+
};
315315

316316

317-
void SortPacked(ffi::PackedArgs args, ffi::Any* ret) {
317+
// Sort implemented C library sort.
318+
// Return sorted tensor.
319+
// By default, the last axis will be used to sort.
320+
// sort_num specify the number of elements to be sorted.
321+
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
322+
// and sort axis is dk. sort_num should have dimension of
323+
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
324+
void RegisterSort() {
325+
namespace refl = tvm::ffi::reflection;
326+
refl::GlobalDef().def_packed("tvm.contrib.sort.sort", [](ffi::PackedArgs args, ffi::Any* ret) {
318327
auto input = args[0].cast<DLTensor*>();
319328
auto output = args[1].cast<DLTensor*>();
320329
int32_t axis = args[2].cast<int32_t>();
@@ -348,18 +357,8 @@ void SortPacked(ffi::PackedArgs args, ffi::Any* ret) {
348357
} else {
349358
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
350359
}
360+
});
351361
}
352-
// Sort implemented C library sort.
353-
// Return sorted tensor.
354-
// By default, the last axis will be used to sort.
355-
// sort_num specify the number of elements to be sorted.
356-
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
357-
// and sort axis is dk. sort_num should have dimension of
358-
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
359-
TVM_FFI_STATIC_INIT_BLOCK({
360-
namespace refl = tvm::ffi::reflection;
361-
refl::GlobalDef().def_packed("tvm.contrib.sort.sort", SortPacked);
362-
});
363362

364363
template <typename DataType, typename IndicesType>
365364
void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis,
@@ -454,7 +453,7 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i
454453
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
455454
// and sort axis is dk. sort_num should have dimension of
456455
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
457-
TVM_FFI_STATIC_INIT_BLOCK({
456+
void RegisterTopk() {
458457
namespace refl = tvm::ffi::reflection;
459458
refl::GlobalDef().def_packed("tvm.contrib.sort.topk", [](ffi::PackedArgs args, ffi::Any* ret) {
460459
auto input = args[0].cast<DLTensor*>();
@@ -576,6 +575,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
576575
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
577576
}
578577
});
578+
}
579+
580+
TVM_FFI_STATIC_INIT_BLOCK({
581+
RegisterArgsortNMS();
582+
RegisterArgsort();
583+
RegisterSort();
584+
RegisterTopk();
579585
});
580586

581587
} // namespace contrib

tests/python/relax/test_vm_callback_function.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def relax_func(
100100
)
101101
vm = tvm.relax.VirtualMachine(ex, dev)
102102

103+
# custom callback that raises an error in python
103104
def custom_callback():
104105
local_var = 42
105106
raise RuntimeError("Error thrown from callback")

0 commit comments

Comments
 (0)