Skip to content

Commit 593a4bd

Browse files
author
Siyuan Feng
authored
[Relax] NDArray Cache Update with DLTensor Support (#16464)
As NDArray on RPC devices is only returned as a DLTensor, we add support for DLTensor in the NDArray Cache. It is hard to add test cases because a raw DLTensor cannot be created from the Python interface.
1 parent 0e8e421 commit 593a4bd

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

src/runtime/relax_vm/ndarray_cache_support.cc

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,28 @@ class NDArrayCache {
267267
};
268268

269269
TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.get").set_body_typed(NDArrayCache::Get);
// Update an entry in the NDArray cache.
//
// Args (via TVMArgs):
//   [0] name        : String  — cache key to update.
//   [1] array       : NDArray or DLTensor* — the data. DLTensor is accepted
//                     because NDArrays on RPC devices are surfaced to the
//                     callee as raw DLTensor handles.
//   [2] is_override : bool (optional, default false) — whether an existing
//                     entry with the same name may be replaced.
TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update").set_body([](TVMArgs args, TVMRetValue* rv) {
  CHECK(args.size() == 2 || args.size() == 3);
  String name = args[0];
  bool is_override = args.size() == 2 ? false : args[2];

  NDArray arr;
  if (args[1].type_code() == kTVMNDArrayHandle) {
    arr = args[1];
  } else {
    // We support converting DLTensors to NDArrays as RPC references are always DLTensors
    DLTensor* tensor = args[1];
    std::vector<int64_t> shape(tensor->shape, tensor->shape + tensor->ndim);
    // BUGFIX: assign to the outer `arr`. The previous code declared a new
    // local `NDArray arr` here, shadowing the outer one, so the copy below
    // happened on the shadow and a null NDArray was handed to Update().
    arr = NDArray::Empty(shape, tensor->dtype, tensor->device);
    arr.CopyFrom(tensor);
    // Make sure the async copy has completed before the source DLTensor
    // (owned by the caller) can go out of scope.
    TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr);
  }

  NDArrayCache::Update(name, arr, is_override);
});
TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove);
TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear);
TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.load").set_body_typed(NDArrayCache::Load);

0 commit comments

Comments
 (0)