From d7b8ac8e0c98b0d7a2e21889d35aad9f6b093560 Mon Sep 17 00:00:00 2001 From: melMass Date: Tue, 10 Oct 2023 14:31:28 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9C=A8=20Add=20support=20for=20extra?= =?UTF-8?q?=5Fmodel=5Fpaths.yaml?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit closes #66 --- nodes/deep_bump.py | 13 +++++++---- nodes/faceenhance.py | 28 +++++++++++------------ nodes/faceswap.py | 44 ++++++++++++++---------------------- nodes/image_interpolation.py | 15 ++++++++---- utils.py | 35 +++++++++++++++++++++++----- 5 files changed, 78 insertions(+), 57 deletions(-) diff --git a/nodes/deep_bump.py b/nodes/deep_bump.py index fcc4ce1..82639f1 100644 --- a/nodes/deep_bump.py +++ b/nodes/deep_bump.py @@ -1,14 +1,15 @@ import tempfile from pathlib import Path +import folder_paths import numpy as np import onnxruntime as ort import torch from PIL import Image +from ..errors import ModelNotFound from ..log import mklog -from ..utils import (models_dir, tensor2pil, tiles_infer, tiles_merge, - tiles_split) +from ..utils import get_model_path, tensor2pil, tiles_infer, tiles_merge, tiles_split # Disable MS telemetry ort.disable_telemetry_events() @@ -54,9 +55,11 @@ def color_to_normals(color_img, overlap, progress_callback, save_temp=False): # Load model log.debug("DeepBump Color → Normals : loading model") - ort_session = ort.InferenceSession( - (models_dir / "deepbump" / "deepbump256.onnx").as_posix() - ) + model = get_model_path("deepbump", "deepbump256.onnx") + if not model or not model.exists(): + raise ModelNotFound(f"deepbump ({model})") + + ort_session = ort.InferenceSession(model) # Predict normal map for each tile log.debug("DeepBump Color → Normals : generating") diff --git a/nodes/faceenhance.py b/nodes/faceenhance.py index 2e81861..dcb1e06 100644 --- a/nodes/faceenhance.py +++ b/nodes/faceenhance.py @@ -1,21 +1,20 @@ -from gfpgan import GFPGANer -import cv2 -import numpy as np import os from pathlib import Path -import folder_paths -from ..utils import pil2tensor, np2tensor, tensor2np +from typing import Tuple +import comfy +import comfy.utils +import cv2 +import folder_paths +import numpy as np +import torch from basicsr.utils import imwrite - - +from comfy import model_management +from gfpgan import GFPGANer from PIL import Image -import torch + from ..log import NullWriter, log -from comfy import model_management -import comfy -import comfy.utils -from typing import Tuple +from ..utils import get_model_path, np2tensor, pil2tensor, tensor2np class LoadFaceEnhanceModel: @@ -26,11 +25,12 @@ def __init__(self) -> None: @classmethod def get_models_root(cls): - fr = Path(folder_paths.models_dir) / "face_restore" + fr = get_model_path("face_restore") + # fr = Path(folder_paths.models_dir) / "face_restore" if fr.exists(): return (fr, None) - um = Path(folder_paths.models_dir) / "upscale_models" + um = get_model_path("upscale_models") return (fr, um) if um.exists() else (None, None) @classmethod diff --git a/nodes/faceswap.py b/nodes/faceswap.py index 0047ba6..2711426 100644 --- a/nodes/faceswap.py +++ b/nodes/faceswap.py @@ -1,21 +1,21 @@ +# Optional face enhance nodes # region imports -import onnxruntime +import sys from pathlib import Path -from PIL import Image -from typing import List, Set, Union, Optional +from typing import List, Optional, Set, Union + +import comfy.model_management as model_management import cv2 -import folder_paths -import glob import insightface import numpy as np -import os +import onnxruntime import torch from insightface.model_zoo.inswapper import INSwapper -from ..utils import pil2tensor, tensor2pil, download_antelopev2 -from ..log import mklog, NullWriter -import sys -import comfy.model_management as model_management +from PIL import Image +from ..errors import ModelNotFound +from ..log import NullWriter, mklog +from ..utils import download_antelopev2, get_model_path, pil2tensor, tensor2pil # endregion @@ -27,15 +27,6 @@ class LoadFaceAnalysisModel: models = [] - @staticmethod - def get_models() -> List[str]: - models_path = os.path.join(folder_paths.models_dir, "insightface/*") - models = glob.glob(models_path) - models = [ - Path(x).name for x in models if x.endswith(".onnx") or x.endswith(".pth") - ] - return models - @classmethod def INPUT_TYPES(cls): return { @@ -57,7 +48,7 @@ def load_model(self, faceswap_model: str): face_analyser = insightface.app.FaceAnalysis( name=faceswap_model, - root=os.path.join(folder_paths.models_dir, "insightface"), + root=get_model_path("insightface"), ) return (face_analyser,) @@ -67,10 +58,8 @@ class LoadFaceSwapModel: @staticmethod def get_models() -> List[Path]: - models_path = os.path.join(folder_paths.models_dir, "insightface/*") - models = glob.glob(models_path) - models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")] - return models + models_path = get_model_path("insightface").iterdir() + return [x for x in models_path if x.suffix in [".onnx", ".pth"]] @classmethod def INPUT_TYPES(cls): @@ -88,9 +77,10 @@ def INPUT_TYPES(cls): CATEGORY = "mtb/facetools" def load_model(self, faceswap_model: str): - model_path = os.path.join( - folder_paths.models_dir, "insightface", faceswap_model - ) + model_path = get_model_path("insightface", faceswap_model) + if not model_path or not model_path.exists(): + raise ModelNotFound(f"{faceswap_model} ({model_path})") + log.info(f"Loading model {model_path}") return ( INSwapper( diff --git a/nodes/image_interpolation.py b/nodes/image_interpolation.py index 0774952..9dd07ca 100644 --- a/nodes/image_interpolation.py +++ b/nodes/image_interpolation.py @@ -12,6 +12,9 @@ import torch from frame_interpolation.eval import interpolator, util +from utils import get_model_path + +from ..errors import ModelNotFound from ..log import log @@ -20,10 +23,9 @@ class LoadFilmModel: @staticmethod def get_models() -> List[Path]: - models_path = os.path.join(folder_paths.models_dir, "FILM/*") - models = glob.glob(models_path) - models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")] - return models + models_paths = get_model_path("FILM").iterdir() + + return [x for x in models_paths if x.suffix in [".onnx", ".pth"]] @classmethod def INPUT_TYPES(cls): @@ -41,7 +43,10 @@ def INPUT_TYPES(cls): CATEGORY = "mtb/frame iterpolation" def load_model(self, film_model: str): - model_path = Path(folder_paths.models_dir) / "FILM" / film_model + model_path = get_model_path("FILM", film_model) + if not model_path or not model_path.exists(): + raise ModelNotFound(f"FILM ({model_path})") + if not (model_path / "saved_model.pb").exists(): model_path = model_path / "saved_model" diff --git a/utils.py b/utils.py index 85c05c3..0d15bf2 100644 --- a/utils.py +++ b/utils.py @@ -1,17 +1,14 @@ +import contextlib import functools import math import os import shlex import shutil -import signal import socket import subprocess import sys -import threading import uuid -from contextlib import suppress from pathlib import Path -from queue import Empty, Queue from typing import List, Optional, Union import folder_paths @@ -506,12 +503,11 @@ def download_antelopev2(): antelopev2_url = "https://drive.google.com/uc?id=18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8" try: - import folder_paths import gdown log.debug("Loading antelopev2 model") - dest = Path(folder_paths.models_dir) / "insightface" + dest = get_model_path("insightface") archive = dest / "antelopev2.zip" final_path = dest / "models" / "antelopev2" if not final_path.exists(): @@ -538,6 +534,33 @@ def download_antelopev2(): raise e +def get_model_path(fam, model=None): + log.debug(f"Requesting {fam} with model {model}") + res = None + if model: + res = folder_paths.get_full_path(fam, model) + else: + # this one can raise errors... + with contextlib.suppress(KeyError): + res = folder_paths.get_folder_paths(fam) + + if res: + if isinstance(res, list): + if len(res) > 1: + log.warning( + f"Found multiple match, we will pick the first {res[0]}\n{res}" + ) + res = res[0] + res = Path(res) + log.debug(f"Resolved model path from folder_paths: {res}") + else: + res = models_dir / fam + if model: + res /= model + + return res + + # endregion