Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions onnxruntime/python/onnxruntime_inference_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,109 @@ def numpy(self) -> np.ndarray:
"""
return self._ortvalue.numpy()

def __array__(self, dtype=None, copy=None) -> np.ndarray:
"""
Supports ``numpy.asarray(ortvalue)`` and ``numpy.array(ortvalue)`` via the
`numpy __array__ protocol <https://numpy.org/devdocs/user/basics.interoperability.html>`_.

Valid only for OrtValues holding Tensors on CPU.

:param dtype: Optional numpy dtype to cast the result to.
:param copy: Optional bool (numpy >= 2.0). If ``False``, a copy will
only be made if necessary. If ``True``, a copy is always forced.
If ``None`` (default), a copy will be made only if needed.
:return: A numpy array with the same data as the OrtValue.
"""
import numpy as np # noqa: PLC0415

arr = self.numpy()

if copy is not None:
# numpy >= 2.0 added the copy kwarg to np.asarray;
# np.array has always accepted it but with weaker semantics pre-2.0.
arr = np.array(arr, dtype=dtype, copy=copy)
elif dtype is not None:
# np.asarray avoids a copy when the dtype already matches,
# preserving memory sharing with the underlying OrtValue.
arr = np.asarray(arr, dtype=dtype)

return arr

def __dlpack__(self, *, stream=None):
"""
Returns a DLPack capsule representing the tensor (part of the
`DLPack protocol <https://dmlc.github.io/dlpack/latest/>`_).

This enables interoperability with other frameworks via
``from_dlpack(ortvalue)`` (e.g. ``torch.from_dlpack``,
``jax.dlpack.from_dlpack``, ``numpy.from_dlpack``).

The OrtValue must hold a contiguous tensor. No data is copied;
the consumer shares memory with this OrtValue, which must remain
alive while the capsule is in use.

:param stream: Optional stream on which the tensor data is accessible.
Currently unused; included for protocol compliance.
:return: A PyCapsule holding a DLManagedTensor.
"""
return self._ortvalue.__dlpack__(stream=stream)

def __dlpack_device__(self) -> tuple[int, int]:
"""
Returns ``(device_type, device_id)`` indicating where the tensor data
resides (part of the `DLPack protocol
<https://dmlc.github.io/dlpack/latest/>`_).

:return: Tuple of ``(device_type, device_id)`` as ints following DLPack
``DLDeviceType`` enum values.
"""
return self._ortvalue.__dlpack_device__()

@classmethod
def from_dlpack(cls, data, /) -> OrtValue:
"""
Construct an OrtValue from an object that implements the DLPack protocol.

Accepts either:

* An object with ``__dlpack__`` / ``__dlpack_device__`` methods
(e.g. a PyTorch tensor, JAX array, or numpy array).
* A raw DLPack PyCapsule (legacy path).

Boolean tensors are automatically detected when the source object
exposes a ``dtype`` attribute (numpy, PyTorch, etc.) or is an
``OrtValue``. For raw DLPack capsules where the original dtype cannot
be inspected, bool tensors encoded as uint8 by older DLPack versions
are not distinguishable from true uint8 tensors and will be imported
as uint8.

No data is copied; the new OrtValue shares memory with the source.

:param data: A tensor object supporting the DLPack protocol, or a raw
DLPack PyCapsule.
:return: An OrtValue wrapping the tensor data.
"""
# Detect boolean dtype from the source object before consuming it,
# because DLPack encodes bool as uint8 and the capsule alone cannot
# distinguish between the two.
is_bool = False
if isinstance(data, OrtValue):
Comment thread
tianleiwu marked this conversation as resolved.
is_bool = data.data_type() == "tensor(bool)"
elif hasattr(data, "dtype"):
dtype_obj = data.dtype
# Use .name when available (numpy, cupy, tensorflow all expose it).
# Fall back to str() for frameworks that don't (e.g. PyTorch).
dtype_name = getattr(dtype_obj, "name", str(dtype_obj))
is_bool = dtype_name in ("bool", "bool_", "torch.bool")

# If the input supports the __dlpack__ protocol, call it to get the capsule.
Comment thread
tianleiwu marked this conversation as resolved.
if hasattr(data, "__dlpack__"):
capsule = data.__dlpack__()
else:
capsule = data

return cls(C.OrtValue.from_dlpack(capsule, is_bool))

def update_inplace(self, np_arr) -> None:
"""
Update the OrtValue in place with a new Numpy array. The numpy contents
Expand Down
91 changes: 91 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,97 @@ def test_ort_value_dlpack_zero_size(self):
ortvalue2 = C.OrtValue.from_dlpack(dlp2, False)
self.assertEqual(list(shape), list(ortvalue2.shape()))

def test_ort_value_array_protocol(self):
"""Test that OrtValue supports numpy's __array__ protocol."""
numpy_arr = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr)

# np.asarray should work via __array__ and share memory (zero-copy)
result = np.asarray(ortvalue)
np.testing.assert_equal(numpy_arr, result)
self.assertEqual(result.dtype, np.float32)
self.assertEqual(ortvalue.data_ptr(), result.ctypes.data)

# np.array should also work
result2 = np.array(ortvalue)
np.testing.assert_equal(numpy_arr, result2)

# same dtype should still share memory (no unnecessary copy)
result_same = np.asarray(ortvalue, dtype=np.float32)
np.testing.assert_equal(numpy_arr, result_same)
self.assertEqual(ortvalue.data_ptr(), result_same.ctypes.data)

# dtype conversion via __array__
result_f64 = np.asarray(ortvalue, dtype=np.float64)
np.testing.assert_equal(numpy_arr.astype(np.float64), result_f64)
self.assertEqual(result_f64.dtype, np.float64)

# Integer tensor
int_arr = np.array([1, 2, 3], dtype=np.int64)
ortvalue_int = onnxrt.OrtValue.ortvalue_from_numpy(int_arr)
result_int = np.asarray(ortvalue_int)
np.testing.assert_equal(int_arr, result_int)
self.assertEqual(result_int.dtype, np.int64)

# Boolean tensor
bool_arr = np.array([True, False, True], dtype=np.bool_)
ortvalue_bool = onnxrt.OrtValue.ortvalue_from_numpy(bool_arr)
result_bool = np.asarray(ortvalue_bool)
np.testing.assert_equal(bool_arr, result_bool)
self.assertEqual(result_bool.dtype, np.bool_)

@unittest.skipIf(not hasattr(C.OrtValue, "from_dlpack"), "dlpack not enabled in this build")
def test_ort_value_dlpack_protocol(self):
"""Test that OrtValue exposes __dlpack__ and __dlpack_device__ protocols."""
numpy_arr = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr)

# __dlpack_device__ should return (device_type, device_id) for CPU
device = ortvalue.__dlpack_device__()
self.assertEqual((1, 0), device)

# __dlpack__ should return a capsule that can be consumed by from_dlpack
dlp = ortvalue.__dlpack__()
ortvalue2 = onnxrt.OrtValue.from_dlpack(dlp)
np.testing.assert_equal(numpy_arr, ortvalue2.numpy())

@unittest.skipIf(not hasattr(C.OrtValue, "from_dlpack"), "dlpack not enabled in this build")
def test_ort_value_from_dlpack_protocol_object(self):
"""Test OrtValue.from_dlpack with objects implementing __dlpack__ protocol."""
numpy_arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

# numpy arrays support __dlpack__ protocol since numpy 1.22
if hasattr(numpy_arr, "__dlpack__"):
ortvalue = onnxrt.OrtValue.from_dlpack(numpy_arr)
np.testing.assert_equal(numpy_arr, ortvalue.numpy())
self.assertEqual(list(numpy_arr.shape), list(ortvalue.shape()))

# Round-trip: numpy -> OrtValue -> OrtValue (via __dlpack__)
ortvalue_src = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr)
ortvalue_dst = onnxrt.OrtValue.from_dlpack(ortvalue_src)
np.testing.assert_equal(numpy_arr, ortvalue_dst.numpy())
# Verify shared memory (no copy)
self.assertEqual(ortvalue_src.data_ptr(), ortvalue_dst.data_ptr())

@unittest.skipIf(not hasattr(C.OrtValue, "from_dlpack"), "dlpack not enabled in this build")
def test_ort_value_from_dlpack_bool(self):
"""Test that from_dlpack auto-detects boolean tensors."""
bool_arr = np.array([True, False, True, False], dtype=np.bool_)
ortvalue_src = onnxrt.OrtValue.ortvalue_from_numpy(bool_arr)

# Round-trip through DLPack should preserve bool dtype
ortvalue_dst = onnxrt.OrtValue.from_dlpack(ortvalue_src)
result = ortvalue_dst.numpy()
np.testing.assert_equal(bool_arr, result)

# Ensure uint8 is NOT falsely detected as bool
uint8_arr = np.array([1, 2, 255], dtype=np.uint8)
ortvalue_uint8 = onnxrt.OrtValue.ortvalue_from_numpy(uint8_arr)
ortvalue_uint8_dst = onnxrt.OrtValue.from_dlpack(ortvalue_uint8)
result_uint8 = ortvalue_uint8_dst.numpy()
np.testing.assert_equal(uint8_arr, result_uint8)
self.assertEqual(result_uint8.dtype, np.uint8)

def test_sparse_tensor_coo_format(self):
cpu_device = onnxrt.OrtDevice.make("cpu", 0)
shape = [9, 9]
Expand Down
Loading