From 88a277968745ac990406b14d300a8ada9c575b11 Mon Sep 17 00:00:00 2001 From: Michael Poutre Date: Fri, 8 Sep 2023 03:37:45 -0700 Subject: [PATCH] =?UTF-8?q?fix:=20=E2=9C=8F=EF=B8=8F=20use=20Union=20to=20?= =?UTF-8?q?allow=20support=20for=20<3.10=20(#91)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils.py b/utils.py index 401ba4c..05bd75e 100644 --- a/utils.py +++ b/utils.py @@ -3,7 +3,7 @@ import torch from pathlib import Path import sys -from typing import List, Optional +from typing import List, Optional, Union import signal from contextlib import suppress from queue import Queue, Empty @@ -291,14 +291,14 @@ def tensor2pil(image: torch.Tensor) -> List[Image.Image]: ] -def pil2tensor(image: Image.Image | List[Image.Image]) -> torch.Tensor: +def pil2tensor(image: Union[Image.Image, List[Image.Image]]) -> torch.Tensor: if isinstance(image, list): return torch.cat([pil2tensor(img) for img in image], dim=0) return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) -def np2tensor(img_np: np.ndarray | List[np.ndarray]) -> torch.Tensor: +def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor: if isinstance(img_np, list): return torch.cat([np2tensor(img) for img in img_np], dim=0)