Skip to content

Commit

Permalink
feat: ✨ Add support for extra_model_paths.yaml
Browse files Browse the repository at this point in the history
closes #66
  • Loading branch information
melMass committed Oct 10, 2023
1 parent af94203 commit d7b8ac8
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 57 deletions.
13 changes: 8 additions & 5 deletions nodes/deep_bump.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import tempfile
from pathlib import Path

import folder_paths
import numpy as np
import onnxruntime as ort
import torch
from PIL import Image

from ..errors import ModelNotFound
from ..log import mklog
from ..utils import (models_dir, tensor2pil, tiles_infer, tiles_merge,
tiles_split)
from ..utils import get_model_path, tensor2pil, tiles_infer, tiles_merge, tiles_split

# Disable MS telemetry
ort.disable_telemetry_events()
Expand Down Expand Up @@ -54,9 +55,11 @@ def color_to_normals(color_img, overlap, progress_callback, save_temp=False):

# Load model
log.debug("DeepBump Color → Normals : loading model")
ort_session = ort.InferenceSession(
(models_dir / "deepbump" / "deepbump256.onnx").as_posix()
)
model = get_model_path("deepbump", "deepbump256.onnx")
if not model or not model.exists():
raise ModelNotFound(f"deepbump ({model})")

ort_session = ort.InferenceSession(model)

# Predict normal map for each tile
log.debug("DeepBump Color → Normals : generating")
Expand Down
28 changes: 14 additions & 14 deletions nodes/faceenhance.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from gfpgan import GFPGANer
import cv2
import numpy as np
import os
from pathlib import Path
import folder_paths
from ..utils import pil2tensor, np2tensor, tensor2np
from typing import Tuple

import comfy
import comfy.utils
import cv2
import folder_paths
import numpy as np
import torch
from basicsr.utils import imwrite


from comfy import model_management
from gfpgan import GFPGANer
from PIL import Image
import torch

from ..log import NullWriter, log
from comfy import model_management
import comfy
import comfy.utils
from typing import Tuple
from ..utils import get_model_path, np2tensor, pil2tensor, tensor2np


class LoadFaceEnhanceModel:
Expand All @@ -26,11 +25,12 @@ def __init__(self) -> None:

@classmethod
def get_models_root(cls):
fr = Path(folder_paths.models_dir) / "face_restore"
fr = get_model_path("face_restore")
# fr = Path(folder_paths.models_dir) / "face_restore"
if fr.exists():
return (fr, None)

um = Path(folder_paths.models_dir) / "upscale_models"
um = get_model_path("upscale_models")
return (fr, um) if um.exists() else (None, None)

@classmethod
Expand Down
44 changes: 17 additions & 27 deletions nodes/faceswap.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# Optional face enhance nodes
# region imports
import onnxruntime
import sys
from pathlib import Path
from PIL import Image
from typing import List, Set, Union, Optional
from typing import List, Optional, Set, Union

import comfy.model_management as model_management
import cv2
import folder_paths
import glob
import insightface
import numpy as np
import os
import onnxruntime
import torch
from insightface.model_zoo.inswapper import INSwapper
from ..utils import pil2tensor, tensor2pil, download_antelopev2
from ..log import mklog, NullWriter
import sys
import comfy.model_management as model_management
from PIL import Image

from ..errors import ModelNotFound
from ..log import NullWriter, mklog
from ..utils import download_antelopev2, get_model_path, pil2tensor, tensor2pil

# endregion

Expand All @@ -27,15 +27,6 @@ class LoadFaceAnalysisModel:

models = []

@staticmethod
def get_models() -> List[str]:
models_path = os.path.join(folder_paths.models_dir, "insightface/*")
models = glob.glob(models_path)
models = [
Path(x).name for x in models if x.endswith(".onnx") or x.endswith(".pth")
]
return models

@classmethod
def INPUT_TYPES(cls):
return {
Expand All @@ -57,7 +48,7 @@ def load_model(self, faceswap_model: str):

face_analyser = insightface.app.FaceAnalysis(
name=faceswap_model,
root=os.path.join(folder_paths.models_dir, "insightface"),
root=get_model_path("insightface"),
)
return (face_analyser,)

Expand All @@ -67,10 +58,8 @@ class LoadFaceSwapModel:

@staticmethod
def get_models() -> List[Path]:
models_path = os.path.join(folder_paths.models_dir, "insightface/*")
models = glob.glob(models_path)
models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")]
return models
models_path = get_model_path("insightface").iterdir()
return [x for x in models_path if x.suffix in [".onnx", ".pth"]]

@classmethod
def INPUT_TYPES(cls):
Expand All @@ -88,9 +77,10 @@ def INPUT_TYPES(cls):
CATEGORY = "mtb/facetools"

def load_model(self, faceswap_model: str):
model_path = os.path.join(
folder_paths.models_dir, "insightface", faceswap_model
)
model_path = get_model_path("insightface", faceswap_model)
if not model_path or not model_path.exists():
raise ModelNotFound(f"{faceswap_model} ({model_path})")

log.info(f"Loading model {model_path}")
return (
INSwapper(
Expand Down
15 changes: 10 additions & 5 deletions nodes/image_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import torch
from frame_interpolation.eval import interpolator, util

from utils import get_model_path

from ..errors import ModelNotFound
from ..log import log


Expand All @@ -20,10 +23,9 @@ class LoadFilmModel:

@staticmethod
def get_models() -> List[Path]:
models_path = os.path.join(folder_paths.models_dir, "FILM/*")
models = glob.glob(models_path)
models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")]
return models
models_paths = get_model_path("FILM").iterdir()

return [x for x in models_paths if x.suffix in [".onnx", ".pth"]]

@classmethod
def INPUT_TYPES(cls):
Expand All @@ -41,7 +43,10 @@ def INPUT_TYPES(cls):
CATEGORY = "mtb/frame iterpolation"

def load_model(self, film_model: str):
model_path = Path(folder_paths.models_dir) / "FILM" / film_model
model_path = get_model_path("FILM", film_model)
if not model_path or not model_path.exists():
raise ModelNotFound(f"FILM ({model_path})")

if not (model_path / "saved_model.pb").exists():
model_path = model_path / "saved_model"

Expand Down
35 changes: 29 additions & 6 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import contextlib
import functools
import math
import os
import shlex
import shutil
import signal
import socket
import subprocess
import sys
import threading
import uuid
from contextlib import suppress
from pathlib import Path
from queue import Empty, Queue
from typing import List, Optional, Union

import folder_paths
Expand Down Expand Up @@ -506,12 +503,11 @@ def download_antelopev2():
antelopev2_url = "https://drive.google.com/uc?id=18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8"

try:
import folder_paths
import gdown

log.debug("Loading antelopev2 model")

dest = Path(folder_paths.models_dir) / "insightface"
dest = get_model_path("insightface")
archive = dest / "antelopev2.zip"
final_path = dest / "models" / "antelopev2"
if not final_path.exists():
Expand All @@ -538,6 +534,33 @@ def download_antelopev2():
raise e


def get_model_path(fam, model=None):
log.debug(f"Requesting {fam} with model {model}")
res = None
if model:
res = folder_paths.get_full_path(fam, model)
else:
# this one can raise errors...
with contextlib.suppress(KeyError):
res = folder_paths.get_folder_paths(fam)

if res:
if isinstance(res, list):
if len(res) > 1:
log.warning(
f"Found multiple match, we will pick the first {res[0]}\n{res}"
)
res = res[0]
res = Path(res)
log.debug(f"Resolved model path from folder_paths: {res}")
else:
res = models_dir / fam
if model:
res /= model

return res


# endregion


Expand Down

0 comments on commit d7b8ac8

Please sign in to comment.