Commit 08f92d5

Partial model shift support.
1 parent 8115d8c commit 08f92d5

File tree

2 files changed: +202 -38 lines changed


comfy/model_management.py

Lines changed: 62 additions & 5 deletions
@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import psutil
 import logging
 from enum import Enum

@@ -273,6 +291,9 @@ def __init__(self, model):
     def model_memory(self):
         return self.model.model_size()
 
+    def model_offloaded_memory(self):
+        return self.model.model_size() - self.model.loaded_size()
+
     def model_memory_required(self, device):
         if device == self.model.current_loaded_device():
             return 0

@@ -308,15 +329,37 @@ def should_reload_model(self, force_patch_weights=False):
             return True
         return False
 
-    def model_unload(self, unpatch_weights=True):
+    def model_unload(self, memory_to_free=None, unpatch_weights=True):
+        if memory_to_free is not None:
+            if memory_to_free < self.model.loaded_size():
+                self.model.partially_unload(self.model.offload_device, memory_to_free)
+                return False
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
         self.real_model = None
+        return True
+
+    def model_use_more_vram(self, extra_memory):
+        return self.model.partially_load(self.device, extra_memory)
 
     def __eq__(self, other):
         return self.model is other.model
 
+def use_more_memory(extra_memory, loaded_models, device):
+    for m in loaded_models:
+        if m.device == device:
+            extra_memory -= m.model_use_more_vram(extra_memory)
+            if extra_memory <= 0:
+                break
+
+def offloaded_memory(loaded_models, device):
+    offloaded_mem = 0
+    for m in loaded_models:
+        if m.device == device:
+            offloaded_mem += m.model_offloaded_memory()
+    return offloaded_mem
+
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 1.2

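As a rough illustration of what the new use_more_memory helper is meant to do, the sketch below replays its loop against a hypothetical stand-in for LoadedModel (FakeLoadedModel, with invented byte counts and device strings): each model on the target device is offered the remaining budget, reports how many bytes it actually brought back, and the budget shrinks until it runs out.

```python
# Minimal sketch, not ComfyUI code: FakeLoadedModel is a hypothetical stand-in
# for LoadedModel; model_use_more_vram here mimics what partially_load() would
# report back (bytes actually moved onto the device).
class FakeLoadedModel:
    def __init__(self, device, offloaded):
        self.device = device
        self.offloaded = offloaded                 # bytes still on the offload device

    def model_use_more_vram(self, extra_memory):
        taken = min(extra_memory, self.offloaded)  # reload at most what is offloaded
        self.offloaded -= taken
        return taken

def use_more_memory(extra_memory, loaded_models, device):
    # same loop as the committed helper: hand out the budget model by model
    for m in loaded_models:
        if m.device == device:
            extra_memory -= m.model_use_more_vram(extra_memory)
            if extra_memory <= 0:
                break

models = [FakeLoadedModel("cuda:0", 3 * 1024**3), FakeLoadedModel("cuda:0", 5 * 1024**3)]
use_more_memory(4 * 1024**3, models, "cuda:0")
print([m.offloaded // 1024**3 for m in models])    # [0, 4]: first model fully reloaded, second gets 1 GiB
```
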

@@ -363,11 +406,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
 
     for x in sorted(can_unload):
         i = x[-1]
+        memory_to_free = None
         if not DISABLE_SMART_MEMORY:
-            if get_free_memory(device) > memory_required:
+            free_mem = get_free_memory(device)
+            if free_mem > memory_required:
                 break
-        current_loaded_models[i].model_unload()
-        unloaded_model.append(i)
+            memory_to_free = memory_required - free_mem
+        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+        if current_loaded_models[i].model_unload(memory_to_free, free_mem):
+            unloaded_model.append(i)
 
     for i in sorted(unloaded_model, reverse=True):
         unloaded_models.append(current_loaded_models.pop(i))

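A small worked example, with invented numbers, of the accounting introduced above: free_memory now computes how far short the device is and passes only that deficit to model_unload, which shifts part of the model off instead of fully unloading it whenever the deficit is smaller than the model's currently loaded size (in that case it returns False and the model is not popped from current_loaded_models).

```python
# Invented numbers, purely illustrative of the new accounting above.
memory_required = 10 * 1024**3      # what the caller wants available on the device
free_mem = 7 * 1024**3              # what get_free_memory(device) reports right now
memory_to_free = memory_required - free_mem   # 3 GiB deficit passed to model_unload()

loaded_size = 6 * 1024**3           # hypothetical LoadedModel loaded_size() value
partial = memory_to_free < loaded_size        # True: partially_unload() runs, model stays listed
print(memory_to_free // 1024**3, partial)     # 3 True
```
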
@@ -422,12 +469,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         devs = set(map(lambda a: a.device, models_already_loaded))
         for d in devs:
             if d != torch.device("cpu"):
-                free_memory(extra_mem, d, models_already_loaded)
+                free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
                 free_mem = get_free_memory(d)
                 if free_mem < minimum_memory_required:
                     logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
                     models_to_load = free_memory(minimum_memory_required, d)
                     logging.info("{} models unloaded.".format(len(models_to_load)))
+                else:
+                    use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
         if len(models_to_load) == 0:
             return
 
@@ -467,6 +516,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
 
         cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
+
+
+    devs = set(map(lambda a: a.device, models_already_loaded))
+    for d in devs:
+        if d != torch.device("cpu"):
+            free_mem = get_free_memory(d)
+            if free_mem > minimum_memory_required:
+                use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
     return
 
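The two load_models_gpu changes fit together roughly as sketched below (invented byte counts): the free_memory request now also covers the offloaded portion of models that are already partially resident, and any surplus left after the new models are placed is handed back through use_more_memory.

```python
# Illustrative arithmetic only; names mirror the diff, values are made up.
extra_mem = 2 * 1024**3                 # memory the incoming models need
already_offloaded = 3 * 1024**3         # offloaded_memory(models_already_loaded, d)
request = extra_mem + already_offloaded # what free_memory() is now asked to clear
print(request // 1024**3)               # 5: leaves headroom to un-lowvram existing models

free_mem = 6 * 1024**3                  # free VRAM once loading is done
minimum_memory_required = int(1.2 * 1024**3)
surplus = free_mem - minimum_memory_required
print(surplus / 1024**3)                # ~4.8 GiB handed to use_more_memory()
```
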

comfy/model_patcher.py

Lines changed: 140 additions & 33 deletions
@@ -1,8 +1,27 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
 import copy
 import inspect
 import logging
 import uuid
+import collections
 
 import comfy.utils
 import comfy.model_management

@@ -63,6 +82,21 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
         model_options["disable_cfg1_optimization"] = True
     return model_options
 
+def wipe_lowvram_weight(m):
+    if hasattr(m, "prev_comfy_cast_weights"):
+        m.comfy_cast_weights = m.prev_comfy_cast_weights
+        del m.prev_comfy_cast_weights
+    m.weight_function = None
+    m.bias_function = None
+
+class LowVramPatch:
+    def __init__(self, key, model_patcher):
+        self.key = key
+        self.model_patcher = model_patcher
+    def __call__(self, weight):
+        return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size

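For readers unfamiliar with the weight_function hook, the self-contained sketch below shows the idea LowVramPatch is built on: rather than baking patches into a stored weight, a callable is attached to the module and applied to the weight each time it is prepared for compute. FakePatcher and FakeLowVramPatch are hypothetical stand-ins, not ComfyUI classes.

```python
import torch

class FakePatcher:
    # stand-in for ModelPatcher: holds per-key patches and knows how to apply them
    def __init__(self):
        self.patches = {"layer.weight": torch.full((2, 2), 0.5)}  # an additive delta

    def calculate_weight(self, patches, weight, key):
        return weight + patches        # simplified analogue of calculate_weight()

class FakeLowVramPatch:
    # same shape as the committed LowVramPatch: remember the key, patch on call
    def __init__(self, key, model_patcher):
        self.key = key
        self.model_patcher = model_patcher

    def __call__(self, weight):
        return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)

weight = torch.zeros(2, 2)
patched = FakeLowVramPatch("layer.weight", FakePatcher())(weight)
print(patched)                         # tensor of 0.5s, computed on the fly
```
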
@@ -82,16 +116,29 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
-        self.model_lowvram = False
-        self.lowvram_patch_counter = 0
         self.patches_uuid = uuid.uuid4()
 
+        if not hasattr(self.model, 'model_loaded_weight_memory'):
+            self.model.model_loaded_weight_memory = 0
+
+        if not hasattr(self.model, 'lowvram_patch_counter'):
+            self.model.lowvram_patch_counter = 0
+
+        if not hasattr(self.model, 'model_lowvram'):
+            self.model.model_lowvram = False
+
     def model_size(self):
         if self.size > 0:
             return self.size
         self.size = comfy.model_management.module_size(self.model)
         return self.size
 
+    def loaded_size(self):
+        return self.model.model_loaded_weight_memory
+
+    def lowvram_patch_counter(self):
+        return self.model.lowvram_patch_counter
+
     def clone(self):
         n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
         n.patches = {}

@@ -265,16 +312,16 @@ def model_state_dict(self, filter_prefix=None):
                     sd.pop(k)
         return sd
 
-    def patch_weight_to_device(self, key, device_to=None):
+    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
         if key not in self.patches:
             return
 
         weight = comfy.utils.get_attr(self.model, key)
 
-        inplace_update = self.weight_inplace_update
+        inplace_update = self.weight_inplace_update or inplace_update
 
         if key not in self.backup:
-            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+            self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
 
         if device_to is not None:
             temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)

@@ -304,28 +351,24 @@ def patch_model(self, device_to=None, patch_weights=True):
             if device_to is not None:
                 self.model.to(device_to)
                 self.model.device = device_to
+                self.model.model_loaded_weight_memory = self.model_size()
 
         return self.model
 
-    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
-        self.patch_model(device_to, patch_weights=False)
-
+    def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
         logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
-        class LowVramPatch:
-            def __init__(self, key, model_patcher):
-                self.key = key
-                self.model_patcher = model_patcher
-            def __call__(self, weight):
-                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
-
         mem_counter = 0
         patch_counter = 0
+        lowvram_counter = 0
         for n, m in self.model.named_modules():
             lowvram_weight = False
             if hasattr(m, "comfy_cast_weights"):
                 module_mem = comfy.model_management.module_size(m)
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
+                    lowvram_counter += 1
+                    if m.comfy_cast_weights:
+                        continue
 
             weight_key = "{}.weight".format(n)
             bias_key = "{}.bias".format(n)

@@ -347,16 +390,31 @@ def __call__(self, weight):
                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
             else:
+                if hasattr(m, "comfy_cast_weights"):
+                    if m.comfy_cast_weights:
+                        wipe_lowvram_weight(m)
+
                 if hasattr(m, "weight"):
-                    self.patch_weight_to_device(weight_key) #TODO: speed this up without causing OOM
+                    mem_counter += comfy.model_management.module_size(m)
+                    if m.weight is not None and m.weight.device == device_to:
+                        continue
+                    self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
                     self.patch_weight_to_device(bias_key)
                     m.to(device_to)
-                    mem_counter += comfy.model_management.module_size(m)
                     logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
 
-        self.model_lowvram = True
-        self.lowvram_patch_counter = patch_counter
+        if lowvram_counter > 0:
+            self.model.model_lowvram = True
+        else:
+            self.model.model_lowvram = False
+        self.model.lowvram_patch_counter += patch_counter
         self.model.device = device_to
+        self.model.model_loaded_weight_memory = mem_counter
+
+
+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
+        self.patch_model(device_to, patch_weights=False)
+        self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
         return self.model
 
     def calculate_weight(self, patches, weight, key):

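The byte-budget walk in lowvram_load can be pictured with the toy simulation below (module names and sizes invented): modules are visited in order and loaded normally until the running total would cross lowvram_model_memory; everything past that point stays in lowvram mode with on-the-fly patches, and only fully loaded modules advance the counter.

```python
# Toy simulation of the budget decision; not ComfyUI code.
module_sizes = {"input_blocks.0": 300, "middle_block": 500, "output_blocks.0": 400}
lowvram_model_memory = 700

mem_counter = 0
fully_loaded, lowvram = [], []
for name, size in module_sizes.items():
    if mem_counter + size >= lowvram_model_memory:
        lowvram.append(name)           # would get weight_function/bias_function patches
    else:
        fully_loaded.append(name)      # would be patched and moved to device_to
        mem_counter += size
print(fully_loaded, lowvram)           # ['input_blocks.0'] ['middle_block', 'output_blocks.0']
```
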
@@ -529,37 +587,86 @@ def calculate_weight(self, patches, weight, key):
 
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         if unpatch_weights:
-            if self.model_lowvram:
+            if self.model.model_lowvram:
                 for m in self.model.modules():
-                    if hasattr(m, "prev_comfy_cast_weights"):
-                        m.comfy_cast_weights = m.prev_comfy_cast_weights
-                        del m.prev_comfy_cast_weights
-                    m.weight_function = None
-                    m.bias_function = None
+                    wipe_lowvram_weight(m)
 
-                self.model_lowvram = False
-                self.lowvram_patch_counter = 0
+                self.model.model_lowvram = False
+                self.model.lowvram_patch_counter = 0
 
             keys = list(self.backup.keys())
 
-            if self.weight_inplace_update:
-                for k in keys:
-                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
-            else:
-                for k in keys:
-                    comfy.utils.set_attr_param(self.model, k, self.backup[k])
+            for k in keys:
+                bk = self.backup[k]
+                if bk.inplace_update:
+                    comfy.utils.copy_to_param(self.model, k, bk.weight)
+                else:
+                    comfy.utils.set_attr_param(self.model, k, bk.weight)
 
             self.backup.clear()
 
             if device_to is not None:
                 self.model.to(device_to)
                 self.model.device = device_to
+            self.model.model_loaded_weight_memory = 0
 
         keys = list(self.object_patches_backup.keys())
         for k in keys:
             comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
 
         self.object_patches_backup.clear()
 
+    def partially_unload(self, device_to, memory_to_free=0):
+        memory_freed = 0
+        patch_counter = 0
+
+        for n, m in list(self.model.named_modules())[::-1]:
+            if memory_to_free < memory_freed:
+                break
+
+            shift_lowvram = False
+            if hasattr(m, "comfy_cast_weights"):
+                module_mem = comfy.model_management.module_size(m)
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+
+
+                if m.weight is not None and m.weight.device != device_to:
+                    for key in [weight_key, bias_key]:
+                        bk = self.backup.get(key, None)
+                        if bk is not None:
+                            if bk.inplace_update:
+                                comfy.utils.copy_to_param(self.model, key, bk.weight)
+                            else:
+                                comfy.utils.set_attr_param(self.model, key, bk.weight)
+                            self.backup.pop(key)
+
+                    m.to(device_to)
+                    if weight_key in self.patches:
+                        m.weight_function = LowVramPatch(weight_key, self)
+                        patch_counter += 1
+                    if bias_key in self.patches:
+                        m.bias_function = LowVramPatch(bias_key, self)
+                        patch_counter += 1
+
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    memory_freed += module_mem
+                    logging.debug("freed {}".format(n))
+
+        self.model.model_lowvram = True
+        self.model.lowvram_patch_counter += patch_counter
+        self.model.model_loaded_weight_memory -= memory_freed
+        return memory_freed
+
+    def partially_load(self, device_to, extra_memory=0):
+        if self.model.model_lowvram == False:
+            return 0
+        if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
+            pass #TODO: Full load
+        current_used = self.model.model_loaded_weight_memory
+        self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
+        return self.model.model_loaded_weight_memory - current_used
+
     def current_loaded_device(self):
         return self.model.device
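Finally, the reverse walk in partially_unload can be summarized with a similar toy (invented sizes): modules are visited from the back of the network and shifted to the offload device until at least memory_to_free bytes have been released, which is why the method can overshoot slightly.

```python
# Toy simulation of the reverse walk; not ComfyUI code.
module_sizes = [("input_blocks.0", 300), ("middle_block", 500), ("output_blocks.0", 400)]
memory_to_free = 600

memory_freed = 0
shifted = []
for name, size in reversed(module_sizes):
    if memory_to_free < memory_freed:
        break
    shifted.append(name)               # would be m.to(device_to) plus lowvram patches
    memory_freed += size
print(shifted, memory_freed)           # ['output_blocks.0', 'middle_block'] 900
```
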
