Skip to content

Commit b334605

Browse files
Fix OOMs happening in some cases.
A cloned model patcher sometimes reported a model was loaded on a device when it wasn't.
1 parent de17a97 commit b334605

File tree

4 files changed

+17
-11
lines changed

4 files changed

+17
-11
lines changed

comfy/model_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
7474
self.latent_format = model_config.latent_format
7575
self.model_config = model_config
7676
self.manual_cast_dtype = model_config.manual_cast_dtype
77+
self.device = device
7778

7879
if not unet_config.get("disable_unet_model_creation", False):
7980
if self.manual_cast_dtype is not None:

comfy/model_management.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def model_memory(self):
274274
return self.model.model_size()
275275

276276
def model_memory_required(self, device):
277-
if device == self.model.current_device:
277+
if device == self.model.current_loaded_device():
278278
return 0
279279
else:
280280
return self.model_memory()

comfy/model_patcher.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
6464
return model_options
6565

6666
class ModelPatcher:
67-
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
67+
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
6868
self.size = size
6969
self.model = model
70+
if not hasattr(self.model, 'device'):
71+
logging.info("Model doesn't have a device attribute.")
72+
self.model.device = offload_device
73+
elif self.model.device is None:
74+
self.model.device = offload_device
75+
7076
self.patches = {}
7177
self.backup = {}
7278
self.object_patches = {}
@@ -75,11 +81,6 @@ def __init__(self, model, load_device, offload_device, size=0, current_device=No
7581
self.model_size()
7682
self.load_device = load_device
7783
self.offload_device = offload_device
78-
if current_device is None:
79-
self.current_device = self.offload_device
80-
else:
81-
self.current_device = current_device
82-
8384
self.weight_inplace_update = weight_inplace_update
8485
self.model_lowvram = False
8586
self.lowvram_patch_counter = 0
@@ -92,7 +93,7 @@ def model_size(self):
9293
return self.size
9394

9495
def clone(self):
95-
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
96+
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
9697
n.patches = {}
9798
for k in self.patches:
9899
n.patches[k] = self.patches[k][:]
@@ -302,7 +303,7 @@ def patch_model(self, device_to=None, patch_weights=True):
302303

303304
if device_to is not None:
304305
self.model.to(device_to)
305-
self.current_device = device_to
306+
self.model.device = device_to
306307

307308
return self.model
308309

@@ -355,6 +356,7 @@ def __call__(self, weight):
355356

356357
self.model_lowvram = True
357358
self.lowvram_patch_counter = patch_counter
359+
self.model.device = device_to
358360
return self.model
359361

360362
def calculate_weight(self, patches, weight, key):
@@ -551,10 +553,13 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
551553

552554
if device_to is not None:
553555
self.model.to(device_to)
554-
self.current_device = device_to
556+
self.model.device = device_to
555557

556558
keys = list(self.object_patches_backup.keys())
557559
for k in keys:
558560
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
559561

560562
self.object_patches_backup.clear()
563+
564+
def current_loaded_device(self):
565+
return self.model.device

comfy/sd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
564564
logging.debug("left over keys: {}".format(left_over))
565565

566566
if output_model:
567-
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
567+
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
568568
if inital_load_device != torch.device("cpu"):
569569
logging.info("loaded straight to GPU")
570570
model_management.load_model_gpu(model_patcher)

0 commit comments

Comments
 (0)