Skip to content

Commit be0e6f8

Browse files
committed
fix sd-checkpoint switching issue cause by #14170
1 parent 4a66638 commit be0e6f8

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

modules/sd_hijack.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
optimizers = []
3939
current_optimizer: sd_hijack_optimizations.SdOptimization = None
4040

41+
sgm_original_forward = None
42+
ldm_original_forward = None
43+
44+
4145
def list_optimizers():
4246
new_optimizers = script_callbacks.list_optimizers_callback()
4347

@@ -255,8 +259,13 @@ def flatten(el):
255259

256260
import modules.models.diffusion.ddpm_edit
257261

258-
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
259-
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
262+
global sgm_original_forward
263+
global ldm_original_forward
264+
try:
265+
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
266+
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
267+
except RuntimeError:
268+
pass
260269

261270
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
262271
sd_unet.original_forward = ldm_original_forward
@@ -267,7 +276,6 @@ def flatten(el):
267276
else:
268277
sd_unet.original_forward = None
269278

270-
271279
def undo_hijack(self, m):
272280
conditioner = getattr(m, 'conditioner', None)
273281
if conditioner:

0 commit comments

Comments
 (0)