include/tvm/runtime/ndarray.h (20 changes: 18 additions & 2 deletions)
@@ -126,13 +126,29 @@ class NDArray : public ObjectRef {
* \param stream The output data stream
*/
inline void Save(dmlc::Stream* stream) const;

/*!
* \brief Create a NDArray that shares the data memory with the current one.
*
* \param shape The shape of the new array.
*
* \param dtype The data type of the new array.
* \note The memory size of new array must be smaller than the current one.
*
* \param relative_byte_offset The offset of the output NDArray,
* relative to the current byte offset.
*
* By default, the offset of the view is the same as the offset
* of the current array.
*
* \note The new array must not allow access of addresses which
* would be out of bounds in the current array. If the new
* array is larger than the current array, or if the
* `relative_byte_offset` would place the end of the new array
* outside the bounds of the current array, this function will
* raise an exception.
*/
TVM_DLL NDArray CreateView(ShapeTuple shape, DLDataType dtype);
TVM_DLL NDArray CreateView(ShapeTuple shape, DLDataType dtype, size_t relative_byte_offset = 0);

/*!
* \brief Create a reference view of NDArray that
* represents as DLManagedTensor.
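The new `relative_byte_offset` parameter in the CreateView declaration above is specified purely in bytes, relative to the start of the array being viewed. A minimal NumPy sketch of the layout the comment describes (NumPy is only a stand-in to show the byte arithmetic; the shapes and values are hypothetical, not taken from the PR):

import numpy as np

# Backing buffer: 16 float32 elements = 64 bytes.
backing = np.arange(16, dtype="float32")

# A view of shape (4,) and dtype float32 at relative_byte_offset = 32
# covers bytes [32, 48) of the backing buffer, i.e. elements 8..11.
view = np.frombuffer(backing, dtype="float32", count=4, offset=32)
print(view)  # [ 8.  9. 10. 11.]

Passing relative_byte_offset = 0 (the default) reproduces the previous behaviour of viewing from the current array's own offset.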
python/tvm/runtime/ndarray.py (25 changes: 23 additions & 2 deletions)
@@ -18,6 +18,7 @@
"""Runtime NDArray API"""
import ctypes
import warnings
from typing import Optional

import numpy as np

@@ -287,7 +288,7 @@ def copyto(self, target, mem_scope=None):
return self._copyto(res)
raise ValueError(f"Unsupported target type {type(target)}")

def _create_view(self, shape):
def _create_view(self, shape, dtype: Optional[str] = None, relative_byte_offset: int = 0):
"""Create a view into an existing array.

The view shares the same allocation and datatype as the
@@ -307,12 +308,32 @@ def _create_view(self, shape):
shape: Union[tvm.runtime.ShapeTuple, Sequence[typing.SupportsInt]]

The shape of the view.

dtype: Optional[str]

The datatype of the view. If None (default), the view
will have the same data type as the current array.

relative_byte_offset: int

The location of the view, relative to the location of the current
array.

Note: While the `DLTensor.byte_offset` field of the returned view
is usually the same as `relative_byte_offset`, this is not
guaranteed. The `DLTensor.byte_offset` field is relative to the
start of the backing allocation, while the `relative_byte_offset`
is relative to the start of `self`.

"""

if not isinstance(shape, tvm.runtime.ShapeTuple):
shape = tvm.runtime.ShapeTuple([int(dim) for dim in shape])

return _ffi_api.TVMArrayCreateView(self, shape)
if dtype is None:
dtype = self.dtype

return _ffi_api.TVMArrayCreateView(self, shape, dtype, relative_byte_offset)


def device(dev_type, dev_id=0):
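A usage sketch of the extended binding from the Python side, assuming a CPU device and calling the private `_create_view` helper exactly as defined in this diff (the printed values follow from the documented semantics rather than from the PR's test suite):

import numpy as np
import tvm

arr = tvm.nd.array(np.arange(16, dtype="float32"))

# View the second half of the buffer: 8 float32 elements,
# starting 8 * 4 = 32 bytes past the start of `arr`.
half = arr._create_view([8], dtype="float32", relative_byte_offset=32)
print(half.numpy())  # expected: [ 8.  9. 10. 11. 12. 13. 14. 15.]

# Omitting dtype keeps the dtype of `arr`; omitting relative_byte_offset
# keeps the current offset, matching the previous single-argument behaviour.
reshaped = arr._create_view([4, 4])

Note that the returned view shares its allocation with `arr`, so writes through either array are visible through the other.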
src/runtime/ndarray.cc (70 changes: 39 additions & 31 deletions)
@@ -179,42 +179,53 @@ struct NDArray::Internal {
}
};

NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) {
NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype, size_t relative_byte_offset) {
ICHECK(data_ != nullptr);

const DLTensor& orig = get_mutable()->dl_tensor;
ICHECK(IsContiguous()) << "Can only create view for compact tensor, but found strides " <<
[&orig]() {
std::stringstream ss;
ss << "[";
for (int i = 0; i < orig.ndim; i++) {
if (i) ss << ", ";
ss << orig.strides[i];
}
ss << "]";
return ss.str();
}() << ", for shape "
<< [&]() {
std::stringstream ss;
ss << "[";
for (int i = 0; i < orig.ndim; i++) {
if (i) ss << ", ";
ss << orig.shape[i];
}
ss << "]";
return ss.str();
}();

NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.device);
ret.get_mutable()->dl_tensor.byte_offset = this->get_mutable()->dl_tensor.byte_offset;
CHECK(IsContiguous()) << [&orig]() {
std::stringstream ss;
ss << "Can only create view for compact tensor, but found strides ";

ss << "[";
for (int i = 0; i < orig.ndim; i++) {
if (i) ss << ", ";
ss << orig.strides[i];
}
ss << "]";

ss << ", for shape ";
ss << "[";
for (int i = 0; i < orig.ndim; i++) {
if (i) ss << ", ";
ss << orig.shape[i];
}
ss << "]";
return ss.str();
}();

const auto& curr_dl_tensor = get_mutable()->dl_tensor;

NDArray ret = Internal::Create(shape, dtype, curr_dl_tensor.device);

size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor);
size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor);
ICHECK_LE(view_size, curr_size)
<< "Tries to create a view that has bigger memory than current one";
CHECK_LE(relative_byte_offset + view_size, curr_size)
<< "ValueError: "
<< "View with shape " << shape << " and datatype " << dtype << " would have a size of "
<< view_size << " bytes. "
<< "This would occupy bytes " << relative_byte_offset << " <= i_byte < "
<< (relative_byte_offset + view_size) << " within the backing array. "
<< "However, the NDArray being viewed only contains " << curr_size << " bytes (shape = "
<< ShapeTuple(curr_dl_tensor.shape, curr_dl_tensor.shape + curr_dl_tensor.ndim)
<< ", dtype= " << curr_dl_tensor.dtype << ").";

// increase ref count
get_mutable()->IncRef();
ret.get_mutable()->manager_ctx = get_mutable();
ret.get_mutable()->dl_tensor.data = get_mutable()->dl_tensor.data;
ret.get_mutable()->dl_tensor.byte_offset =
get_mutable()->dl_tensor.byte_offset + relative_byte_offset;
return ret;
}
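The CHECK_LE above reduces to plain byte arithmetic: the view is rejected whenever relative_byte_offset + view_size exceeds the byte count of the array being viewed. A small worked example mirroring that check in Python (the shapes and offset are hypothetical, chosen only to trip the error path):

import numpy as np

itemsize = np.dtype("float32").itemsize  # 4 bytes
curr_size = 16 * itemsize                # backing array: shape (16,) -> 64 bytes
view_size = 8 * itemsize                 # requested view: shape (8,) -> 32 bytes
relative_byte_offset = 40

# 40 + 32 = 72 > 64, so this request would occupy bytes 40 <= i_byte < 72
# of a 64-byte backing array, and CreateView reports the ValueError above.
assert relative_byte_offset + view_size > curr_size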

@@ -372,10 +383,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_

TVM_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body_typed(NDArray::Empty);

TVM_REGISTER_GLOBAL("runtime.TVMArrayCreateView").set_body_typed([](NDArray arr, ShapeTuple shape) {
NDArray view = arr.CreateView(shape, arr->dtype);
return view;
});
TVM_REGISTER_GLOBAL("runtime.TVMArrayCreateView").set_body_method(&NDArray::CreateView);

int TVMArrayFree(TVMArrayHandle handle) {
API_BEGIN();