diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 7d93fed3566bf..def2240358c10 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -1199,6 +1199,109 @@ def numpy(self) -> np.ndarray:
         """
         return self._ortvalue.numpy()
 
+    def __array__(self, dtype=None, copy=None) -> np.ndarray:
+        """
+        Supports ``numpy.asarray(ortvalue)`` and ``numpy.array(ortvalue)`` via the
+        `numpy __array__ protocol <https://numpy.org/doc/stable/user/basics.interoperability.html>`_.
+
+        Valid only for OrtValues holding Tensors on CPU.
+
+        :param dtype: Optional numpy dtype to cast the result to.
+        :param copy: Optional bool (numpy >= 2.0). If ``True``, always copy.
+            If ``False``, never copy; a ``ValueError`` is raised if a copy
+            would be required. If ``None`` (default), copy only if needed.
+        :return: A numpy array with the same data as the OrtValue.
+        """
+        import numpy as np  # noqa: PLC0415
+
+        arr = self.numpy()
+
+        if copy is not None:
+            # numpy >= 2.0 added the copy kwarg to np.asarray;
+            # np.array has always accepted it but with weaker semantics pre-2.0.
+            arr = np.array(arr, dtype=dtype, copy=copy)
+        elif dtype is not None:
+            # np.asarray avoids a copy when the dtype already matches,
+            # preserving memory sharing with the underlying OrtValue.
+            arr = np.asarray(arr, dtype=dtype)
+
+        return arr
+
+    def __dlpack__(self, *, stream=None):
+        """
+        Returns a DLPack capsule representing the tensor (part of the
+        `DLPack protocol <https://dmlc.github.io/dlpack/latest/>`_).
+
+        This enables interoperability with other frameworks via
+        ``from_dlpack(ortvalue)`` (e.g. ``torch.from_dlpack``,
+        ``jax.dlpack.from_dlpack``, ``numpy.from_dlpack``).
+
+        The OrtValue must hold a contiguous tensor. No data is copied;
+        the consumer shares memory with this OrtValue, which must remain
+        alive while the capsule is in use.
+
+        :param stream: Optional stream on which the tensor data is accessible.
+            Currently unused; included for protocol compliance.
+        :return: A PyCapsule holding a DLManagedTensor.
+        """
+        return self._ortvalue.__dlpack__(stream=stream)
+
+    def __dlpack_device__(self) -> tuple[int, int]:
+        """
+        Returns ``(device_type, device_id)`` indicating where the tensor data
+        resides (part of the `DLPack protocol
+        <https://dmlc.github.io/dlpack/latest/>`_).
+
+        :return: Tuple of ``(device_type, device_id)`` as ints following DLPack
+            ``DLDeviceType`` enum values.
+        """
+        return self._ortvalue.__dlpack_device__()
+
+    @classmethod
+    def from_dlpack(cls, data, /) -> OrtValue:
+        """
+        Construct an OrtValue from an object that implements the DLPack protocol.
+
+        Accepts either:
+
+        * An object with ``__dlpack__`` / ``__dlpack_device__`` methods
+          (e.g. a PyTorch tensor, JAX array, or numpy array).
+        * A raw DLPack PyCapsule (legacy path).
+
+        Boolean tensors are automatically detected when the source object
+        exposes a ``dtype`` attribute (numpy, PyTorch, etc.) or is an
+        ``OrtValue``. For raw DLPack capsules where the original dtype cannot
+        be inspected, bool tensors encoded as uint8 by older DLPack versions
+        are not distinguishable from true uint8 tensors and will be imported
+        as uint8.
+
+        No data is copied; the new OrtValue shares memory with the source.
+
+        :param data: A tensor object supporting the DLPack protocol, or a raw
+            DLPack PyCapsule.
+        :return: An OrtValue wrapping the tensor data.
+        """
+        # Detect boolean dtype from the source object before consuming it,
+        # because DLPack encodes bool as uint8 and the capsule alone cannot
+        # distinguish between the two.
+        is_bool = False
+        if isinstance(data, OrtValue):
+            is_bool = data.data_type() == "tensor(bool)"
+        elif hasattr(data, "dtype"):
+            dtype_obj = data.dtype
+            # Use .name when available (numpy, cupy, tensorflow all expose it).
+            # Fall back to str() for frameworks that don't (e.g. PyTorch).
+            dtype_name = getattr(dtype_obj, "name", str(dtype_obj))
+            is_bool = dtype_name in ("bool", "bool_", "torch.bool")
+
+        # If the input supports the __dlpack__ protocol, call it to get the capsule.
+        if hasattr(data, "__dlpack__"):
+            capsule = data.__dlpack__()
+        else:
+            capsule = data
+
+        return cls(C.OrtValue.from_dlpack(capsule, is_bool))
+
     def update_inplace(self, np_arr) -> None:
         """
         Update the OrtValue in place with a new Numpy array. The numpy contents
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index 73d2ab1938ee2..ae4fc2616168f 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -1475,6 +1475,97 @@ def test_ort_value_dlpack_zero_size(self):
         ortvalue2 = C.OrtValue.from_dlpack(dlp2, False)
         self.assertEqual(list(shape), list(ortvalue2.shape()))
 
+    def test_ort_value_array_protocol(self):
+        """Test that OrtValue supports numpy's __array__ protocol."""
+        numpy_arr = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
+        ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr)
+
+        # np.asarray should work via __array__ and share memory (zero-copy)
+        result = np.asarray(ortvalue)
+        np.testing.assert_equal(numpy_arr, result)
+        self.assertEqual(result.dtype, np.float32)
+        self.assertEqual(ortvalue.data_ptr(), result.ctypes.data)
+
+        # np.array should also work
+        result2 = np.array(ortvalue)
+        np.testing.assert_equal(numpy_arr, result2)
+
+        # same dtype should still share memory (no unnecessary copy)
+        result_same = np.asarray(ortvalue, dtype=np.float32)
+        np.testing.assert_equal(numpy_arr, result_same)
+        self.assertEqual(ortvalue.data_ptr(), result_same.ctypes.data)
+
+        # dtype conversion via __array__
+        result_f64 = np.asarray(ortvalue, dtype=np.float64)
+        np.testing.assert_equal(numpy_arr.astype(np.float64), result_f64)
+        self.assertEqual(result_f64.dtype, np.float64)
+
+        # Integer tensor
+        int_arr = np.array([1, 2, 3], dtype=np.int64)
+        ortvalue_int = onnxrt.OrtValue.ortvalue_from_numpy(int_arr)
+        result_int = np.asarray(ortvalue_int)
+        np.testing.assert_equal(int_arr, result_int)
+        self.assertEqual(result_int.dtype, np.int64)
+
+        # Boolean tensor
+        bool_arr = np.array([True, False, True], dtype=np.bool_)
+        ortvalue_bool = onnxrt.OrtValue.ortvalue_from_numpy(bool_arr)
+        result_bool = np.asarray(ortvalue_bool)
+        np.testing.assert_equal(bool_arr, result_bool)
+        self.assertEqual(result_bool.dtype, np.bool_)
+
+    @unittest.skipIf(not hasattr(C.OrtValue, "from_dlpack"), "dlpack not enabled in this build")
+    def test_ort_value_dlpack_protocol(self):
+        """Test that OrtValue exposes __dlpack__ and __dlpack_device__ protocols."""
+        numpy_arr = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
+        ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr)
+
+        # __dlpack_device__ should return (device_type, device_id) for CPU
+        device = ortvalue.__dlpack_device__()
+        self.assertEqual((1, 0), device)
+
+        # __dlpack__ should return a capsule that can be consumed by from_dlpack
+        dlp = ortvalue.__dlpack__()
+        ortvalue2 = onnxrt.OrtValue.from_dlpack(dlp)
+        np.testing.assert_equal(numpy_arr, ortvalue2.numpy())
+
+    @unittest.skipIf(not hasattr(C.OrtValue, "from_dlpack"), "dlpack not enabled in this build")
+    def test_ort_value_from_dlpack_protocol_object(self):
+        """Test OrtValue.from_dlpack with objects implementing __dlpack__ protocol."""
+        numpy_arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
+
+        # numpy arrays support __dlpack__ protocol since numpy 1.22
+        if hasattr(numpy_arr, "__dlpack__"):
+            ortvalue = onnxrt.OrtValue.from_dlpack(numpy_arr)
+            np.testing.assert_equal(numpy_arr, ortvalue.numpy())
+            self.assertEqual(list(numpy_arr.shape), list(ortvalue.shape()))
+
+        # Round-trip: numpy -> OrtValue -> OrtValue (via __dlpack__)
+        ortvalue_src = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr)
+        ortvalue_dst = onnxrt.OrtValue.from_dlpack(ortvalue_src)
+        np.testing.assert_equal(numpy_arr, ortvalue_dst.numpy())
+        # Verify shared memory (no copy)
+        self.assertEqual(ortvalue_src.data_ptr(), ortvalue_dst.data_ptr())
+
+    @unittest.skipIf(not hasattr(C.OrtValue, "from_dlpack"), "dlpack not enabled in this build")
+    def test_ort_value_from_dlpack_bool(self):
+        """Test that from_dlpack auto-detects boolean tensors."""
+        bool_arr = np.array([True, False, True, False], dtype=np.bool_)
+        ortvalue_src = onnxrt.OrtValue.ortvalue_from_numpy(bool_arr)
+
+        # Round-trip through DLPack should preserve bool dtype
+        ortvalue_dst = onnxrt.OrtValue.from_dlpack(ortvalue_src)
+        result = ortvalue_dst.numpy()
+        np.testing.assert_equal(bool_arr, result)
+
+        # Ensure uint8 is NOT falsely detected as bool
+        uint8_arr = np.array([1, 2, 255], dtype=np.uint8)
+        ortvalue_uint8 = onnxrt.OrtValue.ortvalue_from_numpy(uint8_arr)
+        ortvalue_uint8_dst = onnxrt.OrtValue.from_dlpack(ortvalue_uint8)
+        result_uint8 = ortvalue_uint8_dst.numpy()
+        np.testing.assert_equal(uint8_arr, result_uint8)
+        self.assertEqual(result_uint8.dtype, np.uint8)
+
     def test_sparse_tensor_coo_format(self):
         cpu_device = onnxrt.OrtDevice.make("cpu", 0)
         shape = [9, 9]