Skip to content

Commit 7433b2f

Browse files
authored
Add optional mem_scope parameter to tvm.nd.array and tvm.nd.copyto (#11717)
1 parent 7e376e2 commit 7433b2f

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

python/tvm/runtime/ndarray.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,18 +236,21 @@ def numpy(self):
236236
return np_arr_ret.reshape(shape)
237237
return np_arr
238238

239-
def copyto(self, target):
239+
def copyto(self, target, mem_scope=None):
240240
"""Copy array to target
241241
242242
Parameters
243243
----------
244244
target : NDArray
245245
The target array to be copied, must have same shape as this array.
246+
247+
mem_scope : Optional[str]
248+
The memory scope of the array.
246249
"""
247250
if isinstance(target, NDArrayBase):
248251
return self._copyto(target)
249252
if isinstance(target, Device):
250-
res = empty(self.shape, self.dtype, target)
253+
res = empty(self.shape, self.dtype, target, mem_scope)
251254
return self._copyto(res)
252255
raise ValueError("Unsupported target type %s" % str(type(target)))
253256

@@ -574,7 +577,7 @@ def webgpu(dev_id=0):
574577
mtl = metal
575578

576579

577-
def array(arr, device=cpu(0)):
580+
def array(arr, device=cpu(0), mem_scope=None):
578581
"""Create an array from source arr.
579582
580583
Parameters
@@ -585,6 +588,9 @@ def array(arr, device=cpu(0)):
585588
device : Device, optional
586589
The device device to create the array
587590
591+
mem_scope : Optional[str]
592+
The memory scope of the array
593+
588594
Returns
589595
-------
590596
ret : NDArray
@@ -595,7 +601,7 @@ def array(arr, device=cpu(0)):
595601

596602
if not isinstance(arr, (np.ndarray, NDArray)):
597603
arr = np.array(arr)
598-
return empty(arr.shape, arr.dtype, device).copyfrom(arr)
604+
return empty(arr.shape, arr.dtype, device, mem_scope).copyfrom(arr)
599605

600606

601607
# Register back to FFI

0 commit comments

Comments
 (0)