From 629e2b5f5fbebe4e79e8b7a4cff2de6017e79225 Mon Sep 17 00:00:00 2001
From: melMass
Date: Sat, 15 Jul 2023 00:39:15 +0200
Subject: [PATCH] feat: 🎨 add support for image.size(0) == 0
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Attempt at ignoring a branch when the incoming image batch is empty
(images.size(0) == 0): the empty tensor is passed through untouched
instead of being processed.
---
 nodes/image_interpolation.py |  6 ++++++
 utils.py                     | 24 ++++++++++++++++++------
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/nodes/image_interpolation.py b/nodes/image_interpolation.py
index 1f94a5a..cfaf337 100644
--- a/nodes/image_interpolation.py
+++ b/nodes/image_interpolation.py
@@ -82,6 +82,10 @@ def do_interpolation(
     film_model: interpolator.Interpolator,
 ):
     n = images.size(0)
+    # empty batch: return it unchanged so the branch is ignored
+    if n == 0:
+        return (images,)
+
     # check if tensorflow GPU is available
     available_gpus = tf.config.list_physical_devices("GPU")
     if not len(available_gpus):
@@ -184,6 +188,8 @@ def export_prores(
     fps: float,
     prefix: str,
 ):
+    if images.size(0) == 0:
+        return ("",)
     output_dir = Path(folder_paths.get_output_directory())
     id = f"{prefix}_{uuid.uuid4()}.mov"
 
diff --git a/utils.py b/utils.py
index ce94de4..7ab1bcc 100644
--- a/utils.py
+++ b/utils.py
@@ -5,6 +5,7 @@
 import sys
 
 from typing import Union, List
+from .log import log
 
 
 def add_path(path, prepend=False):
@@ -45,16 +46,22 @@ def add_path(path, prepend=False):
     add_path((comfy_dir / "custom_nodes"))
 
 
-def tensor2pil(image: torch.Tensor) -> Union[Image.Image, List[Image.Image]]:
+def tensor2pil(image: torch.Tensor) -> List[Image.Image]:
     batch_count = 1
     if len(image.shape) > 3:
         batch_count = image.size(0)
 
-    if batch_count == 1:
-        return Image.fromarray(
+    if batch_count > 1:
+        out = []
+        for i in range(batch_count):
+            out.extend(tensor2pil(image[i]))
+        return out
+
+    return [
+        Image.fromarray(
             np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
         )
-    return [tensor2pil(image[i]) for i in range(batch_count)]
+    ]
 
 
 def pil2tensor(image: Image.Image | List[Image.Image]) -> torch.Tensor:
@@ -76,5 +83,10 @@ def tensor2np(tensor: torch.Tensor) -> Union[np.ndarray, List[np.ndarray]]:
     if len(tensor.shape) > 3:
         batch_count = tensor.size(0)
     if batch_count > 1:
-        return [tensor2np(tensor[i]) for i in range(batch_count)]
-    return np.clip(255.0 * tensor.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
+        out = []
+        for i in range(batch_count):
+            out.extend(tensor2np(tensor[i]))
+        return out
+
+    return [np.clip(255.0 * tensor.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)]
+
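
Note: a minimal sketch of the guard this patch adds, assuming a
ComfyUI-style IMAGE tensor of shape (batch, height, width, channels);
the `frames` name and the shape below are illustrative, not part of
the patch:

    import torch

    # an empty IMAGE batch, e.g. produced by a muted/bypassed upstream branch
    frames = torch.zeros((0, 512, 512, 3))

    # the same up-front test do_interpolation and export_prores now perform
    if frames.size(0) == 0:
        result = (frames,)  # hand the empty tensor back instead of processing it
    else:
        result = ...  # normal interpolation / export path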
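
Note: tensor2pil (and tensor2np) now always return a list, even for a
single input, so callers can iterate unconditionally. A minimal usage
sketch, assuming torch, PIL, and this repo's utils module are
importable; names and shapes are illustrative:

    import torch
    from utils import tensor2pil

    single = torch.rand(64, 64, 3)    # one HWC image
    batch = torch.rand(4, 64, 64, 3)  # a batch of four

    # always a flat List[Image.Image]; no isinstance check needed anymore
    assert len(tensor2pil(single)) == 1
    assert len(tensor2pil(batch)) == 4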