Skip to content

Commit

Permalink
fix sd-checkpoint switching issue cause by #14170
Browse files Browse the repository at this point in the history
  • Loading branch information
w-e-w committed Dec 2, 2023
1 parent 4a66638 commit be0e6f8
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions modules/sd_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None

sgm_original_forward = None
ldm_original_forward = None


def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()

Expand Down Expand Up @@ -255,8 +259,13 @@ def flatten(el):

import modules.models.diffusion.ddpm_edit

ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
global sgm_original_forward
global ldm_original_forward
try:
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
except RuntimeError:
pass

if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
Expand All @@ -267,7 +276,6 @@ def flatten(el):
else:
sd_unet.original_forward = None


def undo_hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
Expand Down

0 comments on commit be0e6f8

Please sign in to comment.