Skip to content

Commit

Permalink
add swinir v2 support
Browse files Browse the repository at this point in the history
  • Loading branch information
C43H66N12O12S2 authored and AUTOMATIC1111 committed Oct 10, 2022
1 parent ece27fe commit ed76997
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions modules/swinir_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from modules import modelloader
from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData

precision_scope = (
Expand Down Expand Up @@ -57,22 +58,42 @@ def load_model(self, path, scale=4):
filename = path
if filename is None or not os.path.exists(filename):
return None
model = net(
if filename.endswith(".v2.pth"):
model = net2(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
embed_dim=240,
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="3conv",
)
resi_connection="1conv",
)
params = None
else:
model = net(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
embed_dim=240,
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="3conv",
)
params = "params_ema"

pretrained_model = torch.load(filename)
model.load_state_dict(pretrained_model["params_ema"], strict=True)
if params is not None:
model.load_state_dict(pretrained_model[params], strict=True)
else:
model.load_state_dict(pretrained_model, strict=True)
if not cmd_opts.no_half:
model = model.half()
return model
Expand Down

0 comments on commit ed76997

Please sign in to comment.