diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 6b5f7526cc506..785eb9c485d25 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -15,6 +15,9 @@ from onnxruntime.capi import _pybind_state as C
 
 if typing.TYPE_CHECKING:
+    import numpy as np
+    import numpy.typing as npt
+
     import onnxruntime
 
@@ -59,22 +62,22 @@ def export_adapter(self, file_path: os.PathLike):
         """
         self._adapter.export_adapter(file_path)
 
-    def get_format_version(self):
+    def get_format_version(self) -> int:
         return self._adapter.format_version
 
-    def set_adapter_version(self, adapter_version: int):
+    def set_adapter_version(self, adapter_version: int) -> None:
         self._adapter.adapter_version = adapter_version
 
-    def get_adapter_version(self):
+    def get_adapter_version(self) -> int:
         return self._adapter.adapter_version
 
-    def set_model_version(self, model_version: int):
+    def set_model_version(self, model_version: int) -> None:
         self._adapter.model_version = model_version
 
-    def get_model_version(self):
+    def get_model_version(self) -> int:
         return self._adapter.model_version
 
-    def set_parameters(self, params: dict[str, OrtValue]):
+    def set_parameters(self, params: dict[str, OrtValue]) -> None:
         self._adapter.parameters = {k: v._ortvalue for k, v in params.items()}
 
     def get_parameters(self) -> dict[str, OrtValue]:
@@ -174,27 +177,27 @@ def __init__(self):
         self._sess = None
         self._enable_fallback = True
 
-    def get_session_options(self):
+    def get_session_options(self) -> onnxruntime.SessionOptions:
         "Return the session options. See :class:`onnxruntime.SessionOptions`."
         return self._sess_options
 
-    def get_inputs(self):
+    def get_inputs(self) -> Sequence[onnxruntime.NodeArg]:
         "Return the inputs metadata as a list of :class:`onnxruntime.NodeArg`."
         return self._inputs_meta
 
-    def get_outputs(self):
+    def get_outputs(self) -> Sequence[onnxruntime.NodeArg]:
         "Return the outputs metadata as a list of :class:`onnxruntime.NodeArg`."
         return self._outputs_meta
 
-    def get_overridable_initializers(self):
+    def get_overridable_initializers(self) -> Sequence[onnxruntime.NodeArg]:
         "Return the inputs (including initializers) metadata as a list of :class:`onnxruntime.NodeArg`."
         return self._overridable_initializers
 
-    def get_modelmeta(self):
+    def get_modelmeta(self) -> onnxruntime.ModelMetadata:
         "Return the metadata. See :class:`onnxruntime.ModelMetadata`."
         return self._model_meta
 
-    def get_providers(self):
+    def get_providers(self) -> Sequence[str]:
         "Return list of registered execution providers."
         return self._providers
 
@@ -202,7 +205,7 @@ def get_provider_options(self):
         "Return registered execution providers' configurations."
         return self._provider_options
 
-    def set_providers(self, providers=None, provider_options=None):
+    def set_providers(self, providers=None, provider_options=None) -> None:
         """
         Register the input list of execution providers. The underlying session is re-created.
 
@@ -224,13 +227,13 @@ def set_providers(self, providers=None, provider_options=None):
         # recreate the underlying C.InferenceSession
         self._reset_session(providers, provider_options)
 
-    def disable_fallback(self):
+    def disable_fallback(self) -> None:
         """
         Disable session.run() fallback mechanism.
         """
         self._enable_fallback = False
 
-    def enable_fallback(self):
+    def enable_fallback(self) -> None:
         """
         Enable session.Run() fallback mechanism. If session.Run() fails due to an internal Execution Provider failure,
         reset the Execution Providers enabled for this session.
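
For orientation, a minimal sketch (not part of the patch) of how the accessors annotated above are typically exercised; "model.onnx" is a placeholder path and CPUExecutionProvider is just an example provider choice.

import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
for node in sess.get_inputs():                    # Sequence[onnxruntime.NodeArg]
    print(node.name, node.type, node.shape)
print(sess.get_providers())                       # Sequence[str]
sess.set_providers(["CPUExecutionProvider"])      # returns None; recreates the underlying session
sess.disable_fallback()                           # returns None
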
@@ -249,7 +252,7 @@ def _validate_input(self, feed_input_names):
                 f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names})."
             )
 
-    def run(self, output_names, input_feed, run_options=None):
+    def run(self, output_names, input_feed, run_options=None) -> Sequence[np.ndarray | SparseTensor | list | dict]:
         """
         Compute the predictions.
 
@@ -308,7 +311,7 @@ def callback(results: np.ndarray, user_data: MyData, err: str) -> None:
             output_names = [output.name for output in self._outputs_meta]
         return self._sess.run_async(output_names, input_feed, callback, user_data, run_options)
 
-    def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None):
+    def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None) -> Sequence[OrtValue]:
         """
         Compute the predictions.
 
@@ -367,7 +370,7 @@ def get_profiling_start_time_ns(self):
         """
         return self._sess.get_profiling_start_time_ns
 
-    def io_binding(self):
+    def io_binding(self) -> IOBinding:
         "Return an onnxruntime.IOBinding object`."
         return IOBinding(self)
 
@@ -550,7 +553,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
         self._provider_options = self._sess.get_provider_options()
         self._profiling_start_time_ns = self._sess.get_profiling_start_time_ns
 
-    def _reset_session(self, providers, provider_options):
+    def _reset_session(self, providers, provider_options) -> None:
         "release underlying session object."
         # meta data references session internal structures
         # so they must be set to None to decrement _sess reference count.
@@ -721,7 +724,7 @@ class OrtValue:
     This class provides APIs to construct and deal with OrtValues.
     """
 
-    def __init__(self, ortvalue, numpy_obj=None):
+    def __init__(self, ortvalue: C.OrtValue, numpy_obj: np.ndarray | None = None):
         if isinstance(ortvalue, C.OrtValue):
             self._ortvalue = ortvalue
             # Hold a ref count to the numpy object if the OrtValue is backed directly
@@ -733,11 +736,11 @@ def __init__(self, ortvalue, numpy_obj=None):
                 "`Provided ortvalue` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtValue`"
             )
 
-    def _get_c_value(self):
+    def _get_c_value(self) -> C.OrtValue:
         return self._ortvalue
 
-    @staticmethod
-    def ortvalue_from_numpy(numpy_obj, device_type="cpu", device_id=0):
+    @classmethod
+    def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0) -> OrtValue:
         """
         Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object
         A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu
@@ -749,7 +752,7 @@ def ortvalue_from_numpy(numpy_obj, device_type="cpu", device_id=0):
         # Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue
         # is backed directly by the data buffer of the numpy object and so the numpy object
         # must be around until this OrtValue instance is around
-        return OrtValue(
+        return cls(
             C.OrtValue.ortvalue_from_numpy(
                 numpy_obj,
                 C.OrtDevice(
@@ -761,8 +764,8 @@ def ortvalue_from_numpy(numpy_obj, device_type="cpu", device_id=0):
             numpy_obj if device_type.lower() == "cpu" else None,
         )
 
-    @staticmethod
-    def ortvalue_from_numpy_with_onnx_type(data, onnx_element_type: int):
+    @classmethod
+    def ortvalue_from_numpy_with_onnx_type(cls, data: np.ndarray, /, onnx_element_type: int) -> OrtValue:
         """
         This method creates an instance of OrtValue on top of the numpy array.
         No data copy is made and the lifespan of the resulting OrtValue should never
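
A hedged sketch of the run()/run_with_ort_values() return types annotated above; the model path and the input name "x" are placeholders that must match the real model.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
x = np.zeros((1, 3), dtype=np.float32)

outputs = sess.run(None, {"x": x})                            # Sequence[np.ndarray | ...]; None selects all outputs
value = ort.OrtValue.ortvalue_from_numpy(x)                   # CPU OrtValue backed by the numpy buffer (no copy)
names = [o.name for o in sess.get_outputs()]
ort_outputs = sess.run_with_ort_values(names, {"x": value})   # Sequence[OrtValue]
result = ort_outputs[0].numpy()
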
@@ -771,12 +774,14 @@ when we want to use an ONNX data type that is not supported by numpy.
 
         :param data: numpy.ndarray.
-        :param onnx_elemenet_type: a valid onnx TensorProto::DataType enum value
+        :param onnx_element_type: a valid onnx TensorProto::DataType enum value
         """
-        return OrtValue(C.OrtValue.ortvalue_from_numpy_with_onnx_type(data, onnx_element_type), data)
+        return cls(C.OrtValue.ortvalue_from_numpy_with_onnx_type(data, onnx_element_type), data)
 
-    @staticmethod
-    def ortvalue_from_shape_and_type(shape, element_type, device_type: str = "cpu", device_id: int = 0):
+    @classmethod
+    def ortvalue_from_shape_and_type(
+        cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0
+    ) -> OrtValue:
         """
         Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type
 
@@ -788,7 +793,7 @@ def ortvalue_from_shape_and_type(shape, element_type, device_type: str = "cpu",
         # Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html).
         # This is helpful for some data type (like TensorProto.BFLOAT16) that is not available in numpy.
         if isinstance(element_type, int):
-            return OrtValue(
+            return cls(
                 C.OrtValue.ortvalue_from_shape_and_onnx_type(
                     shape,
                     element_type,
@@ -800,7 +805,7 @@ def ortvalue_from_shape_and_type(shape, element_type, device_type: str = "cpu",
                 )
             )
 
-        return OrtValue(
+        return cls(
             C.OrtValue.ortvalue_from_shape_and_type(
                 shape,
                 element_type,
@@ -812,77 +817,77 @@ def ortvalue_from_shape_and_type(shape, element_type, device_type: str = "cpu",
             )
         )
 
-    @staticmethod
-    def ort_value_from_sparse_tensor(sparse_tensor):
+    @classmethod
+    def ort_value_from_sparse_tensor(cls, sparse_tensor: SparseTensor) -> OrtValue:
         """
         The function will construct an OrtValue instance from a valid SparseTensor
         The new instance of OrtValue will assume the ownership of sparse_tensor
         """
-        return OrtValue(C.OrtValue.ort_value_from_sparse_tensor(sparse_tensor._get_c_tensor()))
+        return cls(C.OrtValue.ort_value_from_sparse_tensor(sparse_tensor._get_c_tensor()))
 
-    def as_sparse_tensor(self):
+    def as_sparse_tensor(self) -> SparseTensor:
         """
         The function will return SparseTensor contained in this OrtValue
         """
         return SparseTensor(self._ortvalue.as_sparse_tensor())
 
-    def data_ptr(self):
+    def data_ptr(self) -> int:
         """
         Returns the address of the first element in the OrtValue's data buffer
         """
         return self._ortvalue.data_ptr()
 
-    def device_name(self):
+    def device_name(self) -> str:
         """
         Returns the name of the device where the OrtValue's data buffer resides e.g. cpu, cuda, cann
         """
         return self._ortvalue.device_name().lower()
 
-    def shape(self):
+    def shape(self) -> Sequence[int]:
         """
         Returns the shape of the data in the OrtValue
         """
         return self._ortvalue.shape()
 
-    def data_type(self):
+    def data_type(self) -> str:
         """
-        Returns the data type of the data in the OrtValue
+        Returns the data type of the data in the OrtValue. E.g. 'tensor(int64)'
         """
         return self._ortvalue.data_type()
 
-    def element_type(self):
+    def element_type(self) -> int:
         """
         Returns the proto type of the data in the OrtValue
         if the OrtValue is a tensor.
         """
         return self._ortvalue.element_type()
 
-    def has_value(self):
+    def has_value(self) -> bool:
         """
         Returns True if the OrtValue corresponding to an optional type contains data, else returns False
         """
         return self._ortvalue.has_value()
 
-    def is_tensor(self):
+    def is_tensor(self) -> bool:
         """
         Returns True if the OrtValue contains a Tensor, else returns False
         """
         return self._ortvalue.is_tensor()
 
-    def is_sparse_tensor(self):
+    def is_sparse_tensor(self) -> bool:
         """
         Returns True if the OrtValue contains a SparseTensor, else returns False
         """
         return self._ortvalue.is_sparse_tensor()
 
-    def is_tensor_sequence(self):
+    def is_tensor_sequence(self) -> bool:
         """
         Returns True if the OrtValue contains a Tensor Sequence, else returns False
         """
         return self._ortvalue.is_tensor_sequence()
 
-    def numpy(self):
+    def numpy(self) -> np.ndarray:
         """
         Returns a Numpy object from the OrtValue.
         Valid only for OrtValues holding Tensors. Throws for OrtValues holding non-Tensors.
@@ -890,7 +895,7 @@ def numpy(self):
         """
         return self._ortvalue.numpy()
 
-    def update_inplace(self, np_arr):
+    def update_inplace(self, np_arr) -> None:
         """
         Update the OrtValue in place with a new Numpy array. The numpy contents
         are copied over to the device memory backing the OrtValue. It can be used
@@ -948,7 +953,7 @@ class SparseTensor:
     depending on the format
     """
 
-    def __init__(self, sparse_tensor):
+    def __init__(self, sparse_tensor: C.SparseTensor):
         """
         Internal constructor
         """
@@ -960,11 +965,17 @@ def __init__(self, sparse_tensor):
                 "`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.SparseTensor`"
             )
 
-    def _get_c_tensor(self):
+    def _get_c_tensor(self) -> C.SparseTensor:
         return self._tensor
 
-    @staticmethod
-    def sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device):
+    @classmethod
+    def sparse_coo_from_numpy(
+        cls,
+        dense_shape: npt.NDArray[np.int64],
+        values: np.ndarray,
+        coo_indices: npt.NDArray[np.int64],
+        ort_device: OrtDevice,
+    ) -> SparseTensor:
         """
         Factory method to construct a SparseTensor in COO format from given arguments
 
@@ -985,12 +996,17 @@ def sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device):
         For strings and objects, it will create a copy of the arrays in CPU memory
         as ORT does not support those on other devices and their memory can not be mapped.
         """
-        return SparseTensor(
-            C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device._get_c_device())
-        )
+        return cls(C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device._get_c_device()))
 
-    @staticmethod
-    def sparse_csr_from_numpy(dense_shape, values, inner_indices, outer_indices, ort_device):
+    @classmethod
+    def sparse_csr_from_numpy(
+        cls,
+        dense_shape: npt.NDArray[np.int64],
+        values: np.ndarray,
+        inner_indices: npt.NDArray[np.int64],
+        outer_indices: npt.NDArray[np.int64],
+        ort_device: OrtDevice,
+    ) -> SparseTensor:
         """
         Factory method to construct a SparseTensor in CSR format from given arguments
 
@@ -1011,7 +1027,7 @@ def sparse_csr_from_numpy(dense_shape, values, inner_indices, outer_indices, ort
         For strings and objects, it will create a copy of the arrays in CPU memory
         as ORT does not support those on other devices and their memory can not be mapped.
         """
-        return SparseTensor(
+        return cls(
             C.SparseTensor.sparse_csr_from_numpy(
                 dense_shape,
                 values,
@@ -1021,7 +1037,7 @@ def sparse_csr_from_numpy(dense_shape, values, inner_indices, outer_indices, ort
             )
         )
 
-    def values(self):
+    def values(self) -> np.ndarray:
         """
         The method returns a numpy array that is backed by the native memory
         if the data type is numeric. Otherwise, the returned numpy array that contains
@@ -1093,19 +1109,19 @@ def format(self):
         """
         return self._tensor.format
 
-    def dense_shape(self):
+    def dense_shape(self) -> npt.NDArray[np.int64]:
         """
         Returns a numpy array(int64) containing a dense shape of a sparse tensor
         """
         return self._tensor.dense_shape()
 
-    def data_type(self):
+    def data_type(self) -> str:
         """
         Returns a string data type of the data in the OrtValue
         """
         return self._tensor.data_type()
 
-    def device_name(self):
+    def device_name(self) -> str:
         """
         Returns the name of the device where the SparseTensor data buffers reside e.g. cpu, cuda
         """