Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,21 @@ def numpy(self):
return np_arr_ret.reshape(shape)
return np_arr

def copyto(self, target):
def copyto(self, target, mem_scope=None):
"""Copy array to target

Parameters
----------
target : NDArray
The target array to be copied, must have same shape as this array.

mem_scope : Optional[str]
The memory scope of the array.
"""
if isinstance(target, NDArrayBase):
return self._copyto(target)
if isinstance(target, Device):
res = empty(self.shape, self.dtype, target)
res = empty(self.shape, self.dtype, target, mem_scope)
return self._copyto(res)
raise ValueError("Unsupported target type %s" % str(type(target)))

Expand Down Expand Up @@ -574,7 +577,7 @@ def webgpu(dev_id=0):
mtl = metal


def array(arr, device=cpu(0)):
def array(arr, device=cpu(0), mem_scope=None):
"""Create an array from source arr.

Parameters
Expand All @@ -585,6 +588,9 @@ def array(arr, device=cpu(0)):
device : Device, optional
The device device to create the array

mem_scope : Optional[str]
The memory scope of the array

Returns
-------
ret : NDArray
Expand All @@ -595,7 +601,7 @@ def array(arr, device=cpu(0)):

if not isinstance(arr, (np.ndarray, NDArray)):
arr = np.array(arr)
return empty(arr.shape, arr.dtype, device).copyfrom(arr)
return empty(arr.shape, arr.dtype, device, mem_scope).copyfrom(arr)


# Register back to FFI
Expand Down