diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 515bf38a39..c104da702c 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -22,7 +22,7 @@ from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -57,6 +57,13 @@ cp, has_cp = optional_import("cupy") kvikio, has_kvikio = optional_import("kvikio") +if TYPE_CHECKING: + import cupy + + NdarrayOrCupy = Union[np.ndarray, cupy.ndarray] +else: + NdarrayOrCupy = Any + __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] @@ -663,10 +670,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]: metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape dicom_data.append((data_array, metadata)) - # TODO: the actual type is list[np.ndarray | cp.ndarray] - # should figure out how to define correct types without having cupy not found error - # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 - img_array: list[np.ndarray] = [] + img_array: list[NdarrayOrCupy] = [] compatible_meta: dict = {} for data_array, metadata in ensure_tuple(dicom_data): @@ -1104,10 +1108,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. """ - # TODO: the actual type is list[np.ndarray | cp.ndarray] - # should figure out how to define correct types without having cupy not found error - # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 - img_array: list[np.ndarray] = [] + img_array: list[NdarrayOrCupy] = [] compatible_meta: dict = {} for i, filename in zip(ensure_tuple(img), self.filenames):