Skip to content

Commit

Permalink
gfpgan: just download the damn model
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Sep 23, 2022
1 parent d6fd71f commit d4205e6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
19 changes: 13 additions & 6 deletions modules/gfpgan_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys
import traceback
from glob import glob

from modules import shared, devices
from modules.shared import cmd_opts
Expand All @@ -11,14 +12,20 @@
def gfpgan_model_path():
from modules.shared import cmd_opts

filemask = 'GFPGAN*.pth'

if cmd_opts.gfpgan_model is not None:
return cmd_opts.gfpgan_model

places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
found = [x for x in files if os.path.exists(x)]

if len(found) == 0:
raise Exception("GFPGAN model not found in paths: " + ", ".join(files))
filename = None
for place in places:
filename = next(iter(glob(os.path.join(place, filemask))), None)
if filename is not None:
break

return found[0]
return filename


loaded_gfpgan_model = None
Expand All @@ -34,7 +41,7 @@ def gfpgan():
if gfpgan_constructor is None:
return None

model = gfpgan_constructor(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model

Expand Down
3 changes: 1 addition & 2 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import argparse
import json
import os
from glob import glob
import gradio as gr
import tqdm

Expand All @@ -22,7 +21,7 @@
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=next(iter(glob('GFPGAN*.pth')), ''))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
Expand Down

0 comments on commit d4205e6

Please sign in to comment.