Skip to content

Commit

Permalink
[MXNET-1403] Disable numpy's writability of NDArray once it is zero-c…
Browse files Browse the repository at this point in the history
…opied to MXNet (apache#14948)

* Initial commit

* update

* Update test_ndarray.py

* Retrigger
  • Loading branch information
junrushao authored and haohuw committed Jun 23, 2019
1 parent 1aad52e commit 64b7661
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
8 changes: 7 additions & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4212,7 +4212,12 @@ def dl_managed_tensor_deleter(dl_managed_tensor_handle):


def from_numpy(ndarray, zero_copy=True):
"""Returns an MXNet's NDArray backed by Numpy's ndarray.
"""Returns an MXNet's ndarray backed by numpy's ndarray.
When `zero_copy` is set to be true,
this API consumes numpy's ndarray and produces MXNet's ndarray
without having to copy the content. In this case, we disallow
users to modify the given numpy ndarray, and it is suggested
not to read the numpy ndarray as well for internal correctness.
Parameters
----------
Expand Down Expand Up @@ -4261,6 +4266,7 @@ def _make_dl_managed_tensor(array):

if not ndarray.flags['C_CONTIGUOUS']:
raise ValueError("Only c-contiguous arrays are supported for zero-copy")
ndarray.flags['WRITEABLE'] = False
c_obj = _make_dl_managed_tensor(ndarray)
address = ctypes.addressof(c_obj)
address = ctypes.cast(address, ctypes.c_void_p)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,8 +1687,8 @@ def test_zero_from_numpy():
mx.test_utils.assert_almost_equal(np_array, mx_array.asnumpy())
np_array = arrays[0]
mx_array = mx.nd.from_numpy(np_array)
np_array[2, 1] = 0
mx.test_utils.assert_almost_equal(np_array, mx_array.asnumpy())
assertRaises(ValueError, np_array.__setitem__, (2, 1), 0)

mx_array[2, 1] = 100
mx.test_utils.assert_almost_equal(np_array, mx_array.asnumpy())
np_array = np.array([[1, 2], [3, 4], [5, 6]]).transpose()
Expand Down

0 comments on commit 64b7661

Please sign in to comment.